sqlx-macros-core-0.8.3/.cargo_vcs_info.json0000644000000001560000000000100142450ustar { "git": { "sha1": "28cfdbb40c4fe535721c9ee5e1583409e0cac27e" }, "path_in_vcs": "sqlx-macros-core" }sqlx-macros-core-0.8.3/Cargo.toml0000644000000074450000000000100122530ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. [package] edition = "2021" name = "sqlx-macros-core" version = "0.8.3" authors = [ "Ryan Leckey ", "Austin Bonander ", "Chloe Ross ", "Daniel Akhterov ", ] description = "Macro support core for SQLx, the Rust SQL toolkit. Not intended to be used directly." license = "MIT OR Apache-2.0" repository = "https://github.com/launchbadge/sqlx" [dependencies.async-std] version = "1.12" optional = true [dependencies.dotenvy] version = "0.15.0" default-features = false [dependencies.either] version = "1.6.1" [dependencies.heck] version = "0.5" [dependencies.hex] version = "0.4.3" [dependencies.once_cell] version = "1.9.0" [dependencies.proc-macro2] version = "1.0.79" default-features = false [dependencies.quote] version = "1.0.26" default-features = false [dependencies.serde] version = "1.0.132" features = ["derive"] [dependencies.serde_json] version = "1.0.73" [dependencies.sha2] version = "0.10.0" [dependencies.sqlx-core] version = "=0.8.3" features = ["offline"] [dependencies.sqlx-mysql] version = "=0.8.3" features = [ "offline", "migrate", ] optional = true [dependencies.sqlx-postgres] version = "=0.8.3" features = [ "offline", "migrate", ] optional = true [dependencies.sqlx-sqlite] version = "=0.8.3" features = [ "offline", "migrate", ] optional = true [dependencies.syn] version = "2.0.52" features = [ "full", "derive", "parsing", "printing", "clone-impls", ] default-features = false [dependencies.tempfile] version = "3.10.1" [dependencies.tokio] version = "1" features = [ "time", "net", "sync", "fs", "io-util", "rt", ] optional = true default-features = false [dependencies.url] version = "2.2.2" [features] _rt-async-std = [ "async-std", "sqlx-core/_rt-async-std", ] _rt-tokio = [ "tokio", "sqlx-core/_rt-tokio", ] _sqlite = [] _tls-native-tls = ["sqlx-core/_tls-native-tls"] _tls-rustls-aws-lc-rs = ["sqlx-core/_tls-rustls-aws-lc-rs"] _tls-rustls-ring-native-roots = ["sqlx-core/_tls-rustls-ring-native-roots"] _tls-rustls-ring-webpki = ["sqlx-core/_tls-rustls-ring-webpki"] bigdecimal = [ "sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal", ] bit-vec = [ "sqlx-core/bit-vec", "sqlx-postgres?/bit-vec", ] chrono = [ "sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono", ] default = [] derive = [] ipnetwork = [ "sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork", ] json = [ "sqlx-core/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json", ] mac_address = [ "sqlx-core/mac_address", "sqlx-postgres?/mac_address", ] macros = [] migrate = ["sqlx-core/migrate"] mysql = ["sqlx-mysql"] postgres = ["sqlx-postgres"] rust_decimal = [ "sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal", ] sqlite = [ "_sqlite", "sqlx-sqlite/bundled", ] sqlite-unbundled = [ "_sqlite", "sqlx-sqlite/unbundled", ] time = [ "sqlx-core/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time", ] uuid = [ "sqlx-core/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid", ] [lints.rust.unexpected_cfgs] level = "warn" priority = 0 sqlx-macros-core-0.8.3/Cargo.toml.orig000064400000000000000000000054231046102023000157260ustar 00000000000000[package] name = "sqlx-macros-core" description = "Macro support core for SQLx, the Rust SQL toolkit. Not intended to be used directly." version.workspace = true license.workspace = true edition.workspace = true authors.workspace = true repository.workspace = true [features] default = [] # for conditional compilation _rt-async-std = ["async-std", "sqlx-core/_rt-async-std"] _rt-tokio = ["tokio", "sqlx-core/_rt-tokio"] _tls-native-tls = ["sqlx-core/_tls-native-tls"] _tls-rustls-aws-lc-rs = ["sqlx-core/_tls-rustls-aws-lc-rs"] _tls-rustls-ring-webpki = ["sqlx-core/_tls-rustls-ring-webpki"] _tls-rustls-ring-native-roots = ["sqlx-core/_tls-rustls-ring-native-roots"] _sqlite = [] # SQLx features derive = [] macros = [] migrate = ["sqlx-core/migrate"] # database mysql = ["sqlx-mysql"] postgres = ["sqlx-postgres"] sqlite = ["_sqlite", "sqlx-sqlite/bundled"] sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled"] # type integrations json = ["sqlx-core/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-postgres?/bit-vec"] chrono = ["sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-postgres?/mac_address"] rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] time = ["sqlx-core/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] uuid = ["sqlx-core/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] [dependencies] sqlx-core = { workspace = true, features = ["offline"] } sqlx-mysql = { workspace = true, features = ["offline", "migrate"], optional = true } sqlx-postgres = { workspace = true, features = ["offline", "migrate"], optional = true } sqlx-sqlite = { workspace = true, features = ["offline", "migrate"], optional = true } async-std = { workspace = true, optional = true } tokio = { workspace = true, optional = true } dotenvy = { workspace = true } hex = { version = "0.4.3" } heck = { version = "0.5" } either = "1.6.1" once_cell = "1.9.0" proc-macro2 = { version = "1.0.79", default-features = false } serde = { version = "1.0.132", features = ["derive"] } serde_json = { version = "1.0.73" } sha2 = { version = "0.10.0" } syn = { version = "2.0.52", default-features = false, features = ["full", "derive", "parsing", "printing", "clone-impls"] } tempfile = { version = "3.10.1" } quote = { version = "1.0.26", default-features = false } url = { version = "2.2.2" } [lints.rust.unexpected_cfgs] level = "warn" # 1.80 will warn without this check-cfg = ['cfg(sqlx_macros_unstable)', 'cfg(procmacro2_semver_exempt)'] sqlx-macros-core-0.8.3/LICENSE-APACHE000064400000000000000000000240031046102023000147560ustar 00000000000000Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2020 LaunchBadge, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.sqlx-macros-core-0.8.3/LICENSE-MIT000064400000000000000000000020441046102023000144670ustar 00000000000000Copyright (c) 2020 LaunchBadge, LLC Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. sqlx-macros-core-0.8.3/src/common.rs000064400000000000000000000020611046102023000154570ustar 00000000000000use proc_macro2::Span; use std::env; use std::path::{Path, PathBuf}; pub(crate) fn resolve_path(path: impl AsRef, err_span: Span) -> syn::Result { let path = path.as_ref(); if path.is_absolute() { return Err(syn::Error::new( err_span, "absolute paths will only work on the current machine", )); } // requires `proc_macro::SourceFile::path()` to be stable // https://github.com/rust-lang/rust/issues/54725 if path.is_relative() && !path .parent() .map_or(false, |parent| !parent.as_os_str().is_empty()) { return Err(syn::Error::new( err_span, "paths relative to the current file's directory are not currently supported", )); } let base_dir = env::var("CARGO_MANIFEST_DIR").map_err(|_| { syn::Error::new( err_span, "CARGO_MANIFEST_DIR is not set; please use Cargo to build", ) })?; let base_dir_path = Path::new(&base_dir); Ok(base_dir_path.join(path)) } sqlx-macros-core-0.8.3/src/database/impls.rs000064400000000000000000000042261046102023000170640ustar 00000000000000macro_rules! impl_database_ext { ( $database:path, row: $row:path, $(describe-blocking: $describe:path,)? ) => { impl $crate::database::DatabaseExt for $database { const DATABASE_PATH: &'static str = stringify!($database); const ROW_PATH: &'static str = stringify!($row); impl_describe_blocking!($database, $($describe)?); } } } macro_rules! impl_describe_blocking { ($database:path $(,)?) => { fn describe_blocking( query: &str, database_url: &str, ) -> sqlx_core::Result> { use $crate::database::CachingDescribeBlocking; // This can't be a provided method because the `static` can't reference `Self`. static CACHE: CachingDescribeBlocking<$database> = CachingDescribeBlocking::new(); CACHE.describe(query, database_url) } }; ($database:path, $describe:path) => { fn describe_blocking( query: &str, database_url: &str, ) -> sqlx_core::Result> { $describe(query, database_url) } }; } // The paths below will also be emitted from the macros, so they need to match the final facade. mod sqlx { #[cfg(feature = "mysql")] pub use sqlx_mysql as mysql; #[cfg(feature = "postgres")] pub use sqlx_postgres as postgres; #[cfg(feature = "_sqlite")] pub use sqlx_sqlite as sqlite; } // NOTE: type mappings have been moved to `src/type_checking.rs` in their respective driver crates. #[cfg(feature = "mysql")] impl_database_ext! { sqlx::mysql::MySql, row: sqlx::mysql::MySqlRow, } #[cfg(feature = "postgres")] impl_database_ext! { sqlx::postgres::Postgres, row: sqlx::postgres::PgRow, } #[cfg(feature = "_sqlite")] impl_database_ext! { sqlx::sqlite::Sqlite, row: sqlx::sqlite::SqliteRow, // Since proc-macros don't benefit from async, we can make a describe call directly // which also ensures that the database is closed afterwards, regardless of errors. describe-blocking: sqlx_sqlite::describe_blocking, } sqlx-macros-core-0.8.3/src/database/mod.rs000064400000000000000000000034751046102023000165240ustar 00000000000000use std::collections::hash_map; use std::collections::HashMap; use std::sync::Mutex; use once_cell::sync::Lazy; use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::type_checking::TypeChecking; #[cfg(any(feature = "postgres", feature = "mysql", feature = "_sqlite"))] mod impls; pub trait DatabaseExt: Database + TypeChecking { const DATABASE_PATH: &'static str; const ROW_PATH: &'static str; fn db_path() -> syn::Path { syn::parse_str(Self::DATABASE_PATH).unwrap() } fn row_path() -> syn::Path { syn::parse_str(Self::ROW_PATH).unwrap() } fn describe_blocking(query: &str, database_url: &str) -> sqlx_core::Result>; } #[allow(dead_code)] pub struct CachingDescribeBlocking { connections: Lazy>>, } #[allow(dead_code)] impl CachingDescribeBlocking { pub const fn new() -> Self { CachingDescribeBlocking { connections: Lazy::new(|| Mutex::new(HashMap::new())), } } pub fn describe(&self, query: &str, database_url: &str) -> sqlx_core::Result> where for<'a> &'a mut DB::Connection: Executor<'a, Database = DB>, { let mut cache = self .connections .lock() .expect("previous panic in describe call"); crate::block_on(async { let conn = match cache.entry(database_url.to_string()) { hash_map::Entry::Occupied(hit) => hit.into_mut(), hash_map::Entry::Vacant(miss) => { miss.insert(DB::Connection::connect(database_url).await?) } }; conn.describe(query).await }) } } sqlx-macros-core-0.8.3/src/derives/attributes.rs000064400000000000000000000201531046102023000200200ustar 00000000000000use proc_macro2::{Ident, Span, TokenStream}; use quote::quote_spanned; use syn::{ punctuated::Punctuated, token::Comma, Attribute, DeriveInput, Field, LitStr, Meta, Token, Type, Variant, }; macro_rules! assert_attribute { ($e:expr, $err:expr, $input:expr) => { if !$e { return Err(syn::Error::new_spanned($input, $err)); } }; } macro_rules! fail { ($t:expr, $m:expr) => { return Err(syn::Error::new_spanned($t, $m)) }; } macro_rules! try_set { ($i:ident, $v:expr, $t:expr) => { match $i { None => $i = Some($v), Some(_) => fail!($t, "duplicate attribute"), } }; } pub struct TypeName { pub val: String, pub span: Span, } impl TypeName { pub fn get(&self) -> TokenStream { let val = &self.val; quote_spanned! { self.span => #val } } } #[derive(Copy, Clone)] #[allow(clippy::enum_variant_names)] pub enum RenameAll { LowerCase, SnakeCase, UpperCase, ScreamingSnakeCase, KebabCase, CamelCase, PascalCase, } pub struct SqlxContainerAttributes { pub transparent: bool, pub type_name: Option, pub rename_all: Option, pub repr: Option, pub no_pg_array: bool, pub default: bool, } pub struct SqlxChildAttributes { pub rename: Option, pub default: bool, pub flatten: bool, pub try_from: Option, pub skip: bool, pub json: bool, } pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { let mut transparent = None; let mut repr = None; let mut type_name = None; let mut rename_all = None; let mut no_pg_array = None; let mut default = None; for attr in input { if attr.path().is_ident("sqlx") { attr.parse_nested_meta(|meta| { if meta.path.is_ident("transparent") { try_set!(transparent, true, attr); } else if meta.path.is_ident("no_pg_array") { try_set!(no_pg_array, true, attr); } else if meta.path.is_ident("default") { try_set!(default, true, attr); } else if meta.path.is_ident("rename_all") { meta.input.parse::()?; let lit: LitStr = meta.input.parse()?; let val = match lit.value().as_str() { "lowercase" => RenameAll::LowerCase, "snake_case" => RenameAll::SnakeCase, "UPPERCASE" => RenameAll::UpperCase, "SCREAMING_SNAKE_CASE" => RenameAll::ScreamingSnakeCase, "kebab-case" => RenameAll::KebabCase, "camelCase" => RenameAll::CamelCase, "PascalCase" => RenameAll::PascalCase, _ => fail!(lit, "unexpected value for rename_all"), }; try_set!(rename_all, val, lit) } else if meta.path.is_ident("type_name") { meta.input.parse::()?; let lit: LitStr = meta.input.parse()?; let name = TypeName { val: lit.value(), span: lit.span(), }; try_set!(type_name, name, lit) } else { fail!(meta.path, "unexpected attribute") } Ok(()) })?; } else if attr.path().is_ident("repr") { let list: Punctuated = attr.parse_args_with(>::parse_terminated)?; if let Some(path) = list.iter().find_map(|f| f.require_path_only().ok()) { try_set!(repr, path.get_ident().unwrap().clone(), list); } } } Ok(SqlxContainerAttributes { transparent: transparent.unwrap_or(false), repr, type_name, rename_all, no_pg_array: no_pg_array.unwrap_or(false), default: default.unwrap_or(false), }) } pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result { let mut rename = None; let mut default = false; let mut try_from = None; let mut flatten = false; let mut skip: bool = false; let mut json = false; for attr in input.iter().filter(|a| a.path().is_ident("sqlx")) { attr.parse_nested_meta(|meta| { if meta.path.is_ident("rename") { meta.input.parse::()?; let val: LitStr = meta.input.parse()?; try_set!(rename, val.value(), val); } else if meta.path.is_ident("try_from") { meta.input.parse::()?; let val: LitStr = meta.input.parse()?; try_set!(try_from, val.parse()?, val); } else if meta.path.is_ident("default") { default = true; } else if meta.path.is_ident("flatten") { flatten = true; } else if meta.path.is_ident("skip") { skip = true; } else if meta.path.is_ident("json") { json = true; } Ok(()) })?; if json && flatten { fail!( attr, "Cannot use `json` and `flatten` together on the same field" ); } } Ok(SqlxChildAttributes { rename, default, flatten, try_from, skip, json, }) } pub fn check_transparent_attributes( input: &DeriveInput, field: &Field, ) -> syn::Result { let attributes = parse_container_attributes(&input.attrs)?; assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(rename_all = ..)]", field ); let ch_attributes = parse_child_attributes(&field.attrs)?; assert_attribute!( ch_attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); Ok(attributes) } pub fn check_enum_attributes(input: &DeriveInput) -> syn::Result { let attributes = parse_container_attributes(&input.attrs)?; assert_attribute!( !attributes.transparent, "unexpected #[sqlx(transparent)]", input ); Ok(attributes) } pub fn check_weak_enum_attributes( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attributes = check_enum_attributes(input)?; assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(c = ..)]", input ); for variant in variants { let attributes = parse_child_attributes(&variant.attrs)?; assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", variant ); } Ok(attributes) } pub fn check_strong_enum_attributes( input: &DeriveInput, _variants: &Punctuated, ) -> syn::Result { let attributes = check_enum_attributes(input)?; assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); Ok(attributes) } pub fn check_struct_attributes( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { let attributes = parse_container_attributes(&input.attrs)?; assert_attribute!( !attributes.transparent, "unexpected #[sqlx(transparent)]", input ); assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(rename_all = ..)]", input ); assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); for field in fields { let attributes = parse_child_attributes(&field.attrs)?; assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); } Ok(attributes) } sqlx-macros-core-0.8.3/src/derives/decode.rs000064400000000000000000000246601046102023000170640ustar 00000000000000use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_child_attributes, parse_container_attributes, }; use super::rename_all; use proc_macro2::TokenStream; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, Stmt, TypeParamBound, Variant, }; pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { let attrs = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { expand_derive_decode_transparent(input, unnamed.first().unwrap()) } Data::Enum(DataEnum { variants, .. }) => match attrs.repr { Some(_) => expand_derive_decode_weak_enum(input, variants), None => expand_derive_decode_strong_enum(input, variants), }, Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) => expand_derive_decode_struct(input, named), Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), Data::Struct(DataStruct { fields: Fields::Unnamed(..), .. }) => Err(syn::Error::new_spanned( input, "structs with zero or more than one unnamed field are not supported", )), Data::Struct(DataStruct { fields: Fields::Unit, .. }) => Err(syn::Error::new_spanned( input, "unit structs are not supported", )), } } fn expand_derive_decode_transparent( input: &DeriveInput, field: &Field, ) -> syn::Result { check_transparent_attributes(input, field)?; let ident = &input.ident; let ty = &field.ty; // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let mut generics = generics.clone(); generics .params .insert(0, parse_quote!(DB: ::sqlx::Database)); generics.params.insert(0, parse_quote!('r)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: ::sqlx::decode::Decode<'r, DB>)); let (impl_generics, _, where_clause) = generics.split_for_impl(); let tts = quote!( #[automatically_derived] impl #impl_generics ::sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause { fn decode( value: ::ValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { <#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value).map(Self) } } ); Ok(tts) } fn expand_derive_decode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attr = check_weak_enum_attributes(input, variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; let ident_s = ident.to_string(); let arms = variants .iter() .map(|v| { let id = &v.ident; parse_quote! { _ if (#ident::#id as #repr) == value => ::std::result::Result::Ok(#ident::#id), } }) .collect::>(); Ok(quote!( #[automatically_derived] impl<'r, DB: ::sqlx::Database> ::sqlx::decode::Decode<'r, DB> for #ident where #repr: ::sqlx::decode::Decode<'r, DB>, { fn decode( value: ::ValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let value = <#repr as ::sqlx::decode::Decode<'r, DB>>::decode(value)?; match value { #(#arms)* _ => ::std::result::Result::Err(::std::boxed::Box::new(::sqlx::Error::Decode( ::std::format!("invalid value {:?} for enum {}", value, #ident_s).into(), ))) } } } )) } fn expand_derive_decode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let cattr = check_strong_enum_attributes(input, variants)?; let ident = &input.ident; let ident_s = ident.to_string(); let value_arms = variants.iter().map(|v| -> Arm { let id = &v.ident; let attributes = parse_child_attributes(&v.attrs).unwrap(); if let Some(rename) = attributes.rename { parse_quote!(#rename => ::std::result::Result::Ok(#ident :: #id),) } else if let Some(pattern) = cattr.rename_all { let name = rename_all(&id.to_string(), pattern); parse_quote!(#name => ::std::result::Result::Ok(#ident :: #id),) } else { let name = id.to_string(); parse_quote!(#name => ::std::result::Result::Ok(#ident :: #id),) } }); let values = quote! { match value { #(#value_arms)* _ => Err(format!("invalid value {:?} for enum {}", value, #ident_s).into()) } }; let mut tts = TokenStream::new(); if cfg!(feature = "mysql") { tts.extend(quote!( #[automatically_derived] impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::mysql::MySql> for #ident { fn decode( value: ::sqlx::mysql::MySqlValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode< 'r, ::sqlx::mysql::MySql, >>::decode(value)?; #values } } )); } if cfg!(feature = "postgres") { tts.extend(quote!( #[automatically_derived] impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::postgres::Postgres> for #ident { fn decode( value: ::sqlx::postgres::PgValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode< 'r, ::sqlx::postgres::Postgres, >>::decode(value)?; #values } } )); } if cfg!(feature = "_sqlite") { tts.extend(quote!( #[automatically_derived] impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::sqlite::Sqlite> for #ident { fn decode( value: ::sqlx::sqlite::SqliteValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode< 'r, ::sqlx::sqlite::Sqlite, >>::decode(value)?; #values } } )); } Ok(tts) } fn expand_derive_decode_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { check_struct_attributes(input, fields)?; let mut tts = TokenStream::new(); if cfg!(feature = "postgres") { let ident = &input.ident; let (_, ty_generics, where_clause) = input.generics.split_for_impl(); let mut generics = input.generics.clone(); // add db type for impl generics & where clause for type_param in &mut generics.type_params_mut() { type_param.bounds.extend::<[TypeParamBound; 2]>([ parse_quote!(for<'decode> ::sqlx::decode::Decode<'decode, ::sqlx::Postgres>), parse_quote!(::sqlx::types::Type<::sqlx::Postgres>), ]); } generics.params.push(parse_quote!('r)); let (impl_generics, _, _) = generics.split_for_impl(); let reads = fields.iter().map(|field| -> Stmt { let id = &field.ident; let ty = &field.ty; parse_quote!( let #id = decoder.try_decode::<#ty>()?; ) }); let names = fields.iter().map(|field| &field.ident); tts.extend(quote!( #[automatically_derived] impl #impl_generics ::sqlx::decode::Decode<'r, ::sqlx::Postgres> for #ident #ty_generics #where_clause { fn decode( value: ::sqlx::postgres::PgValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let mut decoder = ::sqlx::postgres::types::PgRecordDecoder::new(value)?; #(#reads)* ::std::result::Result::Ok(#ident { #(#names),* }) } } )); } Ok(tts) } sqlx-macros-core-0.8.3/src/derives/encode.rs000064400000000000000000000213121046102023000170650ustar 00000000000000use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_child_attributes, parse_container_attributes, }; use super::rename_all; use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, FieldsUnnamed, Lifetime, LifetimeParam, Stmt, TypeParamBound, Variant, }; pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { let args = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { expand_derive_encode_transparent(input, unnamed.first().unwrap()) } Data::Enum(DataEnum { variants, .. }) => match args.repr { Some(_) => expand_derive_encode_weak_enum(input, variants), None => expand_derive_encode_strong_enum(input, variants), }, Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) => expand_derive_encode_struct(input, named), Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), Data::Struct(DataStruct { fields: Fields::Unnamed(..), .. }) => Err(syn::Error::new_spanned( input, "structs with zero or more than one unnamed field are not supported", )), Data::Struct(DataStruct { fields: Fields::Unit, .. }) => Err(syn::Error::new_spanned( input, "unit structs are not supported", )), } } fn expand_derive_encode_transparent( input: &DeriveInput, field: &Field, ) -> syn::Result { check_transparent_attributes(input, field)?; let ident = &input.ident; let ty = &field.ty; // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let lifetime = Lifetime::new("'q", Span::call_site()); let mut generics = generics.clone(); generics .params .insert(0, LifetimeParam::new(lifetime.clone()).into()); generics .params .insert(0, parse_quote!(DB: ::sqlx::Database)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: ::sqlx::encode::Encode<#lifetime, DB>)); let (impl_generics, _, where_clause) = generics.split_for_impl(); Ok(quote!( #[automatically_derived] impl #impl_generics ::sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause { fn encode_by_ref( &self, buf: &mut ::ArgumentBuffer<#lifetime>, ) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> { <#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf) } fn produces(&self) -> Option { <#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0) } fn size_hint(&self) -> usize { <#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0) } } )) } fn expand_derive_encode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attr = check_weak_enum_attributes(input, variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; let mut values = Vec::new(); for v in variants { let id = &v.ident; values.push(quote!(#ident :: #id => (#ident :: #id as #repr),)); } Ok(quote!( #[automatically_derived] impl<'q, DB: ::sqlx::Database> ::sqlx::encode::Encode<'q, DB> for #ident where #repr: ::sqlx::encode::Encode<'q, DB>, { fn encode_by_ref( &self, buf: &mut ::ArgumentBuffer<'q>, ) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> { let value = match self { #(#values)* }; <#repr as ::sqlx::encode::Encode>::encode_by_ref(&value, buf) } fn size_hint(&self) -> usize { <#repr as ::sqlx::encode::Encode>::size_hint(&Default::default()) } } )) } fn expand_derive_encode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let cattr = check_strong_enum_attributes(input, variants)?; let ident = &input.ident; let mut value_arms = Vec::new(); for v in variants { let id = &v.ident; let attributes = parse_child_attributes(&v.attrs)?; if let Some(rename) = attributes.rename { value_arms.push(quote!(#ident :: #id => #rename,)); } else if let Some(pattern) = cattr.rename_all { let name = rename_all(&id.to_string(), pattern); value_arms.push(quote!(#ident :: #id => #name,)); } else { let name = id.to_string(); value_arms.push(quote!(#ident :: #id => #name,)); } } Ok(quote!( #[automatically_derived] impl<'q, DB: ::sqlx::Database> ::sqlx::encode::Encode<'q, DB> for #ident where &'q ::std::primitive::str: ::sqlx::encode::Encode<'q, DB>, { fn encode_by_ref( &self, buf: &mut ::ArgumentBuffer<'q>, ) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> { let val = match self { #(#value_arms)* }; <&::std::primitive::str as ::sqlx::encode::Encode<'q, DB>>::encode(val, buf) } fn size_hint(&self) -> ::std::primitive::usize { let val = match self { #(#value_arms)* }; <&::std::primitive::str as ::sqlx::encode::Encode<'q, DB>>::size_hint(&val) } } )) } fn expand_derive_encode_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { check_struct_attributes(input, fields)?; let mut tts = TokenStream::new(); if cfg!(feature = "postgres") { let ident = &input.ident; let column_count = fields.len(); let (_, ty_generics, where_clause) = input.generics.split_for_impl(); let mut generics = input.generics.clone(); // add db type for impl generics & where clause for type_param in &mut generics.type_params_mut() { type_param.bounds.extend::<[TypeParamBound; 2]>([ parse_quote!(for<'encode> ::sqlx::encode::Encode<'encode, ::sqlx::Postgres>), parse_quote!(::sqlx::types::Type<::sqlx::Postgres>), ]); } generics.params.push(parse_quote!('q)); let (impl_generics, _, _) = generics.split_for_impl(); let writes = fields.iter().map(|field| -> Stmt { let id = &field.ident; parse_quote!( encoder.encode(&self. #id)?; ) }); let sizes = fields.iter().map(|field| -> Expr { let id = &field.ident; let ty = &field.ty; parse_quote!( <#ty as ::sqlx::encode::Encode<::sqlx::Postgres>>::size_hint(&self. #id) ) }); tts.extend(quote!( #[automatically_derived] impl #impl_generics ::sqlx::encode::Encode<'_, ::sqlx::Postgres> for #ident #ty_generics #where_clause { fn encode_by_ref( &self, buf: &mut ::sqlx::postgres::PgArgumentBuffer, ) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> { let mut encoder = ::sqlx::postgres::types::PgRecordEncoder::new(buf); #(#writes)* encoder.finish(); ::std::result::Result::Ok(::sqlx::encode::IsNull::No) } fn size_hint(&self) -> ::std::primitive::usize { #column_count * (4 + 4) // oid (int) and length (int) for each column + #(#sizes)+* // sum of the size hints for each column } } )); } Ok(tts) } sqlx-macros-core-0.8.3/src/derives/mod.rs000064400000000000000000000023201046102023000164050ustar 00000000000000mod attributes; mod decode; mod encode; mod row; mod r#type; pub use decode::expand_derive_decode; pub use encode::expand_derive_encode; pub use r#type::expand_derive_type; pub use row::expand_derive_from_row; use self::attributes::RenameAll; use heck::{ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase}; use proc_macro2::TokenStream; use syn::DeriveInput; pub fn expand_derive_type_encode_decode(input: &DeriveInput) -> syn::Result { let encode_tts = expand_derive_encode(input)?; let decode_tts = expand_derive_decode(input)?; let type_tts = expand_derive_type(input)?; let combined = TokenStream::from_iter(encode_tts.into_iter().chain(decode_tts).chain(type_tts)); Ok(combined) } pub(crate) fn rename_all(s: &str, pattern: RenameAll) -> String { match pattern { RenameAll::LowerCase => s.to_lowercase(), RenameAll::SnakeCase => s.to_snake_case(), RenameAll::UpperCase => s.to_uppercase(), RenameAll::ScreamingSnakeCase => s.to_shouty_snake_case(), RenameAll::KebabCase => s.to_kebab_case(), RenameAll::CamelCase => s.to_lower_camel_case(), RenameAll::PascalCase => s.to_upper_camel_case(), } } sqlx-macros-core-0.8.3/src/derives/row.rs000064400000000000000000000252001046102023000164370ustar 00000000000000use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ parse_quote, punctuated::Punctuated, token::Comma, Data, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, FieldsUnnamed, Lifetime, Stmt, }; use super::{ attributes::{parse_child_attributes, parse_container_attributes}, rename_all, }; pub fn expand_derive_from_row(input: &DeriveInput) -> syn::Result { match &input.data { Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) => expand_derive_from_row_struct(input, named), Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) => expand_derive_from_row_struct_unnamed(input, unnamed), Data::Struct(DataStruct { fields: Fields::Unit, .. }) => Err(syn::Error::new_spanned( input, "unit structs are not supported", )), Data::Enum(_) => Err(syn::Error::new_spanned(input, "enums are not supported")), Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), } } fn expand_derive_from_row_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { let ident = &input.ident; let generics = &input.generics; let (lifetime, provided) = generics .lifetimes() .next() .map(|def| (def.lifetime.clone(), false)) .unwrap_or_else(|| (Lifetime::new("'a", Span::call_site()), true)); let (_, ty_generics, _) = generics.split_for_impl(); let mut generics = generics.clone(); generics.params.insert(0, parse_quote!(R: ::sqlx::Row)); if provided { generics.params.insert(0, parse_quote!(#lifetime)); } let predicates = &mut generics.make_where_clause().predicates; predicates.push(parse_quote!(&#lifetime ::std::primitive::str: ::sqlx::ColumnIndex)); let container_attributes = parse_container_attributes(&input.attrs)?; let default_instance: Option = if container_attributes.default { predicates.push(parse_quote!(#ident: ::std::default::Default)); Some(parse_quote!( let __default = #ident::default(); )) } else { None }; let reads: Vec = fields .iter() .filter_map(|field| -> Option { let id = &field.ident.as_ref()?; let attributes = parse_child_attributes(&field.attrs).unwrap(); let ty = &field.ty; if attributes.skip { return Some(parse_quote!( let #id: #ty = Default::default(); )); } let id_s = if let Some(s) = attributes.rename { s } else { let s = id.to_string().trim_start_matches("r#").to_owned(); match container_attributes.rename_all { Some(pattern) => rename_all(&s, pattern), None => s } }; let expr: Expr = match (attributes.flatten, attributes.try_from, attributes.json) { // (false, None, false) => { predicates .push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); parse_quote!(__row.try_get(#id_s)) } // Flatten (true, None, false) => { predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>)); parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(__row)) } // Flatten + Try from (true, Some(try_from), false) => { predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>)); parse_quote!( <#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(__row) .and_then(|v| { <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v) .map_err(|e| { // Triggers a lint warning if `TryFrom::Err = Infallible` #[allow(unreachable_code)] ::sqlx::Error::ColumnDecode { index: #id_s.to_string(), source: sqlx::__spec_error!(e), } }) }) ) } // Flatten + Json (true, _, true) => { panic!("Cannot use both flatten and json") } // Try from (false, Some(try_from), false) => { predicates .push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(#try_from: ::sqlx::types::Type)); parse_quote!( __row.try_get(#id_s) .and_then(|v| { <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v) .map_err(|e| { // Triggers a lint warning if `TryFrom::Err = Infallible` #[allow(unreachable_code)] ::sqlx::Error::ColumnDecode { index: #id_s.to_string(), source: sqlx::__spec_error!(e), } }) }) ) } // Try from + Json (false, Some(try_from), true) => { predicates .push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::types::Type)); parse_quote!( __row.try_get::<::sqlx::types::Json<_>, _>(#id_s) .and_then(|v| { <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v.0) .map_err(|e| { // Triggers a lint warning if `TryFrom::Err = Infallible` #[allow(unreachable_code)] ::sqlx::Error::ColumnDecode { index: #id_s.to_string(), source: sqlx::__spec_error!(e), } }) }) ) }, // Json (false, None, true) => { predicates .push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::types::Type)); parse_quote!(__row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0)) }, }; if attributes.default { Some(parse_quote!( let #id: #ty = #expr.or_else(|e| match e { ::sqlx::Error::ColumnNotFound(_) => { ::std::result::Result::Ok(Default::default()) }, e => ::std::result::Result::Err(e) })?; )) } else if container_attributes.default { Some(parse_quote!( let #id: #ty = #expr.or_else(|e| match e { ::sqlx::Error::ColumnNotFound(_) => { ::std::result::Result::Ok(__default.#id) }, e => ::std::result::Result::Err(e) })?; )) } else { Some(parse_quote!( let #id: #ty = #expr?; )) } }) .collect(); let (impl_generics, _, where_clause) = generics.split_for_impl(); let names = fields.iter().map(|field| &field.ident); Ok(quote!( #[automatically_derived] impl #impl_generics ::sqlx::FromRow<#lifetime, R> for #ident #ty_generics #where_clause { fn from_row(__row: &#lifetime R) -> ::sqlx::Result { #default_instance #(#reads)* ::std::result::Result::Ok(#ident { #(#names),* }) } } )) } fn expand_derive_from_row_struct_unnamed( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { let ident = &input.ident; let generics = &input.generics; let (lifetime, provided) = generics .lifetimes() .next() .map(|def| (def.lifetime.clone(), false)) .unwrap_or_else(|| (Lifetime::new("'a", Span::call_site()), true)); let (_, ty_generics, _) = generics.split_for_impl(); let mut generics = generics.clone(); generics.params.insert(0, parse_quote!(R: ::sqlx::Row)); if provided { generics.params.insert(0, parse_quote!(#lifetime)); } let predicates = &mut generics.make_where_clause().predicates; predicates.push(parse_quote!( ::std::primitive::usize: ::sqlx::ColumnIndex )); for field in fields { let ty = &field.ty; predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); let gets = fields .iter() .enumerate() .map(|(idx, _)| quote!(row.try_get(#idx)?)); Ok(quote!( #[automatically_derived] impl #impl_generics ::sqlx::FromRow<#lifetime, R> for #ident #ty_generics #where_clause { fn from_row(row: &#lifetime R) -> ::sqlx::Result { ::std::result::Result::Ok(#ident ( #(#gets),* )) } } )) } sqlx-macros-core-0.8.3/src/derives/type.rs000064400000000000000000000214521046102023000166160ustar 00000000000000use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_container_attributes, TypeName, }; use proc_macro2::{Ident, TokenStream}; use quote::{quote, quote_spanned}; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, Variant, }; pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { let attrs = parse_container_attributes(&input.attrs)?; match &input.data { // Newtype structs: // struct Foo(i32); Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) => { if unnamed.len() == 1 { expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) } else { Err(syn::Error::new_spanned( input, "structs with zero or more than one unnamed field are not supported", )) } } // Record types // struct Foo { foo: i32, bar: String } Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) => expand_derive_has_sql_type_struct(input, named), Data::Struct(DataStruct { fields: Fields::Unit, .. }) => Err(syn::Error::new_spanned( input, "unit structs are not supported", )), Data::Enum(DataEnum { variants, .. }) => match attrs.repr { // Enums that encode to/from integers (weak enums) Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), // Enums that decode to/from strings (strong enums) None => expand_derive_has_sql_type_strong_enum(input, variants), }, Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), } } fn expand_derive_has_sql_type_transparent( input: &DeriveInput, field: &Field, ) -> syn::Result { let attr = check_transparent_attributes(input, field)?; let ident = &input.ident; let ty = &field.ty; let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); if attr.transparent { let mut generics = generics.clone(); let mut array_generics = generics.clone(); generics .params .insert(0, parse_quote!(DB: ::sqlx::Database)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: ::sqlx::Type)); let (impl_generics, _, where_clause) = generics.split_for_impl(); array_generics .make_where_clause() .predicates .push(parse_quote!(#ty: ::sqlx::postgres::PgHasArrayType)); let (array_impl_generics, _, array_where_clause) = array_generics.split_for_impl(); let mut tokens = quote!( #[automatically_derived] impl #impl_generics ::sqlx::Type< DB > for #ident #ty_generics #where_clause { fn type_info() -> DB::TypeInfo { <#ty as ::sqlx::Type>::type_info() } fn compatible(ty: &DB::TypeInfo) -> ::std::primitive::bool { <#ty as ::sqlx::Type>::compatible(ty) } } ); if cfg!(feature = "postgres") && !attr.no_pg_array { tokens.extend(quote!( #[automatically_derived] impl #array_impl_generics ::sqlx::postgres::PgHasArrayType for #ident #ty_generics #array_where_clause { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { <#ty as ::sqlx::postgres::PgHasArrayType>::array_type_info() } } )); } return Ok(tokens); } let mut tts = TokenStream::new(); if cfg!(feature = "postgres") { let ty_name = type_name(ident, attr.type_name.as_ref()); tts.extend(quote!( #[automatically_derived] impl ::sqlx::Type<::sqlx::postgres::Postgres> for #ident #ty_generics { fn type_info() -> ::sqlx::postgres::PgTypeInfo { ::sqlx::postgres::PgTypeInfo::with_name(#ty_name) } } )); } Ok(tts) } fn expand_derive_has_sql_type_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attrs = check_weak_enum_attributes(input, variants)?; let repr = attrs.repr.unwrap(); let ident = &input.ident; let mut ts = quote!( #[automatically_derived] impl ::sqlx::Type for #ident where #repr: ::sqlx::Type, { fn type_info() -> DB::TypeInfo { <#repr as ::sqlx::Type>::type_info() } fn compatible(ty: &DB::TypeInfo) -> bool { <#repr as ::sqlx::Type>::compatible(ty) } } ); if cfg!(feature = "postgres") && !attrs.no_pg_array { ts.extend(quote!( #[automatically_derived] impl ::sqlx::postgres::PgHasArrayType for #ident { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { <#repr as ::sqlx::postgres::PgHasArrayType>::array_type_info() } } )); } Ok(ts) } fn expand_derive_has_sql_type_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attributes = check_strong_enum_attributes(input, variants)?; let ident = &input.ident; let mut tts = TokenStream::new(); if cfg!(feature = "mysql") { tts.extend(quote!( #[automatically_derived] impl ::sqlx::Type<::sqlx::MySql> for #ident { fn type_info() -> ::sqlx::mysql::MySqlTypeInfo { ::sqlx::mysql::MySqlTypeInfo::__enum() } } )); } if cfg!(feature = "postgres") { let ty_name = type_name(ident, attributes.type_name.as_ref()); tts.extend(quote!( #[automatically_derived] impl ::sqlx::Type<::sqlx::Postgres> for #ident { fn type_info() -> ::sqlx::postgres::PgTypeInfo { ::sqlx::postgres::PgTypeInfo::with_name(#ty_name) } } )); if !attributes.no_pg_array { tts.extend(quote!( #[automatically_derived] impl ::sqlx::postgres::PgHasArrayType for #ident { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { ::sqlx::postgres::PgTypeInfo::array_of(#ty_name) } } )); } } if cfg!(feature = "_sqlite") { tts.extend(quote!( #[automatically_derived] impl sqlx::Type<::sqlx::Sqlite> for #ident { fn type_info() -> ::sqlx::sqlite::SqliteTypeInfo { <::std::primitive::str as ::sqlx::Type>::type_info() } fn compatible(ty: &::sqlx::sqlite::SqliteTypeInfo) -> ::std::primitive::bool { <&::std::primitive::str as ::sqlx::types::Type>::compatible(ty) } } )); } Ok(tts) } fn expand_derive_has_sql_type_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { let attributes = check_struct_attributes(input, fields)?; let ident = &input.ident; let mut tts = TokenStream::new(); if cfg!(feature = "postgres") { let ty_name = type_name(ident, attributes.type_name.as_ref()); tts.extend(quote!( #[automatically_derived] impl ::sqlx::Type<::sqlx::Postgres> for #ident { fn type_info() -> ::sqlx::postgres::PgTypeInfo { ::sqlx::postgres::PgTypeInfo::with_name(#ty_name) } } )); if !attributes.no_pg_array { tts.extend(quote!( #[automatically_derived] impl ::sqlx::postgres::PgHasArrayType for #ident { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { ::sqlx::postgres::PgTypeInfo::array_of(#ty_name) } } )); } } Ok(tts) } fn type_name(ident: &Ident, explicit_name: Option<&TypeName>) -> TokenStream { explicit_name.map(|tn| tn.get()).unwrap_or_else(|| { let s = ident.to_string(); quote_spanned!(ident.span()=> #s) }) } sqlx-macros-core-0.8.3/src/lib.rs000064400000000000000000000045051046102023000147420ustar 00000000000000//! Support crate for SQLx's proc macros. //! //! ### Note: Semver Exempt API //! The API of this crate is not meant for general use and does *not* follow Semantic Versioning. //! The only crate that follows Semantic Versioning in the project is the `sqlx` crate itself. //! If you are building a custom SQLx driver, you should pin an exact version of this and //! `sqlx-core` to avoid breakages: //! //! ```toml //! sqlx-core = "=0.6.2" //! sqlx-macros-core = "=0.6.2" //! ``` //! //! And then make releases in lockstep with `sqlx-core`. We recommend all driver crates, in-tree //! or otherwise, use the same version numbers as `sqlx-core` to avoid confusion. #![cfg_attr( any(sqlx_macros_unstable, procmacro2_semver_exempt), feature(track_path) )] #[cfg(feature = "macros")] use crate::query::QueryDriver; pub type Error = Box; pub type Result = std::result::Result; mod common; mod database; #[cfg(feature = "derive")] pub mod derives; #[cfg(feature = "macros")] pub mod query; #[cfg(feature = "macros")] // The compiler gives misleading help messages about `#[cfg(test)]` when this is just named `test`. pub mod test_attr; #[cfg(feature = "migrate")] pub mod migrate; #[cfg(feature = "macros")] pub const FOSS_DRIVERS: &[QueryDriver] = &[ #[cfg(feature = "mysql")] QueryDriver::new::(), #[cfg(feature = "postgres")] QueryDriver::new::(), #[cfg(feature = "_sqlite")] QueryDriver::new::(), ]; pub fn block_on(f: F) -> F::Output where F: std::future::Future, { #[cfg(feature = "_rt-tokio")] { use once_cell::sync::Lazy; use tokio::runtime::{self, Runtime}; // We need a single, persistent Tokio runtime since we're caching connections, // otherwise we'll get "IO driver has terminated" errors. static TOKIO_RT: Lazy = Lazy::new(|| { runtime::Builder::new_current_thread() .enable_all() .build() .expect("failed to start Tokio runtime") }); TOKIO_RT.block_on(f) } #[cfg(all(feature = "_rt-async-std", not(feature = "tokio")))] { async_std::task::block_on(f) } #[cfg(not(any(feature = "_rt-async-std", feature = "tokio")))] sqlx_core::rt::missing_rt(f) } sqlx-macros-core-0.8.3/src/migrate.rs000064400000000000000000000074351046102023000156310ustar 00000000000000#[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))] extern crate proc_macro; use std::path::{Path, PathBuf}; use proc_macro2::TokenStream; use quote::{quote, ToTokens, TokenStreamExt}; use syn::LitStr; use sqlx_core::migrate::{Migration, MigrationType}; pub struct QuoteMigrationType(MigrationType); impl ToTokens for QuoteMigrationType { fn to_tokens(&self, tokens: &mut TokenStream) { let ts = match self.0 { MigrationType::Simple => quote! { ::sqlx::migrate::MigrationType::Simple }, MigrationType::ReversibleUp => quote! { ::sqlx::migrate::MigrationType::ReversibleUp }, MigrationType::ReversibleDown => { quote! { ::sqlx::migrate::MigrationType::ReversibleDown } } }; tokens.append_all(ts); } } struct QuoteMigration { migration: Migration, path: PathBuf, } impl ToTokens for QuoteMigration { fn to_tokens(&self, tokens: &mut TokenStream) { let Migration { version, description, migration_type, checksum, no_tx, .. } = &self.migration; let migration_type = QuoteMigrationType(*migration_type); let sql = self .path .canonicalize() .map_err(|e| { format!( "error canonicalizing migration path {}: {e}", self.path.display() ) }) .and_then(|path| { let path_str = path.to_str().ok_or_else(|| { format!( "migration path cannot be represented as a string: {}", self.path.display() ) })?; // this tells the compiler to watch this path for changes Ok(quote! { include_str!(#path_str) }) }) .unwrap_or_else(|e| quote! { compile_error!(#e) }); let ts = quote! { ::sqlx::migrate::Migration { version: #version, description: ::std::borrow::Cow::Borrowed(#description), migration_type: #migration_type, sql: ::std::borrow::Cow::Borrowed(#sql), no_tx: #no_tx, checksum: ::std::borrow::Cow::Borrowed(&[ #(#checksum),* ]), } }; tokens.append_all(ts); } } pub fn expand_migrator_from_lit_dir(dir: LitStr) -> crate::Result { expand_migrator_from_dir(&dir.value(), dir.span()) } pub(crate) fn expand_migrator_from_dir( dir: &str, err_span: proc_macro2::Span, ) -> crate::Result { let path = crate::common::resolve_path(dir, err_span)?; expand_migrator(&path) } pub(crate) fn expand_migrator(path: &Path) -> crate::Result { let path = path.canonicalize().map_err(|e| { format!( "error canonicalizing migration directory {}: {e}", path.display() ) })?; // Use the same code path to resolve migrations at compile time and runtime. let migrations = sqlx_core::migrate::resolve_blocking(&path)? .into_iter() .map(|(migration, path)| QuoteMigration { migration, path }); #[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))] { let path = path.to_str().ok_or_else(|| { format!( "migration directory path cannot be represented as a string: {:?}", path ) })?; proc_macro::tracked_path::path(path); } Ok(quote! { ::sqlx::migrate::Migrator { migrations: ::std::borrow::Cow::Borrowed(&[ #(#migrations),* ]), ..::sqlx::migrate::Migrator::DEFAULT } }) } sqlx-macros-core-0.8.3/src/query/args.rs000064400000000000000000000131321046102023000162710ustar 00000000000000use crate::database::DatabaseExt; use crate::query::QueryMacroInput; use either::Either; use proc_macro2::TokenStream; use quote::{format_ident, quote, quote_spanned}; use sqlx_core::describe::Describe; use syn::spanned::Spanned; use syn::{Expr, ExprCast, ExprGroup, Type}; /// Returns a tokenstream which typechecks the arguments passed to the macro /// and binds them to `DB::Arguments` with the ident `query_args`. pub fn quote_args( input: &QueryMacroInput, info: &Describe, ) -> crate::Result { let db_path = DB::db_path(); if input.arg_exprs.is_empty() { return Ok(quote! { let query_args = ::core::result::Result::<_, ::sqlx::error::BoxDynError>::Ok(<#db_path as ::sqlx::database::Database>::Arguments::<'_>::default()); }); } let arg_names = (0..input.arg_exprs.len()) .map(|i| format_ident!("arg{}", i)) .collect::>(); let arg_name = &arg_names; let arg_expr = input.arg_exprs.iter().cloned().map(strip_wildcard); let arg_bindings = quote! { #(let #arg_name = &(#arg_expr);)* }; let args_check = match info.parameters() { None | Some(Either::Right(_)) => { // all we can do is check arity which we did TokenStream::new() } Some(Either::Left(_)) if !input.checked => { // this is an `*_unchecked!()` macro invocation TokenStream::new() } Some(Either::Left(params)) => { params .iter() .zip(arg_names.iter().zip(&input.arg_exprs)) .enumerate() .map(|(i, (param_ty, (name, expr)))| -> crate::Result<_> { if get_type_override(expr).is_some() { // cast will fail to compile if the type does not match // and we strip casts to wildcard return Ok(quote!()); } let param_ty = DB::param_type_for_id(param_ty) .ok_or_else(|| { if let Some(feature_gate) = DB::get_feature_gate(param_ty) { format!( "optional sqlx feature `{}` required for type {} of param #{}", feature_gate, param_ty, i + 1, ) } else { format!("unsupported type {} for param #{}", param_ty, i + 1) } })? .parse::() .map_err(|_| format!("Rust type mapping for {param_ty} not parsable"))?; Ok(quote_spanned!(expr.span() => // this shouldn't actually run #[allow(clippy::missing_panics_doc, clippy::unreachable)] if false { use ::sqlx::ty_match::{WrapSameExt as _, MatchBorrowExt as _}; // evaluate the expression only once in case it contains moves let expr = ::sqlx::ty_match::dupe_value(#name); // if `expr` is `Option`, get `Option<$ty>`, otherwise `$ty` let ty_check = ::sqlx::ty_match::WrapSame::<#param_ty, _>::new(&expr).wrap_same(); // if `expr` is `&str`, convert `String` to `&str` let (mut _ty_check, match_borrow) = ::sqlx::ty_match::MatchBorrow::new(ty_check, &expr); _ty_check = match_borrow.match_borrow(); // this causes move-analysis to effectively ignore this block ::std::unreachable!(); } )) }) .collect::>()? } }; let args_count = input.arg_exprs.len(); Ok(quote! { #arg_bindings #args_check let mut query_args = <#db_path as ::sqlx::database::Database>::Arguments::<'_>::default(); query_args.reserve( #args_count, 0 #(+ ::sqlx::encode::Encode::<#db_path>::size_hint(#arg_name))* ); let query_args = ::core::result::Result::<_, ::sqlx::error::BoxDynError>::Ok(query_args) #(.and_then(move |mut query_args| query_args.add(#arg_name).map(move |()| query_args) ))*; }) } fn get_type_override(expr: &Expr) -> Option<&Type> { match expr { Expr::Group(group) => get_type_override(&group.expr), Expr::Cast(cast) => Some(&cast.ty), _ => None, } } fn strip_wildcard(expr: Expr) -> Expr { match expr { Expr::Group(ExprGroup { attrs, group_token, expr, }) => Expr::Group(ExprGroup { attrs, group_token, expr: Box::new(strip_wildcard(*expr)), }), // we want to retain casts if they semantically matter Expr::Cast(ExprCast { attrs, expr, as_token, ty, }) => match *ty { // cast to wildcard `_` will produce weird errors; we interpret it as taking the value as-is Type::Infer(_) => *expr, _ => Expr::Cast(ExprCast { attrs, expr, as_token, ty, }), }, _ => expr, } } sqlx-macros-core-0.8.3/src/query/data.rs000064400000000000000000000141701046102023000162510ustar 00000000000000use std::collections::HashMap; use std::fmt::{Debug, Display, Formatter}; use std::fs; use std::io::Write as _; use std::marker::PhantomData; use std::path::{Path, PathBuf}; use std::sync::Mutex; use once_cell::sync::Lazy; use serde::{Serialize, Serializer}; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use crate::database::DatabaseExt; #[derive(serde::Serialize)] #[serde(bound(serialize = "Describe: serde::Serialize"))] #[derive(Debug)] pub struct QueryData { db_name: SerializeDbName, #[allow(dead_code)] pub(super) query: String, pub(super) describe: Describe, pub(super) hash: String, } impl QueryData { pub fn from_describe(query: &str, describe: Describe) -> Self { QueryData { db_name: SerializeDbName::default(), query: query.into(), describe, hash: hash_string(query), } } } struct SerializeDbName(PhantomData); impl Default for SerializeDbName { fn default() -> Self { SerializeDbName(PhantomData) } } impl Debug for SerializeDbName { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_tuple("SerializeDbName").field(&DB::NAME).finish() } } impl Display for SerializeDbName { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.pad(DB::NAME) } } impl Serialize for SerializeDbName { fn serialize(&self, serializer: S) -> Result where S: Serializer, { serializer.serialize_str(DB::NAME) } } static OFFLINE_DATA_CACHE: Lazy>> = Lazy::new(Default::default); /// Offline query data #[derive(Clone, serde::Deserialize)] pub struct DynQueryData { pub db_name: String, pub query: String, pub describe: serde_json::Value, pub hash: String, } impl DynQueryData { /// Loads a query given the path to its "query-.json" file. Subsequent calls for the same /// path are retrieved from an in-memory cache. pub fn from_data_file(path: impl AsRef, query: &str) -> crate::Result { let path = path.as_ref(); let mut cache = OFFLINE_DATA_CACHE .lock() // Just reset the cache on error .unwrap_or_else(|posion_err| { let mut guard = posion_err.into_inner(); *guard = Default::default(); guard }); if let Some(cached) = cache.get(path).cloned() { if query != cached.query { return Err("hash collision for saved query data".into()); } return Ok(cached); } #[cfg(procmacro2_semver_exempt)] { let path = path.as_ref().canonicalize()?; let path = path.to_str().ok_or_else(|| { format!( "query-.json path cannot be represented as a string: {:?}", path ) })?; proc_macro::tracked_path::path(path); } let offline_data_contents = fs::read_to_string(path) .map_err(|e| format!("failed to read saved query path {}: {}", path.display(), e))?; let dyn_data: DynQueryData = serde_json::from_str(&offline_data_contents)?; if query != dyn_data.query { return Err("hash collision for saved query data".into()); } let _ = cache.insert(path.to_owned(), dyn_data.clone()); Ok(dyn_data) } } impl QueryData where Describe: serde::Serialize + serde::de::DeserializeOwned, { pub fn from_dyn_data(dyn_data: DynQueryData) -> crate::Result { assert!(!dyn_data.db_name.is_empty()); assert!(!dyn_data.hash.is_empty()); if DB::NAME == dyn_data.db_name { let describe: Describe = serde_json::from_value(dyn_data.describe)?; Ok(QueryData { db_name: SerializeDbName::default(), query: dyn_data.query, describe, hash: dyn_data.hash, }) } else { Err(format!( "expected query data for {}, got data for {}", DB::NAME, dyn_data.db_name ) .into()) } } pub(super) fn save_in(&self, dir: impl AsRef) -> crate::Result<()> { use std::io::ErrorKind; let path = dir.as_ref().join(format!("query-{}.json", self.hash)); match std::fs::remove_file(&path) { Ok(()) => {} Err(err) if matches!( err.kind(), ErrorKind::NotFound | ErrorKind::PermissionDenied, ) => {} Err(err) => return Err(format!("failed to delete {path:?}: {err:?}").into()), } let mut file = match std::fs::OpenOptions::new() .write(true) .create_new(true) .open(&path) { Ok(file) => file, // We overlapped with a concurrent invocation and the other one succeeded. Err(err) if matches!(err.kind(), ErrorKind::AlreadyExists) => return Ok(()), Err(err) => { return Err(format!("failed to exclusively create {path:?}: {err:?}").into()) } }; let data = serde_json::to_string_pretty(self) .map_err(|err| format!("failed to serialize query data: {err:?}"))?; file.write_all(data.as_bytes()) .map_err(|err| format!("failed to write query data to file: {err:?}"))?; // Ensure there is a newline at the end of the JSON file to avoid // accidental modification by IDE and make github diff tool happier. file.write_all(b"\n") .map_err(|err| format!("failed to append a newline to file: {err:?}"))?; Ok(()) } } pub(super) fn hash_string(query: &str) -> String { // picked `sha2` because it's already in the dependency tree for both MySQL and Postgres use sha2::{Digest, Sha256}; hex::encode(Sha256::digest(query.as_bytes())) } sqlx-macros-core-0.8.3/src/query/input.rs000064400000000000000000000114201046102023000164720ustar 00000000000000use std::fs; use proc_macro2::{Ident, Span}; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::{Expr, LitBool, LitStr, Token}; use syn::{ExprArray, Type}; /// Macro input shared by `query!()` and `query_file!()` pub struct QueryMacroInput { pub(super) sql: String, pub(super) src_span: Span, pub(super) record_type: RecordType, pub(super) arg_exprs: Vec, pub(super) checked: bool, pub(super) file_path: Option, } enum QuerySrc { String(String), File(String), } pub enum RecordType { Given(Type), Scalar, Generated, } impl Parse for QueryMacroInput { fn parse(input: ParseStream) -> syn::Result { let mut query_src: Option<(QuerySrc, Span)> = None; let mut args: Option> = None; let mut record_type = RecordType::Generated; let mut checked = true; let mut expect_comma = false; while !input.is_empty() { if expect_comma { let _ = input.parse::()?; } let key: Ident = input.parse()?; let _ = input.parse::()?; if key == "source" { let span = input.span(); let query_str = Punctuated::::parse_separated_nonempty(input)? .iter() .map(LitStr::value) .collect(); query_src = Some((QuerySrc::String(query_str), span)); } else if key == "source_file" { let lit_str = input.parse::()?; query_src = Some((QuerySrc::File(lit_str.value()), lit_str.span())); } else if key == "args" { let exprs = input.parse::()?; args = Some(exprs.elems.into_iter().collect()) } else if key == "record" { if !matches!(record_type, RecordType::Generated) { return Err(input.error("colliding `scalar` or `record` key")); } record_type = RecordType::Given(input.parse()?); } else if key == "scalar" { if !matches!(record_type, RecordType::Generated) { return Err(input.error("colliding `scalar` or `record` key")); } // we currently expect only `scalar = _` // a `query_as_scalar!()` variant seems less useful than just overriding the type // of the column in SQL input.parse::()?; record_type = RecordType::Scalar; } else if key == "checked" { let lit_bool = input.parse::()?; checked = lit_bool.value; } else { let message = format!("unexpected input key: {key}"); return Err(syn::Error::new_spanned(key, message)); } expect_comma = true; } let (src, src_span) = query_src.ok_or_else(|| input.error("expected `source` or `source_file` key"))?; let arg_exprs = args.unwrap_or_default(); let file_path = src.file_path(src_span)?; Ok(QueryMacroInput { sql: src.resolve(src_span)?, src_span, record_type, arg_exprs, checked, file_path, }) } } impl QuerySrc { /// If the query source is a file, read it to a string. Otherwise return the query string. fn resolve(self, source_span: Span) -> syn::Result { match self { QuerySrc::String(string) => Ok(string), QuerySrc::File(file) => read_file_src(&file, source_span), } } fn file_path(&self, source_span: Span) -> syn::Result> { if let QuerySrc::File(ref file) = *self { let path = crate::common::resolve_path(file, source_span)? .canonicalize() .map_err(|e| syn::Error::new(source_span, e))?; Ok(Some( path.to_str() .ok_or_else(|| { syn::Error::new( source_span, "query file path cannot be represented as a string", ) })? .to_string(), )) } else { Ok(None) } } } fn read_file_src(source: &str, source_span: Span) -> syn::Result { let file_path = crate::common::resolve_path(source, source_span)?; fs::read_to_string(&file_path).map_err(|e| { syn::Error::new( source_span, format!( "failed to read query file at {}: {}", file_path.display(), e ), ) }) } sqlx-macros-core-0.8.3/src/query/mod.rs000064400000000000000000000300031046102023000161100ustar 00000000000000use std::path::PathBuf; use std::sync::{Arc, Mutex}; use std::{fs, io}; use once_cell::sync::Lazy; use proc_macro2::TokenStream; use syn::Type; pub use input::QueryMacroInput; use quote::{format_ident, quote}; use sqlx_core::database::Database; use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo}; use crate::database::DatabaseExt; use crate::query::data::{hash_string, DynQueryData, QueryData}; use crate::query::input::RecordType; use either::Either; use url::Url; mod args; mod data; mod input; mod output; #[derive(Copy, Clone)] pub struct QueryDriver { db_name: &'static str, url_schemes: &'static [&'static str], expand: fn(QueryMacroInput, QueryDataSource) -> crate::Result, } impl QueryDriver { pub const fn new() -> Self where Describe: serde::Serialize + serde::de::DeserializeOwned, { QueryDriver { db_name: DB::NAME, url_schemes: DB::URL_SCHEMES, expand: expand_with::, } } } pub enum QueryDataSource<'a> { Live { database_url: &'a str, database_url_parsed: Url, }, Cached(DynQueryData), } impl<'a> QueryDataSource<'a> { pub fn live(database_url: &'a str) -> crate::Result { Ok(QueryDataSource::Live { database_url, database_url_parsed: database_url.parse()?, }) } pub fn matches_driver(&self, driver: &QueryDriver) -> bool { match self { Self::Live { database_url_parsed, .. } => driver.url_schemes.contains(&database_url_parsed.scheme()), Self::Cached(dyn_data) => dyn_data.db_name == driver.db_name, } } } struct Metadata { #[allow(unused)] manifest_dir: PathBuf, offline: bool, database_url: Option, workspace_root: Arc>>, } impl Metadata { pub fn workspace_root(&self) -> PathBuf { let mut root = self.workspace_root.lock().unwrap(); if root.is_none() { use serde::Deserialize; use std::process::Command; let cargo = env("CARGO").expect("`CARGO` must be set"); let output = Command::new(cargo) .args(["metadata", "--format-version=1", "--no-deps"]) .current_dir(&self.manifest_dir) .env_remove("__CARGO_FIX_PLZ") .output() .expect("Could not fetch metadata"); #[derive(Deserialize)] struct CargoMetadata { workspace_root: PathBuf, } let metadata: CargoMetadata = serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output"); *root = Some(metadata.workspace_root); } root.clone().unwrap() } } // If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't // reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946 static METADATA: Lazy = Lazy::new(|| { let manifest_dir: PathBuf = env("CARGO_MANIFEST_DIR") .expect("`CARGO_MANIFEST_DIR` must be set") .into(); // If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this, // otherwise fallback to default dotenv behaviour. let env_path = manifest_dir.join(".env"); #[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))] let env_path = if env_path.exists() { let res = dotenvy::from_path(&env_path); if let Err(e) = res { panic!("failed to load environment from {env_path:?}, {e}"); } Some(env_path) } else { dotenvy::dotenv().ok() }; // tell the compiler to watch the `.env` for changes, if applicable #[cfg(procmacro2_semver_exempt)] if let Some(env_path) = env_path.as_ref().and_then(|path| path.to_str()) { proc_macro::tracked_path::path(env_path); } let offline = env("SQLX_OFFLINE") .map(|s| s.eq_ignore_ascii_case("true") || s == "1") .unwrap_or(false); let database_url = env("DATABASE_URL").ok(); Metadata { manifest_dir, offline, database_url, workspace_root: Arc::new(Mutex::new(None)), } }); pub fn expand_input<'a>( input: QueryMacroInput, drivers: impl IntoIterator, ) -> crate::Result { let data_source = match &*METADATA { Metadata { offline: false, database_url: Some(db_url), .. } => QueryDataSource::live(db_url)?, Metadata { offline, .. } => { // Try load the cached query metadata file. let filename = format!("query-{}.json", hash_string(&input.sql)); // Check SQLX_OFFLINE_DIR, then local .sqlx, then workspace .sqlx. let dirs = [ || env("SQLX_OFFLINE_DIR").ok().map(PathBuf::from), || Some(METADATA.manifest_dir.join(".sqlx")), || Some(METADATA.workspace_root().join(".sqlx")), ]; let Some(data_file_path) = dirs .iter() .filter_map(|path| path()) .map(|path| path.join(&filename)) .find(|path| path.exists()) else { return Err( if *offline { "`SQLX_OFFLINE=true` but there is no cached data for this query, run `cargo sqlx prepare` to update the query cache or unset `SQLX_OFFLINE`" } else { "set `DATABASE_URL` to use query macros online, or run `cargo sqlx prepare` to update the query cache" }.into() ); }; QueryDataSource::Cached(DynQueryData::from_data_file(&data_file_path, &input.sql)?) } }; for driver in drivers { if data_source.matches_driver(driver) { return (driver.expand)(input, data_source); } } match data_source { QueryDataSource::Live { database_url_parsed, .. } => Err(format!( "no database driver found matching URL scheme {:?}; the corresponding Cargo feature may need to be enabled", database_url_parsed.scheme() ).into()), QueryDataSource::Cached(data) => { Err(format!( "found cached data for database {:?} but no matching driver; the corresponding Cargo feature may need to be enabled", data.db_name ).into()) } } } fn expand_with( input: QueryMacroInput, data_source: QueryDataSource, ) -> crate::Result where Describe: DescribeExt, { let (query_data, offline): (QueryData, bool) = match data_source { QueryDataSource::Cached(dyn_data) => (QueryData::from_dyn_data(dyn_data)?, true), QueryDataSource::Live { database_url, .. } => { let describe = DB::describe_blocking(&input.sql, database_url)?; (QueryData::from_describe(&input.sql, describe), false) } }; expand_with_data(input, query_data, offline) } // marker trait for `Describe` that lets us conditionally require it to be `Serialize + Deserialize` trait DescribeExt: serde::Serialize + serde::de::DeserializeOwned {} impl DescribeExt for Describe where Describe: serde::Serialize + serde::de::DeserializeOwned { } fn expand_with_data( input: QueryMacroInput, data: QueryData, offline: bool, ) -> crate::Result where Describe: DescribeExt, { // validate at the minimum that our args match the query's input parameters let num_parameters = match data.describe.parameters() { Some(Either::Left(params)) => Some(params.len()), Some(Either::Right(num)) => Some(num), None => None, }; if let Some(num) = num_parameters { if num != input.arg_exprs.len() { return Err( format!("expected {} parameters, got {}", num, input.arg_exprs.len()).into(), ); } } let args_tokens = args::quote_args(&input, &data.describe)?; let query_args = format_ident!("query_args"); let output = if data .describe .columns() .iter() .all(|it| it.type_info().is_void()) { let db_path = DB::db_path(); let sql = &input.sql; quote! { ::sqlx::__query_with_result::<#db_path, _>(#sql, #query_args) } } else { match input.record_type { RecordType::Generated => { let columns = output::columns_to_rust::(&data.describe)?; let record_name: Type = syn::parse_str("Record").unwrap(); for rust_col in &columns { if rust_col.type_.is_wildcard() { return Err( "wildcard overrides are only allowed with an explicit record type, \ e.g. `query_as!()` and its variants" .into(), ); } } let record_fields = columns .iter() .map(|output::RustColumn { ident, type_, .. }| quote!(#ident: #type_,)); let mut record_tokens = quote! { #[derive(Debug)] #[allow(non_snake_case)] struct #record_name { #(#record_fields)* } }; record_tokens.extend(output::quote_query_as::( &input, &record_name, &query_args, &columns, )); record_tokens } RecordType::Given(ref out_ty) => { let columns = output::columns_to_rust::(&data.describe)?; output::quote_query_as::(&input, out_ty, &query_args, &columns) } RecordType::Scalar => { output::quote_query_scalar::(&input, &query_args, &data.describe)? } } }; let ret_tokens = quote! { { #[allow(clippy::all)] { use ::sqlx::Arguments as _; #args_tokens #output } } }; // Store query metadata only if offline support is enabled but the current build is online. // If the build is offline, the cache is our input so it's pointless to also write data for it. if !offline { // Only save query metadata if SQLX_OFFLINE_DIR is set manually or by `cargo sqlx prepare`. // Note: in a cargo workspace this path is relative to the root. if let Ok(dir) = env("SQLX_OFFLINE_DIR") { let path = PathBuf::from(&dir); match fs::metadata(&path) { Err(e) => { if e.kind() != io::ErrorKind::NotFound { // Can't obtain information about .sqlx return Err(format!("{e}: {dir}").into()); } // .sqlx doesn't exist. return Err(format!("sqlx offline path does not exist: {dir}").into()); } Ok(meta) => { if !meta.is_dir() { return Err(format!( "sqlx offline path exists, but is not a directory: {dir}" ) .into()); } // .sqlx exists and is a directory, store data. data.save_in(path)?; } } } } Ok(ret_tokens) } /// Get the value of an environment variable, telling the compiler about it if applicable. fn env(name: &str) -> Result { #[cfg(procmacro2_semver_exempt)] { proc_macro::tracked_env::var(name) } #[cfg(not(procmacro2_semver_exempt))] { std::env::var(name) } } sqlx-macros-core-0.8.3/src/query/output.rs000064400000000000000000000247261046102023000167100ustar 00000000000000use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::Type; use sqlx_core::column::Column; use sqlx_core::describe::Describe; use crate::database::DatabaseExt; use crate::query::QueryMacroInput; use sqlx_core::type_checking::TypeChecking; use std::fmt::{self, Display, Formatter}; use syn::parse::{Parse, ParseStream}; use syn::Token; pub struct RustColumn { pub(super) ident: Ident, pub(super) var_name: Ident, pub(super) type_: ColumnType, } pub(super) enum ColumnType { Exact(TokenStream), Wildcard, OptWildcard, } impl ColumnType { pub(super) fn is_wildcard(&self) -> bool { !matches!(self, ColumnType::Exact(_)) } } impl ToTokens for ColumnType { fn to_tokens(&self, tokens: &mut TokenStream) { tokens.append_all(match &self { ColumnType::Exact(type_) => type_.clone().into_iter(), ColumnType::Wildcard => quote! { _ }.into_iter(), ColumnType::OptWildcard => quote! { ::std::option::Option<_> }.into_iter(), }) } } struct DisplayColumn<'a> { // zero-based index, converted to 1-based number idx: usize, name: &'a str, } struct ColumnDecl { ident: Ident, r#override: ColumnOverride, } struct ColumnOverride { nullability: ColumnNullabilityOverride, type_: ColumnTypeOverride, } #[derive(PartialEq)] enum ColumnNullabilityOverride { NonNull, Nullable, None, } enum ColumnTypeOverride { Exact(Type), Wildcard, None, } impl Display for DisplayColumn<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "column #{} ({:?})", self.idx + 1, self.name) } } pub fn columns_to_rust(describe: &Describe) -> crate::Result> { (0..describe.columns().len()) .map(|i| column_to_rust(describe, i)) .collect::>>() } fn column_to_rust(describe: &Describe, i: usize) -> crate::Result { let column = &describe.columns()[i]; // add raw prefix to all identifiers let decl = ColumnDecl::parse(column.name()) .map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?; let ColumnOverride { nullability, type_ } = decl.r#override; let nullable = match nullability { ColumnNullabilityOverride::NonNull => false, ColumnNullabilityOverride::Nullable => true, ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true), }; let type_ = match (type_, nullable) { (ColumnTypeOverride::Exact(type_), false) => ColumnType::Exact(type_.to_token_stream()), (ColumnTypeOverride::Exact(type_), true) => { ColumnType::Exact(quote! { ::std::option::Option<#type_> }) } (ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard, (ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard, (ColumnTypeOverride::None, _) => { let type_ = get_column_type::(i, column); if !nullable { ColumnType::Exact(type_) } else { ColumnType::Exact(quote! { ::std::option::Option<#type_> }) } } }; Ok(RustColumn { // prefix the variable name we use in `quote_query_as!()` so it doesn't conflict // https://github.com/launchbadge/sqlx/issues/1322 var_name: quote::format_ident!("sqlx_query_as_{}", decl.ident), ident: decl.ident, type_, }) } pub fn quote_query_as( input: &QueryMacroInput, out_ty: &Type, bind_args: &Ident, columns: &[RustColumn], ) -> TokenStream { let instantiations = columns.iter().enumerate().map( |( i, RustColumn { var_name, type_, .. }, )| { match (input.checked, type_) { // we guarantee the type is valid so we can skip the runtime check (true, ColumnType::Exact(type_)) => quote! { // binding to a `let` avoids confusing errors about // "try expression alternatives have incompatible types" // it doesn't seem to hurt inference in the other branches #[allow(non_snake_case)] let #var_name = row.try_get_unchecked::<#type_, _>(#i)?.into(); }, // type was overridden to be a wildcard so we fallback to the runtime check (true, ColumnType::Wildcard) => quote! ( #[allow(non_snake_case)] let #var_name = row.try_get(#i)?; ), (true, ColumnType::OptWildcard) => { quote! ( #[allow(non_snake_case)] let #var_name = row.try_get::<::std::option::Option<_>, _>(#i)?; ) } // macro is the `_unchecked!()` variant so this will die in decoding if it's wrong (false, _) => quote!( #[allow(non_snake_case)] let #var_name = row.try_get_unchecked(#i)?; ), } }, ); let ident = columns.iter().map(|col| &col.ident); let var_name = columns.iter().map(|col| &col.var_name); let db_path = DB::db_path(); let row_path = DB::row_path(); // if this query came from a file, use `include_str!()` to tell the compiler where it came from let sql = if let Some(ref path) = &input.file_path { quote::quote_spanned! { input.src_span => include_str!(#path) } } else { let sql = &input.sql; quote! { #sql } }; quote! { ::sqlx::__query_with_result::<#db_path, _>(#sql, #bind_args).try_map(|row: #row_path| { use ::sqlx::Row as _; #(#instantiations)* ::std::result::Result::Ok(#out_ty { #(#ident: #var_name),* }) }) } } pub fn quote_query_scalar( input: &QueryMacroInput, bind_args: &Ident, describe: &Describe, ) -> crate::Result { let columns = describe.columns(); if columns.len() != 1 { return Err(syn::Error::new( input.src_span, format!("expected exactly 1 column, got {}", columns.len()), ) .into()); } // attempt to parse a column override, otherwise fall back to the inferred type of the column let ty = if let Ok(rust_col) = column_to_rust(describe, 0) { rust_col.type_.to_token_stream() } else if input.checked { let ty = get_column_type::(0, &columns[0]); if describe.nullable(0).unwrap_or(true) { quote! { ::std::option::Option<#ty> } } else { ty } } else { quote! { _ } }; let db = DB::db_path(); let query = &input.sql; Ok(quote! { ::sqlx::__query_scalar_with_result::<#db, #ty, _>(#query, #bind_args) }) } fn get_column_type(i: usize, column: &DB::Column) -> TokenStream { let type_info = column.type_info(); ::return_type_for_id(type_info).map_or_else( || { let message = if let Some(feature_gate) = ::get_feature_gate(type_info) { format!( "optional sqlx feature `{feat}` required for type {ty} of {col}", ty = &type_info, feat = feature_gate, col = DisplayColumn { idx: i, name: column.name() } ) } else { format!( "unsupported type {ty} of {col}", ty = type_info, col = DisplayColumn { idx: i, name: column.name() } ) }; syn::Error::new(Span::call_site(), message).to_compile_error() }, |t| t.parse().unwrap(), ) } impl ColumnDecl { fn parse(col_name: &str) -> crate::Result { // find the end of the identifier because we want to use our own logic to parse it // if we tried to feed this into `syn::parse_str()` we might get an un-great error // for some kinds of invalid identifiers let (ident, remainder) = if let Some(i) = col_name.find(&[':', '!', '?'][..]) { let (ident, remainder) = col_name.split_at(i); (parse_ident(ident)?, remainder) } else { (parse_ident(col_name)?, "") }; Ok(ColumnDecl { ident, r#override: if !remainder.is_empty() { syn::parse_str(remainder)? } else { ColumnOverride { nullability: ColumnNullabilityOverride::None, type_: ColumnTypeOverride::None, } }, }) } } impl Parse for ColumnOverride { fn parse(input: ParseStream) -> syn::Result { let lookahead = input.lookahead1(); let nullability = if lookahead.peek(Token![!]) { input.parse::()?; ColumnNullabilityOverride::NonNull } else if lookahead.peek(Token![?]) { input.parse::()?; ColumnNullabilityOverride::Nullable } else { ColumnNullabilityOverride::None }; let type_ = if input.lookahead1().peek(Token![:]) { input.parse::()?; let ty = Type::parse(input)?; if let Type::Infer(_) = ty { ColumnTypeOverride::Wildcard } else { ColumnTypeOverride::Exact(ty) } } else { ColumnTypeOverride::None }; Ok(Self { nullability, type_ }) } } fn parse_ident(name: &str) -> crate::Result { // workaround for the following issue (it's semi-fixed but still spits out extra diagnostics) // https://github.com/dtolnay/syn/issues/749#issuecomment-575451318 let is_valid_ident = !name.is_empty() && name.starts_with(|c: char| c.is_alphabetic() || c == '_') && name.chars().all(|c| c.is_alphanumeric() || c == '_'); if is_valid_ident { let ident = String::from("r#") + name; if let Ok(ident) = syn::parse_str(&ident) { return Ok(ident); } } Err(format!("{name:?} is not a valid Rust identifier").into()) } sqlx-macros-core-0.8.3/src/test_attr.rs000064400000000000000000000333011046102023000162010ustar 00000000000000use proc_macro2::TokenStream; use quote::quote; use syn::parse::Parser; #[cfg(feature = "migrate")] struct Args { fixtures: Vec<(FixturesType, Vec)>, migrations: MigrationsOpt, } #[cfg(feature = "migrate")] enum FixturesType { None, RelativePath, CustomRelativePath(syn::LitStr), ExplicitPath, } #[cfg(feature = "migrate")] enum MigrationsOpt { InferredPath, ExplicitPath(syn::LitStr), ExplicitMigrator(syn::Path), Disabled, } type AttributeArgs = syn::punctuated::Punctuated; pub fn expand(args: TokenStream, input: syn::ItemFn) -> crate::Result { let parser = AttributeArgs::parse_terminated; let args = parser.parse2(args)?; if input.sig.inputs.is_empty() { if !args.is_empty() { if cfg!(not(feature = "migrate")) { return Err(syn::Error::new_spanned( args.first().unwrap(), "control attributes are not allowed unless \ the `migrate` feature is enabled and \ automatic test DB management is used; see docs", ) .into()); } return Err(syn::Error::new_spanned( args.first().unwrap(), "control attributes are not allowed unless \ automatic test DB management is used; see docs", ) .into()); } return Ok(expand_simple(input)); } #[cfg(feature = "migrate")] return expand_advanced(args, input); #[cfg(not(feature = "migrate"))] return Err(syn::Error::new_spanned(input, "`migrate` feature required").into()); } fn expand_simple(input: syn::ItemFn) -> TokenStream { let ret = &input.sig.output; let name = &input.sig.ident; let body = &input.block; let attrs = &input.attrs; quote! { #[::core::prelude::v1::test] #(#attrs)* fn #name() #ret { ::sqlx::test_block_on(async { #body }) } } } #[cfg(feature = "migrate")] fn expand_advanced(args: AttributeArgs, input: syn::ItemFn) -> crate::Result { let ret = &input.sig.output; let name = &input.sig.ident; let inputs = &input.sig.inputs; let body = &input.block; let attrs = &input.attrs; let args = parse_args(args)?; let fn_arg_types = inputs.iter().map(|_| quote! { _ }); let mut fixtures = Vec::new(); for (fixture_type, fixtures_local) in args.fixtures { let mut res = match fixture_type { FixturesType::None => vec![], FixturesType::RelativePath => fixtures_local .into_iter() .map(|fixture| { let mut fixture_str = fixture.value(); add_sql_extension_if_missing(&mut fixture_str); let path = format!("fixtures/{}", fixture_str); quote! { ::sqlx::testing::TestFixture { path: #path, contents: include_str!(#path), } } }) .collect(), FixturesType::CustomRelativePath(path) => fixtures_local .into_iter() .map(|fixture| { let mut fixture_str = fixture.value(); add_sql_extension_if_missing(&mut fixture_str); let path = format!("{}/{}", path.value(), fixture_str); quote! { ::sqlx::testing::TestFixture { path: #path, contents: include_str!(#path), } } }) .collect(), FixturesType::ExplicitPath => fixtures_local .into_iter() .map(|fixture| { let path = fixture.value(); quote! { ::sqlx::testing::TestFixture { path: #path, contents: include_str!(#path), } } }) .collect(), }; fixtures.append(&mut res) } let migrations = match args.migrations { MigrationsOpt::ExplicitPath(path) => { let migrator = crate::migrate::expand_migrator_from_lit_dir(path)?; quote! { args.migrator(&#migrator); } } MigrationsOpt::InferredPath if !inputs.is_empty() => { let migrations_path = crate::common::resolve_path("./migrations", proc_macro2::Span::call_site())?; if migrations_path.is_dir() { let migrator = crate::migrate::expand_migrator(&migrations_path)?; quote! { args.migrator(&#migrator); } } else { quote! {} } } MigrationsOpt::ExplicitMigrator(path) => { quote! { args.migrator(&#path); } } _ => quote! {}, }; Ok(quote! { #(#attrs)* #[::core::prelude::v1::test] fn #name() #ret { async fn #name(#inputs) #ret { #body } let mut args = ::sqlx::testing::TestArgs::new(concat!(module_path!(), "::", stringify!(#name))); #migrations args.fixtures(&[#(#fixtures),*]); // We need to give a coercion site or else we get "unimplemented trait" errors. let f: fn(#(#fn_arg_types),*) -> _ = #name; ::sqlx::testing::TestFn::run_test(f, args) } }) } #[cfg(feature = "migrate")] fn parse_args(attr_args: AttributeArgs) -> syn::Result { use syn::{ parenthesized, parse::Parse, punctuated::Punctuated, token::Comma, Expr, Lit, LitStr, Meta, MetaNameValue, Token, }; let mut fixtures = Vec::new(); let mut migrations = MigrationsOpt::InferredPath; for arg in attr_args { let path = arg.path().clone(); match arg { syn::Meta::List(list) if list.path.is_ident("fixtures") => { let mut fixtures_local = vec![]; let mut fixtures_type = FixturesType::None; let parse_nested = list.parse_nested_meta(|meta| { if meta.path.is_ident("path") { // fixtures(path = "", scripts("","")) checking `path` argument meta.input.parse::()?; let val: LitStr = meta.input.parse()?; parse_fixtures_path_args(&mut fixtures_type, val)?; } else if meta.path.is_ident("scripts") { // fixtures(path = "", scripts("","")) checking `scripts` argument let content; parenthesized!(content in meta.input); let list = content.parse_terminated(::parse, Comma)?; parse_fixtures_scripts_args(&mut fixtures_type, list, &mut fixtures_local)?; } else { return Err(syn::Error::new_spanned( meta.path, "unexpected fixture meta", )); } Ok(()) }); if parse_nested.is_err() { // fixtures("","") or fixtures("","") let args = list.parse_args_with(>::parse_terminated)?; for arg in args { parse_fixtures_args(&mut fixtures_type, arg, &mut fixtures_local)?; } } fixtures.push((fixtures_type, fixtures_local)); } syn::Meta::NameValue(value) if value.path.is_ident("migrations") => { if !matches!(migrations, MigrationsOpt::InferredPath) { return Err(syn::Error::new_spanned( value, "cannot have more than one `migrations` or `migrator` arg", )); } fn recurse_lit_lookup(expr: Expr) -> Option { match expr { Expr::Lit(syn::ExprLit { lit, .. }) => { return Some(lit); } Expr::Group(syn::ExprGroup { expr, .. }) => { return recurse_lit_lookup(*expr); } _ => return None, } } let Some(lit) = recurse_lit_lookup(value.value) else { return Err(syn::Error::new_spanned(path, "expected string or `false`")); }; migrations = match lit { // migrations = false Lit::Bool(b) if !b.value => MigrationsOpt::Disabled, // migrations = true Lit::Bool(b) => { return Err(syn::Error::new_spanned( b, "`migrations = true` is redundant", )); } // migrations = "path" Lit::Str(s) => MigrationsOpt::ExplicitPath(s), lit => return Err(syn::Error::new_spanned(lit, "expected string or `false`")), }; } // migrator = "" Meta::NameValue(MetaNameValue { value, .. }) if path.is_ident("migrator") => { if !matches!(migrations, MigrationsOpt::InferredPath) { return Err(syn::Error::new_spanned( path, "cannot have more than one `migrations` or `migrator` arg", )); } let Expr::Lit(syn::ExprLit { lit: Lit::Str(lit), .. }) = value else { return Err(syn::Error::new_spanned(path, "expected string")); }; migrations = MigrationsOpt::ExplicitMigrator(lit.parse()?); } arg => { return Err(syn::Error::new_spanned( arg, r#"expected `fixtures("", ...)` or `migrations = "" | false` or `migrator = ""`"#, )) } } } Ok(Args { fixtures, migrations, }) } #[cfg(feature = "migrate")] fn parse_fixtures_args( fixtures_type: &mut FixturesType, litstr: syn::LitStr, fixtures_local: &mut Vec, ) -> syn::Result<()> { // fixtures(path = "", scripts("","")) checking `path` argument let path_str = litstr.value(); let path = std::path::Path::new(&path_str); // This will be `true` if there's at least one path separator (`/` or `\`) // It's also true for all absolute paths, even e.g. `/foo.sql` as the root directory is counted as a component. let is_explicit_path = path.components().count() > 1; match fixtures_type { FixturesType::None => { if is_explicit_path { *fixtures_type = FixturesType::ExplicitPath; } else { *fixtures_type = FixturesType::RelativePath; } } FixturesType::RelativePath => { if is_explicit_path { return Err(syn::Error::new_spanned( litstr, "expected only relative path fixtures", )); } } FixturesType::ExplicitPath => { if !is_explicit_path { return Err(syn::Error::new_spanned( litstr, "expected only explicit path fixtures", )); } } FixturesType::CustomRelativePath(_) => { return Err(syn::Error::new_spanned( litstr, "custom relative path fixtures must be defined in `scripts` argument", )) } } if (matches!(fixtures_type, FixturesType::ExplicitPath) && !is_explicit_path) { return Err(syn::Error::new_spanned( litstr, "expected explicit path fixtures to have `.sql` extension", )); } fixtures_local.push(litstr); Ok(()) } #[cfg(feature = "migrate")] fn parse_fixtures_path_args( fixtures_type: &mut FixturesType, namevalue: syn::LitStr, ) -> syn::Result<()> { if !matches!(fixtures_type, FixturesType::None) { return Err(syn::Error::new_spanned( namevalue, "`path` must be the first argument of `fixtures`", )); } *fixtures_type = FixturesType::CustomRelativePath(namevalue); Ok(()) } #[cfg(feature = "migrate")] fn parse_fixtures_scripts_args( fixtures_type: &mut FixturesType, list: syn::punctuated::Punctuated, fixtures_local: &mut Vec, ) -> syn::Result<()> { // fixtures(path = "", scripts("","")) checking `scripts` argument if !matches!(fixtures_type, FixturesType::CustomRelativePath(_)) { return Err(syn::Error::new_spanned( list, "`scripts` must be the second argument of `fixtures` and used together with `path`", )); } fixtures_local.extend(list); Ok(()) } #[cfg(feature = "migrate")] fn add_sql_extension_if_missing(fixture: &mut String) { let has_extension = std::path::Path::new(&fixture).extension().is_some(); if !has_extension { fixture.push_str(".sql") } }