sqlx-macros-core-0.7.3/.cargo_vcs_info.json
{
  "git": {
    "sha1": "c55aba0dc14f33b8a26cab6af565fcc4c8af8962"
  },
  "path_in_vcs": "sqlx-macros-core"
}

sqlx-macros-core-0.7.3/Cargo.toml
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.

[package]
edition = "2021"
name = "sqlx-macros-core"
version = "0.7.3"
authors = [
    "Ryan Leckey ",
    "Austin Bonander ",
    "Chloe Ross ",
    "Daniel Akhterov ",
]
description = "Macro support core for SQLx, the Rust SQL toolkit. Not intended to be used directly."
license = "MIT OR Apache-2.0"
repository = "https://github.com/launchbadge/sqlx"

[dependencies.async-std]
version = "1.12"
optional = true

[dependencies.atomic-write-file]
version = "0.1"

[dependencies.dotenvy]
version = "0.15.0"
default-features = false

[dependencies.either]
version = "1.6.1"

[dependencies.heck]
version = "0.4"
features = ["unicode"]

[dependencies.hex]
version = "0.4.3"

[dependencies.once_cell]
version = "1.9.0"

[dependencies.proc-macro2]
version = "1.0.36"
default-features = false

[dependencies.quote]
version = "1.0.14"
default-features = false

[dependencies.serde]
version = "1.0.132"
features = ["derive"]

[dependencies.serde_json]
version = "1.0.73"

[dependencies.sha2]
version = "0.10.0"

[dependencies.sqlx-core]
version = "=0.7.3"
features = ["offline"]

[dependencies.sqlx-mysql]
version = "=0.7.3"
features = [
    "offline",
    "migrate",
]
optional = true

[dependencies.sqlx-postgres]
version = "=0.7.3"
features = [
    "offline",
    "migrate",
]
optional = true

[dependencies.sqlx-sqlite]
version = "=0.7.3"
features = [
    "offline",
    "migrate",
]
optional = true

[dependencies.syn]
version = "1.0.84"
features = [
    "full",
    "derive",
    "parsing",
    "printing",
    "clone-impls",
]
default-features = false

[dependencies.tempfile]
version = "3.3.0"

[dependencies.tokio]
version = "1"
features = [
    "time",
    "net",
    "sync",
    "fs",
    "io-util",
    "rt",
]
optional = true
default-features = false

[dependencies.url]
version = "2.2.2"
default-features = false

[features]
_rt-async-std = [
    "async-std",
    "sqlx-core/_rt-async-std",
]
_rt-tokio = [
    "tokio",
    "sqlx-core/_rt-tokio",
]
_tls-native-tls = ["sqlx-core/_tls-native-tls"]
_tls-rustls = ["sqlx-core/_tls-rustls"]
bigdecimal = [
    "sqlx-core/bigdecimal",
    "sqlx-mysql?/bigdecimal",
    "sqlx-postgres?/bigdecimal",
]
bit-vec = [
    "sqlx-core/bit-vec",
    "sqlx-postgres?/bit-vec",
]
chrono = [
    "sqlx-core/chrono",
    "sqlx-mysql?/chrono",
    "sqlx-postgres?/chrono",
    "sqlx-sqlite?/chrono",
]
default = []
ipnetwork = [
    "sqlx-core/ipnetwork",
    "sqlx-postgres?/ipnetwork",
]
json = [
    "sqlx-core/json",
    "sqlx-mysql?/json",
    "sqlx-sqlite?/json",
]
mac_address = [
    "sqlx-core/mac_address",
    "sqlx-postgres?/mac_address",
]
migrate = ["sqlx-core/migrate"]
mysql = ["sqlx-mysql"]
postgres = ["sqlx-postgres"]
rust_decimal = [
    "sqlx-core/rust_decimal",
    "sqlx-mysql?/rust_decimal",
    "sqlx-postgres?/rust_decimal",
]
sqlite = ["sqlx-sqlite"]
time = [
    "sqlx-core/time",
    "sqlx-mysql?/time",
    "sqlx-postgres?/time",
    "sqlx-sqlite?/time",
]
uuid = [
    "sqlx-core/uuid",
    "sqlx-mysql?/uuid",
    "sqlx-postgres?/uuid",
    "sqlx-sqlite?/uuid",
]
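The `mysql`, `postgres`, and `sqlite` features above activate the optional driver crates, while entries like `sqlx-postgres?/uuid` are weak dependency features: they forward `uuid` to a driver only if that driver is already enabled, without activating the driver itself. Inside this crate the same flags are consumed through `cfg` guards; a minimal sketch of that pattern (the module name here is illustrative, not part of the crate):

#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
mod driver_support {
    // A weak feature (`sqlx-postgres?/uuid`) never turns `sqlx-postgres` on by
    // itself; it only adds `uuid` support when the `postgres` feature does.
    #[cfg(feature = "postgres")]
    pub use sqlx_postgres as postgres;
}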
sqlx-macros-core-0.7.3/Cargo.toml.orig
[package]
name = "sqlx-macros-core"
description = "Macro support core for SQLx, the Rust SQL toolkit. Not intended to be used directly."
version.workspace = true
license.workspace = true
edition.workspace = true
authors.workspace = true
repository.workspace = true

[features]
default = []

# for conditional compilation
_rt-async-std = ["async-std", "sqlx-core/_rt-async-std"]
_rt-tokio = ["tokio", "sqlx-core/_rt-tokio"]

_tls-native-tls = ["sqlx-core/_tls-native-tls"]
_tls-rustls = ["sqlx-core/_tls-rustls"]

# SQLx features
migrate = ["sqlx-core/migrate"]

# database
mysql = ["sqlx-mysql"]
postgres = ["sqlx-postgres"]
sqlite = ["sqlx-sqlite"]

# type integrations
json = ["sqlx-core/json", "sqlx-mysql?/json", "sqlx-sqlite?/json"]
bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"]
bit-vec = ["sqlx-core/bit-vec", "sqlx-postgres?/bit-vec"]
chrono = ["sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"]
ipnetwork = ["sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork"]
mac_address = ["sqlx-core/mac_address", "sqlx-postgres?/mac_address"]
rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"]
time = ["sqlx-core/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"]
uuid = ["sqlx-core/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"]

[dependencies]
sqlx-core = { workspace = true, features = ["offline"] }
sqlx-mysql = { workspace = true, features = ["offline", "migrate"], optional = true }
sqlx-postgres = { workspace = true, features = ["offline", "migrate"], optional = true }
sqlx-sqlite = { workspace = true, features = ["offline", "migrate"], optional = true }

async-std = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }

dotenvy = { workspace = true }

atomic-write-file = { version = "0.1" }
hex = { version = "0.4.3" }
heck = { version = "0.4", features = ["unicode"] }
either = "1.6.1"
once_cell = "1.9.0"
proc-macro2 = { version = "1.0.36", default-features = false }
serde = { version = "1.0.132", features = ["derive"] }
serde_json = { version = "1.0.73" }
sha2 = { version = "0.10.0" }
syn = { version = "1.0.84", default-features = false, features = ["full", "derive", "parsing", "printing", "clone-impls"] }
tempfile = { version = "3.3.0" }
quote = { version = "1.0.14", default-features = false }
url = { version = "2.2.2", default-features = false }

sqlx-macros-core-0.7.3/src/common.rs
use proc_macro2::Span;
use std::env;
use std::path::{Path, PathBuf};

pub(crate) fn resolve_path(path: impl AsRef<Path>, err_span: Span) -> syn::Result<PathBuf> {
    let path = path.as_ref();

    if path.is_absolute() {
        return Err(syn::Error::new(
            err_span,
            "absolute paths will only work on the current machine",
        ));
    }

    // requires `proc_macro::SourceFile::path()` to be stable
    // https://github.com/rust-lang/rust/issues/54725
    if path.is_relative()
        && !path
            .parent()
            .map_or(false, |parent| !parent.as_os_str().is_empty())
    {
        return Err(syn::Error::new(
            err_span,
            "paths relative to the current file's directory are not currently supported",
        ));
    }

    let base_dir = env::var("CARGO_MANIFEST_DIR").map_err(|_| {
        syn::Error::new(
            err_span,
            "CARGO_MANIFEST_DIR is not set; please use Cargo to build",
        )
    })?;
    let base_dir_path = Path::new(&base_dir);

    Ok(base_dir_path.join(path))
}
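`resolve_path` anchors macro-supplied paths to the crate root: absolute paths are rejected outright, and so are bare single-component relative paths, which would otherwise read as relative to the invoking source file (unsupported until `proc_macro::SourceFile::path()` stabilizes). A minimal caller sketch under those rules (hypothetical usage, not part of the crate):

use proc_macro2::Span;

fn resolve_migrations_dir() -> syn::Result<std::path::PathBuf> {
    // "./migrations" has a non-empty parent component ("."), so it is joined
    // onto CARGO_MANIFEST_DIR; an absolute path or a bare "migrations" would
    // take the two error branches above instead.
    crate::common::resolve_path("./migrations", Span::call_site())
}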
sqlx-macros-core-0.7.3/src/database/mod.rs000064400000000000000000000130070072674642500165430ustar 00000000000000use std::collections::hash_map; use std::collections::HashMap; use std::sync::Mutex; use once_cell::sync::Lazy; use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; #[derive(PartialEq, Eq)] #[allow(dead_code)] pub enum ParamChecking { Strong, Weak, } pub trait DatabaseExt: Database { const DATABASE_PATH: &'static str; const ROW_PATH: &'static str; const PARAM_CHECKING: ParamChecking; fn db_path() -> syn::Path { syn::parse_str(Self::DATABASE_PATH).unwrap() } fn row_path() -> syn::Path { syn::parse_str(Self::ROW_PATH).unwrap() } fn param_type_for_id(id: &Self::TypeInfo) -> Option<&'static str>; fn return_type_for_id(id: &Self::TypeInfo) -> Option<&'static str>; fn get_feature_gate(info: &Self::TypeInfo) -> Option<&'static str>; fn describe_blocking(query: &str, database_url: &str) -> sqlx_core::Result>; } #[allow(dead_code)] pub struct CachingDescribeBlocking { connections: Lazy>>, } #[allow(dead_code)] impl CachingDescribeBlocking { pub const fn new() -> Self { CachingDescribeBlocking { connections: Lazy::new(|| Mutex::new(HashMap::new())), } } pub fn describe(&self, query: &str, database_url: &str) -> sqlx_core::Result> where for<'a> &'a mut DB::Connection: Executor<'a, Database = DB>, { crate::block_on(async { let mut cache = self .connections .lock() .expect("previous panic in describe call"); let conn = match cache.entry(database_url.to_string()) { hash_map::Entry::Occupied(hit) => hit.into_mut(), hash_map::Entry::Vacant(miss) => { miss.insert(DB::Connection::connect(&database_url).await?) } }; conn.describe(query).await }) } } #[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))] macro_rules! impl_database_ext { ( $database:path { $($(#[$meta:meta])? $ty:ty $(| $input:ty)?),*$(,)? }, ParamChecking::$param_checking:ident, feature-types: $ty_info:ident => $get_gate:expr, row: $row:path, $(describe-blocking: $describe:path,)? ) => { impl $crate::database::DatabaseExt for $database { const DATABASE_PATH: &'static str = stringify!($database); const ROW_PATH: &'static str = stringify!($row); const PARAM_CHECKING: $crate::database::ParamChecking = $crate::database::ParamChecking::$param_checking; fn param_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> { match () { $( $(#[$meta])? _ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => Some(input_ty!($ty $(, $input)?)), )* $( $(#[$meta])? _ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => Some(input_ty!($ty $(, $input)?)), )* _ => None } } fn return_type_for_id(info: &Self::TypeInfo) -> Option<&'static str> { match () { $( $(#[$meta])? _ if <$ty as sqlx_core::types::Type<$database>>::type_info() == *info => return Some(stringify!($ty)), )* $( $(#[$meta])? _ if <$ty as sqlx_core::types::Type<$database>>::compatible(info) => return Some(stringify!($ty)), )* _ => None } } fn get_feature_gate($ty_info: &Self::TypeInfo) -> Option<&'static str> { $get_gate } impl_describe_blocking!($database, $($describe)?); } } } #[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))] macro_rules! impl_describe_blocking { ($database:path $(,)?) => { fn describe_blocking( query: &str, database_url: &str, ) -> sqlx_core::Result> { use $crate::database::CachingDescribeBlocking; // This can't be a provided method because the `static` can't reference `Self`. 
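        // As a hedged illustration: for `impl_describe_blocking!(sqlx::postgres::Postgres)`,
        // this no-override arm expands to roughly
        //
        //     fn describe_blocking(query: &str, database_url: &str)
        //         -> sqlx_core::Result<Describe<sqlx::postgres::Postgres>>
        //     {
        //         static CACHE: CachingDescribeBlocking<sqlx::postgres::Postgres> =
        //             CachingDescribeBlocking::new();
        //         CACHE.describe(query, database_url)
        //     }
        //
        // giving each database its own connection cache keyed by database URL.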
static CACHE: CachingDescribeBlocking<$database> = CachingDescribeBlocking::new(); CACHE.describe(query, database_url) } }; ($database:path, $describe:path) => { fn describe_blocking( query: &str, database_url: &str, ) -> sqlx_core::Result> { $describe(query, database_url) } }; } #[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))] macro_rules! input_ty { ($ty:ty, $input:ty) => { stringify!($input) }; ($ty:ty) => { stringify!($ty) }; } #[cfg(feature = "postgres")] mod postgres; #[cfg(feature = "mysql")] mod mysql; #[cfg(feature = "sqlite")] mod sqlite; mod fake_sqlx { pub use sqlx_core::*; #[cfg(feature = "mysql")] pub use sqlx_mysql as mysql; #[cfg(feature = "postgres")] pub use sqlx_postgres as postgres; #[cfg(feature = "sqlite")] pub use sqlx_sqlite as sqlite; } sqlx-macros-core-0.7.3/src/database/mysql.rs000064400000000000000000000026670072674642500171430ustar 00000000000000use super::fake_sqlx as sqlx; impl_database_ext! { sqlx::mysql::MySql { u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, // ordering is important here as otherwise we might infer strings to be binary // CHAR, VAR_CHAR, TEXT String, // BINARY, VAR_BINARY, BLOB Vec, #[cfg(all(feature = "chrono", not(feature = "time")))] sqlx::types::chrono::NaiveTime, #[cfg(all(feature = "chrono", not(feature = "time")))] sqlx::types::chrono::NaiveDate, #[cfg(all(feature = "chrono", not(feature = "time")))] sqlx::types::chrono::NaiveDateTime, #[cfg(all(feature = "chrono", not(feature = "time")))] sqlx::types::chrono::DateTime, #[cfg(feature = "time")] sqlx::types::time::Time, #[cfg(feature = "time")] sqlx::types::time::Date, #[cfg(feature = "time")] sqlx::types::time::PrimitiveDateTime, #[cfg(feature = "time")] sqlx::types::time::OffsetDateTime, #[cfg(feature = "bigdecimal")] sqlx::types::BigDecimal, #[cfg(feature = "rust_decimal")] sqlx::types::Decimal, #[cfg(feature = "json")] sqlx::types::JsonValue, }, ParamChecking::Weak, feature-types: info => info.__type_feature_gate(), row: sqlx::mysql::MySqlRow, } sqlx-macros-core-0.7.3/src/database/postgres.rs000064400000000000000000000156670072674642500176500ustar 00000000000000use super::fake_sqlx as sqlx; impl_database_ext! 
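// A hedged reading of the macro grammar used below: each entry is
// `<output type> [| <input type>]`, optionally behind a `#[cfg(...)]` feature
// gate. For example `String | &str` maps the SQL type to `String` for query
// outputs (`return_type_for_id`) while letting bind parameters type-check as
// `&str` (`param_type_for_id`, via `input_ty!` above).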
{ sqlx::postgres::Postgres { (), bool, String | &str, i8, i16, i32, i64, f32, f64, Vec | &[u8], sqlx::postgres::types::Oid, sqlx::postgres::types::PgInterval, sqlx::postgres::types::PgMoney, sqlx::postgres::types::PgLTree, sqlx::postgres::types::PgLQuery, #[cfg(feature = "uuid")] sqlx::types::Uuid, #[cfg(feature = "chrono")] sqlx::types::chrono::NaiveTime, #[cfg(feature = "chrono")] sqlx::types::chrono::NaiveDate, #[cfg(feature = "chrono")] sqlx::types::chrono::NaiveDateTime, #[cfg(feature = "chrono")] sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, #[cfg(feature = "chrono")] sqlx::postgres::types::PgTimeTz, #[cfg(feature = "time")] sqlx::types::time::Time, #[cfg(feature = "time")] sqlx::types::time::Date, #[cfg(feature = "time")] sqlx::types::time::PrimitiveDateTime, #[cfg(feature = "time")] sqlx::types::time::OffsetDateTime, #[cfg(feature = "time")] sqlx::postgres::types::PgTimeTz, #[cfg(feature = "bigdecimal")] sqlx::types::BigDecimal, #[cfg(feature = "rust_decimal")] sqlx::types::Decimal, #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, #[cfg(feature = "mac_address")] sqlx::types::mac_address::MacAddress, #[cfg(feature = "json")] sqlx::types::JsonValue, #[cfg(feature = "bit-vec")] sqlx::types::BitVec, // Arrays Vec | &[bool], Vec | &[String], Vec> | &[Vec], Vec | &[i8], Vec | &[i16], Vec | &[i32], Vec | &[i64], Vec | &[f32], Vec | &[f64], Vec | &[sqlx::postgres::types::Oid], Vec | &[sqlx::postgres::types::PgMoney], #[cfg(feature = "uuid")] Vec | &[sqlx::types::Uuid], #[cfg(feature = "chrono")] Vec | &[sqlx::types::chrono::NaiveTime], #[cfg(feature = "chrono")] Vec | &[sqlx::types::chrono::NaiveDate], #[cfg(feature = "chrono")] Vec | &[sqlx::types::chrono::NaiveDateTime], #[cfg(feature = "chrono")] Vec> | &[sqlx::types::chrono::DateTime<_>], #[cfg(feature = "time")] Vec | &[sqlx::types::time::Time], #[cfg(feature = "time")] Vec | &[sqlx::types::time::Date], #[cfg(feature = "time")] Vec | &[sqlx::types::time::PrimitiveDateTime], #[cfg(feature = "time")] Vec | &[sqlx::types::time::OffsetDateTime], #[cfg(feature = "bigdecimal")] Vec | &[sqlx::types::BigDecimal], #[cfg(feature = "rust_decimal")] Vec | &[sqlx::types::Decimal], #[cfg(feature = "ipnetwork")] Vec | &[sqlx::types::ipnetwork::IpNetwork], #[cfg(feature = "mac_address")] Vec | &[sqlx::types::mac_address::MacAddress], #[cfg(feature = "json")] Vec | &[sqlx::types::JsonValue], // Ranges sqlx::postgres::types::PgRange, sqlx::postgres::types::PgRange, #[cfg(feature = "bigdecimal")] sqlx::postgres::types::PgRange, #[cfg(feature = "rust_decimal")] sqlx::postgres::types::PgRange, #[cfg(feature = "chrono")] sqlx::postgres::types::PgRange, #[cfg(feature = "chrono")] sqlx::postgres::types::PgRange, #[cfg(feature = "chrono")] sqlx::postgres::types::PgRange> | sqlx::postgres::types::PgRange>, #[cfg(feature = "time")] sqlx::postgres::types::PgRange, #[cfg(feature = "time")] sqlx::postgres::types::PgRange, #[cfg(feature = "time")] sqlx::postgres::types::PgRange, // Range arrays Vec> | &[sqlx::postgres::types::PgRange], Vec> | &[sqlx::postgres::types::PgRange], #[cfg(feature = "bigdecimal")] Vec> | &[sqlx::postgres::types::PgRange], #[cfg(feature = "rust_decimal")] Vec> | &[sqlx::postgres::types::PgRange], #[cfg(feature = "chrono")] Vec> | &[sqlx::postgres::types::PgRange], #[cfg(feature = "chrono")] Vec> | &[sqlx::postgres::types::PgRange], #[cfg(feature = "chrono")] Vec>> | Vec>>, #[cfg(feature = "chrono")] &[sqlx::postgres::types::PgRange>] | &[sqlx::postgres::types::PgRange>], #[cfg(feature = "time")] Vec> | 
&[sqlx::postgres::types::PgRange], #[cfg(feature = "time")] Vec> | &[sqlx::postgres::types::PgRange], #[cfg(feature = "time")] Vec> | &[sqlx::postgres::types::PgRange], }, ParamChecking::Strong, feature-types: info => info.__type_feature_gate(), row: sqlx::postgres::PgRow, } sqlx-macros-core-0.7.3/src/database/sqlite.rs000064400000000000000000000024060072674642500172660ustar 00000000000000use super::fake_sqlx as sqlx; // f32 is not included below as REAL represents a floating point value // stored as an 8-byte IEEE floating point number // For more info see: https://www.sqlite.org/datatype3.html#storage_classes_and_datatypes impl_database_ext! { sqlx::sqlite::Sqlite { bool, i32, i64, f64, String, Vec, #[cfg(feature = "chrono")] sqlx::types::chrono::NaiveDate, #[cfg(feature = "chrono")] sqlx::types::chrono::NaiveDateTime, #[cfg(feature = "chrono")] sqlx::types::chrono::DateTime | sqlx::types::chrono::DateTime<_>, #[cfg(feature = "time")] sqlx::types::time::OffsetDateTime, #[cfg(feature = "time")] sqlx::types::time::PrimitiveDateTime, #[cfg(feature = "time")] sqlx::types::time::Date, #[cfg(feature = "uuid")] sqlx::types::Uuid, }, ParamChecking::Weak, feature-types: _info => None, row: sqlx::sqlite::SqliteRow, // Since proc-macros don't benefit from async, we can make a describe call directly // which also ensures that the database is closed afterwards, regardless of errors. describe-blocking: sqlx_sqlite::describe_blocking, } sqlx-macros-core-0.7.3/src/derives/attributes.rs000064400000000000000000000242540072674642500200550ustar 00000000000000use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; use syn::{ punctuated::Punctuated, spanned::Spanned, token::Comma, Attribute, DeriveInput, Field, Lit, Meta, MetaNameValue, NestedMeta, Type, Variant, }; macro_rules! assert_attribute { ($e:expr, $err:expr, $input:expr) => { if !$e { return Err(syn::Error::new_spanned($input, $err)); } }; } macro_rules! fail { ($t:expr, $m:expr) => { return Err(syn::Error::new_spanned($t, $m)) }; } macro_rules! try_set { ($i:ident, $v:expr, $t:expr) => { match $i { None => $i = Some($v), Some(_) => fail!($t, "duplicate attribute"), } }; } pub struct TypeName { pub val: String, pub span: Span, } impl TypeName { pub fn get(&self) -> TokenStream { let val = &self.val; quote! { #val } } } #[derive(Copy, Clone)] pub enum RenameAll { LowerCase, SnakeCase, UpperCase, ScreamingSnakeCase, KebabCase, CamelCase, PascalCase, } pub struct SqlxContainerAttributes { pub transparent: bool, pub type_name: Option, pub rename_all: Option, pub repr: Option, pub no_pg_array: bool, pub default: bool, } pub struct SqlxChildAttributes { pub rename: Option, pub default: bool, pub flatten: bool, pub try_from: Option, pub skip: bool, pub json: bool, } pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result { let mut transparent = None; let mut repr = None; let mut type_name = None; let mut rename_all = None; let mut no_pg_array = None; let mut default = None; for attr in input .iter() .filter(|a| a.path.is_ident("sqlx") || a.path.is_ident("repr")) { let meta = attr .parse_meta() .map_err(|e| syn::Error::new_spanned(attr, e))?; match meta { Meta::List(list) if list.path.is_ident("sqlx") => { for value in list.nested.iter() { match value { NestedMeta::Meta(meta) => match meta { Meta::Path(p) if p.is_ident("transparent") => { try_set!(transparent, true, value) } Meta::Path(p) if p.is_ident("no_pg_array") => { try_set!(no_pg_array, true, value); } Meta::NameValue(MetaNameValue { path, lit: Lit::Str(val), .. 
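                            // For orientation, a hedged example (not from this
                            // crate) of the container attributes these arms accept:
                            //
                            //     #[derive(sqlx::Type)]
                            //     #[sqlx(type_name = "status", rename_all = "lowercase")]
                            //     enum Status { Open, Closed }
                            //
                            // where `rename_all` must be one of the string values
                            // matched just below.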
}) if path.is_ident("rename_all") => { let val = match &*val.value() { "lowercase" => RenameAll::LowerCase, "snake_case" => RenameAll::SnakeCase, "UPPERCASE" => RenameAll::UpperCase, "SCREAMING_SNAKE_CASE" => RenameAll::ScreamingSnakeCase, "kebab-case" => RenameAll::KebabCase, "camelCase" => RenameAll::CamelCase, "PascalCase" => RenameAll::PascalCase, _ => fail!(meta, "unexpected value for rename_all"), }; try_set!(rename_all, val, value) } Meta::NameValue(MetaNameValue { path, lit: Lit::Str(val), .. }) if path.is_ident("type_name") => { try_set!( type_name, TypeName { val: val.value(), span: value.span(), }, value ) } Meta::Path(p) if p.is_ident("default") => { try_set!(default, true, value) } u => fail!(u, "unexpected attribute"), }, u => fail!(u, "unexpected attribute"), } } } Meta::List(list) if list.path.is_ident("repr") => { if list.nested.len() != 1 { fail!(&list.nested, "expected one value") } match list.nested.first().unwrap() { NestedMeta::Meta(Meta::Path(p)) if p.get_ident().is_some() => { try_set!(repr, p.get_ident().unwrap().clone(), list); } u => fail!(u, "unexpected value"), } } _ => {} } } Ok(SqlxContainerAttributes { transparent: transparent.unwrap_or(false), repr, type_name, rename_all, no_pg_array: no_pg_array.unwrap_or(false), default: default.unwrap_or(false), }) } pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result { let mut rename = None; let mut default = false; let mut try_from = None; let mut flatten = false; let mut skip: bool = false; let mut json = false; for attr in input.iter().filter(|a| a.path.is_ident("sqlx")) { let meta = attr .parse_meta() .map_err(|e| syn::Error::new_spanned(attr, e))?; if let Meta::List(list) = meta { for value in list.nested.iter() { match value { NestedMeta::Meta(meta) => match meta { Meta::NameValue(MetaNameValue { path, lit: Lit::Str(val), .. }) if path.is_ident("rename") => try_set!(rename, val.value(), value), Meta::NameValue(MetaNameValue { path, lit: Lit::Str(val), .. 
}) if path.is_ident("try_from") => try_set!(try_from, val.parse()?, value), Meta::Path(path) if path.is_ident("default") => default = true, Meta::Path(path) if path.is_ident("flatten") => flatten = true, Meta::Path(path) if path.is_ident("skip") => skip = true, Meta::Path(path) if path.is_ident("json") => json = true, u => fail!(u, "unexpected attribute"), }, u => fail!(u, "unexpected attribute"), } } } if json && flatten { fail!( attr, "Cannot use `json` and `flatten` together on the same field" ); } } Ok(SqlxChildAttributes { rename, default, flatten, try_from, skip, json, }) } pub fn check_transparent_attributes( input: &DeriveInput, field: &Field, ) -> syn::Result { let attributes = parse_container_attributes(&input.attrs)?; assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(rename_all = ..)]", field ); let ch_attributes = parse_child_attributes(&field.attrs)?; assert_attribute!( ch_attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); Ok(attributes) } pub fn check_enum_attributes(input: &DeriveInput) -> syn::Result { let attributes = parse_container_attributes(&input.attrs)?; assert_attribute!( !attributes.transparent, "unexpected #[sqlx(transparent)]", input ); assert_attribute!( !attributes.no_pg_array, "unused #[sqlx(no_pg_array)]; derive does not emit `PgHasArrayType` impls for enums", input ); Ok(attributes) } pub fn check_weak_enum_attributes( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attributes = check_enum_attributes(input)?; assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input); assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(c = ..)]", input ); for variant in variants { let attributes = parse_child_attributes(&variant.attrs)?; assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", variant ); } Ok(attributes) } pub fn check_strong_enum_attributes( input: &DeriveInput, _variants: &Punctuated, ) -> syn::Result { let attributes = check_enum_attributes(input)?; assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); Ok(attributes) } pub fn check_struct_attributes<'a>( input: &'a DeriveInput, fields: &Punctuated, ) -> syn::Result { let attributes = parse_container_attributes(&input.attrs)?; assert_attribute!( !attributes.transparent, "unexpected #[sqlx(transparent)]", input ); assert_attribute!( attributes.rename_all.is_none(), "unexpected #[sqlx(rename_all = ..)]", input ); assert_attribute!( !attributes.no_pg_array, "unused #[sqlx(no_pg_array)]; derive does not emit `PgHasArrayType` impls for custom structs", input ); assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); for field in fields { let attributes = parse_child_attributes(&field.attrs)?; assert_attribute!( attributes.rename.is_none(), "unexpected #[sqlx(rename = ..)]", field ); } Ok(attributes) } sqlx-macros-core-0.7.3/src/derives/decode.rs000064400000000000000000000247700072674642500171150ustar 00000000000000use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_child_attributes, parse_container_attributes, }; use super::rename_all; use proc_macro2::TokenStream; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Arm, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, Stmt, Variant, }; pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result { let attrs = 
parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { expand_derive_decode_transparent(input, unnamed.first().unwrap()) } Data::Enum(DataEnum { variants, .. }) => match attrs.repr { Some(_) => expand_derive_decode_weak_enum(input, variants), None => expand_derive_decode_strong_enum(input, variants), }, Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) => expand_derive_decode_struct(input, named), Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), Data::Struct(DataStruct { fields: Fields::Unnamed(..), .. }) => Err(syn::Error::new_spanned( input, "structs with zero or more than one unnamed field are not supported", )), Data::Struct(DataStruct { fields: Fields::Unit, .. }) => Err(syn::Error::new_spanned( input, "unit structs are not supported", )), } } fn expand_derive_decode_transparent( input: &DeriveInput, field: &Field, ) -> syn::Result { check_transparent_attributes(input, field)?; let ident = &input.ident; let ty = &field.ty; // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let mut generics = generics.clone(); generics .params .insert(0, parse_quote!(DB: ::sqlx::Database)); generics.params.insert(0, parse_quote!('r)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: ::sqlx::decode::Decode<'r, DB>)); let (impl_generics, _, where_clause) = generics.split_for_impl(); let tts = quote!( #[automatically_derived] impl #impl_generics ::sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause { fn decode( value: >::ValueRef, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { <#ty as ::sqlx::decode::Decode<'r, DB>>::decode(value).map(Self) } } ); Ok(tts) } fn expand_derive_decode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attr = check_weak_enum_attributes(input, &variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; let ident_s = ident.to_string(); let arms = variants .iter() .map(|v| { let id = &v.ident; parse_quote! 
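            // Hedged illustration: for `#[derive(sqlx::Decode)]` on
            // `#[repr(i32)] enum Color { Red = 1, Blue = 2 }`, each variant
            // produces an arm like
            //
            //     _ if (Color::Red as i32) == value => ::std::result::Result::Ok(Color::Red),
            //
            // so decoding compares the decoded repr value against every discriminant.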
{ _ if (#ident::#id as #repr) == value => ::std::result::Result::Ok(#ident::#id), } }) .collect::>(); Ok(quote!( #[automatically_derived] impl<'r, DB: ::sqlx::Database> ::sqlx::decode::Decode<'r, DB> for #ident where #repr: ::sqlx::decode::Decode<'r, DB>, { fn decode( value: >::ValueRef, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let value = <#repr as ::sqlx::decode::Decode<'r, DB>>::decode(value)?; match value { #(#arms)* _ => ::std::result::Result::Err(::std::boxed::Box::new(::sqlx::Error::Decode( ::std::format!("invalid value {:?} for enum {}", value, #ident_s).into(), ))) } } } )) } fn expand_derive_decode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let cattr = check_strong_enum_attributes(input, &variants)?; let ident = &input.ident; let ident_s = ident.to_string(); let value_arms = variants.iter().map(|v| -> Arm { let id = &v.ident; let attributes = parse_child_attributes(&v.attrs).unwrap(); if let Some(rename) = attributes.rename { parse_quote!(#rename => ::std::result::Result::Ok(#ident :: #id),) } else if let Some(pattern) = cattr.rename_all { let name = rename_all(&*id.to_string(), pattern); parse_quote!(#name => ::std::result::Result::Ok(#ident :: #id),) } else { let name = id.to_string(); parse_quote!(#name => ::std::result::Result::Ok(#ident :: #id),) } }); let values = quote! { match value { #(#value_arms)* _ => Err(format!("invalid value {:?} for enum {}", value, #ident_s).into()) } }; let mut tts = TokenStream::new(); if cfg!(feature = "mysql") { tts.extend(quote!( #[automatically_derived] impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::mysql::MySql> for #ident { fn decode( value: ::sqlx::mysql::MySqlValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode< 'r, ::sqlx::mysql::MySql, >>::decode(value)?; #values } } )); } if cfg!(feature = "postgres") { tts.extend(quote!( #[automatically_derived] impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::postgres::Postgres> for #ident { fn decode( value: ::sqlx::postgres::PgValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode< 'r, ::sqlx::postgres::Postgres, >>::decode(value)?; #values } } )); } if cfg!(feature = "sqlite") { tts.extend(quote!( #[automatically_derived] impl<'r> ::sqlx::decode::Decode<'r, ::sqlx::sqlite::Sqlite> for #ident { fn decode( value: ::sqlx::sqlite::SqliteValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let value = <&'r ::std::primitive::str as ::sqlx::decode::Decode< 'r, ::sqlx::sqlite::Sqlite, >>::decode(value)?; #values } } )); } Ok(tts) } fn expand_derive_decode_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { check_struct_attributes(input, fields)?; let mut tts = TokenStream::new(); if cfg!(feature = "postgres") { let ident = &input.ident; // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let mut generics = generics.clone(); generics.params.insert(0, parse_quote!('r)); let predicates = &mut generics.make_where_clause().predicates; for field in 
fields { let ty = &field.ty; predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<'r, ::sqlx::Postgres>)); predicates.push(parse_quote!(#ty: ::sqlx::types::Type<::sqlx::Postgres>)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); let reads = fields.iter().map(|field| -> Stmt { let id = &field.ident; let ty = &field.ty; parse_quote!( let #id = decoder.try_decode::<#ty>()?; ) }); let names = fields.iter().map(|field| &field.ident); tts.extend(quote!( #[automatically_derived] impl #impl_generics ::sqlx::decode::Decode<'r, ::sqlx::Postgres> for #ident #ty_generics #where_clause { fn decode( value: ::sqlx::postgres::PgValueRef<'r>, ) -> ::std::result::Result< Self, ::std::boxed::Box< dyn ::std::error::Error + 'static + ::std::marker::Send + ::std::marker::Sync, >, > { let mut decoder = ::sqlx::postgres::types::PgRecordDecoder::new(value)?; #(#reads)* ::std::result::Result::Ok(#ident { #(#names),* }) } } )); } Ok(tts) } sqlx-macros-core-0.7.3/src/derives/encode.rs000064400000000000000000000210230072674642500171130ustar 00000000000000use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_child_attributes, parse_container_attributes, }; use super::rename_all; use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, FieldsUnnamed, Lifetime, LifetimeDef, Stmt, Variant, }; pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result { let args = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { expand_derive_encode_transparent(&input, unnamed.first().unwrap()) } Data::Enum(DataEnum { variants, .. }) => match args.repr { Some(_) => expand_derive_encode_weak_enum(input, variants), None => expand_derive_encode_strong_enum(input, variants), }, Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) => expand_derive_encode_struct(input, named), Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), Data::Struct(DataStruct { fields: Fields::Unnamed(..), .. }) => Err(syn::Error::new_spanned( input, "structs with zero or more than one unnamed field are not supported", )), Data::Struct(DataStruct { fields: Fields::Unit, .. 
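        // Hedged note on the strong-enum path handled below: variants encode as
        // their string names, after applying `#[sqlx(rename = "...")]` or the
        // container's `rename_all` rule; e.g. `enum Status { #[sqlx(rename = "open")] Open }`
        // encodes `Status::Open` as the string "open".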
}) => Err(syn::Error::new_spanned( input, "unit structs are not supported", )), } } fn expand_derive_encode_transparent( input: &DeriveInput, field: &Field, ) -> syn::Result { check_transparent_attributes(input, field)?; let ident = &input.ident; let ty = &field.ty; // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let lifetime = Lifetime::new("'q", Span::call_site()); let mut generics = generics.clone(); generics .params .insert(0, LifetimeDef::new(lifetime.clone()).into()); generics .params .insert(0, parse_quote!(DB: ::sqlx::Database)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: ::sqlx::encode::Encode<#lifetime, DB>)); let (impl_generics, _, where_clause) = generics.split_for_impl(); Ok(quote!( #[automatically_derived] impl #impl_generics ::sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause { fn encode_by_ref( &self, buf: &mut >::ArgumentBuffer, ) -> ::sqlx::encode::IsNull { <#ty as ::sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf) } fn produces(&self) -> Option { <#ty as ::sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0) } fn size_hint(&self) -> usize { <#ty as ::sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0) } } )) } fn expand_derive_encode_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attr = check_weak_enum_attributes(input, &variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; let mut values = Vec::new(); for v in variants { let id = &v.ident; values.push(quote!(#ident :: #id => (#ident :: #id as #repr),)); } Ok(quote!( #[automatically_derived] impl<'q, DB: ::sqlx::Database> ::sqlx::encode::Encode<'q, DB> for #ident where #repr: ::sqlx::encode::Encode<'q, DB>, { fn encode_by_ref( &self, buf: &mut >::ArgumentBuffer, ) -> ::sqlx::encode::IsNull { let value = match self { #(#values)* }; <#repr as ::sqlx::encode::Encode>::encode_by_ref(&value, buf) } fn size_hint(&self) -> usize { <#repr as ::sqlx::encode::Encode>::size_hint(&Default::default()) } } )) } fn expand_derive_encode_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let cattr = check_strong_enum_attributes(input, &variants)?; let ident = &input.ident; let mut value_arms = Vec::new(); for v in variants { let id = &v.ident; let attributes = parse_child_attributes(&v.attrs)?; if let Some(rename) = attributes.rename { value_arms.push(quote!(#ident :: #id => #rename,)); } else if let Some(pattern) = cattr.rename_all { let name = rename_all(&*id.to_string(), pattern); value_arms.push(quote!(#ident :: #id => #name,)); } else { let name = id.to_string(); value_arms.push(quote!(#ident :: #id => #name,)); } } Ok(quote!( #[automatically_derived] impl<'q, DB: ::sqlx::Database> ::sqlx::encode::Encode<'q, DB> for #ident where &'q ::std::primitive::str: ::sqlx::encode::Encode<'q, DB>, { fn encode_by_ref( &self, buf: &mut >::ArgumentBuffer, ) -> ::sqlx::encode::IsNull { let val = match self { #(#value_arms)* }; <&::std::primitive::str as ::sqlx::encode::Encode<'q, DB>>::encode(val, buf) } fn size_hint(&self) -> ::std::primitive::usize { let val = match self { #(#value_arms)* }; <&::std::primitive::str as ::sqlx::encode::Encode<'q, DB>>::size_hint(&val) } } )) } fn expand_derive_encode_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { check_struct_attributes(input, &fields)?; let mut tts = TokenStream::new(); if cfg!(feature = "postgres") { let ident = &input.ident; 
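        // Hedged sketch of the arithmetic used by `size_hint` further below:
        // Postgres record encoding spends one (OID: i32, length: i32) pair per
        // field before the field's value bytes, hence `column_count * (4 + 4)`
        // plus the sum of the per-field size hints.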
let column_count = fields.len(); // extract type generics let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); // add db type for impl generics & where clause let mut generics = generics.clone(); let predicates = &mut generics.make_where_clause().predicates; for field in fields { let ty = &field.ty; predicates .push(parse_quote!(#ty: for<'q> ::sqlx::encode::Encode<'q, ::sqlx::Postgres>)); predicates.push(parse_quote!(#ty: ::sqlx::types::Type<::sqlx::Postgres>)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); let writes = fields.iter().map(|field| -> Stmt { let id = &field.ident; parse_quote!( encoder.encode(&self. #id); ) }); let sizes = fields.iter().map(|field| -> Expr { let id = &field.ident; let ty = &field.ty; parse_quote!( <#ty as ::sqlx::encode::Encode<::sqlx::Postgres>>::size_hint(&self. #id) ) }); tts.extend(quote!( #[automatically_derived] impl #impl_generics ::sqlx::encode::Encode<'_, ::sqlx::Postgres> for #ident #ty_generics #where_clause { fn encode_by_ref( &self, buf: &mut ::sqlx::postgres::PgArgumentBuffer, ) -> ::sqlx::encode::IsNull { let mut encoder = ::sqlx::postgres::types::PgRecordEncoder::new(buf); #(#writes)* encoder.finish(); ::sqlx::encode::IsNull::No } fn size_hint(&self) -> ::std::primitive::usize { #column_count * (4 + 4) // oid (int) and length (int) for each column + #(#sizes)+* // sum of the size hints for each column } } )); } Ok(tts) } sqlx-macros-core-0.7.3/src/derives/mod.rs000064400000000000000000000023550072674642500164440ustar 00000000000000mod attributes; mod decode; mod encode; mod row; mod r#type; pub use decode::expand_derive_decode; pub use encode::expand_derive_encode; pub use r#type::expand_derive_type; pub use row::expand_derive_from_row; use self::attributes::RenameAll; use heck::{ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase}; use proc_macro2::TokenStream; use std::iter::FromIterator; use syn::DeriveInput; pub fn expand_derive_type_encode_decode(input: &DeriveInput) -> syn::Result { let encode_tts = expand_derive_encode(input)?; let decode_tts = expand_derive_decode(input)?; let type_tts = expand_derive_type(input)?; let combined = TokenStream::from_iter(encode_tts.into_iter().chain(decode_tts).chain(type_tts)); Ok(combined) } pub(crate) fn rename_all(s: &str, pattern: RenameAll) -> String { match pattern { RenameAll::LowerCase => s.to_lowercase(), RenameAll::SnakeCase => s.to_snake_case(), RenameAll::UpperCase => s.to_uppercase(), RenameAll::ScreamingSnakeCase => s.to_shouty_snake_case(), RenameAll::KebabCase => s.to_kebab_case(), RenameAll::CamelCase => s.to_lower_camel_case(), RenameAll::PascalCase => s.to_upper_camel_case(), } } sqlx-macros-core-0.7.3/src/derives/row.rs000064400000000000000000000217430072674642500164760ustar 00000000000000use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ parse_quote, punctuated::Punctuated, token::Comma, Data, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, FieldsUnnamed, Lifetime, Stmt, }; use super::{ attributes::{parse_child_attributes, parse_container_attributes}, rename_all, }; pub fn expand_derive_from_row(input: &DeriveInput) -> syn::Result { match &input.data { Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) => expand_derive_from_row_struct(input, named), Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) => expand_derive_from_row_struct_unnamed(input, unnamed), Data::Struct(DataStruct { fields: Fields::Unit, .. 
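        // Hedged usage example for the named-struct path dispatched above:
        //
        //     #[derive(sqlx::FromRow)]
        //     struct User {
        //         id: i64,
        //         #[sqlx(rename = "user_name")]
        //         name: String,
        //         #[sqlx(default)]
        //         nickname: Option<String>,
        //     }
        //
        // Each field becomes a `row.try_get("...")` call, with `default`,
        // `flatten`, `json`, and `try_from` adjusting the generated read as
        // handled below.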
}) => Err(syn::Error::new_spanned( input, "unit structs are not supported", )), Data::Enum(_) => Err(syn::Error::new_spanned(input, "enums are not supported")), Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), } } fn expand_derive_from_row_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { let ident = &input.ident; let generics = &input.generics; let (lifetime, provided) = generics .lifetimes() .next() .map(|def| (def.lifetime.clone(), false)) .unwrap_or_else(|| (Lifetime::new("'a", Span::call_site()), true)); let (_, ty_generics, _) = generics.split_for_impl(); let mut generics = generics.clone(); generics.params.insert(0, parse_quote!(R: ::sqlx::Row)); if provided { generics.params.insert(0, parse_quote!(#lifetime)); } let predicates = &mut generics.make_where_clause().predicates; predicates.push(parse_quote!(&#lifetime ::std::primitive::str: ::sqlx::ColumnIndex)); let container_attributes = parse_container_attributes(&input.attrs)?; let default_instance: Option; if container_attributes.default { predicates.push(parse_quote!(#ident: ::std::default::Default)); default_instance = Some(parse_quote!( let __default = #ident::default(); )); } else { default_instance = None; } let reads: Vec = fields .iter() .filter_map(|field| -> Option { let id = &field.ident.as_ref()?; let attributes = parse_child_attributes(&field.attrs).unwrap(); let ty = &field.ty; if attributes.skip { return Some(parse_quote!( let #id: #ty = Default::default(); )); } let id_s = attributes .rename .or_else(|| Some(id.to_string().trim_start_matches("r#").to_owned())) .map(|s| match container_attributes.rename_all { Some(pattern) => rename_all(&s, pattern), None => s, }) .unwrap(); let expr: Expr = match (attributes.flatten, attributes.try_from, attributes.json) { // (false, None, false) => { predicates .push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); parse_quote!(row.try_get(#id_s)) } // Flatten (true, None, false) => { predicates.push(parse_quote!(#ty: ::sqlx::FromRow<#lifetime, R>)); parse_quote!(<#ty as ::sqlx::FromRow<#lifetime, R>>::from_row(row)) } // Flatten + Try from (true, Some(try_from), false) => { predicates.push(parse_quote!(#try_from: ::sqlx::FromRow<#lifetime, R>)); parse_quote!(<#try_from as ::sqlx::FromRow<#lifetime, R>>::from_row(row).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string())))) } // Flatten + Json (true, _, true) => { panic!("Cannot use both flatten and json") } // Try from (false, Some(try_from), false) => { predicates .push(parse_quote!(#try_from: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(#try_from: ::sqlx::types::Type)); parse_quote!(row.try_get(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v).map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string())))) } // Try from + Json (false, Some(try_from), true) => { predicates .push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(::sqlx::types::Json<#try_from>: ::sqlx::types::Type)); parse_quote!( row.try_get::<::sqlx::types::Json<_>, _>(#id_s).and_then(|v| <#ty as ::std::convert::TryFrom::<#try_from>>::try_from(v.0) .map_err(|e| ::sqlx::Error::ColumnNotFound("FromRow: try_from failed".to_string())) ) ) }, // Json (false, None, true) => { predicates 
.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(::sqlx::types::Json<#ty>: ::sqlx::types::Type)); parse_quote!(row.try_get::<::sqlx::types::Json<_>, _>(#id_s).map(|x| x.0)) }, }; if attributes.default { Some(parse_quote!(let #id: #ty = #expr.or_else(|e| match e { ::sqlx::Error::ColumnNotFound(_) => { ::std::result::Result::Ok(Default::default()) }, e => ::std::result::Result::Err(e) })?;)) } else if container_attributes.default { Some(parse_quote!(let #id: #ty = #expr.or_else(|e| match e { ::sqlx::Error::ColumnNotFound(_) => { ::std::result::Result::Ok(__default.#id) }, e => ::std::result::Result::Err(e) })?;)) } else { Some(parse_quote!( let #id: #ty = #expr?; )) } }) .collect(); let (impl_generics, _, where_clause) = generics.split_for_impl(); let names = fields.iter().map(|field| &field.ident); Ok(quote!( #[automatically_derived] impl #impl_generics ::sqlx::FromRow<#lifetime, R> for #ident #ty_generics #where_clause { fn from_row(row: &#lifetime R) -> ::sqlx::Result { #default_instance #(#reads)* ::std::result::Result::Ok(#ident { #(#names),* }) } } )) } fn expand_derive_from_row_struct_unnamed( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { let ident = &input.ident; let generics = &input.generics; let (lifetime, provided) = generics .lifetimes() .next() .map(|def| (def.lifetime.clone(), false)) .unwrap_or_else(|| (Lifetime::new("'a", Span::call_site()), true)); let (_, ty_generics, _) = generics.split_for_impl(); let mut generics = generics.clone(); generics.params.insert(0, parse_quote!(R: ::sqlx::Row)); if provided { generics.params.insert(0, parse_quote!(#lifetime)); } let predicates = &mut generics.make_where_clause().predicates; predicates.push(parse_quote!( ::std::primitive::usize: ::sqlx::ColumnIndex )); for field in fields { let ty = &field.ty; predicates.push(parse_quote!(#ty: ::sqlx::decode::Decode<#lifetime, R::Database>)); predicates.push(parse_quote!(#ty: ::sqlx::types::Type)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); let gets = fields .iter() .enumerate() .map(|(idx, _)| quote!(row.try_get(#idx)?)); Ok(quote!( #[automatically_derived] impl #impl_generics ::sqlx::FromRow<#lifetime, R> for #ident #ty_generics #where_clause { fn from_row(row: &#lifetime R) -> ::sqlx::Result { ::std::result::Result::Ok(#ident ( #(#gets),* )) } } )) } sqlx-macros-core-0.7.3/src/derives/type.rs000064400000000000000000000171130072674642500166440ustar 00000000000000use super::attributes::{ check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes, check_weak_enum_attributes, parse_container_attributes, TypeName, }; use proc_macro2::{Ident, TokenStream}; use quote::{quote, quote_spanned}; use syn::punctuated::Punctuated; use syn::token::Comma; use syn::{ parse_quote, Data, DataEnum, DataStruct, DeriveInput, Field, Fields, FieldsNamed, FieldsUnnamed, Variant, }; pub fn expand_derive_type(input: &DeriveInput) -> syn::Result { let attrs = parse_container_attributes(&input.attrs)?; match &input.data { Data::Struct(DataStruct { fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), .. }) if unnamed.len() == 1 => { expand_derive_has_sql_type_transparent(input, unnamed.first().unwrap()) } Data::Enum(DataEnum { variants, .. }) => match attrs.repr { Some(_) => expand_derive_has_sql_type_weak_enum(input, variants), None => expand_derive_has_sql_type_strong_enum(input, variants), }, Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. 
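        // Hedged example for the transparent path in this dispatch:
        // `#[derive(sqlx::Type)] #[sqlx(transparent)] struct UserId(i64);`
        // forwards `type_info`/`compatible` to `i64` for every database, and on
        // Postgres also forwards `PgHasArrayType` unless `#[sqlx(no_pg_array)]`
        // is set.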
}) => expand_derive_has_sql_type_struct(input, named), Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")), Data::Struct(DataStruct { fields: Fields::Unnamed(..), .. }) => Err(syn::Error::new_spanned( input, "structs with zero or more than one unnamed field are not supported", )), Data::Struct(DataStruct { fields: Fields::Unit, .. }) => Err(syn::Error::new_spanned( input, "unit structs are not supported", )), } } fn expand_derive_has_sql_type_transparent( input: &DeriveInput, field: &Field, ) -> syn::Result { let attr = check_transparent_attributes(input, field)?; let ident = &input.ident; let ty = &field.ty; let generics = &input.generics; let (_, ty_generics, _) = generics.split_for_impl(); if attr.transparent { let mut generics = generics.clone(); let mut array_generics = generics.clone(); generics .params .insert(0, parse_quote!(DB: ::sqlx::Database)); generics .make_where_clause() .predicates .push(parse_quote!(#ty: ::sqlx::Type)); let (impl_generics, _, where_clause) = generics.split_for_impl(); array_generics .make_where_clause() .predicates .push(parse_quote!(#ty: ::sqlx::postgres::PgHasArrayType)); let (array_impl_generics, _, array_where_clause) = array_generics.split_for_impl(); let mut tokens = quote!( #[automatically_derived] impl #impl_generics ::sqlx::Type< DB > for #ident #ty_generics #where_clause { fn type_info() -> DB::TypeInfo { <#ty as ::sqlx::Type>::type_info() } fn compatible(ty: &DB::TypeInfo) -> ::std::primitive::bool { <#ty as ::sqlx::Type>::compatible(ty) } } ); if cfg!(feature = "postgres") && !attr.no_pg_array { tokens.extend(quote!( #[automatically_derived] impl #array_impl_generics ::sqlx::postgres::PgHasArrayType for #ident #ty_generics #array_where_clause { fn array_type_info() -> ::sqlx::postgres::PgTypeInfo { <#ty as ::sqlx::postgres::PgHasArrayType>::array_type_info() } } )); } return Ok(tokens); } let mut tts = TokenStream::new(); if cfg!(feature = "postgres") { let ty_name = type_name(ident, attr.type_name.as_ref()); tts.extend(quote!( #[automatically_derived] impl ::sqlx::Type<::sqlx::postgres::Postgres> for #ident #ty_generics { fn type_info() -> ::sqlx::postgres::PgTypeInfo { ::sqlx::postgres::PgTypeInfo::with_name(#ty_name) } } )); } Ok(tts) } fn expand_derive_has_sql_type_weak_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attr = check_weak_enum_attributes(input, variants)?; let repr = attr.repr.unwrap(); let ident = &input.ident; let ts = quote!( #[automatically_derived] impl ::sqlx::Type for #ident where #repr: ::sqlx::Type, { fn type_info() -> DB::TypeInfo { <#repr as ::sqlx::Type>::type_info() } fn compatible(ty: &DB::TypeInfo) -> bool { <#repr as ::sqlx::Type>::compatible(ty) } } ); Ok(ts) } fn expand_derive_has_sql_type_strong_enum( input: &DeriveInput, variants: &Punctuated, ) -> syn::Result { let attributes = check_strong_enum_attributes(input, variants)?; let ident = &input.ident; let mut tts = TokenStream::new(); if cfg!(feature = "mysql") { tts.extend(quote!( #[automatically_derived] impl ::sqlx::Type<::sqlx::MySql> for #ident { fn type_info() -> ::sqlx::mysql::MySqlTypeInfo { ::sqlx::mysql::MySqlTypeInfo::__enum() } fn compatible(ty: &::sqlx::mysql::MySqlTypeInfo) -> ::std::primitive::bool { *ty == ::sqlx::mysql::MySqlTypeInfo::__enum() } } )); } if cfg!(feature = "postgres") { let ty_name = type_name(ident, attributes.type_name.as_ref()); tts.extend(quote!( #[automatically_derived] impl ::sqlx::Type<::sqlx::Postgres> for #ident { fn type_info() -> ::sqlx::postgres::PgTypeInfo { 
::sqlx::postgres::PgTypeInfo::with_name(#ty_name) } } )); } if cfg!(feature = "sqlite") { tts.extend(quote!( #[automatically_derived] impl sqlx::Type<::sqlx::Sqlite> for #ident { fn type_info() -> ::sqlx::sqlite::SqliteTypeInfo { <::std::primitive::str as ::sqlx::Type>::type_info() } fn compatible(ty: &::sqlx::sqlite::SqliteTypeInfo) -> ::std::primitive::bool { <&::std::primitive::str as ::sqlx::types::Type>::compatible(ty) } } )); } Ok(tts) } fn expand_derive_has_sql_type_struct( input: &DeriveInput, fields: &Punctuated, ) -> syn::Result { let attributes = check_struct_attributes(input, fields)?; let ident = &input.ident; let mut tts = TokenStream::new(); if cfg!(feature = "postgres") { let ty_name = type_name(ident, attributes.type_name.as_ref()); tts.extend(quote!( #[automatically_derived] impl ::sqlx::Type<::sqlx::Postgres> for #ident { fn type_info() -> ::sqlx::postgres::PgTypeInfo { ::sqlx::postgres::PgTypeInfo::with_name(#ty_name) } } )); } Ok(tts) } fn type_name(ident: &Ident, explicit_name: Option<&TypeName>) -> TokenStream { explicit_name.map(|tn| tn.get()).unwrap_or_else(|| { let s = ident.to_string(); quote_spanned!(ident.span()=> #s) }) } sqlx-macros-core-0.7.3/src/lib.rs000064400000000000000000000042750072674642500147750ustar 00000000000000//! Support crate for SQLx's proc macros. //! //! ### Note: Semver Exempt API //! The API of this crate is not meant for general use and does *not* follow Semantic Versioning. //! The only crate that follows Semantic Versioning in the project is the `sqlx` crate itself. //! If you are building a custom SQLx driver, you should pin an exact version of this and //! `sqlx-core` to avoid breakages: //! //! ```toml //! sqlx-core = "=0.6.2" //! sqlx-macros-core = "=0.6.2" //! ``` //! //! And then make releases in lockstep with `sqlx-core`. We recommend all driver crates, in-tree //! or otherwise, use the same version numbers as `sqlx-core` to avoid confusion. #![cfg_attr( any(sqlx_macros_unstable, procmacro2_semver_exempt), feature(track_path) )] use crate::query::QueryDriver; pub type Error = Box; pub type Result = std::result::Result; mod common; mod database; pub mod derives; pub mod query; // The compiler gives misleading help messages about `#[cfg(test)]` when this is just named `test`. pub mod test_attr; #[cfg(feature = "migrate")] pub mod migrate; pub const FOSS_DRIVERS: &[QueryDriver] = &[ #[cfg(feature = "mysql")] QueryDriver::new::(), #[cfg(feature = "postgres")] QueryDriver::new::(), #[cfg(feature = "sqlite")] QueryDriver::new::(), ]; pub fn block_on(f: F) -> F::Output where F: std::future::Future, { #[cfg(feature = "_rt-tokio")] { use once_cell::sync::Lazy; use tokio::runtime::{self, Runtime}; // We need a single, persistent Tokio runtime since we're caching connections, // otherwise we'll get "IO driver has terminated" errors. 
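        // Hedged usage sketch: sync proc-macro code drives async connection
        // work through this function, e.g. `CachingDescribeBlocking::describe`
        // in src/database/mod.rs calls
        // `crate::block_on(async { conn.describe(query).await })`, so the
        // runtime below must persist across macro invocations.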
static TOKIO_RT: Lazy = Lazy::new(|| { runtime::Builder::new_current_thread() .enable_all() .build() .expect("failed to start Tokio runtime") }); return TOKIO_RT.block_on(f); } #[cfg(all(feature = "_rt-async-std", not(feature = "tokio")))] return async_std::task::block_on(f); #[cfg(not(any(feature = "_rt-async-std", feature = "tokio")))] sqlx_core::rt::missing_rt(f) } sqlx-macros-core-0.7.3/src/migrate.rs000064400000000000000000000107470072674642500156600ustar 00000000000000#[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))] extern crate proc_macro; use proc_macro2::TokenStream; use quote::{quote, ToTokens, TokenStreamExt}; use sha2::{Digest, Sha384}; use sqlx_core::migrate::MigrationType; use std::fs; use std::path::Path; use syn::LitStr; pub struct QuotedMigrationType(MigrationType); impl ToTokens for QuotedMigrationType { fn to_tokens(&self, tokens: &mut TokenStream) { let ts = match self.0 { MigrationType::Simple => quote! { ::sqlx::migrate::MigrationType::Simple }, MigrationType::ReversibleUp => quote! { ::sqlx::migrate::MigrationType::ReversibleUp }, MigrationType::ReversibleDown => { quote! { ::sqlx::migrate::MigrationType::ReversibleDown } } }; tokens.append_all(ts.into_iter()); } } struct QuotedMigration { version: i64, description: String, migration_type: QuotedMigrationType, path: String, checksum: Vec, } impl ToTokens for QuotedMigration { fn to_tokens(&self, tokens: &mut TokenStream) { let QuotedMigration { version, description, migration_type, path, checksum, } = &self; let ts = quote! { ::sqlx::migrate::Migration { version: #version, description: ::std::borrow::Cow::Borrowed(#description), migration_type: #migration_type, // this tells the compiler to watch this path for changes sql: ::std::borrow::Cow::Borrowed(include_str!(#path)), checksum: ::std::borrow::Cow::Borrowed(&[ #(#checksum),* ]), } }; tokens.append_all(ts.into_iter()); } } // mostly copied from sqlx-core/src/migrate/source.rs pub fn expand_migrator_from_lit_dir(dir: LitStr) -> crate::Result { expand_migrator_from_dir(&dir.value(), dir.span()) } pub(crate) fn expand_migrator_from_dir( dir: &str, err_span: proc_macro2::Span, ) -> crate::Result { let path = crate::common::resolve_path(dir, err_span)?; expand_migrator(&path) } pub(crate) fn expand_migrator(path: &Path) -> crate::Result { let mut migrations = Vec::new(); for entry in fs::read_dir(&path)? { let entry = entry?; if !fs::metadata(entry.path())?.is_file() { // not a file; ignore continue; } let file_name = entry.file_name(); let file_name = file_name.to_string_lossy(); let parts = file_name.splitn(2, '_').collect::>(); if parts.len() != 2 || !parts[1].ends_with(".sql") { // not of the format: _.sql; ignore continue; } let version: i64 = parts[0].parse()?; let migration_type = MigrationType::from_filename(parts[1]); // remove the `.sql` and replace `_` with ` ` let description = parts[1] .trim_end_matches(migration_type.suffix()) .replace('_', " ") .to_owned(); let sql = fs::read_to_string(&entry.path())?; let checksum = Vec::from(Sha384::digest(sql.as_bytes()).as_slice()); // canonicalize the path so we can pass it to `include_str!()` let path = entry.path().canonicalize()?; let path = path .to_str() .ok_or_else(|| { format!( "migration path cannot be represented as a string: {:?}", path ) })? 
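            // Hedged walk-through of the filename handling above: for
            // `20230101000000_create_users.sql`, `splitn(2, '_')` yields
            // ["20230101000000", "create_users.sql"], the version parses as
            // 20230101000000, and the description becomes "create users" after
            // trimming the migration-type suffix and replacing underscores
            // with spaces.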
        migrations.push(QuotedMigration {
            version,
            description,
            migration_type: QuotedMigrationType(migration_type),
            path,
            checksum,
        })
    }

    // ensure that we are sorted by `VERSION ASC`
    migrations.sort_by_key(|m| m.version);

    #[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))]
    {
        let path = path.canonicalize()?;
        let path = path.to_str().ok_or_else(|| {
            format!(
                "migration directory path cannot be represented as a string: {:?}",
                path
            )
        })?;

        proc_macro::tracked_path::path(path);
    }

    Ok(quote! {
        ::sqlx::migrate::Migrator {
            migrations: ::std::borrow::Cow::Borrowed(&[
                #(#migrations),*
            ]),
            ignore_missing: false,
            locking: true,
        }
    })
}

sqlx-macros-core-0.7.3/src/query/args.rs

use crate::database::DatabaseExt;
use crate::query::QueryMacroInput;
use either::Either;
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use sqlx_core::describe::Describe;
use syn::spanned::Spanned;
use syn::{Expr, ExprCast, ExprGroup, ExprType, Type};

/// Returns a tokenstream which typechecks the arguments passed to the macro
/// and binds them to `DB::Arguments` with the ident `query_args`.
pub fn quote_args<DB: DatabaseExt>(
    input: &QueryMacroInput,
    info: &Describe<DB>,
) -> crate::Result<TokenStream> {
    let db_path = DB::db_path();

    if input.arg_exprs.is_empty() {
        return Ok(quote! {
            let query_args = <#db_path as ::sqlx::database::HasArguments>::Arguments::default();
        });
    }

    let arg_names = (0..input.arg_exprs.len())
        .map(|i| format_ident!("arg{}", i))
        .collect::<Vec<_>>();

    let arg_name = &arg_names;
    let arg_expr = input.arg_exprs.iter().cloned().map(strip_wildcard);

    let arg_bindings = quote! {
        #(let #arg_name = &(#arg_expr);)*
    };

    let args_check = match info.parameters() {
        None | Some(Either::Right(_)) => {
            // all we can do is check arity which we did
            TokenStream::new()
        }

        Some(Either::Left(_)) if !input.checked => {
            // this is an `*_unchecked!()` macro invocation
            TokenStream::new()
        }

        Some(Either::Left(params)) => {
            params
                .iter()
                .zip(arg_names.iter().zip(&input.arg_exprs))
                .enumerate()
                .map(|(i, (param_ty, (name, expr)))| -> crate::Result<_> {
                    let param_ty = match get_type_override(expr) {
                        // cast will fail to compile if the type does not match
                        // and we strip casts to wildcard
                        Some((_, false)) => return Ok(quote!()),
                        // type ascription is deprecated
                        Some((ty, true)) => return Ok(create_warning(name.clone(), &ty, &expr)),
                        None => {
                            DB::param_type_for_id(&param_ty)
                                .ok_or_else(|| {
                                    if let Some(feature_gate) =
                                        <DB as DatabaseExt>::get_feature_gate(&param_ty)
                                    {
                                        format!(
                                            "optional sqlx feature `{}` required for type {} of param #{}",
                                            feature_gate,
                                            param_ty,
                                            i + 1,
                                        )
                                    } else {
                                        format!("unsupported type {} for param #{}", param_ty, i + 1)
                                    }
                                })?
                                .parse::<TokenStream>()
                                .map_err(|_| format!("Rust type mapping for {param_ty} not parsable"))?
                        }
                    };

                    Ok(quote_spanned!(expr.span() =>
                        // this shouldn't actually run
                        if false {
                            use ::sqlx::ty_match::{WrapSameExt as _, MatchBorrowExt as _};

                            // evaluate the expression only once in case it contains moves
                            let expr = ::sqlx::ty_match::dupe_value(#name);

                            // if `expr` is `Option<T>`, get `Option<$ty>`, otherwise `$ty`
                            let ty_check = ::sqlx::ty_match::WrapSame::<#param_ty, _>::new(&expr).wrap_same();

                            // if `expr` is `&str`, convert `String` to `&str`
                            let (mut _ty_check, match_borrow) = ::sqlx::ty_match::MatchBorrow::new(ty_check, &expr);

                            _ty_check = match_borrow.match_borrow();

                            // this causes move-analysis to effectively ignore this block
                            ::std::panic!();
                        }
                    ))
                })
                .collect::<crate::Result<TokenStream>>()?
        }
    };

    let args_count = input.arg_exprs.len();
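    // Sketch (ours) of the emitted block for two args; `arg0`/`arg1` are the generated
    // idents from above and `DB` stands in for `#db_path`:
    //
    //     let arg0 = &(expr0);
    //     let arg1 = &(expr1);
    //     /* compile-time type checks from `args_check` */
    //     let mut query_args = <DB as ::sqlx::database::HasArguments>::Arguments::default();
    //     query_args.reserve(2, 0 + Encode::<DB>::size_hint(arg0) + Encode::<DB>::size_hint(arg1));
    //     query_args.add(arg0);
    //     query_args.add(arg1);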
    Ok(quote! {
        #arg_bindings

        #args_check

        let mut query_args = <#db_path as ::sqlx::database::HasArguments>::Arguments::default();
        query_args.reserve(
            #args_count,
            0 #(+ ::sqlx::encode::Encode::<#db_path>::size_hint(#arg_name))*
        );
        #(query_args.add(#arg_name);)*
    })
}

fn create_warning(name: Ident, ty: &Type, expr: &Expr) -> TokenStream {
    let Expr::Type(ExprType { expr: stripped, .. }) = expr else {
        return quote!();
    };
    let current = quote!(#stripped: #ty).to_string();
    let fix = quote!(#stripped as #ty).to_string();
    let name = Ident::new(&format!("warning_{name}"), expr.span());
    let message = format!(
        "
\t\tType ascription pattern is deprecated, prefer casting
\t\tTry changing from
\t\t\t`{current}`
\t\tto
\t\t\t`{fix}`

\t\tSee for more information
"
    );

    quote_spanned!(expr.span() =>
        // this shouldn't actually run
        if false {
            #[deprecated(note = #message)]
            #[allow(non_upper_case_globals)]
            const #name: () = ();
            let _ = #name;
        }
    )
}

fn get_type_override(expr: &Expr) -> Option<(&Type, bool)> {
    match expr {
        Expr::Group(group) => get_type_override(&group.expr),
        Expr::Cast(cast) => Some((&cast.ty, false)),
        Expr::Type(ascription) => Some((&ascription.ty, true)),
        _ => None,
    }
}

fn strip_wildcard(expr: Expr) -> Expr {
    match expr {
        Expr::Group(ExprGroup {
            attrs,
            group_token,
            expr,
        }) => Expr::Group(ExprGroup {
            attrs,
            group_token,
            expr: Box::new(strip_wildcard(*expr)),
        }),
        // type ascription syntax is experimental so we always strip it
        Expr::Type(ExprType { expr, .. }) => *expr,
        // we want to retain casts if they semantically matter
        Expr::Cast(ExprCast {
            attrs,
            expr,
            as_token,
            ty,
        }) => match *ty {
            // cast to wildcard `_` will produce weird errors; we interpret it as taking the value as-is
            Type::Infer(_) => *expr,
            _ => Expr::Cast(ExprCast {
                attrs,
                expr,
                as_token,
                ty,
            }),
        },
        _ => expr,
    }
}

sqlx-macros-core-0.7.3/src/query/data.rs

use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::fs;
use std::io::Write as _;
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::Mutex;

use once_cell::sync::Lazy;
use serde::{Serialize, Serializer};
use sqlx_core::database::Database;
use sqlx_core::describe::Describe;

use crate::database::DatabaseExt;

#[derive(serde::Serialize)]
#[serde(bound(serialize = "Describe<DB>: serde::Serialize"))]
#[derive(Debug)]
pub struct QueryData<DB: Database> {
    db_name: SerializeDbName<DB>,
    #[allow(dead_code)]
    pub(super) query: String,
    pub(super) describe: Describe<DB>,
    pub(super) hash: String,
}

impl<DB: Database> QueryData<DB> {
    pub fn from_describe(query: &str, describe: Describe<DB>) -> Self {
        QueryData {
            db_name: SerializeDbName::default(),
            query: query.into(),
            describe,
            hash: hash_string(query),
        }
    }
}

struct SerializeDbName<DB>(PhantomData<DB>);

impl<DB> Default for SerializeDbName<DB> {
    fn default() -> Self {
        SerializeDbName(PhantomData)
    }
}

impl<DB: Database> Debug for SerializeDbName<DB> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("SerializeDbName").field(&DB::NAME).finish()
    }
}

impl<DB: Database> Display for SerializeDbName<DB> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.pad(DB::NAME)
    }
}

impl<DB: Database> Serialize for SerializeDbName<DB> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(DB::NAME)
    }
}

static OFFLINE_DATA_CACHE: Lazy<Mutex<HashMap<PathBuf, DynQueryData>>> =
    Lazy::new(Default::default);

/// Offline query data
#[derive(Clone, serde::Deserialize)]
pub struct DynQueryData {
    pub db_name: String,
    pub query: String,
    pub describe: serde_json::Value,
    pub hash: String,
}

impl DynQueryData {
    /// Loads a query given the path to
its "query-.json" file. Subsequent calls for the same /// path are retrieved from an in-memory cache. pub fn from_data_file(path: impl AsRef, query: &str) -> crate::Result { let path = path.as_ref(); let mut cache = OFFLINE_DATA_CACHE .lock() // Just reset the cache on error .unwrap_or_else(|posion_err| { let mut guard = posion_err.into_inner(); *guard = Default::default(); guard }); if let Some(cached) = cache.get(path).cloned() { if query != cached.query { return Err("hash collision for saved query data".into()); } return Ok(cached); } #[cfg(procmacr2_semver_exempt)] { let path = path.as_ref().canonicalize()?; let path = path.to_str().ok_or_else(|| { format!( "query-.json path cannot be represented as a string: {:?}", path ) })?; proc_macro::tracked_path::path(path); } let offline_data_contents = fs::read_to_string(path) .map_err(|e| format!("failed to read saved query path {}: {}", path.display(), e))?; let dyn_data: DynQueryData = serde_json::from_str(&offline_data_contents)?; if query != dyn_data.query { return Err("hash collision for saved query data".into()); } let _ = cache.insert(path.to_owned(), dyn_data.clone()); Ok(dyn_data) } } impl QueryData where Describe: serde::Serialize + serde::de::DeserializeOwned, { pub fn from_dyn_data(dyn_data: DynQueryData) -> crate::Result { assert!(!dyn_data.db_name.is_empty()); assert!(!dyn_data.hash.is_empty()); if DB::NAME == dyn_data.db_name { let describe: Describe = serde_json::from_value(dyn_data.describe)?; Ok(QueryData { db_name: SerializeDbName::default(), query: dyn_data.query, describe, hash: dyn_data.hash, }) } else { Err(format!( "expected query data for {}, got data for {}", DB::NAME, dyn_data.db_name ) .into()) } } pub(super) fn save_in(&self, dir: impl AsRef) -> crate::Result<()> { let path = dir.as_ref().join(format!("query-{}.json", self.hash)); let mut file = atomic_write_file::AtomicWriteFile::open(&path) .map_err(|err| format!("failed to open the temporary file: {err:?}"))?; serde_json::to_writer_pretty(file.as_file_mut(), self) .map_err(|err| format!("failed to serialize query data to file: {err:?}"))?; // Ensure there is a newline at the end of the JSON file to avoid // accidental modification by IDE and make github diff tool happier. 
        file.as_file_mut()
            .write_all(b"\n")
            .map_err(|err| format!("failed to append a newline to file: {err:?}"))?;

        file.commit()
            .map_err(|err| format!("failed to commit the query data to {path:?}: {err:?}"))?;

        Ok(())
    }
}

pub(super) fn hash_string(query: &str) -> String {
    // picked `sha2` because it's already in the dependency tree for both MySQL and Postgres
    use sha2::{Digest, Sha256};

    hex::encode(Sha256::digest(query.as_bytes()))
}

sqlx-macros-core-0.7.3/src/query/input.rs

use std::fs;

use proc_macro2::{Ident, Span};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{Expr, LitBool, LitStr, Token};
use syn::{ExprArray, Type};

/// Macro input shared by `query!()` and `query_file!()`
pub struct QueryMacroInput {
    pub(super) sql: String,

    pub(super) src_span: Span,

    pub(super) record_type: RecordType,

    pub(super) arg_exprs: Vec<Expr>,

    pub(super) checked: bool,

    pub(super) file_path: Option<String>,
}

enum QuerySrc {
    String(String),
    File(String),
}

pub enum RecordType {
    Given(Type),
    Scalar,
    Generated,
}

impl Parse for QueryMacroInput {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut query_src: Option<(QuerySrc, Span)> = None;
        let mut args: Option<Vec<Expr>> = None;
        let mut record_type = RecordType::Generated;
        let mut checked = true;

        let mut expect_comma = false;

        while !input.is_empty() {
            if expect_comma {
                let _ = input.parse::<Token![,]>()?;
            }

            let key: Ident = input.parse()?;

            let _ = input.parse::<Token![=]>()?;

            if key == "source" {
                let span = input.span();
                let query_str = Punctuated::<LitStr, Token![+]>::parse_separated_nonempty(input)?
                    .iter()
                    .map(LitStr::value)
                    .collect();
                query_src = Some((QuerySrc::String(query_str), span));
            } else if key == "source_file" {
                let lit_str = input.parse::<LitStr>()?;
                query_src = Some((QuerySrc::File(lit_str.value()), lit_str.span()));
            } else if key == "args" {
                let exprs = input.parse::<ExprArray>()?;
                args = Some(exprs.elems.into_iter().collect())
            } else if key == "record" {
                if !matches!(record_type, RecordType::Generated) {
                    return Err(input.error("colliding `scalar` or `record` key"));
                }

                record_type = RecordType::Given(input.parse()?);
            } else if key == "scalar" {
                if !matches!(record_type, RecordType::Generated) {
                    return Err(input.error("colliding `scalar` or `record` key"));
                }

                // we currently expect only `scalar = _`
                // a `query_as_scalar!()` variant seems less useful than just overriding the type
                // of the column in SQL
                input.parse::<Token![_]>()?;

                record_type = RecordType::Scalar;
            } else if key == "checked" {
                let lit_bool = input.parse::<LitBool>()?;
                checked = lit_bool.value;
            } else {
                let message = format!("unexpected input key: {key}");
                return Err(syn::Error::new_spanned(key, message));
            }

            expect_comma = true;
        }

        let (src, src_span) =
            query_src.ok_or_else(|| input.error("expected `source` or `source_file` key"))?;

        let arg_exprs = args.unwrap_or_default();

        let file_path = src.file_path(src_span)?;

        Ok(QueryMacroInput {
            sql: src.resolve(src_span)?,
            src_span,
            record_type,
            arg_exprs,
            checked,
            file_path,
        })
    }
}

impl QuerySrc {
    /// If the query source is a file, read it to a string. Otherwise return the query string.
    fn resolve(self, source_span: Span) -> syn::Result<String> {
        match self {
            QuerySrc::String(string) => Ok(string),
            QuerySrc::File(file) => read_file_src(&file, source_span),
        }
    }

    fn file_path(&self, source_span: Span) -> syn::Result<Option<String>> {
        if let QuerySrc::File(ref file) = *self {
            let path = crate::common::resolve_path(file, source_span)?
                .canonicalize()
                .map_err(|e| syn::Error::new(source_span, e))?;
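            // Note (ours): the generated code later embeds this path in `include_str!(#path)`,
            // which the compiler resolves relative to the *invoking* source file, so the path
            // stored here must be absolute.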
            Ok(Some(
                path.to_str()
                    .ok_or_else(|| {
                        syn::Error::new(
                            source_span,
                            "query file path cannot be represented as a string",
                        )
                    })?
                    .to_string(),
            ))
        } else {
            Ok(None)
        }
    }
}

fn read_file_src(source: &str, source_span: Span) -> syn::Result<String> {
    let file_path = crate::common::resolve_path(source, source_span)?;

    fs::read_to_string(&file_path).map_err(|e| {
        syn::Error::new(
            source_span,
            format!(
                "failed to read query file at {}: {}",
                file_path.display(),
                e
            ),
        )
    })
}

sqlx-macros-core-0.7.3/src/query/mod.rs

use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::{fs, io};

use once_cell::sync::Lazy;
use proc_macro2::TokenStream;
use syn::Type;

pub use input::QueryMacroInput;
use quote::{format_ident, quote};
use sqlx_core::database::Database;
use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo};

use crate::database::DatabaseExt;
use crate::query::data::{hash_string, DynQueryData, QueryData};
use crate::query::input::RecordType;
use either::Either;
use url::Url;

mod args;
mod data;
mod input;
mod output;

#[derive(Copy, Clone)]
pub struct QueryDriver {
    db_name: &'static str,
    url_schemes: &'static [&'static str],
    expand: fn(QueryMacroInput, QueryDataSource) -> crate::Result<TokenStream>,
}

impl QueryDriver {
    pub const fn new<DB: DatabaseExt>() -> Self
    where
        Describe<DB>: serde::Serialize + serde::de::DeserializeOwned,
    {
        QueryDriver {
            db_name: DB::NAME,
            url_schemes: DB::URL_SCHEMES,
            expand: expand_with::<DB>,
        }
    }
}

pub enum QueryDataSource<'a> {
    Live {
        database_url: &'a str,
        database_url_parsed: Url,
    },
    Cached(DynQueryData),
}

impl<'a> QueryDataSource<'a> {
    pub fn live(database_url: &'a str) -> crate::Result<Self> {
        Ok(QueryDataSource::Live {
            database_url,
            database_url_parsed: database_url.parse()?,
        })
    }

    pub fn matches_driver(&self, driver: &QueryDriver) -> bool {
        match self {
            Self::Live {
                database_url_parsed,
                ..
            } => driver.url_schemes.contains(&database_url_parsed.scheme()),
            Self::Cached(dyn_data) => dyn_data.db_name == driver.db_name,
        }
    }
}

struct Metadata {
    #[allow(unused)]
    manifest_dir: PathBuf,
    offline: bool,
    database_url: Option<String>,
    workspace_root: Arc<Mutex<Option<PathBuf>>>,
}

impl Metadata {
    pub fn workspace_root(&self) -> PathBuf {
        let mut root = self.workspace_root.lock().unwrap();
        if root.is_none() {
            use serde::Deserialize;
            use std::process::Command;

            let cargo = env("CARGO").expect("`CARGO` must be set");

            let output = Command::new(&cargo)
                .args(&["metadata", "--format-version=1", "--no-deps"])
                .current_dir(&self.manifest_dir)
                .env_remove("__CARGO_FIX_PLZ")
                .output()
                .expect("Could not fetch metadata");

            #[derive(Deserialize)]
            struct CargoMetadata {
                workspace_root: PathBuf,
            }

            let metadata: CargoMetadata =
                serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");

            *root = Some(metadata.workspace_root);
        }
        root.clone().unwrap()
    }
}

// If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't
// reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
static METADATA: Lazy<Metadata> = Lazy::new(|| {
    let manifest_dir: PathBuf = env("CARGO_MANIFEST_DIR")
        .expect("`CARGO_MANIFEST_DIR` must be set")
        .into();

    // If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from it,
    // otherwise fall back to the default dotenv behaviour.
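    // e.g. (illustrative) a `.env` next to the crate's Cargo.toml containing:
    //     DATABASE_URL=postgres://postgres@localhost/my_db
    //     SQLX_OFFLINE=true
    // wins over any `.env` found higher up the directory tree.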
let env_path = manifest_dir.join(".env"); #[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))] let env_path = if env_path.exists() { let res = dotenvy::from_path(&env_path); if let Err(e) = res { panic!("failed to load environment from {env_path:?}, {e}"); } Some(env_path) } else { dotenvy::dotenv().ok() }; // tell the compiler to watch the `.env` for changes, if applicable #[cfg(procmacro2_semver_exempt)] if let Some(env_path) = env_path.as_ref().and_then(|path| path.to_str()) { proc_macro::tracked_path::path(env_path); } let offline = env("SQLX_OFFLINE") .map(|s| s.eq_ignore_ascii_case("true") || s == "1") .unwrap_or(false); let database_url = env("DATABASE_URL").ok(); Metadata { manifest_dir, offline, database_url, workspace_root: Arc::new(Mutex::new(None)), } }); pub fn expand_input<'a>( input: QueryMacroInput, drivers: impl IntoIterator, ) -> crate::Result { let data_source = match &*METADATA { Metadata { offline: false, database_url: Some(db_url), .. } => QueryDataSource::live(db_url)?, Metadata { offline, .. } => { // Try load the cached query metadata file. let filename = format!("query-{}.json", hash_string(&input.sql)); // Check SQLX_OFFLINE_DIR, then local .sqlx, then workspace .sqlx. let dirs = [ env("SQLX_OFFLINE_DIR").ok().map(PathBuf::from), Some(METADATA.manifest_dir.join(".sqlx")), Some(METADATA.workspace_root().join(".sqlx")), ]; let Some(data_file_path) = dirs .iter() .filter_map(|path| path.as_ref()) .map(|path| path.join(&filename)) .find(|path| path.exists()) else { return Err( if *offline { "`SQLX_OFFLINE=true` but there is no cached data for this query, run `cargo sqlx prepare` to update the query cache or unset `SQLX_OFFLINE`" } else { "set `DATABASE_URL` to use query macros online, or run `cargo sqlx prepare` to update the query cache" }.into() ); }; QueryDataSource::Cached(DynQueryData::from_data_file(&data_file_path, &input.sql)?) } }; for driver in drivers { if data_source.matches_driver(&driver) { return (driver.expand)(input, data_source); } } match data_source { QueryDataSource::Live { database_url_parsed, .. } => Err(format!( "no database driver found matching URL scheme {:?}; the corresponding Cargo feature may need to be enabled", database_url_parsed.scheme() ).into()), QueryDataSource::Cached(data) => { Err(format!( "found cached data for database {:?} but no matching driver; the corresponding Cargo feature may need to be enabled", data.db_name ).into()) } } } fn expand_with( input: QueryMacroInput, data_source: QueryDataSource, ) -> crate::Result where Describe: DescribeExt, { let (query_data, offline): (QueryData, bool) = match data_source { QueryDataSource::Cached(dyn_data) => (QueryData::from_dyn_data(dyn_data)?, true), QueryDataSource::Live { database_url, .. 
            let describe = DB::describe_blocking(&input.sql, &database_url)?;
            (QueryData::from_describe(&input.sql, describe), false)
        }
    };

    expand_with_data(input, query_data, offline)
}

// marker trait for `Describe` that lets us conditionally require it to be `Serialize + Deserialize`
trait DescribeExt: serde::Serialize + serde::de::DeserializeOwned {}

impl<DB: Database> DescribeExt for Describe<DB> where
    Describe<DB>: serde::Serialize + serde::de::DeserializeOwned
{
}

fn expand_with_data<DB: DatabaseExt>(
    input: QueryMacroInput,
    data: QueryData<DB>,
    offline: bool,
) -> crate::Result<TokenStream>
where
    Describe<DB>: DescribeExt,
{
    // validate at the minimum that our args match the query's input parameters
    let num_parameters = match data.describe.parameters() {
        Some(Either::Left(params)) => Some(params.len()),
        Some(Either::Right(num)) => Some(num),
        None => None,
    };

    if let Some(num) = num_parameters {
        if num != input.arg_exprs.len() {
            return Err(
                format!("expected {} parameters, got {}", num, input.arg_exprs.len()).into(),
            );
        }
    }

    let args_tokens = args::quote_args(&input, &data.describe)?;

    let query_args = format_ident!("query_args");

    let output = if data
        .describe
        .columns()
        .iter()
        .all(|it| it.type_info().is_void())
    {
        let db_path = DB::db_path();
        let sql = &input.sql;

        quote! {
            ::sqlx::query_with::<#db_path, _>(#sql, #query_args)
        }
    } else {
        match input.record_type {
            RecordType::Generated => {
                let columns = output::columns_to_rust::<DB>(&data.describe)?;

                let record_name: Type = syn::parse_str("Record").unwrap();

                for rust_col in &columns {
                    if rust_col.type_.is_wildcard() {
                        return Err(
                            "wildcard overrides are only allowed with an explicit record type, \
                             e.g. `query_as!()` and its variants"
                                .into(),
                        );
                    }
                }

                let record_fields = columns.iter().map(
                    |&output::RustColumn {
                         ref ident,
                         ref type_,
                         ..
                     }| quote!(#ident: #type_,),
                );

                let mut record_tokens = quote! {
                    #[derive(Debug)]
                    struct #record_name {
                        #(#record_fields)*
                    }
                };

                record_tokens.extend(output::quote_query_as::<DB>(
                    &input,
                    &record_name,
                    &query_args,
                    &columns,
                ));

                record_tokens
            }
            RecordType::Given(ref out_ty) => {
                let columns = output::columns_to_rust::<DB>(&data.describe)?;

                output::quote_query_as::<DB>(&input, out_ty, &query_args, &columns)
            }
            RecordType::Scalar => {
                output::quote_query_scalar::<DB>(&input, &query_args, &data.describe)?
            }
        }
    };

    let ret_tokens = quote! {
        {
            #[allow(clippy::all)]
            {
                use ::sqlx::Arguments as _;

                #args_tokens

                #output
            }
        }
    };

    // Store query metadata only if offline support is enabled but the current build is online.
    // If the build is offline, the cache is our input so it's pointless to also write data for it.
    if !offline {
        // Only save query metadata if SQLX_OFFLINE_DIR is set manually or by `cargo sqlx prepare`.
        // Note: in a cargo workspace this path is relative to the root.
        if let Ok(dir) = env("SQLX_OFFLINE_DIR") {
            let path = PathBuf::from(&dir);

            match fs::metadata(&path) {
                Err(e) => {
                    if e.kind() != io::ErrorKind::NotFound {
                        // Can't obtain information about .sqlx
                        return Err(format!("{e}: {dir}").into());
                    }
                    // .sqlx doesn't exist.
                    return Err(format!("sqlx offline path does not exist: {dir}").into());
                }
                Ok(meta) => {
                    if !meta.is_dir() {
                        return Err(format!(
                            "sqlx offline path exists, but is not a directory: {dir}"
                        )
                        .into());
                    }

                    // .sqlx exists and is a directory, store data.
                    data.save_in(path)?;
                }
            }
        }
    }

    Ok(ret_tokens)
}

/// Get the value of an environment variable, telling the compiler about it if applicable.
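///
/// (Illustrative, ours) call sites in this module look like `env("DATABASE_URL").ok()`
/// or `env("SQLX_OFFLINE")`; under `procmacro2_semver_exempt` the lookup is also
/// tracked, so changes to the variable invalidate the build.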
fn env(name: &str) -> Result { #[cfg(procmacro2_semver_exempt)] { proc_macro::tracked_env::var(name) } #[cfg(not(procmacro2_semver_exempt))] { std::env::var(name) } } sqlx-macros-core-0.7.3/src/query/output.rs000064400000000000000000000242640072674642500167340ustar 00000000000000use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens, TokenStreamExt}; use syn::Type; use sqlx_core::column::Column; use sqlx_core::describe::Describe; use crate::database::DatabaseExt; use crate::query::QueryMacroInput; use std::fmt::{self, Display, Formatter}; use syn::parse::{Parse, ParseStream}; use syn::Token; pub struct RustColumn { pub(super) ident: Ident, pub(super) var_name: Ident, pub(super) type_: ColumnType, } pub(super) enum ColumnType { Exact(TokenStream), Wildcard, OptWildcard, } impl ColumnType { pub(super) fn is_wildcard(&self) -> bool { !matches!(self, ColumnType::Exact(_)) } } impl ToTokens for ColumnType { fn to_tokens(&self, tokens: &mut TokenStream) { tokens.append_all(match &self { ColumnType::Exact(type_) => type_.clone().into_iter(), ColumnType::Wildcard => quote! { _ }.into_iter(), ColumnType::OptWildcard => quote! { ::std::option::Option<_> }.into_iter(), }) } } struct DisplayColumn<'a> { // zero-based index, converted to 1-based number idx: usize, name: &'a str, } struct ColumnDecl { ident: Ident, r#override: ColumnOverride, } struct ColumnOverride { nullability: ColumnNullabilityOverride, type_: ColumnTypeOverride, } #[derive(PartialEq)] enum ColumnNullabilityOverride { NonNull, Nullable, None, } enum ColumnTypeOverride { Exact(Type), Wildcard, None, } impl Display for DisplayColumn<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "column #{} ({:?})", self.idx + 1, self.name) } } pub fn columns_to_rust(describe: &Describe) -> crate::Result> { (0..describe.columns().len()) .map(|i| column_to_rust(describe, i)) .collect::>>() } fn column_to_rust(describe: &Describe, i: usize) -> crate::Result { let column = &describe.columns()[i]; // add raw prefix to all identifiers let decl = ColumnDecl::parse(&column.name()) .map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?; let ColumnOverride { nullability, type_ } = decl.r#override; let nullable = match nullability { ColumnNullabilityOverride::NonNull => false, ColumnNullabilityOverride::Nullable => true, ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true), }; let type_ = match (type_, nullable) { (ColumnTypeOverride::Exact(type_), false) => ColumnType::Exact(type_.to_token_stream()), (ColumnTypeOverride::Exact(type_), true) => { ColumnType::Exact(quote! { ::std::option::Option<#type_> }) } (ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard, (ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard, (ColumnTypeOverride::None, _) => { let type_ = get_column_type::(i, column); if !nullable { ColumnType::Exact(type_) } else { ColumnType::Exact(quote! { ::std::option::Option<#type_> }) } } }; Ok(RustColumn { // prefix the variable name we use in `quote_query_as!()` so it doesn't conflict // https://github.com/launchbadge/sqlx/issues/1322 var_name: quote::format_ident!("sqlx_query_as_{}", decl.ident), ident: decl.ident, type_, }) } pub fn quote_query_as( input: &QueryMacroInput, out_ty: &Type, bind_args: &Ident, columns: &[RustColumn], ) -> TokenStream { let instantiations = columns.iter().enumerate().map( |( i, &RustColumn { ref var_name, ref type_, .. 
            },
        )| {
            match (input.checked, type_) {
                // we guarantee the type is valid so we can skip the runtime check
                (true, ColumnType::Exact(type_)) => quote! {
                    // binding to a `let` avoids confusing errors about
                    // "try expression alternatives have incompatible types"
                    // it doesn't seem to hurt inference in the other branches
                    let #var_name = row.try_get_unchecked::<#type_, _>(#i)?.into();
                },
                // type was overridden to be a wildcard so we fall back to the runtime check
                (true, ColumnType::Wildcard) => quote!(
                    let #var_name = row.try_get(#i)?;
                ),
                (true, ColumnType::OptWildcard) => {
                    quote!(
                        let #var_name = row.try_get::<::std::option::Option<_>, _>(#i)?;
                    )
                }
                // macro is the `_unchecked!()` variant so this will die in decoding if it's wrong
                (false, _) => quote!(
                    let #var_name = row.try_get_unchecked(#i)?;
                ),
            }
        },
    );

    let ident = columns.iter().map(|col| &col.ident);
    let var_name = columns.iter().map(|col| &col.var_name);

    let db_path = DB::db_path();
    let row_path = DB::row_path();

    // if this query came from a file, use `include_str!()` to tell the compiler where it came from
    let sql = if let Some(ref path) = &input.file_path {
        quote::quote_spanned! { input.src_span =>
            include_str!(#path)
        }
    } else {
        let sql = &input.sql;
        quote! { #sql }
    };

    quote! {
        ::sqlx::query_with::<#db_path, _>(#sql, #bind_args).try_map(|row: #row_path| {
            use ::sqlx::Row as _;

            #(#instantiations)*

            ::std::result::Result::Ok(#out_ty {
                #(#ident: #var_name),*
            })
        })
    }
}

pub fn quote_query_scalar<DB: DatabaseExt>(
    input: &QueryMacroInput,
    bind_args: &Ident,
    describe: &Describe<DB>,
) -> crate::Result<TokenStream> {
    let columns = describe.columns();

    if columns.len() != 1 {
        return Err(syn::Error::new(
            input.src_span,
            format!("expected exactly 1 column, got {}", columns.len()),
        )
        .into());
    }

    // attempt to parse a column override, otherwise fall back to the inferred type of the column
    let ty = if let Ok(rust_col) = column_to_rust(describe, 0) {
        rust_col.type_.to_token_stream()
    } else if input.checked {
        let ty = get_column_type::<DB>(0, &columns[0]);
        if describe.nullable(0).unwrap_or(true) {
            quote! { ::std::option::Option<#ty> }
        } else {
            ty
        }
    } else {
        quote! { _ }
    };

    let db = DB::db_path();
    let query = &input.sql;

    Ok(quote! {
        ::sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args)
    })
}

fn get_column_type<DB: DatabaseExt>(i: usize, column: &DB::Column) -> TokenStream {
    let type_info = &*column.type_info();

    <DB as DatabaseExt>::return_type_for_id(&type_info).map_or_else(
        || {
            let message =
                if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&type_info) {
                    format!(
                        "optional sqlx feature `{feat}` required for type {ty} of {col}",
                        ty = &type_info,
                        feat = feature_gate,
                        col = DisplayColumn {
                            idx: i,
                            name: &*column.name()
                        }
                    )
                } else {
                    format!(
                        "unsupported type {ty} of {col}",
                        ty = type_info,
                        col = DisplayColumn {
                            idx: i,
                            name: &*column.name()
                        }
                    )
                };
            syn::Error::new(Span::call_site(), message).to_compile_error()
        },
        |t| t.parse().unwrap(),
    )
}

impl ColumnDecl {
    fn parse(col_name: &str) -> crate::Result<Self> {
        // find the end of the identifier because we want to use our own logic to parse it
        // if we tried to feed this into `syn::parse_str()` we might get an un-great error
        // for some kinds of invalid identifiers
        let (ident, remainder) = if let Some(i) = col_name.find(&[':', '!', '?'][..]) {
            let (ident, remainder) = col_name.split_at(i);

            (parse_ident(ident)?, remainder)
        } else {
            (parse_ident(col_name)?, "")
        };

        Ok(ColumnDecl {
            ident,
            r#override: if !remainder.is_empty() {
                syn::parse_str(remainder)?
            } else {
                ColumnOverride {
                    nullability: ColumnNullabilityOverride::None,
                    type_: ColumnTypeOverride::None,
                }
            },
        })
    }
}

impl Parse for ColumnOverride {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let lookahead = input.lookahead1();
        let nullability = if lookahead.peek(Token![!]) {
            input.parse::<Token![!]>()?;

            ColumnNullabilityOverride::NonNull
        } else if lookahead.peek(Token![?]) {
            input.parse::<Token![?]>()?;

            ColumnNullabilityOverride::Nullable
        } else {
            ColumnNullabilityOverride::None
        };

        let type_ = if input.lookahead1().peek(Token![:]) {
            input.parse::<Token![:]>()?;

            let ty = Type::parse(input)?;

            if let Type::Infer(_) = ty {
                ColumnTypeOverride::Wildcard
            } else {
                ColumnTypeOverride::Exact(ty)
            }
        } else {
            ColumnTypeOverride::None
        };

        Ok(Self { nullability, type_ })
    }
}

fn parse_ident(name: &str) -> crate::Result<Ident> {
    // workaround for the following issue (it's semi-fixed but still spits out extra diagnostics)
    // https://github.com/dtolnay/syn/issues/749#issuecomment-575451318

    let is_valid_ident = !name.is_empty()
        && name.starts_with(|c: char| c.is_alphabetic() || c == '_')
        && name.chars().all(|c| c.is_alphanumeric() || c == '_');

    if is_valid_ident {
        let ident = String::from("r#") + name;
        if let Ok(ident) = syn::parse_str(&ident) {
            return Ok(ident);
        }
    }

    Err(format!("{name:?} is not a valid Rust identifier").into())
}

sqlx-macros-core-0.7.3/src/test_attr.rs

use proc_macro2::TokenStream;
use quote::quote;

#[cfg(feature = "migrate")]
struct Args {
    fixtures: Vec<(FixturesType, Vec<syn::LitStr>)>,
    migrations: MigrationsOpt,
}

#[cfg(feature = "migrate")]
enum FixturesType {
    None,
    RelativePath,
    CustomRelativePath(syn::LitStr),
    ExplicitPath,
}

#[cfg(feature = "migrate")]
enum MigrationsOpt {
    InferredPath,
    ExplicitPath(syn::LitStr),
    ExplicitMigrator(syn::Path),
    Disabled,
}

pub fn expand(args: syn::AttributeArgs, input: syn::ItemFn) -> crate::Result<TokenStream> {
    if input.sig.inputs.is_empty() {
        if !args.is_empty() {
            if cfg!(feature = "migrate") {
                return Err(syn::Error::new_spanned(
                    args.first().unwrap(),
                    "control attributes are not allowed unless \
                     the `migrate` feature is enabled and \
                     automatic test DB management is used; see docs",
                )
                .into());
            }

            return Err(syn::Error::new_spanned(
                args.first().unwrap(),
                "control attributes are not allowed unless \
                 automatic test DB management is used; see docs",
            )
            .into());
        }

        return Ok(expand_simple(input));
    }

    #[cfg(feature = "migrate")]
    return expand_advanced(args, input);

    #[cfg(not(feature = "migrate"))]
    return Err(syn::Error::new_spanned(input, "`migrate` feature required").into());
}

fn expand_simple(input: syn::ItemFn) -> TokenStream {
    let ret = &input.sig.output;
    let name = &input.sig.ident;
    let body = &input.block;
    let attrs = &input.attrs;

    quote! {
        #[::core::prelude::v1::test]
        #(#attrs)*
        fn #name() #ret {
            ::sqlx::test_block_on(async { #body })
        }
    }
}

#[cfg(feature = "migrate")]
fn expand_advanced(args: syn::AttributeArgs, input: syn::ItemFn) -> crate::Result<TokenStream> {
    let ret = &input.sig.output;
    let name = &input.sig.ident;
    let inputs = &input.sig.inputs;
    let body = &input.block;
    let attrs = &input.attrs;

    let args = parse_args(args)?;

    let fn_arg_types = inputs.iter().map(|_| quote! { _ });

    let mut fixtures = Vec::new();

    for (fixture_type, fixtures_local) in args.fixtures {
        let mut res = match fixture_type {
            FixturesType::None => vec![],
            FixturesType::RelativePath => fixtures_local
                .into_iter()
                .map(|fixture| {
                    let mut fixture_str = fixture.value();
                    add_sql_extension_if_missing(&mut fixture_str);

                    let path = format!("fixtures/{}", fixture_str);
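                    // e.g. (illustrative) `fixtures("users")` resolves to "fixtures/users.sql",
                    // which the `include_str!()` below reads relative to the test's source file.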
                    quote! {
                        ::sqlx::testing::TestFixture {
                            path: #path,
                            contents: include_str!(#path),
                        }
                    }
                })
                .collect(),
            FixturesType::CustomRelativePath(path) => fixtures_local
                .into_iter()
                .map(|fixture| {
                    let mut fixture_str = fixture.value();
                    add_sql_extension_if_missing(&mut fixture_str);

                    let path = format!("{}/{}", path.value(), fixture_str);
                    quote! {
                        ::sqlx::testing::TestFixture {
                            path: #path,
                            contents: include_str!(#path),
                        }
                    }
                })
                .collect(),
            FixturesType::ExplicitPath => fixtures_local
                .into_iter()
                .map(|fixture| {
                    let path = fixture.value();
                    quote! {
                        ::sqlx::testing::TestFixture {
                            path: #path,
                            contents: include_str!(#path),
                        }
                    }
                })
                .collect(),
        };
        fixtures.append(&mut res)
    }

    let migrations = match args.migrations {
        MigrationsOpt::ExplicitPath(path) => {
            let migrator = crate::migrate::expand_migrator_from_lit_dir(path)?;

            quote! { args.migrator(&#migrator); }
        }
        MigrationsOpt::InferredPath if !inputs.is_empty() => {
            let migrations_path =
                crate::common::resolve_path("./migrations", proc_macro2::Span::call_site())?;

            if migrations_path.is_dir() {
                let migrator = crate::migrate::expand_migrator(&migrations_path)?;

                quote! { args.migrator(&#migrator); }
            } else {
                quote! {}
            }
        }
        MigrationsOpt::ExplicitMigrator(path) => {
            quote! { args.migrator(&#path); }
        }
        _ => quote! {},
    };

    Ok(quote! {
        #[::core::prelude::v1::test]
        #(#attrs)*
        fn #name() #ret {
            async fn inner(#inputs) #ret {
                #body
            }

            let mut args = ::sqlx::testing::TestArgs::new(concat!(module_path!(), "::", stringify!(#name)));

            #migrations

            args.fixtures(&[#(#fixtures),*]);

            // We need to give a coercion site or else we get "unimplemented trait" errors.
            let f: fn(#(#fn_arg_types),*) -> _ = inner;

            ::sqlx::testing::TestFn::run_test(f, args)
        }
    })
}

#[cfg(feature = "migrate")]
fn parse_args(attr_args: syn::AttributeArgs) -> syn::Result<Args> {
    let mut fixtures = Vec::new();
    let mut migrations = MigrationsOpt::InferredPath;

    for arg in attr_args {
        match arg {
            syn::NestedMeta::Meta(syn::Meta::List(list)) if list.path.is_ident("fixtures") => {
                let mut fixtures_local = vec![];
                let mut fixtures_type = FixturesType::None;

                for nested in list.nested {
                    match nested {
                        syn::NestedMeta::Lit(syn::Lit::Str(litstr)) => {
                            // fixtures("<fixture_1>", "<fixture_2>") or
                            // fixtures("<path/fixture_1>.sql", "<path/fixture_2>.sql")
                            parse_fixtures_args(&mut fixtures_type, litstr, &mut fixtures_local)?;
                        }
                        syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue))
                            if namevalue.path.is_ident("path") =>
                        {
                            // fixtures(path = "<path>", scripts("<fixture_1>", "<fixture_2>")) checking `path` argument
                            parse_fixtures_path_args(&mut fixtures_type, namevalue)?;
                        }
                        syn::NestedMeta::Meta(syn::Meta::List(list))
                            if list.path.is_ident("scripts") =>
                        {
                            // fixtures(path = "<path>", scripts("<fixture_1>", "<fixture_2>")) checking `scripts` argument
                            parse_fixtures_scripts_args(&mut fixtures_type, list, &mut fixtures_local)?;
                        }
                        other => {
                            return Err(syn::Error::new_spanned(other, "expected string literal"))
                        }
                    };
                }
                fixtures.push((fixtures_type, fixtures_local));
            }
            syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue))
                if namevalue.path.is_ident("migrations") =>
            {
                if !matches!(migrations, MigrationsOpt::InferredPath) {
                    return Err(syn::Error::new_spanned(
                        namevalue,
                        "cannot have more than one `migrations` or `migrator` arg",
                    ));
                }

                migrations = match namevalue.lit {
                    syn::Lit::Bool(litbool) => {
                        if !litbool.value {
                            // migrations = false
                            MigrationsOpt::Disabled
                        } else {
                            // migrations = true
                            return Err(syn::Error::new_spanned(
                                litbool,
                                "`migrations = true` is redundant",
                            ));
                        }
                    }
                    // migrations = "<path>"
                    syn::Lit::Str(litstr) => MigrationsOpt::ExplicitPath(litstr),
                    _ => {
                        return Err(syn::Error::new_spanned(
                            namevalue,
                            "expected string or `false`",
                        ))
                    }
                };
            }
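            // e.g. (illustrative usage) `#[sqlx::test(migrations = "tests/migrations")]`
            // selects an explicit directory, while `#[sqlx::test(migrations = false)]`
            // lands in the `Disabled` arm above.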
            syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue))
                if namevalue.path.is_ident("migrator") =>
            {
                if !matches!(migrations, MigrationsOpt::InferredPath) {
                    return Err(syn::Error::new_spanned(
                        namevalue,
                        "cannot have more than one `migrations` or `migrator` arg",
                    ));
                }

                migrations = match namevalue.lit {
                    // migrator = "<path to a static migrator>"
                    syn::Lit::Str(litstr) => MigrationsOpt::ExplicitMigrator(litstr.parse()?),
                    _ => return Err(syn::Error::new_spanned(namevalue, "expected string")),
                };
            }
            other => {
                return Err(syn::Error::new_spanned(
                    other,
                    "expected `fixtures(\"<filename>\", ...)` or `migrations = \"<path>\" | false` or `migrator = \"<migrator>\"`",
                ))
            }
        }
    }

    Ok(Args {
        fixtures,
        migrations,
    })
}

#[cfg(feature = "migrate")]
fn parse_fixtures_args(
    fixtures_type: &mut FixturesType,
    litstr: syn::LitStr,
    fixtures_local: &mut Vec<syn::LitStr>,
) -> syn::Result<()> {
    // fixtures("<fixture_1>", "<fixture_2>"): decide whether these are bare names or explicit paths
    let path_str = litstr.value();
    let path = std::path::Path::new(&path_str);
    // This will be `true` if there's at least one path separator (`/` or `\`).
    // It's also true for all absolute paths, even e.g. `/foo.sql`, as the root directory is counted as a component.
    let is_explicit_path = path.components().count() > 1;
    match fixtures_type {
        FixturesType::None => {
            if is_explicit_path {
                *fixtures_type = FixturesType::ExplicitPath;
            } else {
                *fixtures_type = FixturesType::RelativePath;
            }
        }
        FixturesType::RelativePath => {
            if is_explicit_path {
                return Err(syn::Error::new_spanned(
                    litstr,
                    "expected only relative path fixtures",
                ));
            }
        }
        FixturesType::ExplicitPath => {
            if !is_explicit_path {
                return Err(syn::Error::new_spanned(
                    litstr,
                    "expected only explicit path fixtures",
                ));
            }
        }
        FixturesType::CustomRelativePath(_) => {
            return Err(syn::Error::new_spanned(
                litstr,
                "custom relative path fixtures must be defined in `scripts` argument",
            ))
        }
    }
    if matches!(fixtures_type, FixturesType::ExplicitPath) && !path_str.ends_with(".sql") {
        return Err(syn::Error::new_spanned(
            litstr,
            "expected explicit path fixtures to have `.sql` extension",
        ));
    }
    fixtures_local.push(litstr);
    Ok(())
}

#[cfg(feature = "migrate")]
fn parse_fixtures_path_args(
    fixtures_type: &mut FixturesType,
    namevalue: syn::MetaNameValue,
) -> syn::Result<()> {
    // fixtures(path = "<path>", scripts("<fixture_1>", "<fixture_2>")) checking `path` argument
    if !matches!(fixtures_type, FixturesType::None) {
        return Err(syn::Error::new_spanned(
            namevalue,
            "`path` must be the first argument of `fixtures`",
        ));
    }
    *fixtures_type = match namevalue.lit {
        // path = "<path>"
        syn::Lit::Str(litstr) => FixturesType::CustomRelativePath(litstr),
        _ => return Err(syn::Error::new_spanned(namevalue, "expected string")),
    };
    Ok(())
}

#[cfg(feature = "migrate")]
fn parse_fixtures_scripts_args(
    fixtures_type: &mut FixturesType,
    list: syn::MetaList,
    fixtures_local: &mut Vec<syn::LitStr>,
) -> syn::Result<()> {
    // fixtures(path = "<path>", scripts("<fixture_1>", "<fixture_2>")) checking `scripts` argument
    if !matches!(fixtures_type, FixturesType::CustomRelativePath(_)) {
        return Err(syn::Error::new_spanned(
            list,
            "`scripts` must be the second argument of `fixtures` and used together with `path`",
        ));
    }
    for nested in list.nested {
        let litstr = match nested {
            syn::NestedMeta::Lit(syn::Lit::Str(litstr)) => litstr,
            other => return Err(syn::Error::new_spanned(other, "expected string literal")),
        };
        fixtures_local.push(litstr);
    }
    Ok(())
}

#[cfg(feature = "migrate")]
fn add_sql_extension_if_missing(fixture: &mut String) {
    let has_extension = std::path::Path::new(&fixture).extension().is_some();

    if !has_extension {
        fixture.push_str(".sql")
    }
}
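// A minimal sketch (ours, not part of the upstream crate) of the helper's behavior;
// the module and test names are hypothetical:
#[cfg(all(test, feature = "migrate"))]
mod add_sql_extension_tests {
    use super::add_sql_extension_if_missing;

    #[test]
    fn appends_sql_when_extension_is_missing() {
        let mut fixture = String::from("users");
        add_sql_extension_if_missing(&mut fixture);
        assert_eq!(fixture, "users.sql");
    }

    #[test]
    fn leaves_existing_extension_alone() {
        let mut fixture = String::from("users.sql");
        add_sql_extension_if_missing(&mut fixture);
        assert_eq!(fixture, "users.sql");
    }
}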