sqlx-postgres-0.7.3/.cargo_vcs_info.json0000644000000001530000000000100136750ustar { "git": { "sha1": "c55aba0dc14f33b8a26cab6af565fcc4c8af8962" }, "path_in_vcs": "sqlx-postgres" }sqlx-postgres-0.7.3/Cargo.toml0000644000000073360000000000100117050ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. [package] edition = "2021" name = "sqlx-postgres" version = "0.7.3" authors = [ "Ryan Leckey ", "Austin Bonander ", "Chloe Ross ", "Daniel Akhterov ", ] description = "PostgreSQL driver implementation for SQLx. Not for direct use; see the `sqlx` crate for details." documentation = "https://docs.rs/sqlx" license = "MIT OR Apache-2.0" repository = "https://github.com/launchbadge/sqlx" [dependencies.atoi] version = "2.0" [dependencies.base64] version = "0.21.0" features = ["std"] default-features = false [dependencies.bigdecimal] version = "0.3.0" optional = true [dependencies.bit-vec] version = "0.6.3" optional = true [dependencies.bitflags] version = "2" default-features = false [dependencies.byteorder] version = "1.4.3" features = ["std"] default-features = false [dependencies.chrono] version = "0.4.22" optional = true default-features = false [dependencies.crc] version = "3.0.0" [dependencies.dotenvy] version = "0.15.0" default-features = false [dependencies.futures-channel] version = "0.3.19" features = [ "sink", "alloc", "std", ] default-features = false [dependencies.futures-core] version = "0.3.19" default-features = false [dependencies.futures-io] version = "0.3.24" [dependencies.futures-util] version = "0.3.19" features = [ "alloc", "sink", "io", ] 
default-features = false [dependencies.hex] version = "0.4.3" [dependencies.hkdf] version = "0.12.0" [dependencies.hmac] version = "0.12.0" default-features = false [dependencies.home] version = "0.5.5" [dependencies.ipnetwork] version = "0.20.0" optional = true [dependencies.itoa] version = "1.0.1" [dependencies.log] version = "0.4.17" [dependencies.mac_address] version = "1.1.5" optional = true [dependencies.md-5] version = "0.10.0" default-features = false [dependencies.memchr] version = "2.4.1" default-features = false [dependencies.num-bigint] version = "0.4.3" optional = true [dependencies.once_cell] version = "1.9.0" [dependencies.rand] version = "0.8.4" features = [ "std", "std_rng", ] default-features = false [dependencies.rust_decimal] version = "1.26.1" optional = true [dependencies.serde] version = "1.0.144" features = ["derive"] [dependencies.serde_json] version = "1.0.85" features = ["raw_value"] [dependencies.sha1] version = "0.10.1" default-features = false [dependencies.sha2] version = "0.10.0" default-features = false [dependencies.smallvec] version = "1.7.0" [dependencies.sqlx-core] version = "=0.7.3" features = ["json"] [dependencies.stringprep] version = "0.1.2" [dependencies.thiserror] version = "1.0.35" [dependencies.time] version = "0.3.14" features = [ "formatting", "parsing", "macros", ] optional = true [dependencies.tracing] version = "0.1.37" features = ["log"] [dependencies.uuid] version = "1.1.2" optional = true [dependencies.whoami] version = "1.2.1" default-features = false [features] any = ["sqlx-core/any"] bigdecimal = [ "dep:bigdecimal", "dep:num-bigint", ] json = ["sqlx-core/json"] migrate = ["sqlx-core/migrate"] offline = ["sqlx-core/offline"] rust_decimal = [ "dep:rust_decimal", "rust_decimal/maths", ] [target."cfg(target_os = \"windows\")".dependencies.etcetera] version = "0.8.0" sqlx-postgres-0.7.3/Cargo.toml.orig000064400000000000000000000054250072674642500154130ustar 00000000000000[package] name = "sqlx-postgres" 
documentation = "https://docs.rs/sqlx" description = "PostgreSQL driver implementation for SQLx. Not for direct use; see the `sqlx` crate for details." version.workspace = true license.workspace = true edition.workspace = true authors.workspace = true repository.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] any = ["sqlx-core/any"] json = ["sqlx-core/json"] migrate = ["sqlx-core/migrate"] offline = ["sqlx-core/offline"] # Type integration features which require additional dependencies rust_decimal = ["dep:rust_decimal", "rust_decimal/maths"] bigdecimal = ["dep:bigdecimal", "dep:num-bigint"] [dependencies] # Futures crates futures-channel = { version = "0.3.19", default-features = false, features = ["sink", "alloc", "std"] } futures-core = { version = "0.3.19", default-features = false } futures-io = "0.3.24" futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] } # Cryptographic Primitives crc = "3.0.0" hkdf = "0.12.0" hmac = { version = "0.12.0", default-features = false } md-5 = { version = "0.10.0", default-features = false } rand = { version = "0.8.4", default-features = false, features = ["std", "std_rng"] } sha1 = { version = "0.10.1", default-features = false } sha2 = { version = "0.10.0", default-features = false } # Type Integrations (versions inherited from `[workspace.dependencies]`) bigdecimal = { workspace = true, optional = true } bit-vec = { workspace = true, optional = true } chrono = { workspace = true, optional = true } ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } time = { workspace = true, optional = true } uuid = { workspace = true, optional = true } # Misc atoi = "2.0" base64 = { version = "0.21.0", default-features = false, features = ["std"] } bitflags = { version = "2", default-features = false } byteorder = { 
version = "1.4.3", default-features = false, features = ["std"] } dotenvy = { workspace = true } hex = "0.4.3" home = "0.5.5" itoa = "1.0.1" log = "0.4.17" memchr = { version = "2.4.1", default-features = false } num-bigint = { version = "0.4.3", optional = true } once_cell = "1.9.0" smallvec = "1.7.0" stringprep = "0.1.2" thiserror = "1.0.35" tracing = { version = "0.1.37", features = ["log"] } whoami = { version = "1.2.1", default-features = false } serde = { version = "1.0.144", features = ["derive"] } serde_json = { version = "1.0.85", features = ["raw_value"] } [dependencies.sqlx-core] workspace = true # We use JSON in the driver implementation itself so there's no reason not to enable it here. features = ["json"] [target.'cfg(target_os = "windows")'.dependencies] etcetera = "0.8.0" sqlx-postgres-0.7.3/src/advisory_lock.rs000064400000000000000000000450170072674642500165320ustar 00000000000000use crate::error::Result; use crate::Either; use crate::PgConnection; use hkdf::Hkdf; use once_cell::sync::OnceCell; use sha2::Sha256; use std::ops::{Deref, DerefMut}; /// A mutex-like type utilizing [Postgres advisory locks]. /// /// Advisory locks are a mechanism provided by Postgres to have mutually exclusive or shared /// locks tracked in the database with application-defined semantics, as opposed to the standard /// row-level or table-level locks which may not fit all use-cases. /// /// This API provides a convenient wrapper for generating and storing the integer keys that /// advisory locks use, as well as RAII guards for releasing advisory locks when they fall out /// of scope. /// /// This API only handles session-scoped advisory locks (explicitly locked and unlocked, or /// automatically released when a connection is closed). /// /// It is also possible to use transaction-scoped locks but those can be used by beginning a /// transaction and calling the appropriate lock functions (e.g. 
`SELECT pg_advisory_xact_lock()`) /// manually, and cannot be explicitly released, but are automatically released when a transaction /// ends (is committed or rolled back). /// /// Session-level locks can be acquired either inside or outside a transaction and are not /// tied to transaction semantics; a lock acquired inside a transaction is still held when that /// transaction is committed or rolled back, until explicitly released or the connection is closed. /// /// Locks can be acquired in either shared or exclusive modes, which can be thought of as read locks /// and write locks, respectively. Multiple shared locks are allowed for the same key, but a single /// exclusive lock prevents any other lock being taken for a given key until it is released. /// /// [Postgres advisory locks]: https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS #[derive(Debug, Clone)] pub struct PgAdvisoryLock { key: PgAdvisoryLockKey, /// The query to execute to release this lock. release_query: OnceCell, } /// A key type natively used by Postgres advisory locks. /// /// Currently, Postgres advisory locks have two different key spaces: one keyed by a single /// 64-bit integer, and one keyed by a pair of two 32-bit integers. The Postgres docs /// specify that these key spaces "do not overlap": /// /// https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS /// /// The documentation for the `pg_locks` system view explains further how advisory locks /// are treated in Postgres: /// /// https://www.postgresql.org/docs/current/view-pg-locks.html #[derive(Debug, Clone, PartialEq, Eq)] #[non_exhaustive] pub enum PgAdvisoryLockKey { /// The keyspace designated by a single 64-bit integer. /// /// When [PgAdvisoryLock] is constructed with [::new()][PgAdvisoryLock::new()], /// this is the keyspace used. BigInt(i64), /// The keyspace designated by two 32-bit integers. 
IntPair(i32, i32), } /// A wrapper for `PgConnection` (or a similar type) that represents a held Postgres advisory lock. /// /// Can be acquired by [`PgAdvisoryLock::acquire()`] or [`PgAdvisoryLock::try_acquire()`]. /// Released on-drop or via [`Self::release_now()`]. /// /// ### Note: Release-on-drop is not immediate! /// On drop, this guard queues a `pg_advisory_unlock()` call on the connection which will be /// flushed to the server the next time it is used, or when it is returned to /// a [`PgPool`][crate::PgPool] in the case of /// [`PoolConnection`][crate::pool::PoolConnection]. /// /// This means the lock is not actually released as soon as the guard is dropped. To ensure the /// lock is eagerly released, you can call [`.release_now().await`][Self::release_now()]. pub struct PgAdvisoryLockGuard<'lock, C: AsMut> { lock: &'lock PgAdvisoryLock, conn: Option, } impl PgAdvisoryLock { /// Construct a `PgAdvisoryLock` using the given string as a key. /// /// This is intended to make it easier to use an advisory lock by using a human-readable string /// for a key as opposed to manually generating a unique integer key. The generated integer key /// is guaranteed to be stable and in the single 64-bit integer keyspace /// (see [`PgAdvisoryLockKey`] for details). /// /// This is done by applying the [Hash-based Key Derivation Function (HKDF; IETF RFC 5869)][hkdf] /// to the bytes of the input string, but in a way that the calculated integer is unlikely /// to collide with any similar implementations (although we don't currently know of any). /// See the source of this method for details. /// /// [hkdf]: https://datatracker.ietf.org/doc/html/rfc5869 /// ### Example /// ```rust /// # extern crate sqlx_core as sqlx; /// use sqlx::postgres::{PgAdvisoryLock, PgAdvisoryLockKey}; /// /// let lock = PgAdvisoryLock::new("my first Postgres advisory lock!"); /// // Negative values are fine because of how Postgres treats advisory lock keys. 
/// // See the documentation for the `pg_locks` system view for details. /// assert_eq!(lock.key(), &PgAdvisoryLockKey::BigInt(-5560419505042474287)); /// ``` pub fn new(key_string: impl AsRef) -> Self { let input_key_material = key_string.as_ref(); // HKDF was chosen because it is designed to concentrate the entropy in a variable-length // input key and produce a higher quality but reduced-length output key with a // well-specified and reproducible algorithm. // // Granted, the input key is usually meant to be pseudorandom and not human readable, // but we're not trying to produce an unguessable value by any means; just one that's as // unlikely to already be in use as possible, but still deterministic. // // SHA-256 was chosen as the hash function because it's already used in the Postgres driver, // which should save on codegen and optimization. // We don't supply a salt as that is intended to be random, but we want a deterministic key. let hkdf = Hkdf::::new(None, input_key_material.as_bytes()); let mut output_key_material = [0u8; 8]; // The first string is the "info" string of the HKDF which is intended to tie the output // exclusively to SQLx. This should avoid collisions with implementations using a similar // strategy. If you _want_ this to match some other implementation then you should get // the calculated integer key from it and use that directly. // // Do *not* change this string as it will affect the output! hkdf.expand( b"SQLx (Rust) Postgres advisory lock", &mut output_key_material, ) // `Hkdf::expand()` only returns an error if you ask for more than 255 times the digest size. // This is specified by RFC 5869 but not elaborated upon: // https://datatracker.ietf.org/doc/html/rfc5869#section-2.3 // Since we're only asking for 8 bytes, this error shouldn't be returned. .expect("BUG: `output_key_material` should be of acceptable length"); // For ease of use, this method assumes the user doesn't care which keyspace is used. 
// // It doesn't seem likely that someone would care about using the `(int, int)` keyspace // specifically unless they already had keys to use, in which case they wouldn't // care about this method. That's why we also provide `with_key()`. // // The choice of `from_le_bytes()` is mostly due to x86 being the most popular // architecture for server software, so it should be a no-op there. let key = PgAdvisoryLockKey::BigInt(i64::from_le_bytes(output_key_material)); tracing::trace!( ?key, key_string = ?input_key_material, "generated key from key string", ); Self::with_key(key) } /// Construct a `PgAdvisoryLock` with a manually supplied key. pub fn with_key(key: PgAdvisoryLockKey) -> Self { Self { key, release_query: OnceCell::new(), } } /// Returns the current key. pub fn key(&self) -> &PgAdvisoryLockKey { &self.key } // Why doesn't this use `Acquire`? Well, I tried it and got really useless errors // about "cannot project lifetimes to parent scope". // // It has something to do with how lifetimes work on the `Acquire` trait, I couldn't // be bothered to figure it out. Probably another issue with a lack of `async fn` in traits // or lazy normalization. /// Acquires an exclusive lock using `pg_advisory_lock()`, waiting until the lock is acquired. /// /// For a version that returns immediately instead of waiting, see [`Self::try_acquire()`]. /// /// A connection-like type is required to execute the call. Allowed types include `PgConnection`, /// `PoolConnection` and `Transaction`, as well as mutable references to /// any of these. /// /// The returned guard queues a `pg_advisory_unlock()` call on the connection when dropped, /// which will be executed the next time the connection is used, or when returned to a /// [`PgPool`][crate::PgPool] in the case of `PoolConnection`. /// /// Postgres allows a single connection to acquire a given lock more than once without releasing /// it first, so in that sense the lock is re-entrant. 
However, the number of unlock operations /// must match the number of lock operations for the lock to actually be released. /// /// See [Postgres' documentation for the Advisory Lock Functions][advisory-funcs] for details. /// /// [advisory-funcs]: https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS pub async fn acquire>( &self, mut conn: C, ) -> Result> { match &self.key { PgAdvisoryLockKey::BigInt(key) => { crate::query::query("SELECT pg_advisory_lock($1)") .bind(key) .execute(conn.as_mut()) .await?; } PgAdvisoryLockKey::IntPair(key1, key2) => { crate::query::query("SELECT pg_advisory_lock($1, $2)") .bind(key1) .bind(key2) .execute(conn.as_mut()) .await?; } } Ok(PgAdvisoryLockGuard::new(self, conn)) } /// Acquires an exclusive lock using `pg_try_advisory_lock()`, returning immediately /// if the lock could not be acquired. /// /// For a version that waits until the lock is acquired, see [`Self::acquire()`]. /// /// A connection-like type is required to execute the call. Allowed types include `PgConnection`, /// `PoolConnection` and `Transaction`, as well as mutable references to /// any of these. The connection is returned if the lock could not be acquired. /// /// The returned guard queues a `pg_advisory_unlock()` call on the connection when dropped, /// which will be executed the next time the connection is used, or when returned to a /// [`PgPool`][crate::PgPool] in the case of `PoolConnection`. /// /// Postgres allows a single connection to acquire a given lock more than once without releasing /// it first, so in that sense the lock is re-entrant. However, the number of unlock operations /// must match the number of lock operations for the lock to actually be released. /// /// See [Postgres' documentation for the Advisory Lock Functions][advisory-funcs] for details. 
/// /// [advisory-funcs]: https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS pub async fn try_acquire>( &self, mut conn: C, ) -> Result, C>> { let locked: bool = match &self.key { PgAdvisoryLockKey::BigInt(key) => { crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1)") .bind(key) .fetch_one(conn.as_mut()) .await? } PgAdvisoryLockKey::IntPair(key1, key2) => { crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1, $2)") .bind(key1) .bind(key2) .fetch_one(conn.as_mut()) .await? } }; if locked { Ok(Either::Left(PgAdvisoryLockGuard::new(self, conn))) } else { Ok(Either::Right(conn)) } } /// Execute `pg_advisory_unlock()` for this lock's key on the given connection. /// /// This is used by [`PgAdvisoryLockGuard::release_now()`] and is also provided for manually /// releasing the lock from connections returned by [`PgAdvisoryLockGuard::leak()`]. /// /// An error should only be returned if there is something wrong with the connection, /// in which case the lock will be automatically released by the connection closing anyway. /// /// The `boolean` value is that returned by `pg_advisory_lock()`. If it is `false`, it /// indicates that the lock was not actually held by the given connection and that a warning /// has been logged by the Postgres server. pub async fn force_release>(&self, mut conn: C) -> Result<(C, bool)> { let released: bool = match &self.key { PgAdvisoryLockKey::BigInt(key) => { crate::query_scalar::query_scalar("SELECT pg_advisory_unlock($1)") .bind(key) .fetch_one(conn.as_mut()) .await? } PgAdvisoryLockKey::IntPair(key1, key2) => { crate::query_scalar::query_scalar("SELECT pg_advisory_unlock($1, $2)") .bind(key1) .bind(key2) .fetch_one(conn.as_mut()) .await? 
} }; Ok((conn, released)) } fn get_release_query(&self) -> &str { self.release_query.get_or_init(|| match &self.key { PgAdvisoryLockKey::BigInt(key) => format!("SELECT pg_advisory_unlock({key})"), PgAdvisoryLockKey::IntPair(key1, key2) => { format!("SELECT pg_advisory_unlock({key1}, {key2})") } }) } } impl PgAdvisoryLockKey { /// Converts `Self::Bigint(bigint)` to `Some(bigint)` and all else to `None`. pub fn as_bigint(&self) -> Option { if let Self::BigInt(bigint) = self { Some(*bigint) } else { None } } } const NONE_ERR: &str = "BUG: PgAdvisoryLockGuard.conn taken"; impl<'lock, C: AsMut> PgAdvisoryLockGuard<'lock, C> { fn new(lock: &'lock PgAdvisoryLock, conn: C) -> Self { PgAdvisoryLockGuard { lock, conn: Some(conn), } } /// Immediately release the held advisory lock instead of when the connection is next used. /// /// An error should only be returned if there is something wrong with the connection, /// in which case the lock will be automatically released by the connection closing anyway. /// /// If `pg_advisory_unlock()` returns `false`, a warning will be logged, both by SQLx as /// well as the Postgres server. This would only happen if the lock was released without /// using this guard, or the connection was swapped using [`std::mem::replace()`]. pub async fn release_now(mut self) -> Result { let (conn, released) = self .lock .force_release(self.conn.take().expect(NONE_ERR)) .await?; if !released { tracing::warn!( lock = ?self.lock.key, "PgAdvisoryLockGuard: advisory lock was not held by the contained connection", ); } Ok(conn) } /// Cancel the release of the advisory lock, keeping it held until the connection is closed. /// /// To manually release the lock later, see [`PgAdvisoryLock::force_release()`]. 
pub fn leak(mut self) -> C { self.conn.take().expect(NONE_ERR) } } impl<'lock, C: AsMut + AsRef> Deref for PgAdvisoryLockGuard<'lock, C> { type Target = PgConnection; fn deref(&self) -> &Self::Target { self.conn.as_ref().expect(NONE_ERR).as_ref() } } /// Mutable access to the underlying connection is provided so it can still be used like normal, /// even allowing locks to be taken recursively. /// /// However, replacing the connection with a different one using, e.g. [`std::mem::replace()`] /// is a logic error and will cause a warning to be logged by the PostgreSQL server when this /// guard attempts to release the lock. impl<'lock, C: AsMut + AsRef> DerefMut for PgAdvisoryLockGuard<'lock, C> { fn deref_mut(&mut self) -> &mut Self::Target { self.conn.as_mut().expect(NONE_ERR).as_mut() } } impl<'lock, C: AsMut + AsRef> AsRef for PgAdvisoryLockGuard<'lock, C> { fn as_ref(&self) -> &PgConnection { self.conn.as_ref().expect(NONE_ERR).as_ref() } } /// Mutable access to the underlying connection is provided so it can still be used like normal, /// even allowing locks to be taken recursively. /// /// However, replacing the connection with a different one using, e.g. [`std::mem::replace()`] /// is a logic error and will cause a warning to be logged by the PostgreSQL server when this /// guard attempts to release the lock. impl<'lock, C: AsMut> AsMut for PgAdvisoryLockGuard<'lock, C> { fn as_mut(&mut self) -> &mut PgConnection { self.conn.as_mut().expect(NONE_ERR).as_mut() } } /// Queues a `pg_advisory_unlock()` call on the wrapped connection which will be flushed /// to the server the next time it is used, or when it is returned to [`PgPool`][crate::PgPool] /// in the case of [`PoolConnection`][crate::pool::PoolConnection]. impl<'lock, C: AsMut> Drop for PgAdvisoryLockGuard<'lock, C> { fn drop(&mut self) { if let Some(mut conn) = self.conn.take() { // Queue a simple query message to execute next time the connection is used. 
// The `async fn` versions can safely use the prepared statement protocol, // but this is the safest way to queue a query to execute on the next opportunity. conn.as_mut() .queue_simple_query(self.lock.get_release_query()); } } } sqlx-postgres-0.7.3/src/any.rs000064400000000000000000000166130072674642500144510ustar 00000000000000use crate::{ Either, PgColumn, PgConnectOptions, PgConnection, PgQueryResult, PgRow, PgTransactionManager, PgTypeInfo, Postgres, }; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; pub use sqlx_core::any::*; use crate::type_info::PgType; use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; use sqlx_core::ext::ustr::UStr; use sqlx_core::transaction::TransactionManager; sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Postgres); impl AnyConnectionBackend for PgConnection { fn name(&self) -> &str { ::NAME } fn close(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { Connection::close(*self) } fn close_hard(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { Connection::close_hard(*self) } fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { Connection::ping(self) } fn begin(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { PgTransactionManager::begin(self) } fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { PgTransactionManager::commit(self) } fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { PgTransactionManager::rollback(self) } fn start_rollback(&mut self) { PgTransactionManager::start_rollback(self) } fn shrink_buffers(&mut self) { Connection::shrink_buffers(self); } fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { Connection::flush(self) } fn should_flush(&self) -> bool { Connection::should_flush(self) } #[cfg(feature = "migrate")] fn as_migrate( &mut self, ) -> sqlx_core::Result<&mut (dyn 
sqlx_core::migrate::Migrate + Send + 'static)> { Ok(self) } fn fetch_many<'q>( &'q mut self, query: &'q str, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { let persistent = arguments.is_some(); let args = arguments.as_ref().map(AnyArguments::convert_to); Box::pin( self.run(query, args, 0, persistent, None) .try_flatten_stream() .map( move |res: sqlx_core::Result>| match res? { Either::Left(result) => Ok(Either::Left(map_result(result))), Either::Right(row) => Ok(Either::Right(AnyRow::try_from(&row)?)), }, ), ) } fn fetch_optional<'q>( &'q mut self, query: &'q str, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { let persistent = arguments.is_some(); let args = arguments.as_ref().map(AnyArguments::convert_to); Box::pin(async move { let stream = self.run(query, args, 1, persistent, None).await?; futures_util::pin_mut!(stream); if let Some(Either::Right(row)) = stream.try_next().await? { return Ok(Some(AnyRow::try_from(&row)?)); } Ok(None) }) } fn prepare_with<'c, 'q: 'c>( &'c mut self, sql: &'q str, _parameters: &[AnyTypeInfo], ) -> BoxFuture<'c, sqlx_core::Result>> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; AnyStatement::try_from_statement( sql, &statement, statement.metadata.column_names.clone(), ) }) } fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; let columns = describe .columns .iter() .map(AnyColumn::try_from) .collect::, _>>()?; let parameters = match describe.parameters { Some(Either::Left(parameters)) => Some(Either::Left( parameters .iter() .enumerate() .map(|(i, type_info)| { AnyTypeInfo::try_from(type_info).map_err(|_| { sqlx_core::Error::AnyDriverError( format!( "Any driver does not support type {type_info} of parameter {i}" ) .into(), ) }) }) .collect::, _>>()?, )), Some(Either::Right(count)) => Some(Either::Right(count)), None => None, }; Ok(Describe { columns, parameters, 
nullable: describe.nullable, }) }) } } impl<'a> TryFrom<&'a PgTypeInfo> for AnyTypeInfo { type Error = sqlx_core::Error; fn try_from(pg_type: &'a PgTypeInfo) -> Result { Ok(AnyTypeInfo { kind: match &pg_type.0 { PgType::Void => AnyTypeInfoKind::Null, PgType::Int2 => AnyTypeInfoKind::SmallInt, PgType::Int4 => AnyTypeInfoKind::Integer, PgType::Int8 => AnyTypeInfoKind::BigInt, PgType::Float4 => AnyTypeInfoKind::Real, PgType::Float8 => AnyTypeInfoKind::Double, PgType::Bytea => AnyTypeInfoKind::Blob, PgType::Text => AnyTypeInfoKind::Text, PgType::DeclareWithName(UStr::Static("citext")) => AnyTypeInfoKind::Text, _ => { return Err(sqlx_core::Error::AnyDriverError( format!("Any driver does not support the Postgres type {pg_type:?}").into(), )) } }, }) } } impl<'a> TryFrom<&'a PgColumn> for AnyColumn { type Error = sqlx_core::Error; fn try_from(col: &'a PgColumn) -> Result { let type_info = AnyTypeInfo::try_from(&col.type_info).map_err(|e| sqlx_core::Error::ColumnDecode { index: col.name.to_string(), source: e.into(), })?; Ok(AnyColumn { ordinal: col.ordinal, name: col.name.clone(), type_info, }) } } impl<'a> TryFrom<&'a PgRow> for AnyRow { type Error = sqlx_core::Error; fn try_from(row: &'a PgRow) -> Result { AnyRow::map_from(row, row.metadata.column_names.clone()) } } impl<'a> TryFrom<&'a AnyConnectOptions> for PgConnectOptions { type Error = sqlx_core::Error; fn try_from(value: &'a AnyConnectOptions) -> Result { let mut opts = PgConnectOptions::parse_from_url(&value.database_url)?; opts.log_settings = value.log_settings.clone(); Ok(opts) } } fn map_result(res: PgQueryResult) -> AnyQueryResult { AnyQueryResult { rows_affected: res.rows_affected(), last_insert_id: None, } } sqlx-postgres-0.7.3/src/arguments.rs000064400000000000000000000132130072674642500156600ustar 00000000000000use std::fmt::{self, Write}; use std::ops::{Deref, DerefMut}; use crate::encode::{Encode, IsNull}; use crate::error::Error; use crate::ext::ustr::UStr; use crate::types::Type; use 
crate::{PgConnection, PgTypeInfo, Postgres}; pub(crate) use sqlx_core::arguments::Arguments; // TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ? // TODO: Extend the patch system to support dynamic lengths // Considerations: // - The prefixed-len offset needs to be back-tracked and updated // - message::Bind needs to take a &PgArguments and use a `write` method instead of // referencing a buffer directly // - The basic idea is that we write bytes for the buffer until we get somewhere // that has a patch, we then apply the patch which should write to &mut Vec, // backtrack and update the prefixed-len, then write until the next patch offset #[derive(Default)] pub struct PgArgumentBuffer { buffer: Vec, // Number of arguments count: usize, // Whenever an `Encode` impl needs to defer some work until after we resolve parameter types // it can use `patch`. // // This currently is only setup to be useful if there is a *fixed-size* slot that needs to be // tweaked from the input type. However, that's the only use case we currently have. // patches: Vec<( usize, // offset usize, // argument index Box, )>, // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID // It pushes a "hole" that must be patched later. // // The hole is a `usize` offset into the buffer with the type name that should be resolved // This is done for Records and Arrays as the OID is needed well before we are in an async // function and can just ask postgres. // type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }> } /// Implementation of [`Arguments`] for PostgreSQL. 
#[derive(Default)] pub struct PgArguments { // Types of each bind parameter pub(crate) types: Vec, // Buffer of encoded bind parameters pub(crate) buffer: PgArgumentBuffer, } impl PgArguments { pub(crate) fn add<'q, T>(&mut self, value: T) where T: Encode<'q, Postgres> + Type, { // remember the type information for this value self.types .push(value.produces().unwrap_or_else(T::type_info)); // encode the value into our buffer self.buffer.encode(value); // increment the number of arguments we are tracking self.buffer.count += 1; } // Apply patches // This should only go out and ask postgres if we have not seen the type name yet pub(crate) async fn apply_patches( &mut self, conn: &mut PgConnection, parameters: &[PgTypeInfo], ) -> Result<(), Error> { let PgArgumentBuffer { ref patches, ref type_holes, ref mut buffer, .. } = self.buffer; for (offset, ty, callback) in patches { let buf = &mut buffer[*offset..]; let ty = ¶meters[*ty]; callback(buf, ty); } for (offset, name) in type_holes { let oid = conn.fetch_type_id_by_name(&*name).await?; buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes()); } Ok(()) } } impl<'q> Arguments<'q> for PgArguments { type Database = Postgres; fn reserve(&mut self, additional: usize, size: usize) { self.types.reserve(additional); self.buffer.reserve(size); } fn add(&mut self, value: T) where T: Encode<'q, Self::Database> + Type, { self.add(value) } fn format_placeholder(&self, writer: &mut W) -> fmt::Result { write!(writer, "${}", self.buffer.count) } } impl PgArgumentBuffer { pub(crate) fn encode<'q, T>(&mut self, value: T) where T: Encode<'q, Postgres>, { // reserve space to write the prefixed length of the value let offset = self.len(); self.extend(&[0; 4]); // encode the value into our buffer let len = if let IsNull::No = value.encode(self) { (self.len() - offset - 4) as i32 } else { // Write a -1 to indicate NULL // NOTE: It is illegal for [encode] to write any data debug_assert_eq!(self.len(), offset + 4); -1_i32 }; // 
write the len to the beginning of the value self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes()); } // Adds a callback to be invoked later when we know the parameter type #[allow(dead_code)] pub(crate) fn patch(&mut self, callback: F) where F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync, { let offset = self.len(); let index = self.count; self.patches.push((offset, index, Box::new(callback))); } // Extends the inner buffer by enough space to have an OID // Remembers where the OID goes and type name for the OID pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) { let offset = self.len(); self.extend_from_slice(&0_u32.to_be_bytes()); self.type_holes.push((offset, type_name.clone())); } } impl Deref for PgArgumentBuffer { type Target = Vec; #[inline] fn deref(&self) -> &Self::Target { &self.buffer } } impl DerefMut for PgArgumentBuffer { #[inline] fn deref_mut(&mut self) -> &mut Self::Target { &mut self.buffer } } sqlx-postgres-0.7.3/src/column.rs000064400000000000000000000014220072674642500151470ustar 00000000000000use crate::ext::ustr::UStr; use crate::{PgTypeInfo, Postgres}; pub(crate) use sqlx_core::column::{Column, ColumnIndex}; #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgColumn { pub(crate) ordinal: usize, pub(crate) name: UStr, pub(crate) type_info: PgTypeInfo, #[cfg_attr(feature = "offline", serde(skip))] pub(crate) relation_id: Option, #[cfg_attr(feature = "offline", serde(skip))] pub(crate) relation_attribute_no: Option, } impl Column for PgColumn { type Database = Postgres; fn ordinal(&self) -> usize { self.ordinal } fn name(&self) -> &str { &*self.name } fn type_info(&self) -> &PgTypeInfo { &self.type_info } } sqlx-postgres-0.7.3/src/connection/describe.rs000064400000000000000000000437460072674642500176100ustar 00000000000000use crate::error::Error; use crate::ext::ustr::UStr; use crate::message::{ParameterDescription, RowDescription}; use 
crate::query_as::query_as; use crate::query_scalar::{query_scalar, query_scalar_with}; use crate::statement::PgStatementMetadata; use crate::type_info::{PgCustomType, PgType, PgTypeKind}; use crate::types::Json; use crate::types::Oid; use crate::HashMap; use crate::{PgArguments, PgColumn, PgConnection, PgTypeInfo}; use futures_core::future::BoxFuture; use std::fmt::Write; use std::sync::Arc; /// Describes the type of the `pg_type.typtype` column /// /// See #[derive(Copy, Clone, Debug, Eq, PartialEq)] enum TypType { Base, Composite, Domain, Enum, Pseudo, Range, } impl TryFrom for TypType { type Error = (); fn try_from(t: u8) -> Result { let t = match t { b'b' => Self::Base, b'c' => Self::Composite, b'd' => Self::Domain, b'e' => Self::Enum, b'p' => Self::Pseudo, b'r' => Self::Range, _ => return Err(()), }; Ok(t) } } /// Describes the type of the `pg_type.typcategory` column /// /// See #[derive(Copy, Clone, Debug, Eq, PartialEq)] enum TypCategory { Array, Boolean, Composite, DateTime, Enum, Geometric, Network, Numeric, Pseudo, Range, String, Timespan, User, BitString, Unknown, } impl TryFrom for TypCategory { type Error = (); fn try_from(c: u8) -> Result { let c = match c { b'A' => Self::Array, b'B' => Self::Boolean, b'C' => Self::Composite, b'D' => Self::DateTime, b'E' => Self::Enum, b'G' => Self::Geometric, b'I' => Self::Network, b'N' => Self::Numeric, b'P' => Self::Pseudo, b'R' => Self::Range, b'S' => Self::String, b'T' => Self::Timespan, b'U' => Self::User, b'V' => Self::BitString, b'X' => Self::Unknown, _ => return Err(()), }; Ok(c) } } impl PgConnection { pub(super) async fn handle_row_description( &mut self, desc: Option, should_fetch: bool, ) -> Result<(Vec, HashMap), Error> { let mut columns = Vec::new(); let mut column_names = HashMap::new(); let desc = if let Some(desc) = desc { desc } else { // no rows return Ok((columns, column_names)); }; columns.reserve(desc.fields.len()); column_names.reserve(desc.fields.len()); for (index, field) in 
desc.fields.into_iter().enumerate() {
            let name = UStr::from(field.name);

            // Resolve the column's type; this may round-trip to the server for
            // user-defined types, but only when `should_fetch` permits it.
            let type_info = self
                .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
                .await?;

            let column = PgColumn {
                ordinal: index,
                name: name.clone(),
                type_info,
                relation_id: field.relation_id,
                relation_attribute_no: field.relation_attribute_no,
            };

            columns.push(column);
            column_names.insert(name, index);
        }

        Ok((columns, column_names))
    }

    // Resolve a `PgTypeInfo` for every parameter OID the server reported in a
    // ParameterDescription message.
    // NOTE(review): the return type's generic arguments were stripped by extraction
    // (presumably `Result<Vec<PgTypeInfo>, Error>`) — confirm against the original source.
    pub(super) async fn handle_parameter_description(
        &mut self,
        desc: ParameterDescription,
    ) -> Result, Error> {
        let mut params = Vec::with_capacity(desc.types.len());

        for ty in desc.types {
            // `should_fetch = true`: we are between queries here, so it is safe to
            // issue a lookup query against the server for unknown OIDs.
            params.push(self.maybe_fetch_type_info_by_oid(ty, true).await?);
        }

        Ok(params)
    }

    // Map a type OID to a `PgTypeInfo`, consulting in order: the built-in type table,
    // this connection's cache, and finally (only if `should_fetch`) the server itself.
    async fn maybe_fetch_type_info_by_oid(
        &mut self,
        oid: Oid,
        should_fetch: bool,
    ) -> Result {
        // first we check if this is a built-in type
        // in the average application, the vast majority of checks should flow through this
        if let Some(info) = PgTypeInfo::try_from_oid(oid) {
            return Ok(info);
        }

        // next we check a local cache for user-defined type names <-> object id
        if let Some(info) = self.cache_type_info.get(&oid) {
            return Ok(info.clone());
        }

        // fallback to asking the database directly for a type name
        if should_fetch {
            let info = self.fetch_type_by_oid(oid).await?;

            // cache the type name <-> oid relationship in a paired hashmap
            // so we don't come down this road again
            self.cache_type_info.insert(oid, info.clone());
            self.cache_type_oid
                .insert(info.0.name().to_string().into(), oid);

            Ok(info)
        } else {
            // we are not in a place that *can* run a query
            // this generally means we are in the middle of another query
            // this _should_ only happen for complex types sent through the TEXT protocol
            // we're open to ideas to correct this..
but it'd probably be more efficient to figure // out a way to "prime" the type cache for connections rather than make this // fallback work correctly for complex user-defined types for the TEXT protocol Ok(PgTypeInfo(PgType::DeclareWithOid(oid))) } } fn fetch_type_by_oid(&mut self, oid: Oid) -> BoxFuture<'_, Result> { Box::pin(async move { let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, Oid, Oid, Oid) = query_as( "SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1", ) .bind(oid) .fetch_one(&mut *self) .await?; let typ_type = TypType::try_from(typ_type as u8); let category = TypCategory::try_from(category as u8); match (typ_type, category) { (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, (Ok(TypType::Base), Ok(TypCategory::Array)) => { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { kind: PgTypeKind::Array( self.maybe_fetch_type_info_by_oid(element, true).await?, ), name: name.into(), oid, })))) } (Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => { Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { kind: PgTypeKind::Pseudo, name: name.into(), oid, })))) } (Ok(TypType::Range), Ok(TypCategory::Range)) => { self.fetch_range_by_oid(oid, name).await } (Ok(TypType::Enum), Ok(TypCategory::Enum)) => { self.fetch_enum_by_oid(oid, name).await } (Ok(TypType::Composite), Ok(TypCategory::Composite)) => { self.fetch_composite_by_oid(oid, relation_id, name).await } _ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { kind: PgTypeKind::Simple, name: name.into(), oid, })))), } }) } async fn fetch_enum_by_oid(&mut self, oid: Oid, name: String) -> Result { let variants: Vec = query_scalar( r#" SELECT enumlabel FROM pg_catalog.pg_enum WHERE enumtypid = $1 ORDER BY enumsortorder "#, ) .bind(oid) .fetch_all(self) .await?; Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { oid, name: name.into(), kind: PgTypeKind::Enum(Arc::from(variants)), })))) } fn 
fetch_composite_by_oid( &mut self, oid: Oid, relation_id: Oid, name: String, ) -> BoxFuture<'_, Result> { Box::pin(async move { let raw_fields: Vec<(String, Oid)> = query_as( r#" SELECT attname, atttypid FROM pg_catalog.pg_attribute WHERE attrelid = $1 AND NOT attisdropped AND attnum > 0 ORDER BY attnum "#, ) .bind(relation_id) .fetch_all(&mut *self) .await?; let mut fields = Vec::new(); for (field_name, field_oid) in raw_fields.into_iter() { let field_type = self.maybe_fetch_type_info_by_oid(field_oid, true).await?; fields.push((field_name, field_type)); } Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { oid, name: name.into(), kind: PgTypeKind::Composite(Arc::from(fields)), })))) }) } fn fetch_domain_by_oid( &mut self, oid: Oid, base_type: Oid, name: String, ) -> BoxFuture<'_, Result> { Box::pin(async move { let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?; Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { oid, name: name.into(), kind: PgTypeKind::Domain(base_type), })))) }) } fn fetch_range_by_oid( &mut self, oid: Oid, name: String, ) -> BoxFuture<'_, Result> { Box::pin(async move { let element_oid: Oid = query_scalar( r#" SELECT rngsubtype FROM pg_catalog.pg_range WHERE rngtypid = $1 "#, ) .bind(oid) .fetch_one(&mut *self) .await?; let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?; Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType { kind: PgTypeKind::Range(element), name: name.into(), oid, })))) }) } pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result { if let Some(oid) = self.cache_type_oid.get(name) { return Ok(*oid); } // language=SQL let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid") .bind(name) .fetch_optional(&mut *self) .await? 
.ok_or_else(|| Error::TypeNotFound {
                type_name: String::from(name),
            })?;

        // Remember the name <-> OID mapping so later lookups skip the round-trip.
        self.cache_type_oid.insert(name.to_string().into(), oid);

        Ok(oid)
    }

    // Infer a per-column nullability flag for a prepared statement by joining the
    // statement's (relation id, attribute number) pairs against
    // `pg_catalog.pg_attribute.attnotnull`. Columns with no backing table attribute
    // stay `None` (unknown).
    // NOTE(review): the return type's generic arguments were stripped by extraction
    // (presumably `Result<Vec<Option<bool>>, Error>`) — confirm against the original source.
    pub(crate) async fn get_nullable_for_columns(
        &mut self,
        stmt_id: Oid,
        meta: &PgStatementMetadata,
    ) -> Result>, Error> {
        if meta.columns.is_empty() {
            return Ok(vec![]);
        }

        let mut nullable_query = String::from("SELECT NOT pg_attribute.attnotnull FROM (VALUES ");
        let mut args = PgArguments::default();

        // Three placeholders per column: (ordinal, relation oid, attribute number),
        // hence the `(1..).step_by(3)` bind counter.
        for (i, (column, bind)) in meta.columns.iter().zip((1..).step_by(3)).enumerate() {
            if !args.buffer.is_empty() {
                nullable_query += ", ";
            }

            let _ = write!(
                nullable_query,
                "(${}::int4, ${}::int4, ${}::int2)",
                bind,
                bind + 1,
                bind + 2
            );

            args.add(i as i32);
            args.add(column.relation_id);
            args.add(column.relation_attribute_no);
        }

        nullable_query.push_str(
            ") as col(idx, table_id, col_idx) \
            LEFT JOIN pg_catalog.pg_attribute \
                ON table_id IS NOT NULL \
                AND attrelid = table_id \
                AND attnum = col_idx \
            ORDER BY col.idx",
        );

        let mut nullables = query_scalar_with::<_, Option, _>(&nullable_query, args)
            .fetch_all(&mut *self)
            .await?;

        // If the server is CockroachDB or Materialize, skip this step (#1248).
        if !self.stream.parameter_statuses.contains_key("crdb_version")
            && !self.stream.parameter_statuses.contains_key("mz_version")
        {
            // patch up our null inference with data from EXPLAIN
            let nullable_patch = self
                .nullables_from_explain(stmt_id, meta.parameters.len())
                .await?;

            // EXPLAIN-derived knowledge wins where present; otherwise keep the
            // catalog-derived answer.
            for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) {
                *nullable = patch.or(*nullable);
            }
        }

        Ok(nullables)
    }

    /// Infer nullability for columns of this statement using EXPLAIN VERBOSE.
    ///
    /// This currently only marks columns that are on the inner half of an outer join
    /// and returns `None` for all others.
async fn nullables_from_explain( &mut self, stmt_id: Oid, params_len: usize, ) -> Result>, Error> { let mut explain = format!( "EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}", stmt_id.0 ); let mut comma = false; if params_len > 0 { explain += "("; // fill the arguments list with NULL, which should theoretically be valid for _ in 0..params_len { if comma { explain += ", "; } explain += "NULL"; comma = true; } explain += ")"; } let (Json([explain]),): (Json<[Explain; 1]>,) = query_as(&explain).fetch_one(self).await?; let mut nullables = Vec::new(); if let Explain::Plan { plan: plan @ Plan { output: Some(ref outputs), .. }, } = &explain { nullables.resize(outputs.len(), None); visit_plan(plan, outputs, &mut nullables); } Ok(nullables) } } fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec>) { if let Some(plan_outputs) = &plan.output { // all outputs of a Full Join must be marked nullable // otherwise, all outputs of the inner half of an outer join must be marked nullable if plan.join_type.as_deref() == Some("Full") || plan.parent_relation.as_deref() == Some("Inner") { for output in plan_outputs { if let Some(i) = outputs.iter().position(|o| o == output) { // N.B. this may produce false positives but those don't cause runtime errors nullables[i] = Some(true); } } } } if let Some(plans) = &plan.plans { if let Some("Left") | Some("Right") = plan.join_type.as_deref() { for plan in plans { visit_plan(plan, outputs, nullables); } } } } #[derive(serde::Deserialize, Debug)] #[serde(untagged)] enum Explain { // NOTE: the returned JSON may not contain a `plan` field, for example, with `CALL` statements: // https://github.com/launchbadge/sqlx/issues/1449 // // In this case, we should just fall back to assuming all is nullable. 
// // It may also contain additional fields we don't care about, which should not break parsing: // https://github.com/launchbadge/sqlx/issues/2587 // https://github.com/launchbadge/sqlx/issues/2622 Plan { #[serde(rename = "Plan")] plan: Plan, }, // This ensures that parsing never technically fails. // // We don't want to specifically expect `"Utility Statement"` because there might be other cases // and we don't care unless it contains a query plan anyway. Other(serde::de::IgnoredAny), } #[derive(serde::Deserialize, Debug)] struct Plan { #[serde(rename = "Join Type")] join_type: Option, #[serde(rename = "Parent Relationship")] parent_relation: Option, #[serde(rename = "Output")] output: Option>, #[serde(rename = "Plans")] plans: Option>, } #[test] fn explain_parsing() { let normal_plan = r#"[ { "Plan": { "Node Type": "Result", "Parallel Aware": false, "Async Capable": false, "Startup Cost": 0.00, "Total Cost": 0.01, "Plan Rows": 1, "Plan Width": 4, "Output": ["1"] } } ]"#; // https://github.com/launchbadge/sqlx/issues/2622 let extra_field = r#"[ { "Plan": { "Node Type": "Result", "Parallel Aware": false, "Async Capable": false, "Startup Cost": 0.00, "Total Cost": 0.01, "Plan Rows": 1, "Plan Width": 4, "Output": ["1"] }, "Query Identifier": 1147616880456321454 } ]"#; // https://github.com/launchbadge/sqlx/issues/1449 let utility_statement = r#"["Utility Statement"]"#; let normal_plan_parsed = serde_json::from_str::<[Explain; 1]>(normal_plan).unwrap(); let extra_field_parsed = serde_json::from_str::<[Explain; 1]>(extra_field).unwrap(); let utility_statement_parsed = serde_json::from_str::<[Explain; 1]>(utility_statement).unwrap(); assert!( matches!(normal_plan_parsed, [Explain::Plan { plan: Plan { .. } }]), "unexpected parse from {normal_plan:?}: {normal_plan_parsed:?}" ); assert!( matches!(extra_field_parsed, [Explain::Plan { plan: Plan { .. 
} }]), "unexpected parse from {extra_field:?}: {extra_field_parsed:?}" ); assert!( matches!(utility_statement_parsed, [Explain::Other(_)]), "unexpected parse from {utility_statement:?}: {utility_statement_parsed:?}" ) } sqlx-postgres-0.7.3/src/connection/establish.rs000064400000000000000000000133220072674642500177710ustar 00000000000000use crate::HashMap; use crate::common::StatementCache; use crate::connection::{sasl, stream::PgStream}; use crate::error::Error; use crate::io::Decode; use crate::message::{ Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup, }; use crate::types::Oid; use crate::{PgConnectOptions, PgConnection}; // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3 // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11 impl PgConnection { pub(crate) async fn establish(options: &PgConnectOptions) -> Result { // Upgrade to TLS if we were asked to and the server supports it let mut stream = PgStream::connect(options).await?; // To begin a session, a frontend opens a connection to the server // and sends a startup message. let mut params = vec![ // Sets the display format for date and time values, // as well as the rules for interpreting ambiguous date input values. ("DateStyle", "ISO, MDY"), // Sets the client-side encoding (character set). // ("client_encoding", "UTF8"), // Sets the time zone for displaying and interpreting time stamps. 
("TimeZone", "UTC"), ]; if let Some(ref extra_float_digits) = options.extra_float_digits { params.push(("extra_float_digits", extra_float_digits)); } if let Some(ref application_name) = options.application_name { params.push(("application_name", application_name)); } if let Some(ref options) = options.options { params.push(("options", options)); } stream .send(Startup { username: Some(&options.username), database: options.database.as_deref(), params: ¶ms, }) .await?; // The server then uses this information and the contents of // its configuration files (such as pg_hba.conf) to determine whether the connection is // provisionally acceptable, and what additional // authentication is required (if any). let mut process_id = 0; let mut secret_key = 0; let transaction_status; loop { let message = stream.recv().await?; match message.format { MessageFormat::Authentication => match message.decode()? { Authentication::Ok => { // the authentication exchange is successfully completed // do nothing; no more information is required to continue } Authentication::CleartextPassword => { // The frontend must now send a [PasswordMessage] containing the // password in clear-text form. stream .send(Password::Cleartext( options.password.as_deref().unwrap_or_default(), )) .await?; } Authentication::Md5Password(body) => { // The frontend must now send a [PasswordMessage] containing the // password (with user name) encrypted via MD5, then encrypted again // using the 4-byte random salt specified in the // [AuthenticationMD5Password] message. 
stream .send(Password::Md5 { username: &options.username, password: options.password.as_deref().unwrap_or_default(), salt: body.salt, }) .await?; } Authentication::Sasl(body) => { sasl::authenticate(&mut stream, options, body).await?; } method => { return Err(err_protocol!( "unsupported authentication method: {:?}", method )); } }, MessageFormat::BackendKeyData => { // provides secret-key data that the frontend must save if it wants to be // able to issue cancel requests later let data: BackendKeyData = message.decode()?; process_id = data.process_id; secret_key = data.secret_key; } MessageFormat::ReadyForQuery => { // start-up is completed. The frontend can now issue commands transaction_status = ReadyForQuery::decode(message.contents)?.transaction_status; break; } _ => { return Err(err_protocol!( "establish: unexpected message: {:?}", message.format )) } } } Ok(PgConnection { stream, process_id, secret_key, transaction_status, transaction_depth: 0, pending_ready_for_query_count: 0, next_statement_id: Oid(1), cache_statement: StatementCache::new(options.statement_cache_capacity), cache_type_oid: HashMap::new(), cache_type_info: HashMap::new(), log_settings: options.log_settings.clone(), }) } } sqlx-postgres-0.7.3/src/connection/executor.rs000064400000000000000000000370030072674642500176530ustar 00000000000000use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::logger::QueryLogger; use crate::message::{ self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query, RowDescription, }; use crate::statement::PgStatementMetadata; use crate::type_info::PgType; use crate::types::Oid; use crate::{ statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo, PgValueFormat, Postgres, }; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; use sqlx_core::Either; use 
std::{borrow::Cow, sync::Arc}; async fn prepare( conn: &mut PgConnection, sql: &str, parameters: &[PgTypeInfo], metadata: Option>, ) -> Result<(Oid, Arc), Error> { let id = conn.next_statement_id; conn.next_statement_id.incr_one(); // build a list of type OIDs to send to the database in the PARSE command // we have not yet started the query sequence, so we are *safe* to cleanly make // additional queries here to get any missing OIDs let mut param_types = Vec::with_capacity(parameters.len()); for ty in parameters { param_types.push(if let PgType::DeclareWithName(name) = &ty.0 { conn.fetch_type_id_by_name(name).await? } else { ty.0.oid() }); } // flush and wait until we are re-ready conn.wait_until_ready().await?; // next we send the PARSE command to the server conn.stream.write(Parse { param_types: &*param_types, query: sql, statement: id, }); if metadata.is_none() { // get the statement columns and parameters conn.stream.write(message::Describe::Statement(id)); } // we ask for the server to immediately send us the result of the PARSE command conn.write_sync(); conn.stream.flush().await?; // indicates that the SQL query string is now successfully parsed and has semantic validity let _ = conn .stream .recv_expect(MessageFormat::ParseComplete) .await?; let metadata = if let Some(metadata) = metadata { // each SYNC produces one READY FOR QUERY conn.recv_ready_for_query().await?; // we already have metadata metadata } else { let parameters = recv_desc_params(conn).await?; let rows = recv_desc_rows(conn).await?; // each SYNC produces one READY FOR QUERY conn.recv_ready_for_query().await?; let parameters = conn.handle_parameter_description(parameters).await?; let (columns, column_names) = conn.handle_row_description(rows, true).await?; // ensure that if we did fetch custom data, we wait until we are fully ready before // continuing conn.wait_until_ready().await?; Arc::new(PgStatementMetadata { parameters, columns, column_names: Arc::new(column_names), }) }; Ok((id, 
metadata)) } async fn recv_desc_params(conn: &mut PgConnection) -> Result { conn.stream .recv_expect(MessageFormat::ParameterDescription) .await } async fn recv_desc_rows(conn: &mut PgConnection) -> Result, Error> { let rows: Option = match conn.stream.recv().await? { // describes the rows that will be returned when the statement is eventually executed message if message.format == MessageFormat::RowDescription => Some(message.decode()?), // no data would be returned if this statement was executed message if message.format == MessageFormat::NoData => None, message => { return Err(err_protocol!( "expecting RowDescription or NoData but received {:?}", message.format )); } }; Ok(rows) } impl PgConnection { // wait for CloseComplete to indicate a statement was closed pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> { // we need to wait for the [CloseComplete] to be returned from the server while count > 0 { match self.stream.recv().await? { message if message.format == MessageFormat::PortalSuspended => { // there was an open portal // this can happen if the last time a statement was used it was not fully executed } message if message.format == MessageFormat::CloseComplete => { // successfully closed the statement (and freed up the server resources) count -= 1; } message => { return Err(err_protocol!( "expecting PortalSuspended or CloseComplete but received {:?}", message.format )); } } } Ok(()) } pub(crate) fn write_sync(&mut self) { self.stream.write(message::Sync); // all SYNC messages will return a ReadyForQuery self.pending_ready_for_query_count += 1; } async fn get_or_prepare<'a>( &mut self, sql: &str, parameters: &[PgTypeInfo], // should we store the result of this prepare to the cache store_to_cache: bool, // optional metadata that was provided by the user, this means they are reusing // a statement object metadata: Option>, ) -> Result<(Oid, Arc), Error> { if let Some(statement) = self.cache_statement.get_mut(sql) { 
return Ok((*statement).clone()); } let statement = prepare(self, sql, parameters, metadata).await?; if store_to_cache && self.cache_statement.is_enabled() { if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) { self.stream.write(Close::Statement(id)); self.write_sync(); self.stream.flush().await?; self.wait_for_close_complete(1).await?; self.recv_ready_for_query().await?; } } Ok(statement) } pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( &'c mut self, query: &'q str, arguments: Option, limit: u8, persistent: bool, metadata_opt: Option>, ) -> Result, Error>> + 'e, Error> { let mut logger = QueryLogger::new(query, self.log_settings.clone()); // before we continue, wait until we are "ready" to accept more queries self.wait_until_ready().await?; let mut metadata: Arc; let format = if let Some(mut arguments) = arguments { // prepare the statement if this our first time executing it // always return the statement ID here let (statement, metadata_) = self .get_or_prepare(query, &arguments.types, persistent, metadata_opt) .await?; metadata = metadata_; // patch holes created during encoding arguments.apply_patches(self, &metadata.parameters).await?; // consume messages till `ReadyForQuery` before bind and execute self.wait_until_ready().await?; // bind to attach the arguments to the statement and create a portal self.stream.write(Bind { portal: None, statement, formats: &[PgValueFormat::Binary], num_params: arguments.types.len() as i16, params: &*arguments.buffer, result_formats: &[PgValueFormat::Binary], }); // executes the portal up to the passed limit // the protocol-level limit acts nearly identically to the `LIMIT` in SQL self.stream.write(message::Execute { portal: None, limit: limit.into(), }); // From https://www.postgresql.org/docs/current/protocol-flow.html: // // "An unnamed portal is destroyed at the end of the transaction, or as // soon as the next Bind statement specifying the unnamed portal as // destination is issued. 
(Note that a simple Query message also // destroys the unnamed portal." // we ask the database server to close the unnamed portal and free the associated resources // earlier - after the execution of the current query. self.stream.write(message::Close::Portal(None)); // finally, [Sync] asks postgres to process the messages that we sent and respond with // a [ReadyForQuery] message when it's completely done. Theoretically, we could send // dozens of queries before a [Sync] and postgres can handle that. Execution on the server // is still serial but it would reduce round-trips. Some kind of builder pattern that is // termed batching might suit this. self.write_sync(); // prepared statements are binary PgValueFormat::Binary } else { // Query will trigger a ReadyForQuery self.stream.write(Query(query)); self.pending_ready_for_query_count += 1; // metadata starts out as "nothing" metadata = Arc::new(PgStatementMetadata::default()); // and unprepared statements are text PgValueFormat::Text }; self.stream.flush().await?; Ok(try_stream! 
{ loop { let message = self.stream.recv().await?; match message.format { MessageFormat::BindComplete | MessageFormat::ParseComplete | MessageFormat::ParameterDescription | MessageFormat::NoData // unnamed portal has been closed | MessageFormat::CloseComplete => { // harmless messages to ignore } // "Execute phase is always terminated by the appearance of // exactly one of these messages: CommandComplete, // EmptyQueryResponse (if the portal was created from an // empty query string), ErrorResponse, or PortalSuspended" MessageFormat::CommandComplete => { // a SQL command completed normally let cc: CommandComplete = message.decode()?; let rows_affected = cc.rows_affected(); logger.increase_rows_affected(rows_affected); r#yield!(Either::Left(PgQueryResult { rows_affected, })); } MessageFormat::EmptyQueryResponse => { // empty query string passed to an unprepared execute } // Message::ErrorResponse is handled in self.stream.recv() // incomplete query execution has finished MessageFormat::PortalSuspended => {} MessageFormat::RowDescription => { // indicates that a *new* set of rows are about to be returned let (columns, column_names) = self .handle_row_description(Some(message.decode()?), false) .await?; metadata = Arc::new(PgStatementMetadata { column_names: Arc::new(column_names), columns, parameters: Vec::default(), }); } MessageFormat::DataRow => { logger.increment_rows_returned(); // one of the set of rows returned by a SELECT, FETCH, etc query let data: DataRow = message.decode()?; let row = PgRow { data, format, metadata: Arc::clone(&metadata), }; r#yield!(Either::Right(row)); } MessageFormat::ReadyForQuery => { // processing of the query string is complete self.handle_ready_for_query(message)?; break; } _ => { return Err(err_protocol!( "execute: unexpected message: {:?}", message.format )); } } } Ok(()) }) } } impl<'c> Executor<'c> for &'c mut PgConnection { type Database = Postgres; fn fetch_many<'e, 'q: 'e, E: 'q>( self, mut query: E, ) -> BoxStream<'e, 
Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { let sql = query.sql(); let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); let arguments = query.take_arguments(); let persistent = query.persistent(); Box::pin(try_stream! { let s = self.run(sql, arguments, 0, persistent, metadata).await?; pin_mut!(s); while let Some(v) = s.try_next().await? { r#yield!(v); } Ok(()) }) } fn fetch_optional<'e, 'q: 'e, E: 'q>( self, mut query: E, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { let sql = query.sql(); let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); let arguments = query.take_arguments(); let persistent = query.persistent(); Box::pin(async move { let s = self.run(sql, arguments, 1, persistent, metadata).await?; pin_mut!(s); while let Some(s) = s.try_next().await? { if let Either::Right(r) = s { return Ok(Some(r)); } } Ok(None) }) } fn prepare_with<'e, 'q: 'e>( self, sql: &'q str, parameters: &'e [PgTypeInfo], ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { Box::pin(async move { self.wait_until_ready().await?; let (_, metadata) = self.get_or_prepare(sql, parameters, true, None).await?; Ok(PgStatement { sql: Cow::Borrowed(sql), metadata, }) }) } fn describe<'e, 'q: 'e>( self, sql: &'q str, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { Box::pin(async move { self.wait_until_ready().await?; let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None).await?; let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?; Ok(Describe { columns: metadata.columns.clone(), nullable, parameters: Some(Either::Left(metadata.parameters.clone())), }) }) } } sqlx-postgres-0.7.3/src/connection/mod.rs000064400000000000000000000144250072674642500165770ustar 00000000000000use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use crate::HashMap; use futures_core::future::BoxFuture; use futures_util::FutureExt; use crate::common::StatementCache; use crate::error::Error; use 
crate::ext::ustr::UStr; use crate::io::Decode; use crate::message::{ Close, Message, MessageFormat, Query, ReadyForQuery, Terminate, TransactionStatus, }; use crate::statement::PgStatementMetadata; use crate::transaction::Transaction; use crate::types::Oid; use crate::{PgConnectOptions, PgTypeInfo, Postgres}; pub(crate) use sqlx_core::connection::*; pub use self::stream::PgStream; pub(crate) mod describe; mod establish; mod executor; mod sasl; mod stream; mod tls; /// A connection to a PostgreSQL database. pub struct PgConnection { // underlying TCP or UDS stream, // wrapped in a potentially TLS stream, // wrapped in a buffered stream pub(crate) stream: PgStream, // process id of this backend // used to send cancel requests #[allow(dead_code)] process_id: u32, // secret key of this backend // used to send cancel requests #[allow(dead_code)] secret_key: u32, // sequence of statement IDs for use in preparing statements // in PostgreSQL, the statement is prepared to a user-supplied identifier next_statement_id: Oid, // cache statement by query string to the id and columns cache_statement: StatementCache<(Oid, Arc)>, // cache user-defined types by id <-> info cache_type_info: HashMap, cache_type_oid: HashMap, // number of ReadyForQuery messages that we are currently expecting pub(crate) pending_ready_for_query_count: usize, // current transaction status transaction_status: TransactionStatus, pub(crate) transaction_depth: usize, log_settings: LogSettings, } impl PgConnection { /// the version number of the server in `libpq` format pub fn server_version_num(&self) -> Option { self.stream.server_version_num } // will return when the connection is ready for another query pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { if !self.stream.write_buffer_mut().is_empty() { self.stream.flush().await?; } while self.pending_ready_for_query_count > 0 { let message = self.stream.recv().await?; if let MessageFormat::ReadyForQuery = message.format { 
self.handle_ready_for_query(message)?; } } Ok(()) } async fn recv_ready_for_query(&mut self) -> Result<(), Error> { let r: ReadyForQuery = self .stream .recv_expect(MessageFormat::ReadyForQuery) .await?; self.pending_ready_for_query_count -= 1; self.transaction_status = r.transaction_status; Ok(()) } fn handle_ready_for_query(&mut self, message: Message) -> Result<(), Error> { self.pending_ready_for_query_count -= 1; self.transaction_status = ReadyForQuery::decode(message.contents)?.transaction_status; Ok(()) } /// Queue a simple query (not prepared) to execute the next time this connection is used. /// /// Used for rolling back transactions and releasing advisory locks. pub(crate) fn queue_simple_query(&mut self, query: &str) { self.pending_ready_for_query_count += 1; self.stream.write(Query(query)); } } impl Debug for PgConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("PgConnection").finish() } } impl Connection for PgConnection { type Database = Postgres; type Options = PgConnectOptions; fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { // The normal, graceful termination procedure is that the frontend sends a Terminate // message and immediately closes the connection. // On receipt of this message, the backend closes the // connection and terminates. Box::pin(async move { self.stream.send(Terminate).await?; self.stream.shutdown().await?; Ok(()) }) } fn close_hard(mut self) -> BoxFuture<'static, Result<(), Error>> { Box::pin(async move { self.stream.shutdown().await?; Ok(()) }) } fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { // Users were complaining about this showing up in query statistics on the server. // By sending a comment we avoid an error if the connection was in the middle of a rowset // self.execute("/* SQLx ping */").map_ok(|_| ()).boxed() Box::pin(async move { // The simplest call-and-response that's possible. 
self.write_sync(); self.wait_until_ready().await }) } fn begin(&mut self) -> BoxFuture<'_, Result, Error>> where Self: Sized, { Transaction::begin(self) } fn cached_statements_size(&self) -> usize { self.cache_statement.len() } fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { self.cache_type_oid.clear(); let mut cleared = 0_usize; self.wait_until_ready().await?; while let Some((id, _)) = self.cache_statement.remove_lru() { self.stream.write(Close::Statement(id)); cleared += 1; } if cleared > 0 { self.write_sync(); self.stream.flush().await?; self.wait_for_close_complete(cleared).await?; self.recv_ready_for_query().await?; } Ok(()) }) } fn shrink_buffers(&mut self) { self.stream.shrink_buffers(); } #[doc(hidden)] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { self.wait_until_ready().boxed() } #[doc(hidden)] fn should_flush(&self) -> bool { !self.stream.write_buffer().is_empty() } } // Implement `AsMut` so that `PgConnection` can be wrapped in // a `PgAdvisoryLockGuard`. 
// // See: https://github.com/launchbadge/sqlx/issues/2520 impl AsMut for PgConnection { fn as_mut(&mut self) -> &mut PgConnection { self } } sqlx-postgres-0.7.3/src/connection/sasl.rs000064400000000000000000000150040072674642500167540ustar 00000000000000use crate::connection::stream::PgStream; use crate::error::Error; use crate::message::{ Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse, }; use crate::PgConnectOptions; use hmac::{Hmac, Mac}; use rand::Rng; use sha2::{Digest, Sha256}; use stringprep::saslprep; use base64::prelude::{Engine as _, BASE64_STANDARD}; const GS2_HEADER: &str = "n,,"; const CHANNEL_ATTR: &str = "c"; const USERNAME_ATTR: &str = "n"; const CLIENT_PROOF_ATTR: &str = "p"; const NONCE_ATTR: &str = "r"; pub(crate) async fn authenticate( stream: &mut PgStream, options: &PgConnectOptions, data: AuthenticationSasl, ) -> Result<(), Error> { let mut has_sasl = false; let mut has_sasl_plus = false; let mut unknown = Vec::new(); for mechanism in data.mechanisms() { match mechanism { "SCRAM-SHA-256" => { has_sasl = true; } "SCRAM-SHA-256-PLUS" => { has_sasl_plus = true; } _ => { unknown.push(mechanism.to_owned()); } } } if !has_sasl_plus && !has_sasl { return Err(err_protocol!( "unsupported SASL authentication mechanisms: {}", unknown.join(", ") )); } // channel-binding = "c=" base64 let mut channel_binding = format!("{CHANNEL_ATTR}="); BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding); // "n=" saslname ;; Usernames are prepared using SASLprep. let username = format!("{}={}", USERNAME_ATTR, options.username); let username = match saslprep(&username) { Ok(v) => v, // TODO(danielakhterov): Remove panic when we have proper support for configuration errors Err(_) => panic!("Failed to saslprep username"), }; // nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server. 
let nonce = gen_nonce(); // client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions] let client_first_message_bare = format!("{username},{nonce}"); let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}"); stream .send(SaslInitialResponse { response: &client_first_message, plus: false, }) .await?; let cont = match stream.recv_expect(MessageFormat::Authentication).await? { Authentication::SaslContinue(data) => data, auth => { return Err(err_protocol!( "expected SASLContinue but received {:?}", auth )); } }; // SaltedPassword := Hi(Normalize(password), salt, i) let salted_password = hi( options.password.as_deref().unwrap_or_default(), &cont.salt, cont.iterations, )?; // ClientKey := HMAC(SaltedPassword, "Client Key") let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; mac.update(b"Client Key"); let client_key = mac.finalize().into_bytes(); // StoredKey := H(ClientKey) let stored_key = Sha256::digest(&client_key); // client-final-message-without-proof let client_final_message_wo_proof = format!( "{channel_binding},r={nonce}", channel_binding = channel_binding, nonce = &cont.nonce ); // AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof let auth_message = format!( "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}", client_first_message_bare = client_first_message_bare, server_first_message = cont.message, client_final_message_wo_proof = client_final_message_wo_proof ); // ClientSignature := HMAC(StoredKey, AuthMessage) let mut mac = Hmac::::new_from_slice(&stored_key).map_err(Error::protocol)?; mac.update(&auth_message.as_bytes()); let client_signature = mac.finalize().into_bytes(); // ClientProof := ClientKey XOR ClientSignature let client_proof: Vec = client_key .iter() .zip(client_signature.iter()) .map(|(&a, &b)| a ^ b) .collect(); // ServerKey := HMAC(SaltedPassword, "Server Key") let mut mac 
= Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; mac.update(b"Server Key"); let server_key = mac.finalize().into_bytes(); // ServerSignature := HMAC(ServerKey, AuthMessage) let mut mac = Hmac::::new_from_slice(&server_key).map_err(Error::protocol)?; mac.update(&auth_message.as_bytes()); // client-final-message = client-final-message-without-proof "," proof let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}="); BASE64_STANDARD.encode_string(client_proof, &mut client_final_message); stream.send(SaslResponse(&client_final_message)).await?; let data = match stream.recv_expect(MessageFormat::Authentication).await? { Authentication::SaslFinal(data) => data, auth => { return Err(err_protocol!("expected SASLFinal but received {:?}", auth)); } }; // authentication is only considered valid if this verification passes mac.verify_slice(&data.verifier).map_err(Error::protocol)?; Ok(()) } // nonce is a sequence of random printable bytes fn gen_nonce() -> String { let mut rng = rand::thread_rng(); let count = rng.gen_range(64..128); // printable = %x21-2B / %x2D-7E // ;; Printable ASCII except ",". // ;; Note that any "printable" is also // ;; a valid "value". 
let nonce: String = std::iter::repeat(()) .map(|()| { let mut c = rng.gen_range(0x21..0x7F) as u8; while c == 0x2C { c = rng.gen_range(0x21..0x7F) as u8; } c }) .take(count) .map(|c| c as char) .collect(); rng.gen_range(32..128); format!("{NONCE_ATTR}={nonce}") } // Hi(str, salt, i): fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> { let mut mac = Hmac::::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; mac.update(&salt); mac.update(&1u32.to_be_bytes()); let mut u = mac.finalize().into_bytes(); let mut hi = u; for _ in 1..iter_count { let mut mac = Hmac::::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; mac.update(u.as_slice()); u = mac.finalize().into_bytes(); hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect(); } Ok(hi.into()) } sqlx-postgres-0.7.3/src/connection/stream.rs000064400000000000000000000214040072674642500173060ustar 00000000000000use std::collections::BTreeMap; use std::ops::{Deref, DerefMut}; use std::str::FromStr; use futures_channel::mpsc::UnboundedSender; use futures_util::SinkExt; use log::Level; use sqlx_core::bytes::{Buf, Bytes}; use crate::connection::tls::MaybeUpgradeTls; use crate::error::Error; use crate::io::{Decode, Encode}; use crate::message::{Message, MessageFormat, Notice, Notification, ParameterStatus}; use crate::net::{self, BufferedSocket, Socket}; use crate::{PgConnectOptions, PgDatabaseError, PgSeverity}; // the stream is a separate type from the connection to uphold the invariant where an instantiated // [PgConnection] is a **valid** connection to postgres // when a new connection is asked for, we work directly on the [PgStream] type until the // connection is fully established // in other words, `self` in any PgConnection method is a live connection to postgres that // is fully prepared to receive queries pub struct PgStream { // A trait object is okay here as the buffering amortizes the overhead of both the dynamic // function call as well as the syscall. 
inner: BufferedSocket<Box<dyn Socket>>,

    // buffer of unreceived notification messages from `PUBLISH`
    // this is set when creating a PgListener and only written to if that listener is
    // re-used for query execution in-between receiving messages
    pub(crate) notifications: Option<UnboundedSender<Notification>>,

    pub(crate) parameter_statuses: BTreeMap<String, String>,

    pub(crate) server_version_num: Option<u32>,
}

impl PgStream {
    /// Open a socket (UNIX-domain or TCP) to the server, upgrading to TLS
    /// when the connect options request it.
    pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
        let socket_future = match options.fetch_socket() {
            Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
            None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
        };

        let socket = socket_future.await?;

        Ok(Self {
            inner: BufferedSocket::new(socket),
            notifications: None,
            parameter_statuses: BTreeMap::default(),
            server_version_num: None,
        })
    }

    /// Encode `message` into the write buffer and flush it to the server.
    pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error>
    where
        T: Encode<'en>,
    {
        self.write(message);
        self.flush().await?;
        Ok(())
    }

    // Expect a specific type and format
    pub(crate) async fn recv_expect<'de, T: Decode<'de>>(
        &mut self,
        format: MessageFormat,
    ) -> Result<T, Error> {
        let message = self.recv().await?;

        if message.format != format {
            return Err(err_protocol!(
                "expecting {:?} but received {:?}",
                format,
                message.format
            ));
        }

        message.decode()
    }

    /// Read the next raw message off the wire without handling asynchronous
    /// server messages (errors, notices, notifications, parameter changes).
    pub(crate) async fn recv_unchecked(&mut self) -> Result<Message, Error> {
        // all packets in postgres start with a 5-byte header
        // this header contains the message type and the total length of the message
        let mut header: Bytes = self.inner.read(5).await?;

        let format = MessageFormat::try_from_u8(header.get_u8())?;
        // the declared length includes the 4 length bytes themselves
        let size = (header.get_u32() - 4) as usize;

        let contents = self.inner.read(size).await?;

        Ok(Message { format, contents })
    }

    // Get the next message from the server
    // May wait for more data from the server
    pub(crate) async fn recv(&mut self) -> Result<Message, Error> {
        loop {
            let message = self.recv_unchecked().await?;

            match message.format {
                MessageFormat::ErrorResponse => {
                    // An error returned from the database server.
                    return Err(PgDatabaseError(message.decode()?).into());
                }

                MessageFormat::NotificationResponse => {
                    if let Some(buffer) = &mut self.notifications {
                        let notification: Notification = message.decode()?;
                        let _ = buffer.send(notification).await;

                        continue;
                    }
                }

                MessageFormat::ParameterStatus => {
                    // informs the frontend about the current (initial)
                    // setting of backend parameters

                    let ParameterStatus { name, value } = message.decode()?;
                    // TODO: handle `client_encoding`, `DateStyle` change

                    match name.as_str() {
                        "server_version" => {
                            self.server_version_num = parse_server_version(&value);
                        }
                        _ => {
                            self.parameter_statuses.insert(name, value);
                        }
                    }

                    continue;
                }

                MessageFormat::NoticeResponse => {
                    // do we need this to be more configurable?
                    // if you are reading this comment and think so, open an issue

                    let notice: Notice = message.decode()?;

                    let (log_level, tracing_level) = match notice.severity() {
                        PgSeverity::Fatal | PgSeverity::Panic | PgSeverity::Error => {
                            (Level::Error, tracing::Level::ERROR)
                        }
                        PgSeverity::Warning => (Level::Warn, tracing::Level::WARN),
                        PgSeverity::Notice => (Level::Info, tracing::Level::INFO),
                        PgSeverity::Debug => (Level::Debug, tracing::Level::DEBUG),
                        PgSeverity::Info | PgSeverity::Log => (Level::Trace, tracing::Level::TRACE),
                    };

                    let log_is_enabled = log::log_enabled!(
                        target: "sqlx::postgres::notice",
                        log_level
                    ) || sqlx_core::private_tracing_dynamic_enabled!(
                        target: "sqlx::postgres::notice",
                        tracing_level
                    );
                    if log_is_enabled {
                        let message = format!("{}", notice.message());
                        sqlx_core::private_tracing_dynamic_event!(
                            target: "sqlx::postgres::notice",
                            tracing_level,
                            message
                        );
                    }

                    continue;
                }

                _ => {}
            }

            return Ok(message);
        }
    }
}

impl Deref for PgStream {
    type Target = BufferedSocket<Box<dyn Socket>>;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl DerefMut for PgStream {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}

// reference: //
https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065 fn parse_server_version(s: &str) -> Option { let mut parts = Vec::::with_capacity(3); let mut from = 0; let mut chs = s.char_indices().peekable(); while let Some((i, ch)) = chs.next() { match ch { '.' => { if let Ok(num) = u32::from_str(&s[from..i]) { parts.push(num); from = i + 1; } else { break; } } _ if ch.is_digit(10) => { if chs.peek().is_none() { if let Ok(num) = u32::from_str(&s[from..]) { parts.push(num); } break; } } _ => { if let Ok(num) = u32::from_str(&s[from..i]) { parts.push(num); } break; } }; } let version_num = match parts.as_slice() { [major, minor, rev] => (100 * major + minor) * 100 + rev, [major, minor] if *major >= 10 => 100 * 100 * major + minor, [major, minor] => (100 * major + minor) * 100, [major] => 100 * 100 * major, _ => return None, }; Some(version_num) } #[cfg(test)] mod tests { use super::parse_server_version; #[test] fn test_parse_server_version_num() { // old style assert_eq!(parse_server_version("9.6.1"), Some(90601)); // new style assert_eq!(parse_server_version("10.1"), Some(100001)); // old style without minor version assert_eq!(parse_server_version("9.6devel"), Some(90600)); // new style without minor version, e.g. 
*/ assert_eq!(parse_server_version("10devel"), Some(100000)); assert_eq!(parse_server_version("13devel87"), Some(130000)); // unknown assert_eq!(parse_server_version("unknown"), None); } } sqlx-postgres-0.7.3/src/connection/tls.rs000064400000000000000000000060370072674642500166220ustar 00000000000000use futures_core::future::BoxFuture; use crate::error::Error; use crate::net::tls::{self, TlsConfig}; use crate::net::{Socket, SocketIntoBox, WithSocket}; use crate::message::SslRequest; use crate::{PgConnectOptions, PgSslMode}; pub struct MaybeUpgradeTls<'a>(pub &'a PgConnectOptions); impl<'a> WithSocket for MaybeUpgradeTls<'a> { type Output = BoxFuture<'a, crate::Result>>; fn with_socket(self, socket: S) -> Self::Output { Box::pin(maybe_upgrade(socket, self.0)) } } async fn maybe_upgrade( mut socket: S, options: &PgConnectOptions, ) -> Result, Error> { // https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS match options.ssl_mode { // FIXME: Implement ALLOW PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)), PgSslMode::Prefer => { if !tls::available() { return Ok(Box::new(socket)); } // try upgrade, but its okay if we fail if !request_upgrade(&mut socket, options).await? { return Ok(Box::new(socket)); } } PgSslMode::Require | PgSslMode::VerifyFull | PgSslMode::VerifyCa => { tls::error_if_unavailable()?; if !request_upgrade(&mut socket, options).await? 
{ // upgrade failed, die return Err(Error::Tls("server does not support TLS".into())); } } } let accept_invalid_certs = !matches!( options.ssl_mode, PgSslMode::VerifyCa | PgSslMode::VerifyFull ); let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull); let config = TlsConfig { accept_invalid_certs, accept_invalid_hostnames, hostname: &options.host, root_cert_path: options.ssl_root_cert.as_ref(), client_cert_path: options.ssl_client_cert.as_ref(), client_key_path: options.ssl_client_key.as_ref(), }; tls::handshake(socket, config, SocketIntoBox).await } async fn request_upgrade( socket: &mut impl Socket, _options: &PgConnectOptions, ) -> Result { // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11 // To initiate an SSL-encrypted connection, the frontend initially sends an // SSLRequest message rather than a StartupMessage socket.write(SslRequest::BYTES).await?; // The server then responds with a single byte containing S or N, indicating that // it is willing or unwilling to perform SSL, respectively. 
let mut response = [0u8]; socket.read(&mut &mut response[..]).await?; match response[0] { b'S' => { // The server is ready and willing to accept an SSL connection Ok(true) } b'N' => { // The server is _unwilling_ to perform SSL Ok(false) } other => Err(err_protocol!( "unexpected response from SSLRequest: 0x{:02x}", other )), } } sqlx-postgres-0.7.3/src/copy.rs000064400000000000000000000324220072674642500146300ustar 00000000000000use futures_core::future::BoxFuture; use std::borrow::Cow; use std::ops::{Deref, DerefMut}; use futures_core::stream::BoxStream; use sqlx_core::bytes::{BufMut, Bytes}; use crate::connection::PgConnection; use crate::error::{Error, Result}; use crate::ext::async_stream::TryAsyncStream; use crate::io::{AsyncRead, AsyncReadExt}; use crate::message::{ CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, }; use crate::pool::{Pool, PoolConnection}; use crate::Postgres; impl PgConnection { /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data /// to Postgres. This is a more efficient way to import data into Postgres as compared to /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). /// /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is /// returned. /// /// Command examples and accepted formats for `COPY` data are shown here: /// https://www.postgresql.org/docs/current/sql-copy.html /// /// ### Note /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection /// will return an error the next time it is used. pub async fn copy_in_raw(&mut self, statement: &str) -> Result> { PgCopyIn::begin(self, statement).await } /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data /// from Postgres. This is a more efficient way to export data from Postgres but /// arrives in chunks of one of a few data formats (text/CSV/binary). 
/// /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, /// an error is returned. /// /// Note that once this process has begun, unless you read the stream to completion, /// it can only be canceled in two ways: /// /// 1. by closing the connection, or: /// 2. by using another connection to kill the server process that is sending the data as shown /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). /// /// If you don't read the stream to completion, the next time the connection is used it will /// need to read and discard all the remaining queued data, which could take some time. /// /// Command examples and accepted formats for `COPY` data are shown here: /// https://www.postgresql.org/docs/current/sql-copy.html #[allow(clippy::needless_lifetimes)] pub async fn copy_out_raw<'c>( &'c mut self, statement: &str, ) -> Result>> { pg_begin_copy_out(self, statement).await } } /// Implements methods for directly executing `COPY FROM/TO STDOUT` on a [`PgPool`]. /// /// This is a replacement for the inherent methods on `PgPool` which could not exist /// once the Postgres driver was moved out into its own crate. pub trait PgPoolCopyExt { /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres. /// This is a more efficient way to import data into Postgres as compared to /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). /// /// A single connection will be checked out for the duration. /// /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is /// returned. /// /// Command examples and accepted formats for `COPY` data are shown here: /// https://www.postgresql.org/docs/current/sql-copy.html /// /// ### Note /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection /// will return an error the next time it is used. 
fn copy_in_raw<'a>( &'a self, statement: &'a str, ) -> BoxFuture<'a, Result>>>; /// Issue a `COPY TO STDOUT` statement and begin streaming data /// from Postgres. This is a more efficient way to export data from Postgres but /// arrives in chunks of one of a few data formats (text/CSV/binary). /// /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, /// an error is returned. /// /// Note that once this process has begun, unless you read the stream to completion, /// it can only be canceled in two ways: /// /// 1. by closing the connection, or: /// 2. by using another connection to kill the server process that is sending the data as shown /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). /// /// If you don't read the stream to completion, the next time the connection is used it will /// need to read and discard all the remaining queued data, which could take some time. /// /// Command examples and accepted formats for `COPY` data are shown here: /// https://www.postgresql.org/docs/current/sql-copy.html fn copy_out_raw<'a>( &'a self, statement: &'a str, ) -> BoxFuture<'a, Result>>>; } impl PgPoolCopyExt for Pool { fn copy_in_raw<'a>( &'a self, statement: &'a str, ) -> BoxFuture<'a, Result>>> { Box::pin(async { PgCopyIn::begin(self.acquire().await?, statement).await }) } fn copy_out_raw<'a>( &'a self, statement: &'a str, ) -> BoxFuture<'a, Result>>> { Box::pin(async { pg_begin_copy_out(self.acquire().await?, statement).await }) } } /// A connection in streaming `COPY FROM STDIN` mode. /// /// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw]. /// /// ### Note /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection /// will return an error the next time it is used. 
#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
    conn: Option<C>,
    response: CopyResponse,
}

impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
    /// Send the `COPY ... FROM STDIN` statement and capture the server's
    /// `CopyInResponse`, transitioning the connection into copy-in mode.
    async fn begin(mut conn: C, statement: &str) -> Result<Self> {
        conn.wait_until_ready().await?;
        conn.stream.send(Query(statement)).await?;

        let response = match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
            Ok(res) => res,
            Err(e) => {
                // drain the trailing ReadyForQuery so the connection stays usable
                conn.stream.recv().await?;
                return Err(e);
            }
        };

        Ok(PgCopyIn {
            conn: Some(conn),
            response,
        })
    }

    /// Returns `true` if Postgres is expecting data in text or CSV format.
    pub fn is_textual(&self) -> bool {
        self.response.format == 0
    }

    /// Returns the number of columns expected in the input.
    pub fn num_columns(&self) -> usize {
        assert_eq!(
            self.response.num_columns as usize,
            self.response.format_codes.len(),
            "num_columns does not match format_codes.len()"
        );
        self.response.format_codes.len()
    }

    /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
    ///
    /// ### Panics
    /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
    pub fn column_is_textual(&self, column: usize) -> bool {
        self.response.format_codes[column] == 0
    }

    /// Send a chunk of `COPY` data.
    ///
    /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
    pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
        self.conn
            .as_deref_mut()
            .expect("send_data: conn taken")
            .stream
            .send(CopyData(data))
            .await?;

        Ok(self)
    }

    /// Copy data directly from `source` to the database without requiring an intermediate buffer.
    ///
    /// `source` will be read to the end.
    ///
    /// ### Note: Completion Step Required
    /// You must still call either [Self::finish] or [Self::abort] to complete the process.
    ///
    /// ### Note: Runtime Features
    /// This method uses the `AsyncRead` trait which is re-exported from either Tokio or `async-std`
    /// depending on which runtime feature is used.
    ///
    /// The runtime features _used_ to be mutually exclusive, but are no longer.
    /// If both `runtime-async-std` and `runtime-tokio` features are enabled, the Tokio version
    /// takes precedent.
    pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
        // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
        struct BufGuard<'s>(&'s mut Vec<u8>);

        impl Drop for BufGuard<'_> {
            fn drop(&mut self) {
                self.0.clear()
            }
        }

        let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");

        // flush any existing messages in the buffer and clear it
        conn.stream.flush().await?;

        loop {
            let buf = conn.stream.write_buffer_mut();

            // CopyData format code and reserved space for length
            buf.put_slice(b"d\0\0\0\x04");

            let read = match () {
                // Tokio lets us read into the buffer without zeroing first
                #[cfg(feature = "_rt-tokio")]
                _ => source.read_buf(buf.buf_mut()).await?,
                #[cfg(not(feature = "_rt-tokio"))]
                _ => source.read(buf.init_remaining_mut()).await?,
            };

            if read == 0 {
                // This will end up sending an empty `CopyData` packet but that should be fine.
                break;
            }

            buf.advance(read);

            // Write the length
            let read32 = u32::try_from(read)
                .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;

            (&mut buf.get_mut()[1..]).put_u32(read32 + 4);

            conn.stream.flush().await?;
        }

        Ok(self)
    }

    /// Signal that the `COPY` process should be aborted and any data received should be discarded.
    ///
    /// The given message can be used for indicating the reason for the abort in the database logs.
    ///
    /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
    pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
        let mut conn = self
            .conn
            .take()
            .expect("PgCopyIn::fail_with: conn taken illegally");

        conn.stream.send(CopyFail::new(msg)).await?;

        match conn.stream.recv().await {
            Ok(msg) => Err(err_protocol!(
                "fail_with: expected ErrorResponse, got: {:?}",
                msg.format
            )),
            Err(Error::Database(e)) => {
                match e.code() {
                    Some(Cow::Borrowed("57014")) => {
                        // postgres abort received error code
                        conn.stream
                            .recv_expect(MessageFormat::ReadyForQuery)
                            .await?;
                        Ok(())
                    }
                    _ => Err(Error::Database(e)),
                }
            }
            Err(e) => Err(e),
        }
    }

    /// Signal that the `COPY` process is complete.
    ///
    /// The number of rows affected is returned.
    pub async fn finish(mut self) -> Result<u64> {
        let mut conn = self
            .conn
            .take()
            .expect("CopyWriter::finish: conn taken illegally");

        conn.stream.send(CopyDone).await?;
        let cc: CommandComplete = match conn
            .stream
            .recv_expect(MessageFormat::CommandComplete)
            .await
        {
            Ok(cc) => cc,
            Err(e) => {
                conn.stream.recv().await?;
                return Err(e);
            }
        };

        conn.stream
            .recv_expect(MessageFormat::ReadyForQuery)
            .await?;

        Ok(cc.rows_affected())
    }
}

impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
    fn drop(&mut self) {
        // if neither finish() nor abort() was called, fail the COPY so the
        // connection errors (visibly) on next use instead of hanging
        if let Some(mut conn) = self.conn.take() {
            conn.stream.write(CopyFail::new(
                "PgCopyIn dropped without calling finish() or fail()",
            ));
        }
    }
}

async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
    mut conn: C,
    statement: &str,
) -> Result<BoxStream<'c, Result<Bytes>>> {
    conn.wait_until_ready().await?;
    conn.stream.send(Query(statement)).await?;

    let _: CopyResponse = conn
        .stream
        .recv_expect(MessageFormat::CopyOutResponse)
        .await?;

    let stream: TryAsyncStream<'c, Bytes> = try_stream!
{
        loop {
            let msg = conn.stream.recv().await?;
            match msg.format {
                MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
                MessageFormat::CopyDone => {
                    let _ = msg.decode::<CopyDone>()?;
                    conn.stream.recv_expect(MessageFormat::CommandComplete).await?;
                    conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
                    return Ok(())
                },
                _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
            }
        }
    };

    Ok(Box::pin(stream))
}
sqlx-postgres-0.7.3/src/database.rs000064400000000000000000000022700072674642500154200ustar 00000000000000
use crate::arguments::PgArgumentBuffer;
use crate::value::{PgValue, PgValueRef};
use crate::{
    PgArguments, PgColumn, PgConnection, PgQueryResult, PgRow, PgStatement, PgTransactionManager,
    PgTypeInfo,
};

pub(crate) use sqlx_core::database::{
    Database, HasArguments, HasStatement, HasStatementCache, HasValueRef,
};

/// PostgreSQL database driver.
#[derive(Debug)]
pub struct Postgres;

impl Database for Postgres {
    type Connection = PgConnection;

    type TransactionManager = PgTransactionManager;

    type Row = PgRow;

    type QueryResult = PgQueryResult;

    type Column = PgColumn;

    type TypeInfo = PgTypeInfo;

    type Value = PgValue;

    const NAME: &'static str = "PostgreSQL";

    const URL_SCHEMES: &'static [&'static str] = &["postgres", "postgresql"];
}

impl<'r> HasValueRef<'r> for Postgres {
    type Database = Postgres;

    type ValueRef = PgValueRef<'r>;
}

impl HasArguments<'_> for Postgres {
    type Database = Postgres;

    type Arguments = PgArguments;

    type ArgumentBuffer = PgArgumentBuffer;
}

impl<'q> HasStatement<'q> for Postgres {
    type Database = Postgres;

    type Statement = PgStatement<'q>;
}

impl HasStatementCache for Postgres {}
sqlx-postgres-0.7.3/src/error.rs000064400000000000000000000162750072674642500150170ustar 00000000000000
use std::error::Error as StdError;
use std::fmt::{self, Debug, Display, Formatter};

use atoi::atoi;
use smallvec::alloc::borrow::Cow;

pub(crate) use sqlx_core::error::*;

use crate::message::{Notice, PgSeverity};

/// An error returned from the PostgreSQL database.
pub struct PgDatabaseError(pub(crate) Notice);

// Error message fields are documented:
// https://www.postgresql.org/docs/current/protocol-error-fields.html

impl PgDatabaseError {
    #[inline]
    pub fn severity(&self) -> PgSeverity {
        self.0.severity()
    }

    /// The [SQLSTATE](https://www.postgresql.org/docs/current/errcodes-appendix.html) code for
    /// this error.
    #[inline]
    pub fn code(&self) -> &str {
        self.0.code()
    }

    /// The primary human-readable error message. This should be accurate but
    /// terse (typically one line).
    #[inline]
    pub fn message(&self) -> &str {
        self.0.message()
    }

    /// An optional secondary error message carrying more detail about the problem.
    /// Might run to multiple lines.
    #[inline]
    pub fn detail(&self) -> Option<&str> {
        self.0.get(b'D')
    }

    /// An optional suggestion what to do about the problem. This is intended to differ from
    /// `detail` in that it offers advice (potentially inappropriate) rather than hard facts.
    /// Might run to multiple lines.
    #[inline]
    pub fn hint(&self) -> Option<&str> {
        self.0.get(b'H')
    }

    /// Indicates an error cursor position as an index into the original query string; or,
    /// a position into an internally generated query.
    #[inline]
    pub fn position(&self) -> Option<PgErrorPosition<'_>> {
        self.0
            .get_raw(b'P')
            .and_then(atoi)
            .map(PgErrorPosition::Original)
            .or_else(|| {
                let position = self.0.get_raw(b'p').and_then(atoi)?;
                let query = self.0.get(b'q')?;

                Some(PgErrorPosition::Internal { position, query })
            })
    }

    /// An indication of the context in which the error occurred. Presently this includes a call
    /// stack traceback of active procedural language functions and internally-generated queries.
    /// The trace is one entry per line, most recent first.
    pub fn r#where(&self) -> Option<&str> {
        self.0.get(b'W')
    }

    /// If this error is with a specific database object, the
    /// name of the schema containing that object, if any.
    pub fn schema(&self) -> Option<&str> {
        self.0.get(b's')
    }

    /// If this error is with a specific table, the name of the table.
    pub fn table(&self) -> Option<&str> {
        self.0.get(b't')
    }

    /// If the error is with a specific table column, the name of the column.
    pub fn column(&self) -> Option<&str> {
        self.0.get(b'c')
    }

    /// If the error is with a specific data type, the name of the data type.
    pub fn data_type(&self) -> Option<&str> {
        self.0.get(b'd')
    }

    /// If the error is with a specific constraint, the name of the constraint.
    /// For this purpose, indexes are constraints, even if they weren't created
    /// with constraint syntax.
    pub fn constraint(&self) -> Option<&str> {
        self.0.get(b'n')
    }

    /// The file name of the source-code location where this error was reported.
    pub fn file(&self) -> Option<&str> {
        self.0.get(b'F')
    }

    /// The line number of the source-code location where this error was reported.
    pub fn line(&self) -> Option<usize> {
        self.0.get_raw(b'L').and_then(atoi)
    }

    /// The name of the source-code routine reporting this error.
    pub fn routine(&self) -> Option<&str> {
        self.0.get(b'R')
    }
}

#[derive(Debug, Eq, PartialEq)]
pub enum PgErrorPosition<'a> {
    /// A position (in characters) into the original query.
    Original(usize),

    /// A position into the internally-generated query.
    Internal {
        /// The position in characters.
        position: usize,

        /// The text of a failed internally-generated command. This could be, for example,
        /// the SQL query issued by a PL/pgSQL function.
query: &'a str, }, } impl Debug for PgDatabaseError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("PgDatabaseError") .field("severity", &self.severity()) .field("code", &self.code()) .field("message", &self.message()) .field("detail", &self.detail()) .field("hint", &self.hint()) .field("position", &self.position()) .field("where", &self.r#where()) .field("schema", &self.schema()) .field("table", &self.table()) .field("column", &self.column()) .field("data_type", &self.data_type()) .field("constraint", &self.constraint()) .field("file", &self.file()) .field("line", &self.line()) .field("routine", &self.routine()) .finish() } } impl Display for PgDatabaseError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str(self.message()) } } impl StdError for PgDatabaseError {} impl DatabaseError for PgDatabaseError { fn message(&self) -> &str { self.message() } fn code(&self) -> Option> { Some(Cow::Borrowed(self.code())) } #[doc(hidden)] fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static) { self } #[doc(hidden)] fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) { self } #[doc(hidden)] fn into_error(self: Box) -> BoxDynError { self } fn is_transient_in_connect_phase(&self) -> bool { // https://www.postgresql.org/docs/current/errcodes-appendix.html [ // too_many_connections // This may be returned if we just un-gracefully closed a connection, // give the database a chance to notice it and clean it up. "53300", // cannot_connect_now // Returned if the database is still starting up. 
"57P03", ] .contains(&self.code()) } fn constraint(&self) -> Option<&str> { self.constraint() } fn table(&self) -> Option<&str> { self.table() } fn kind(&self) -> ErrorKind { match self.code() { error_codes::UNIQUE_VIOLATION => ErrorKind::UniqueViolation, error_codes::FOREIGN_KEY_VIOLATION => ErrorKind::ForeignKeyViolation, error_codes::NOT_NULL_VIOLATION => ErrorKind::NotNullViolation, error_codes::CHECK_VIOLATION => ErrorKind::CheckViolation, _ => ErrorKind::Other, } } } /// For reference: pub(crate) mod error_codes { /// Caused when a unique or primary key is violated. pub const UNIQUE_VIOLATION: &str = "23505"; /// Caused when a foreign key is violated. pub const FOREIGN_KEY_VIOLATION: &str = "23503"; /// Caused when a column marked as NOT NULL received a null value. pub const NOT_NULL_VIOLATION: &str = "23502"; /// Caused when a check constraint is violated. pub const CHECK_VIOLATION: &str = "23514"; } sqlx-postgres-0.7.3/src/io/buf_mut.rs000064400000000000000000000027660072674642500157360ustar 00000000000000use crate::types::Oid; pub trait PgBufMutExt { fn put_length_prefixed(&mut self, f: F) where F: FnOnce(&mut Vec); fn put_statement_name(&mut self, id: Oid); fn put_portal_name(&mut self, id: Option); } impl PgBufMutExt for Vec { // writes a length-prefixed message, this is used when encoding nearly all messages as postgres // wants us to send the length of the often-variable-sized messages up front fn put_length_prefixed(&mut self, f: F) where F: FnOnce(&mut Vec), { // reserve space to write the prefixed length let offset = self.len(); self.extend(&[0; 4]); // write the main body of the message f(self); // now calculate the size of what we wrote and set the length value let size = (self.len() - offset) as i32; self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); } // writes a statement name by ID #[inline] fn put_statement_name(&mut self, id: Oid) { // N.B. 
if you change this don't forget to update it in ../describe.rs self.extend(b"sqlx_s_"); self.extend(itoa::Buffer::new().format(id.0).as_bytes()); self.push(0); } // writes a portal name by ID #[inline] fn put_portal_name(&mut self, id: Option) { if let Some(id) = id { self.extend(b"sqlx_p_"); self.extend(itoa::Buffer::new().format(id.0).as_bytes()); } self.push(0); } } sqlx-postgres-0.7.3/src/io/mod.rs000064400000000000000000000001160072674642500150370ustar 00000000000000mod buf_mut; pub use buf_mut::PgBufMutExt; pub(crate) use sqlx_core::io::*; sqlx-postgres-0.7.3/src/lib.rs000064400000000000000000000035260072674642500144270ustar 00000000000000//! **PostgreSQL** database driver. #[macro_use] extern crate sqlx_core; use crate::executor::Executor; mod advisory_lock; mod arguments; mod column; mod connection; mod copy; mod database; mod error; mod io; mod listener; mod message; mod options; mod query_result; mod row; mod statement; mod transaction; mod type_info; pub mod types; mod value; #[cfg(feature = "any")] pub mod any; #[cfg(feature = "migrate")] mod migrate; #[cfg(feature = "migrate")] mod testing; pub(crate) use sqlx_core::driver_prelude::*; pub use advisory_lock::{PgAdvisoryLock, PgAdvisoryLockGuard, PgAdvisoryLockKey}; pub use arguments::{PgArgumentBuffer, PgArguments}; pub use column::PgColumn; pub use connection::PgConnection; pub use copy::PgCopyIn; pub use database::Postgres; pub use error::{PgDatabaseError, PgErrorPosition}; pub use listener::{PgListener, PgNotification}; pub use message::PgSeverity; pub use options::{PgConnectOptions, PgSslMode}; pub use query_result::PgQueryResult; pub use row::PgRow; pub use statement::PgStatement; pub use transaction::PgTransactionManager; pub use type_info::{PgTypeInfo, PgTypeKind}; pub use types::PgHasArrayType; pub use value::{PgValue, PgValueFormat, PgValueRef}; /// An alias for [`Pool`][crate::pool::Pool], specialized for Postgres. 
pub type PgPool = crate::pool::Pool; /// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for Postgres. pub type PgPoolOptions = crate::pool::PoolOptions; /// An alias for [`Executor<'_, Database = Postgres>`][Executor]. pub trait PgExecutor<'c>: Executor<'c, Database = Postgres> {} impl<'c, T: Executor<'c, Database = Postgres>> PgExecutor<'c> for T {} impl_into_arguments_for_arguments!(PgArguments); impl_acquire!(Postgres, PgConnection); impl_column_index_for_row!(PgRow); impl_column_index_for_statement!(PgStatement); impl_encode_for_option!(Postgres); sqlx-postgres-0.7.3/src/listener.rs000064400000000000000000000365240072674642500155120ustar 00000000000000use std::fmt::{self, Debug}; use std::io; use std::str::from_utf8; use futures_channel::mpsc; use futures_core::future::BoxFuture; use futures_core::stream::{BoxStream, Stream}; use futures_util::{FutureExt, StreamExt, TryStreamExt}; use sqlx_core::Either; use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::message::{MessageFormat, Notification}; use crate::pool::PoolOptions; use crate::pool::{Pool, PoolConnection}; use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres}; /// A stream of asynchronous notifications from Postgres. /// /// This listener will auto-reconnect. If the active /// connection being used ever dies, this listener will detect that event, create a /// new connection, will re-subscribe to all of the originally specified channels, and will resume /// operations as normal. pub struct PgListener { pool: Pool, connection: Option>, buffer_rx: mpsc::UnboundedReceiver, buffer_tx: Option>, channels: Vec, ignore_close_event: bool, } /// An asynchronous notification from Postgres. 
pub struct PgNotification(Notification); impl PgListener { pub async fn connect(url: &str) -> Result { // Create a pool of 1 without timeouts (as they don't apply here) // We only use the pool to handle re-connections let pool = PoolOptions::::new() .max_connections(1) .max_lifetime(None) .idle_timeout(None) .connect(url) .await?; let mut this = Self::connect_with(&pool).await?; // We don't need to handle close events this.ignore_close_event = true; Ok(this) } pub async fn connect_with(pool: &Pool) -> Result { // Pull out an initial connection let mut connection = pool.acquire().await?; // Setup a notification buffer let (sender, receiver) = mpsc::unbounded(); connection.stream.notifications = Some(sender); Ok(Self { pool: pool.clone(), connection: Some(connection), buffer_rx: receiver, buffer_tx: None, channels: Vec::new(), ignore_close_event: false, }) } /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`. /// /// By default, when [`Pool::close()`] is called on the pool this listener is using /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is /// cancelled and `Err(PoolClosed)` is returned. /// /// This is because `Pool::close()` will wait until _all_ connections are returned and closed, /// including the one being used by this listener. /// /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)` /// on the attempt to acquire a new connection anyway. /// /// However, if you want `PgListener` to ignore the close event and continue waiting for a /// message as long as it can, set this to `true`. /// /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an /// internal pool just for the new instance of `PgListener` which cannot be closed manually. 
pub fn ignore_pool_close_event(&mut self, val: bool) { self.ignore_close_event = val; } /// Starts listening for notifications on a channel. /// The channel name is quoted here to ensure case sensitivity. pub async fn listen(&mut self, channel: &str) -> Result<(), Error> { self.connection() .await? .execute(&*format!(r#"LISTEN "{}""#, ident(channel))) .await?; self.channels.push(channel.to_owned()); Ok(()) } /// Starts listening for notifications on all channels. pub async fn listen_all( &mut self, channels: impl IntoIterator, ) -> Result<(), Error> { let beg = self.channels.len(); self.channels.extend(channels.into_iter().map(|s| s.into())); let query = build_listen_all_query(&self.channels[beg..]); self.connection().await?.execute(&*query).await?; Ok(()) } /// Stops listening for notifications on a channel. /// The channel name is quoted here to ensure case sensitivity. pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> { // use RAW connection and do NOT re-connect automatically, since this is not required for // UNLISTEN (we've disconnected anyways) if let Some(connection) = self.connection.as_mut() { connection .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel))) .await?; } if let Some(pos) = self.channels.iter().position(|s| s == channel) { self.channels.remove(pos); } Ok(()) } /// Stops listening for notifications on all channels. 
pub async fn unlisten_all(&mut self) -> Result<(), Error> { // use RAW connection and do NOT re-connect automatically, since this is not required for // UNLISTEN (we've disconnected anyways) if let Some(connection) = self.connection.as_mut() { connection.execute("UNLISTEN *").await?; } self.channels.clear(); Ok(()) } #[inline] async fn connect_if_needed(&mut self) -> Result<(), Error> { if self.connection.is_none() { let mut connection = self.pool.acquire().await?; connection.stream.notifications = self.buffer_tx.take(); connection .execute(&*build_listen_all_query(&self.channels)) .await?; self.connection = Some(connection); } Ok(()) } #[inline] async fn connection(&mut self) -> Result<&mut PgConnection, Error> { // Ensure we have an active connection to work with. self.connect_if_needed().await?; Ok(self.connection.as_mut().unwrap()) } /// Receives the next notification available from any of the subscribed channels. /// /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next /// call to `recv()`, and should be entirely transparent (as long as it was just an /// intermittent network failure or long-lived connection reaper). /// /// As notifications are transient, any received while the connection was lost, will not /// be returned. If you'd prefer the reconnection to be explicit and have a chance to /// do something before, please see [`try_recv`](Self::try_recv). 
/// /// # Example /// /// ```rust,no_run /// # use sqlx_core::postgres::PgListener; /// # use sqlx_core::error::Error; /// # /// # #[cfg(feature = "_rt")] /// # sqlx::__rt::test_block_on(async move { /// # let mut listener = PgListener::connect("postgres:// ...").await?; /// loop { /// // ask for next notification, re-connecting (transparently) if needed /// let notification = listener.recv().await?; /// /// // handle notification, do something interesting /// } /// # Result::<(), Error>::Ok(()) /// # }).unwrap(); /// ``` pub async fn recv(&mut self) -> Result { loop { if let Some(notification) = self.try_recv().await? { return Ok(notification); } } } /// Receives the next notification available from any of the subscribed channels. /// /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is /// reconnected on the next call to `try_recv()`. /// /// # Example /// /// ```rust,no_run /// # use sqlx_core::postgres::PgListener; /// # use sqlx_core::error::Error; /// # /// # #[cfg(feature = "_rt")] /// # sqlx::__rt::test_block_on(async move { /// # let mut listener = PgListener::connect("postgres:// ...").await?; /// loop { /// // start handling notifications, connecting if needed /// while let Some(notification) = listener.try_recv().await? { /// // handle notification /// } /// /// // connection lost, do something interesting /// } /// # Result::<(), Error>::Ok(()) /// # }).unwrap(); /// ``` pub async fn try_recv(&mut self) -> Result, Error> { // Flush the buffer first, if anything // This would only fill up if this listener is used as a connection if let Ok(Some(notification)) = self.buffer_rx.try_next() { return Ok(Some(PgNotification(notification))); } // Fetch our `CloseEvent` listener, if applicable. 
let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event()); loop { let next_message = self.connection().await?.stream.recv_unchecked(); let res = if let Some(ref mut close_event) = close_event { // cancels the wait and returns `Err(PoolClosed)` if the pool is closed // before `next_message` returns, or if the pool was already closed close_event.do_until(next_message).await? } else { next_message.await }; let message = match res { Ok(message) => message, // The connection is dead, ensure that it is dropped, // update self state, and loop to try again. Err(Error::Io(err)) if (err.kind() == io::ErrorKind::ConnectionAborted || err.kind() == io::ErrorKind::UnexpectedEof) => { self.buffer_tx = self.connection().await?.stream.notifications.take(); self.connection = None; // lost connection return Ok(None); } // Forward other errors Err(error) => { return Err(error); } }; match message.format { // We've received an async notification, return it. MessageFormat::NotificationResponse => { return Ok(Some(PgNotification(message.decode()?))); } // Mark the connection as ready for another query MessageFormat::ReadyForQuery => { self.connection().await?.pending_ready_for_query_count -= 1; } // Ignore unexpected messages _ => {} } } } /// Consume this listener, returning a `Stream` of notifications. /// /// The backing connection will be automatically reconnected should it be lost. /// /// This has the same potential drawbacks as [`recv`](PgListener::recv). /// pub fn into_stream(mut self) -> impl Stream> + Unpin { Box::pin(try_stream! 
{ loop { r#yield!(self.recv().await?); } }) } } impl Drop for PgListener { fn drop(&mut self) { if let Some(mut conn) = self.connection.take() { let fut = async move { let _ = conn.execute("UNLISTEN *").await; // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task // otherwise, it may trigger a panic if this task is dropped because the runtime is going away: // https://github.com/launchbadge/sqlx/issues/1389 conn.return_to_pool().await; }; // Unregister any listeners before returning the connection to the pool. crate::rt::spawn(fut); } } } impl<'c> Executor<'c> for &'c mut PgListener { type Database = Postgres; fn fetch_many<'e, 'q: 'e, E: 'q>( self, query: E, ) -> BoxStream<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { futures_util::stream::once(async move { // need some basic type annotation to help the compiler a bit let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query)); res }) .try_flatten() .boxed() } fn fetch_optional<'e, 'q: 'e, E: 'q>( self, query: E, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { async move { self.connection().await?.fetch_optional(query).await }.boxed() } fn prepare_with<'e, 'q: 'e>( self, query: &'q str, parameters: &'e [PgTypeInfo], ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { async move { self.connection() .await? .prepare_with(query, parameters) .await } .boxed() } #[doc(hidden)] fn describe<'e, 'q: 'e>( self, query: &'q str, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { async move { self.connection().await?.describe(query).await }.boxed() } } impl PgNotification { /// The process ID of the notifying backend process. #[inline] pub fn process_id(&self) -> u32 { self.0.process_id } /// The channel that the notify has been raised on. This can be thought /// of as the message topic. #[inline] pub fn channel(&self) -> &str { from_utf8(&self.0.channel).unwrap() } /// The payload of the notification. 
An empty payload is received as an /// empty string. #[inline] pub fn payload(&self) -> &str { from_utf8(&self.0.payload).unwrap() } } impl Debug for PgListener { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PgListener").finish() } } impl Debug for PgNotification { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PgNotification") .field("process_id", &self.process_id()) .field("channel", &self.channel()) .field("payload", &self.payload()) .finish() } } fn ident(mut name: &str) -> String { // If the input string contains a NUL byte, we should truncate the // identifier. if let Some(index) = name.find('\0') { name = &name[..index]; } // Any double quotes must be escaped name.replace('"', "\"\"") } fn build_listen_all_query(channels: impl IntoIterator>) -> String { channels.into_iter().fold(String::new(), |mut acc, chan| { acc.push_str(r#"LISTEN ""#); acc.push_str(&ident(chan.as_ref())); acc.push_str(r#"";"#); acc }) } #[test] fn test_build_listen_all_query_with_single_channel() { let output = build_listen_all_query(&["test"]); assert_eq!(output.as_str(), r#"LISTEN "test";"#); } #[test] fn test_build_listen_all_query_with_multiple_channels() { let output = build_listen_all_query(&["channel.0", "channel.1"]); assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#); } sqlx-postgres-0.7.3/src/message/authentication.rs000064400000000000000000000131170072674642500203210ustar 00000000000000use std::str::from_utf8; use memchr::memchr; use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::Decode; use base64::prelude::{Engine as _, BASE64_STANDARD}; // On startup, the server sends an appropriate authentication request message, // to which the frontend must reply with an appropriate authentication // response message (such as a password). // For all authentication methods except GSSAPI, SSPI and SASL, there is at // most one request and one response. 
In some methods, no response at all is // needed from the frontend, and so no authentication request occurs. // For GSSAPI, SSPI and SASL, multiple exchanges of packets may // be needed to complete the authentication. // // #[derive(Debug)] pub enum Authentication { /// The authentication exchange is successfully completed. Ok, /// The frontend must now send a [PasswordMessage] containing the /// password in clear-text form. CleartextPassword, /// The frontend must now send a [PasswordMessage] containing the /// password (with user name) encrypted via MD5, then encrypted /// again using the 4-byte random salt. Md5Password(AuthenticationMd5Password), /// The frontend must now initiate a SASL negotiation, /// using one of the SASL mechanisms listed in the message. /// /// The frontend will send a [SaslInitialResponse] with the name /// of the selected mechanism, and the first part of the SASL /// data stream in response to this. /// /// If further messages are needed, the server will /// respond with [Authentication::SaslContinue]. Sasl(AuthenticationSasl), /// This message contains challenge data from the previous step of SASL negotiation. /// /// The frontend must respond with a [SaslResponse] message. SaslContinue(AuthenticationSaslContinue), /// SASL authentication has completed with additional mechanism-specific /// data for the client. /// /// The server will next send [Authentication::Ok] to /// indicate successful authentication. 
SaslFinal(AuthenticationSaslFinal), } impl Decode<'_> for Authentication { fn decode_with(mut buf: Bytes, _: ()) -> Result { Ok(match buf.get_u32() { 0 => Authentication::Ok, 3 => Authentication::CleartextPassword, 5 => { let mut salt = [0; 4]; buf.copy_to_slice(&mut salt); Authentication::Md5Password(AuthenticationMd5Password { salt }) } 10 => Authentication::Sasl(AuthenticationSasl(buf)), 11 => Authentication::SaslContinue(AuthenticationSaslContinue::decode(buf)?), 12 => Authentication::SaslFinal(AuthenticationSaslFinal::decode(buf)?), ty => { return Err(err_protocol!("unknown authentication method: {}", ty)); } }) } } /// Body of [Authentication::Md5Password]. #[derive(Debug)] pub struct AuthenticationMd5Password { pub salt: [u8; 4], } /// Body of [Authentication::Sasl]. #[derive(Debug)] pub struct AuthenticationSasl(Bytes); impl AuthenticationSasl { #[inline] pub fn mechanisms(&self) -> SaslMechanisms<'_> { SaslMechanisms(&self.0) } } /// An iterator over the SASL authentication mechanisms provided by the server. 
pub struct SaslMechanisms<'a>(&'a [u8]); impl<'a> Iterator for SaslMechanisms<'a> { type Item = &'a str; fn next(&mut self) -> Option { if !self.0.is_empty() && self.0[0] == b'\0' { return None; } let mechanism = memchr(b'\0', self.0).and_then(|nul| from_utf8(&self.0[..nul]).ok())?; self.0 = &self.0[(mechanism.len() + 1)..]; Some(mechanism) } } #[derive(Debug)] pub struct AuthenticationSaslContinue { pub salt: Vec, pub iterations: u32, pub nonce: String, pub message: String, } impl Decode<'_> for AuthenticationSaslContinue { fn decode_with(buf: Bytes, _: ()) -> Result { let mut iterations: u32 = 4096; let mut salt = Vec::new(); let mut nonce = Bytes::new(); // [Example] // r=/z+giZiTxAH7r8sNAeHr7cvpqV3uo7G/bJBIJO3pjVM7t3ng,s=4UV68bIkC8f9/X8xH7aPhg==,i=4096 for item in buf.split(|b| *b == b',') { let key = item[0]; let value = &item[2..]; match key { b'r' => { nonce = buf.slice_ref(value); } b'i' => { iterations = atoi::atoi(value).unwrap_or(4096); } b's' => { salt = BASE64_STANDARD.decode(value).map_err(Error::protocol)?; } _ => {} } } Ok(Self { iterations, salt, nonce: from_utf8(&*nonce).map_err(Error::protocol)?.to_owned(), message: from_utf8(&*buf).map_err(Error::protocol)?.to_owned(), }) } } #[derive(Debug)] pub struct AuthenticationSaslFinal { pub verifier: Vec, } impl Decode<'_> for AuthenticationSaslFinal { fn decode_with(buf: Bytes, _: ()) -> Result { let mut verifier = Vec::new(); for item in buf.split(|b| *b == b',') { let key = item[0]; let value = &item[2..]; if let b'v' = key { verifier = BASE64_STANDARD.decode(value).map_err(Error::protocol)?; } } Ok(Self { verifier }) } } sqlx-postgres-0.7.3/src/message/backend_key_data.rs000064400000000000000000000023250072674642500205310ustar 00000000000000use byteorder::{BigEndian, ByteOrder}; use sqlx_core::bytes::Bytes; use crate::error::Error; use crate::io::Decode; /// Contains cancellation key data. The frontend must save these values if it /// wishes to be able to issue `CancelRequest` messages later. 
#[derive(Debug)] pub struct BackendKeyData { /// The process ID of this database. pub process_id: u32, /// The secret key of this database. pub secret_key: u32, } impl Decode<'_> for BackendKeyData { fn decode_with(buf: Bytes, _: ()) -> Result { let process_id = BigEndian::read_u32(&buf); let secret_key = BigEndian::read_u32(&buf[4..]); Ok(Self { process_id, secret_key, }) } } #[test] fn test_decode_backend_key_data() { const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; let m = BackendKeyData::decode(DATA.into()).unwrap(); assert_eq!(m.process_id, 10182); assert_eq!(m.secret_key, 2303903019); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_decode_backend_key_data(b: &mut test::Bencher) { const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; b.iter(|| { BackendKeyData::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } sqlx-postgres-0.7.3/src/message/bind.rs000064400000000000000000000041360072674642500162170ustar 00000000000000use crate::io::Encode; use crate::io::PgBufMutExt; use crate::types::Oid; use crate::PgValueFormat; #[derive(Debug)] pub struct Bind<'a> { /// The ID of the destination portal (`None` selects the unnamed portal). pub portal: Option, /// The id of the source prepared statement. pub statement: Oid, /// The parameter format codes. Each must presently be zero (text) or one (binary). /// /// There can be zero to indicate that there are no parameters or that the parameters all use the /// default format (text); or one, in which case the specified format code is applied to all /// parameters; or it can equal the actual number of parameters. pub formats: &'a [PgValueFormat], /// The number of parameters. pub num_params: i16, /// The value of each parameter, in the indicated format. pub params: &'a [u8], /// The result-column format codes. Each must presently be zero (text) or one (binary). 
/// /// There can be zero to indicate that there are no result columns or that the /// result columns should all use the default format (text); or one, in which /// case the specified format code is applied to all result columns (if any); /// or it can equal the actual number of result columns of the query. pub result_formats: &'a [PgValueFormat], } impl Encode<'_> for Bind<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.push(b'B'); buf.put_length_prefixed(|buf| { buf.put_portal_name(self.portal); buf.put_statement_name(self.statement); buf.extend(&(self.formats.len() as i16).to_be_bytes()); for &format in self.formats { buf.extend(&(format as i16).to_be_bytes()); } buf.extend(&self.num_params.to_be_bytes()); buf.extend(self.params); buf.extend(&(self.result_formats.len() as i16).to_be_bytes()); for &format in self.result_formats { buf.extend(&(format as i16).to_be_bytes()); } }); } } // TODO: Unit Test Bind // TODO: Benchmark Bind sqlx-postgres-0.7.3/src/message/close.rs000064400000000000000000000014510072674642500164050ustar 00000000000000use crate::io::Encode; use crate::io::PgBufMutExt; use crate::types::Oid; const CLOSE_PORTAL: u8 = b'P'; const CLOSE_STATEMENT: u8 = b'S'; #[derive(Debug)] #[allow(dead_code)] pub enum Close { Statement(Oid), // None selects the unnamed portal Portal(Option), } impl Encode<'_> for Close { fn encode_with(&self, buf: &mut Vec, _: ()) { // 15 bytes for 1-digit statement/portal IDs buf.reserve(20); buf.push(b'C'); buf.put_length_prefixed(|buf| match self { Close::Statement(id) => { buf.push(CLOSE_STATEMENT); buf.put_statement_name(*id); } Close::Portal(id) => { buf.push(CLOSE_PORTAL); buf.put_portal_name(*id); } }) } } sqlx-postgres-0.7.3/src/message/command_complete.rs000064400000000000000000000041000072674642500206000ustar 00000000000000use atoi::atoi; use memchr::memrchr; use sqlx_core::bytes::Bytes; use crate::error::Error; use crate::io::Decode; #[derive(Debug)] pub struct CommandComplete { /// The command tag. 
This is usually a single word that identifies which SQL command /// was completed. tag: Bytes, } impl Decode<'_> for CommandComplete { #[inline] fn decode_with(buf: Bytes, _: ()) -> Result { Ok(CommandComplete { tag: buf }) } } impl CommandComplete { /// Returns the number of rows affected. /// If the command does not return rows (e.g., "CREATE TABLE"), returns 0. pub fn rows_affected(&self) -> u64 { // Look backwards for the first SPACE memrchr(b' ', &self.tag) // This is either a word or the number of rows affected .and_then(|i| atoi(&self.tag[(i + 1)..])) .unwrap_or(0) } } #[test] fn test_decode_command_complete_for_insert() { const DATA: &[u8] = b"INSERT 0 1214\0"; let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 1214); } #[test] fn test_decode_command_complete_for_begin() { const DATA: &[u8] = b"BEGIN\0"; let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 0); } #[test] fn test_decode_command_complete_for_update() { const DATA: &[u8] = b"UPDATE 5\0"; let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 5); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_decode_command_complete(b: &mut test::Bencher) { const DATA: &[u8] = b"INSERT 0 1214\0"; b.iter(|| { let _ = CommandComplete::decode(test::black_box(Bytes::from_static(DATA))); }); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) { const DATA: &[u8] = b"INSERT 0 1214\0"; let data = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); b.iter(|| { let _rows = test::black_box(&data).rows_affected(); }); } sqlx-postgres-0.7.3/src/message/copy.rs000064400000000000000000000043000072674642500162460ustar 00000000000000use crate::error::Result; use crate::io::{BufExt, BufMutExt, Decode, Encode}; use sqlx_core::bytes::{Buf, BufMut, Bytes}; use std::ops::Deref; /// The same structure is 
sent for both `CopyInResponse` and `CopyOutResponse` pub struct CopyResponse { pub format: i8, pub num_columns: i16, pub format_codes: Vec, } pub struct CopyData(pub B); pub struct CopyFail { pub message: String, } pub struct CopyDone; impl Decode<'_> for CopyResponse { fn decode_with(mut buf: Bytes, _: ()) -> Result { let format = buf.get_i8(); let num_columns = buf.get_i16(); let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect(); Ok(CopyResponse { format, num_columns, format_codes, }) } } impl Decode<'_> for CopyData { fn decode_with(buf: Bytes, _: ()) -> Result { // well.. that was easy Ok(CopyData(buf)) } } impl> Encode<'_> for CopyData { fn encode_with(&self, buf: &mut Vec, _context: ()) { buf.push(b'd'); buf.put_u32(self.0.len() as u32 + 4); buf.extend_from_slice(&self.0); } } impl Decode<'_> for CopyFail { fn decode_with(mut buf: Bytes, _: ()) -> Result { Ok(CopyFail { message: buf.get_str_nul()?, }) } } impl Encode<'_> for CopyFail { fn encode_with(&self, buf: &mut Vec, _: ()) { let len = 4 + self.message.len() + 1; buf.push(b'f'); // to pay respects buf.put_u32(len as u32); buf.put_str_nul(&self.message); } } impl CopyFail { pub fn new(msg: impl Into) -> CopyFail { CopyFail { message: msg.into(), } } } impl Decode<'_> for CopyDone { fn decode_with(buf: Bytes, _: ()) -> Result { if buf.is_empty() { Ok(CopyDone) } else { Err(err_protocol!( "expected no data for CopyDone, got: {:?}", buf )) } } } impl Encode<'_> for CopyDone { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.reserve(4); buf.push(b'c'); buf.put_u32(4); } } sqlx-postgres-0.7.3/src/message/data_row.rs000064400000000000000000000062240072674642500171030ustar 00000000000000use std::ops::Range; use byteorder::{BigEndian, ByteOrder}; use sqlx_core::bytes::Bytes; use crate::error::Error; use crate::io::Decode; /// A row of data from the database. #[derive(Debug)] pub struct DataRow { pub(crate) storage: Bytes, /// Ranges into the stored row data. 
/// This uses `u32` instead of usize to reduce the size of this type. Values cannot be larger /// than `i32` in postgres. pub(crate) values: Vec>>, } impl DataRow { #[inline] pub(crate) fn get(&self, index: usize) -> Option<&'_ [u8]> { self.values[index] .as_ref() .map(|col| &self.storage[(col.start as usize)..(col.end as usize)]) } } impl Decode<'_> for DataRow { fn decode_with(buf: Bytes, _: ()) -> Result { let cnt = BigEndian::read_u16(&buf) as usize; let mut values = Vec::with_capacity(cnt); let mut offset = 2; for _ in 0..cnt { // Length of the column value, in bytes (this count does not include itself). // Can be zero. As a special case, -1 indicates a NULL column value. // No value bytes follow in the NULL case. let length = BigEndian::read_i32(&buf[(offset as usize)..]); offset += 4; if length < 0 { values.push(None); } else { values.push(Some(offset..(offset + length as u32))); offset += length as u32; } } Ok(Self { storage: buf, values, }) } } #[test] fn test_decode_data_row() { const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; let row = DataRow::decode(DATA.into()).unwrap(); assert_eq!(row.values.len(), 8); assert!(row.get(0).is_none()); assert_eq!(row.get(1).unwrap(), &[0_u8, 0, 0, 10][..]); assert!(row.get(2).is_none()); assert_eq!(row.get(3).unwrap(), &[0_u8, 0, 0, 20][..]); assert!(row.get(4).is_none()); assert_eq!(row.get(5).unwrap(), &[0_u8, 0, 0, 40][..]); assert!(row.get(6).is_none()); assert_eq!(row.get(7).unwrap(), &[0_u8, 0, 0, 80][..]); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_data_row_get(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; let row = 
DataRow::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); b.iter(|| { let _value = test::black_box(&row).get(3); }); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_decode_data_row(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; b.iter(|| { let _ = DataRow::decode(test::black_box(Bytes::from_static(DATA))); }); } sqlx-postgres-0.7.3/src/message/describe.rs000064400000000000000000000053660072674642500170710ustar 00000000000000use crate::io::Encode; use crate::io::PgBufMutExt; use crate::types::Oid; const DESCRIBE_PORTAL: u8 = b'P'; const DESCRIBE_STATEMENT: u8 = b'S'; // [Describe] will emit both a [RowDescription] and a [ParameterDescription] message #[derive(Debug)] #[allow(dead_code)] pub enum Describe { UnnamedStatement, Statement(Oid), UnnamedPortal, Portal(Oid), } impl Encode<'_> for Describe { fn encode_with(&self, buf: &mut Vec, _: ()) { // 15 bytes for 1-digit statement/portal IDs buf.reserve(20); buf.push(b'D'); buf.put_length_prefixed(|buf| { match self { // #[likely] Describe::Statement(id) => { buf.push(DESCRIBE_STATEMENT); buf.put_statement_name(*id); } Describe::UnnamedPortal => { buf.push(DESCRIBE_PORTAL); buf.push(0); } Describe::UnnamedStatement => { buf.push(DESCRIBE_STATEMENT); buf.push(0); } Describe::Portal(id) => { buf.push(DESCRIBE_PORTAL); buf.put_portal_name(Some(*id)); } } }); } } #[test] fn test_encode_describe_portal() { const EXPECTED: &[u8] = b"D\0\0\0\x0EPsqlx_p_5\0"; let mut buf = Vec::new(); let m = Describe::Portal(Oid(5)); m.encode(&mut buf); assert_eq!(buf, EXPECTED); } #[test] fn test_encode_describe_unnamed_portal() { const EXPECTED: &[u8] = b"D\0\0\0\x06P\0"; let mut buf = Vec::new(); let m = Describe::UnnamedPortal; m.encode(&mut buf); assert_eq!(buf, EXPECTED); } #[test] fn test_encode_describe_statement() { 
const EXPECTED: &[u8] = b"D\0\0\0\x0ESsqlx_s_5\0"; let mut buf = Vec::new(); let m = Describe::Statement(Oid(5)); m.encode(&mut buf); assert_eq!(buf, EXPECTED); } #[test] fn test_encode_describe_unnamed_statement() { const EXPECTED: &[u8] = b"D\0\0\0\x06S\0"; let mut buf = Vec::new(); let m = Describe::UnnamedStatement; m.encode(&mut buf); assert_eq!(buf, EXPECTED); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_encode_describe_portal(b: &mut test::Bencher) { use test::black_box; let mut buf = Vec::with_capacity(128); b.iter(|| { buf.clear(); black_box(Describe::Portal(5)).encode(&mut buf); }); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_encode_describe_unnamed_statement(b: &mut test::Bencher) { use test::black_box; let mut buf = Vec::with_capacity(128); b.iter(|| { buf.clear(); black_box(Describe::UnnamedStatement).encode(&mut buf); }); } sqlx-postgres-0.7.3/src/message/execute.rs000064400000000000000000000016410072674642500167430ustar 00000000000000use crate::io::Encode; use crate::io::PgBufMutExt; use crate::types::Oid; pub struct Execute { /// The id of the portal to execute (`None` selects the unnamed portal). pub portal: Option, /// Maximum number of rows to return, if portal contains a query /// that returns rows (ignored otherwise). Zero denotes “no limit”. 
pub limit: u32, } impl Encode<'_> for Execute { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.reserve(20); buf.push(b'E'); buf.put_length_prefixed(|buf| { buf.put_portal_name(self.portal); buf.extend(&self.limit.to_be_bytes()); }); } } #[test] fn test_encode_execute() { const EXPECTED: &[u8] = b"E\0\0\0\x11sqlx_p_5\0\0\0\0\x02"; let mut buf = Vec::new(); let m = Execute { portal: Some(Oid(5)), limit: 2, }; m.encode(&mut buf); assert_eq!(buf, EXPECTED); } sqlx-postgres-0.7.3/src/message/flush.rs000064400000000000000000000010250072674642500164160ustar 00000000000000use crate::io::Encode; // The Flush message does not cause any specific output to be generated, // but forces the backend to deliver any data pending in its output buffers. // A Flush must be sent after any extended-query command except Sync, if the // frontend wishes to examine the results of that command before issuing more commands. #[derive(Debug)] pub struct Flush; impl Encode<'_> for Flush { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.push(b'H'); buf.extend(&4_i32.to_be_bytes()); } } sqlx-postgres-0.7.3/src/message/mod.rs000064400000000000000000000065070072674642500160660ustar 00000000000000use sqlx_core::bytes::Bytes; use crate::error::Error; use crate::io::Decode; mod authentication; mod backend_key_data; mod bind; mod close; mod command_complete; mod copy; mod data_row; mod describe; mod execute; mod flush; mod notification; mod parameter_description; mod parameter_status; mod parse; mod password; mod query; mod ready_for_query; mod response; mod row_description; mod sasl; mod ssl_request; mod startup; mod sync; mod terminate; pub use authentication::{Authentication, AuthenticationSasl}; pub use backend_key_data::BackendKeyData; pub use bind::Bind; pub use close::Close; pub use command_complete::CommandComplete; pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse}; pub use data_row::DataRow; pub use describe::Describe; pub use execute::Execute; pub use flush::Flush; pub use 
notification::Notification; pub use parameter_description::ParameterDescription; pub use parameter_status::ParameterStatus; pub use parse::Parse; pub use password::Password; pub use query::Query; pub use ready_for_query::{ReadyForQuery, TransactionStatus}; pub use response::{Notice, PgSeverity}; pub use row_description::RowDescription; pub use sasl::{SaslInitialResponse, SaslResponse}; pub use ssl_request::SslRequest; pub use startup::Startup; pub use sync::Sync; pub use terminate::Terminate; #[derive(Debug, PartialOrd, PartialEq)] #[repr(u8)] pub enum MessageFormat { Authentication, BackendKeyData, BindComplete, CloseComplete, CommandComplete, CopyData, CopyDone, CopyInResponse, CopyOutResponse, DataRow, EmptyQueryResponse, ErrorResponse, NoData, NoticeResponse, NotificationResponse, ParameterDescription, ParameterStatus, ParseComplete, PortalSuspended, ReadyForQuery, RowDescription, } #[derive(Debug)] pub struct Message { pub format: MessageFormat, pub contents: Bytes, } impl Message { #[inline] pub fn decode<'de, T>(self) -> Result where T: Decode<'de>, { T::decode(self.contents) } } impl MessageFormat { pub fn try_from_u8(v: u8) -> Result { // https://www.postgresql.org/docs/current/protocol-message-formats.html Ok(match v { b'1' => MessageFormat::ParseComplete, b'2' => MessageFormat::BindComplete, b'3' => MessageFormat::CloseComplete, b'C' => MessageFormat::CommandComplete, b'd' => MessageFormat::CopyData, b'c' => MessageFormat::CopyDone, b'G' => MessageFormat::CopyInResponse, b'H' => MessageFormat::CopyOutResponse, b'D' => MessageFormat::DataRow, b'E' => MessageFormat::ErrorResponse, b'I' => MessageFormat::EmptyQueryResponse, b'A' => MessageFormat::NotificationResponse, b'K' => MessageFormat::BackendKeyData, b'N' => MessageFormat::NoticeResponse, b'R' => MessageFormat::Authentication, b'S' => MessageFormat::ParameterStatus, b'T' => MessageFormat::RowDescription, b'Z' => MessageFormat::ReadyForQuery, b'n' => MessageFormat::NoData, b's' => 
MessageFormat::PortalSuspended, b't' => MessageFormat::ParameterDescription, _ => return Err(err_protocol!("unknown message type: {:?}", v as char)), }) } } sqlx-postgres-0.7.3/src/message/notification.rs000064400000000000000000000017230072674642500177700ustar 00000000000000use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::{BufExt, Decode}; #[derive(Debug)] pub struct Notification { pub(crate) process_id: u32, pub(crate) channel: Bytes, pub(crate) payload: Bytes, } impl Decode<'_> for Notification { #[inline] fn decode_with(mut buf: Bytes, _: ()) -> Result { let process_id = buf.get_u32(); let channel = buf.get_bytes_nul()?; let payload = buf.get_bytes_nul()?; Ok(Self { process_id, channel, payload, }) } } #[test] fn test_decode_notification_response() { const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0"; let message = Notification::decode(Bytes::from(NOTIFICATION_RESPONSE)).unwrap(); assert_eq!(message.process_id, 0x34201002); assert_eq!(&*message.channel, &b"TEST-CHANNEL"[..]); assert_eq!(&*message.payload, &b"THIS IS A TEST"[..]); } sqlx-postgres-0.7.3/src/message/parameter_description.rs000064400000000000000000000025100072674642500216600ustar 00000000000000use smallvec::SmallVec; use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::Decode; use crate::types::Oid; #[derive(Debug)] pub struct ParameterDescription { pub types: SmallVec<[Oid; 6]>, } impl Decode<'_> for ParameterDescription { fn decode_with(mut buf: Bytes, _: ()) -> Result { let cnt = buf.get_u16(); let mut types = SmallVec::with_capacity(cnt as usize); for _ in 0..cnt { types.push(Oid(buf.get_u32())); } Ok(Self { types }) } } #[test] fn test_decode_parameter_description() { const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; let m = ParameterDescription::decode(DATA.into()).unwrap(); assert_eq!(m.types.len(), 2); assert_eq!(m.types[0], Oid(0x0000_0000)); assert_eq!(m.types[1], Oid(0x0000_0500)); } 
#[test] fn test_decode_empty_parameter_description() { const DATA: &[u8] = b"\x00\x00"; let m = ParameterDescription::decode(DATA.into()).unwrap(); assert!(m.types.is_empty()); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_decode_parameter_description(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; b.iter(|| { ParameterDescription::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } sqlx-postgres-0.7.3/src/message/parameter_status.rs000064400000000000000000000031460072674642500206660ustar 00000000000000use sqlx_core::bytes::Bytes; use crate::error::Error; use crate::io::{BufExt, Decode}; #[derive(Debug)] pub struct ParameterStatus { pub name: String, pub value: String, } impl Decode<'_> for ParameterStatus { fn decode_with(mut buf: Bytes, _: ()) -> Result { let name = buf.get_str_nul()?; let value = buf.get_str_nul()?; Ok(Self { name, value }) } } #[test] fn test_decode_parameter_status() { const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; let m = ParameterStatus::decode(DATA.into()).unwrap(); assert_eq!(&m.name, "client_encoding"); assert_eq!(&m.value, "UTF8") } #[test] fn test_decode_empty_parameter_status() { const DATA: &[u8] = b"\x00\x00"; let m = ParameterStatus::decode(DATA.into()).unwrap(); assert!(m.name.is_empty()); assert!(m.value.is_empty()); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_decode_parameter_status(b: &mut test::Bencher) { const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; b.iter(|| { ParameterStatus::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } #[test] fn test_decode_parameter_status_response() { const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0"; let message = ParameterStatus::decode(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap(); assert_eq!(message.name, "crdb_version"); assert_eq!( message.value, "CockroachDB CCL v21.1.0 
(x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)" ); } sqlx-postgres-0.7.3/src/message/parse.rs000064400000000000000000000026520072674642500164160ustar 00000000000000use std::i16; use crate::io::PgBufMutExt; use crate::io::{BufMutExt, Encode}; use crate::types::Oid; #[derive(Debug)] pub struct Parse<'a> { /// The ID of the destination prepared statement. pub statement: Oid, /// The query string to be parsed. pub query: &'a str, /// The parameter data types specified (could be zero). Note that this is not an /// indication of the number of parameters that might appear in the query string, /// only the number that the frontend wants to pre-specify types for. pub param_types: &'a [Oid], } impl Encode<'_> for Parse<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.push(b'P'); buf.put_length_prefixed(|buf| { buf.put_statement_name(self.statement); buf.put_str_nul(self.query); // TODO: Return an error here instead assert!(self.param_types.len() <= (u16::MAX as usize)); buf.extend(&(self.param_types.len() as i16).to_be_bytes()); for &oid in self.param_types { buf.extend(&oid.0.to_be_bytes()); } }) } } #[test] fn test_encode_parse() { const EXPECTED: &[u8] = b"P\0\0\0\x1dsqlx_s_1\0SELECT $1\0\0\x01\0\0\0\x19"; let mut buf = Vec::new(); let m = Parse { statement: Oid(1), query: "SELECT $1", param_types: &[Oid(25)], }; m.encode(&mut buf); assert_eq!(buf, EXPECTED); } sqlx-postgres-0.7.3/src/message/password.rs000064400000000000000000000057600072674642500171510ustar 00000000000000use std::fmt::Write; use md5::{Digest, Md5}; use crate::io::PgBufMutExt; use crate::io::{BufMutExt, Encode}; #[derive(Debug)] pub enum Password<'a> { Cleartext(&'a str), Md5 { password: &'a str, username: &'a str, salt: [u8; 4], }, } impl Password<'_> { #[inline] fn len(&self) -> usize { match self { Password::Cleartext(s) => s.len() + 5, Password::Md5 { .. 
} => 35 + 5, } } } impl Encode<'_> for Password<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.reserve(1 + 4 + self.len()); buf.push(b'p'); buf.put_length_prefixed(|buf| { match self { Password::Cleartext(password) => { buf.put_str_nul(password); } Password::Md5 { username, password, salt, } => { // The actual `PasswordMessage` can be computed in SQL as // `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`. // Keep in mind the md5() function returns its result as a hex string. let mut hasher = Md5::new(); hasher.update(password); hasher.update(username); let mut output = String::with_capacity(35); let _ = write!(output, "{:x}", hasher.finalize_reset()); hasher.update(&output); hasher.update(salt); output.clear(); let _ = write!(output, "md5{:x}", hasher.finalize()); buf.put_str_nul(&output); } } }); } } #[test] fn test_encode_clear_password() { const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0"; let mut buf = Vec::new(); let m = Password::Cleartext("password"); m.encode(&mut buf); assert_eq!(buf, EXPECTED); } #[test] fn test_encode_md5_password() { const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; let mut buf = Vec::new(); let m = Password::Md5 { password: "password", username: "root", salt: [147, 24, 57, 152], }; m.encode(&mut buf); assert_eq!(buf, EXPECTED); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_encode_clear_password(b: &mut test::Bencher) { use test::black_box; let mut buf = Vec::with_capacity(128); b.iter(|| { buf.clear(); black_box(Password::Cleartext("password")).encode(&mut buf); }); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_encode_md5_password(b: &mut test::Bencher) { use test::black_box; let mut buf = Vec::with_capacity(128); b.iter(|| { buf.clear(); black_box(Password::Md5 { password: "password", username: "root", salt: [147, 24, 57, 152], }) .encode(&mut buf); }); } sqlx-postgres-0.7.3/src/message/query.rs000064400000000000000000000010700072674642500164420ustar 
00000000000000use crate::io::{BufMutExt, Encode}; #[derive(Debug)] pub struct Query<'a>(pub &'a str); impl Encode<'_> for Query<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { let len = 4 + self.0.len() + 1; buf.reserve(len + 1); buf.push(b'Q'); buf.extend(&(len as i32).to_be_bytes()); buf.put_str_nul(self.0); } } #[test] fn test_encode_query() { const EXPECTED: &[u8] = b"Q\0\0\0\rSELECT 1\0"; let mut buf = Vec::new(); let m = Query("SELECT 1"); m.encode(&mut buf); assert_eq!(buf, EXPECTED); } sqlx-postgres-0.7.3/src/message/ready_for_query.rs000064400000000000000000000023460072674642500205030ustar 00000000000000use sqlx_core::bytes::Bytes; use crate::error::Error; use crate::io::Decode; #[derive(Debug)] #[repr(u8)] pub enum TransactionStatus { /// Not in a transaction block. Idle = b'I', /// In a transaction block. Transaction = b'T', /// In a _failed_ transaction block. Queries will be rejected until block is ended. Error = b'E', } #[derive(Debug)] pub struct ReadyForQuery { pub transaction_status: TransactionStatus, } impl Decode<'_> for ReadyForQuery { fn decode_with(buf: Bytes, _: ()) -> Result { let status = match buf[0] { b'I' => TransactionStatus::Idle, b'T' => TransactionStatus::Transaction, b'E' => TransactionStatus::Error, status => { return Err(err_protocol!( "unknown transaction status: {:?}", status as char )); } }; Ok(Self { transaction_status: status, }) } } #[test] fn test_decode_ready_for_query() -> Result<(), Error> { const DATA: &[u8] = b"E"; let m = ReadyForQuery::decode(Bytes::from_static(DATA))?; assert!(matches!(m.transaction_status, TransactionStatus::Error)); Ok(()) } sqlx-postgres-0.7.3/src/message/response.rs000064400000000000000000000157660072674642500171540ustar 00000000000000use std::str::from_utf8; use memchr::memchr; use sqlx_core::bytes::Bytes; use crate::error::Error; use crate::io::Decode; #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[repr(u8)] pub enum PgSeverity { Panic, Fatal, Error, Warning, Notice, Debug, Info, Log, } 
impl PgSeverity { #[inline] pub fn is_error(self) -> bool { matches!(self, Self::Panic | Self::Fatal | Self::Error) } } impl TryFrom<&str> for PgSeverity { type Error = Error; fn try_from(s: &str) -> Result { let result = match s { "PANIC" => PgSeverity::Panic, "FATAL" => PgSeverity::Fatal, "ERROR" => PgSeverity::Error, "WARNING" => PgSeverity::Warning, "NOTICE" => PgSeverity::Notice, "DEBUG" => PgSeverity::Debug, "INFO" => PgSeverity::Info, "LOG" => PgSeverity::Log, severity => { return Err(err_protocol!("unknown severity: {:?}", severity)); } }; Ok(result) } } #[derive(Debug)] pub struct Notice { storage: Bytes, severity: PgSeverity, message: (u16, u16), code: (u16, u16), } impl Notice { #[inline] pub fn severity(&self) -> PgSeverity { self.severity } #[inline] pub fn code(&self) -> &str { self.get_cached_str(self.code) } #[inline] pub fn message(&self) -> &str { self.get_cached_str(self.message) } // Field descriptions available here: // https://www.postgresql.org/docs/current/protocol-error-fields.html #[inline] pub fn get(&self, ty: u8) -> Option<&str> { self.get_raw(ty).and_then(|v| from_utf8(v).ok()) } pub fn get_raw(&self, ty: u8) -> Option<&[u8]> { self.fields() .filter(|(field, _)| *field == ty) .map(|(_, (start, end))| &self.storage[start as usize..end as usize]) .next() } } impl Notice { #[inline] fn fields(&self) -> Fields<'_> { Fields { storage: &self.storage, offset: 0, } } #[inline] fn get_cached_str(&self, cache: (u16, u16)) -> &str { // unwrap: this cannot fail at this stage from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap() } } impl Decode<'_> for Notice { fn decode_with(buf: Bytes, _: ()) -> Result { // In order to support PostgreSQL 9.5 and older we need to parse the localized S field. // Newer versions additionally come with the V field that is guaranteed to be in English. // We thus read both versions and prefer the unlocalized one if available. 
const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log; let mut severity_v = None; let mut severity_s = None; let mut message = (0, 0); let mut code = (0, 0); // we cache the three always present fields // this enables to keep the access time down for the fields most likely accessed let fields = Fields { storage: &buf, offset: 0, }; for (field, v) in fields { if message.0 != 0 && code.0 != 0 { // stop iterating when we have the 3 fields we were looking for // we assume V (severity) was the first field as it should be break; } match field { b'S' => { severity_s = from_utf8(&buf[v.0 as usize..v.1 as usize]) // If the error string is not UTF-8, we have no hope of interpreting it, // localized or not. The `V` field would likely fail to parse as well. .map_err(|_| notice_protocol_err())? .try_into() // If we couldn't parse the severity here, it might just be localized. .ok(); } b'V' => { // Propagate errors here, because V is not localized and // thus we are missing a possible variant. severity_v = Some( from_utf8(&buf[v.0 as usize..v.1 as usize]) .map_err(|_| notice_protocol_err())? .try_into()?, ); } b'M' => { message = v; } b'C' => { code = v; } _ => {} } } Ok(Self { severity: severity_v.or(severity_s).unwrap_or(DEFAULT_SEVERITY), message, code, storage: buf, }) } } /// An iterator over each field in the Error (or Notice) response. struct Fields<'a> { storage: &'a [u8], offset: u16, } impl<'a> Iterator for Fields<'a> { type Item = (u8, (u16, u16)); fn next(&mut self) -> Option { // The fields in the response body are sequentially stored as [tag][string], // ending in a final, additional [nul] let ty = self.storage[self.offset as usize]; if ty == 0 { return None; } let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? 
as u16; let offset = self.offset; self.offset += nul + 2; Some((ty, (offset + 1, offset + nul + 1))) } } fn notice_protocol_err() -> Error { // https://github.com/launchbadge/sqlx/issues/1144 Error::Protocol( "Postgres returned a non-UTF-8 string for its error message. \ This is most likely due to an error that occurred during authentication and \ the default lc_messages locale is not binary-compatible with UTF-8. \ See the server logs for the error details." .into(), ) } #[test] fn test_decode_error_response() { const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0"; let m = Notice::decode(Bytes::from_static(DATA)).unwrap(); assert_eq!( m.message(), "extension \"uuid-ossp\" already exists, skipping" ); assert!(matches!(m.severity(), PgSeverity::Notice)); assert_eq!(m.code(), "42710"); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_error_response_get_message(b: &mut test::Bencher) { const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0"; let res = Notice::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); b.iter(|| { let _ = test::black_box(&res).message(); }); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_decode_error_response(b: &mut test::Bencher) { const DATA: &[u8] = b"SNOTICE\0VNOTICE\0C42710\0Mextension \"uuid-ossp\" already exists, skipping\0Fextension.c\0L1656\0RCreateExtension\0\0"; b.iter(|| { let _ = Notice::decode(test::black_box(Bytes::from_static(DATA))); }); } sqlx-postgres-0.7.3/src/message/row_description.rs000064400000000000000000000043300072674642500205110ustar 00000000000000use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; use crate::io::{BufExt, Decode}; use crate::types::Oid; #[derive(Debug)] pub struct RowDescription { pub fields: Vec, } #[derive(Debug)] pub struct Field { /// The name of the field. 
pub name: String, /// If the field can be identified as a column of a specific table, the /// object ID of the table; otherwise zero. pub relation_id: Option, /// If the field can be identified as a column of a specific table, the attribute number of /// the column; otherwise zero. pub relation_attribute_no: Option, /// The object ID of the field's data type. pub data_type_id: Oid, /// The data type size (see pg_type.typlen). Note that negative values denote /// variable-width types. pub data_type_size: i16, /// The type modifier (see pg_attribute.atttypmod). The meaning of the /// modifier is type-specific. pub type_modifier: i32, /// The format code being used for the field. pub format: i16, } impl Decode<'_> for RowDescription { fn decode_with(mut buf: Bytes, _: ()) -> Result { let cnt = buf.get_u16(); let mut fields = Vec::with_capacity(cnt as usize); for _ in 0..cnt { let name = buf.get_str_nul()?.to_owned(); let relation_id = buf.get_i32(); let relation_attribute_no = buf.get_i16(); let data_type_id = Oid(buf.get_u32()); let data_type_size = buf.get_i16(); let type_modifier = buf.get_i32(); let format = buf.get_i16(); fields.push(Field { name, relation_id: if relation_id == 0 { None } else { Some(relation_id) }, relation_attribute_no: if relation_attribute_no == 0 { None } else { Some(relation_attribute_no) }, data_type_id, data_type_size, type_modifier, format, }) } Ok(Self { fields }) } } // TODO: Unit Test RowDescription // TODO: Benchmark RowDescription sqlx-postgres-0.7.3/src/message/sasl.rs000064400000000000000000000017020072674642500162410ustar 00000000000000use crate::io::PgBufMutExt; use crate::io::{BufMutExt, Encode}; pub struct SaslInitialResponse<'a> { pub response: &'a str, pub plus: bool, } impl Encode<'_> for SaslInitialResponse<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.push(b'p'); buf.put_length_prefixed(|buf| { // name of the SASL authentication mechanism that the client selected buf.put_str_nul(if self.plus { 
"SCRAM-SHA-256-PLUS" } else { "SCRAM-SHA-256" }); buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes()); buf.extend(self.response.as_bytes()); }); } } pub struct SaslResponse<'a>(pub &'a str); impl Encode<'_> for SaslResponse<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.push(b'p'); buf.put_length_prefixed(|buf| { buf.extend(self.0.as_bytes()); }); } } sqlx-postgres-0.7.3/src/message/ssl_request.rs000064400000000000000000000007700072674642500176540ustar 00000000000000use crate::io::Encode; pub struct SslRequest; impl SslRequest { pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16/"; } impl Encode<'_> for SslRequest { #[inline] fn encode_with(&self, buf: &mut Vec, _: ()) { buf.extend(&8_u32.to_be_bytes()); buf.extend(&(((1234 << 16) | 5679) as u32).to_be_bytes()); } } #[test] fn test_encode_ssl_request() { let mut buf = Vec::new(); SslRequest.encode(&mut buf); assert_eq!(buf, SslRequest::BYTES); } sqlx-postgres-0.7.3/src/message/startup.rs000064400000000000000000000054010072674642500170010ustar 00000000000000use crate::io::PgBufMutExt; use crate::io::{BufMutExt, Encode}; // To begin a session, a frontend opens a connection to the server and sends a startup message. // This message includes the names of the user and of the database the user wants to connect to; // it also identifies the particular protocol version to be used. // Optionally, the startup message can include additional settings for run-time parameters. pub struct Startup<'a> { /// The database user name to connect as. Required; there is no default. pub username: Option<&'a str>, /// The database to connect to. Defaults to the user name. pub database: Option<&'a str>, /// Additional start-up params. /// pub params: &'a [(&'a str, &'a str)], } impl Encode<'_> for Startup<'_> { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.reserve(120); buf.put_length_prefixed(|buf| { // The protocol version number. 
The most significant 16 bits are the // major version number (3 for the protocol described here). The least // significant 16 bits are the minor version number (0 // for the protocol described here) buf.extend(&196_608_i32.to_be_bytes()); if let Some(username) = self.username { // The database user name to connect as. encode_startup_param(buf, "user", username); } if let Some(database) = self.database { // The database to connect to. Defaults to the user name. encode_startup_param(buf, "database", database); } for (name, value) in self.params { encode_startup_param(buf, name, value); } // A zero byte is required as a terminator // after the last name/value pair. buf.push(0); }); } } #[inline] fn encode_startup_param(buf: &mut Vec, name: &str, value: &str) { buf.put_str_nul(name); buf.put_str_nul(value); } #[test] fn test_encode_startup() { const EXPECTED: &[u8] = b"\0\0\0)\0\x03\0\0user\0postgres\0database\0postgres\0\0"; let mut buf = Vec::new(); let m = Startup { username: Some("postgres"), database: Some("postgres"), params: &[], }; m.encode(&mut buf); assert_eq!(buf, EXPECTED); } #[cfg(all(test, not(debug_assertions)))] #[bench] fn bench_encode_startup(b: &mut test::Bencher) { use test::black_box; let mut buf = Vec::with_capacity(128); b.iter(|| { buf.clear(); black_box(Startup { username: Some("postgres"), database: Some("postgres"), params: &[], }) .encode(&mut buf); }); } sqlx-postgres-0.7.3/src/message/sync.rs000064400000000000000000000003260072674642500162540ustar 00000000000000use crate::io::Encode; #[derive(Debug)] pub struct Sync; impl Encode<'_> for Sync { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.push(b'S'); buf.extend(&4_i32.to_be_bytes()); } } sqlx-postgres-0.7.3/src/message/terminate.rs000064400000000000000000000003170072674642500172700ustar 00000000000000use crate::io::Encode; pub struct Terminate; impl Encode<'_> for Terminate { fn encode_with(&self, buf: &mut Vec, _: ()) { buf.push(b'X'); buf.extend(&4_u32.to_be_bytes()); } } 
sqlx-postgres-0.7.3/src/migrate.rs000064400000000000000000000230460072674642500153100ustar 00000000000000use std::str::FromStr; use std::time::Duration; use std::time::Instant; use futures_core::future::BoxFuture; pub(crate) use sqlx_core::migrate::MigrateError; pub(crate) use sqlx_core::migrate::{AppliedMigration, Migration}; pub(crate) use sqlx_core::migrate::{Migrate, MigrateDatabase}; use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::executor::Executor; use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; use crate::{PgConnectOptions, PgConnection, Postgres}; fn parse_for_maintenance(url: &str) -> Result<(PgConnectOptions, String), Error> { let mut options = PgConnectOptions::from_str(url)?; // pull out the name of the database to create let database = options .database .as_deref() .unwrap_or(&options.username) .to_owned(); // switch us to the maintenance database // use `postgres` _unless_ the database is postgres, in which case, use `template1` // this matches the behavior of the `createdb` util options.database = if database == "postgres" { Some("template1".into()) } else { Some("postgres".into()) }; Ok((options, database)) } impl MigrateDatabase for Postgres { fn create_database(url: &str) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; let _ = conn .execute(&*format!( "CREATE DATABASE \"{}\"", database.replace('"', "\"\"") )) .await?; Ok(()) }) } fn database_exists(url: &str) -> BoxFuture<'_, Result> { Box::pin(async move { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; let exists: bool = query_scalar("select exists(SELECT 1 from pg_database WHERE datname = $1)") .bind(database) .fetch_one(&mut conn) .await?; Ok(exists) }) } fn drop_database(url: &str) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { let 
(options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; let _ = conn .execute(&*format!( "DROP DATABASE IF EXISTS \"{}\"", database.replace('"', "\"\"") )) .await?; Ok(()) }) } fn force_drop_database(url: &str) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; let row: (String,) = query_as("SELECT current_setting('server_version_num')") .fetch_one(&mut conn) .await?; let version = row.0.parse::().unwrap(); let pid_type = if version >= 90200 { "pid" } else { "procpid" }; conn.execute(&*format!( "SELECT pg_terminate_backend(pg_stat_activity.{pid_type}) FROM pg_stat_activity \ WHERE pg_stat_activity.datname = '{database}' AND {pid_type} <> pg_backend_pid()" )) .await?; Self::drop_database(url).await }) } } impl Migrate for PgConnection { fn ensure_migrations_table(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { Box::pin(async move { // language=SQL self.execute( r#" CREATE TABLE IF NOT EXISTS _sqlx_migrations ( version BIGINT PRIMARY KEY, description TEXT NOT NULL, installed_on TIMESTAMPTZ NOT NULL DEFAULT now(), success BOOLEAN NOT NULL, checksum BYTEA NOT NULL, execution_time BIGINT NOT NULL ); "#, ) .await?; Ok(()) }) } fn dirty_version(&mut self) -> BoxFuture<'_, Result, MigrateError>> { Box::pin(async move { // language=SQL let row: Option<(i64,)> = query_as( "SELECT version FROM _sqlx_migrations WHERE success = false ORDER BY version LIMIT 1", ) .fetch_optional(self) .await?; Ok(row.map(|r| r.0)) }) } fn list_applied_migrations( &mut self, ) -> BoxFuture<'_, Result, MigrateError>> { Box::pin(async move { // language=SQL let rows: Vec<(i64, Vec)> = query_as("SELECT version, checksum FROM _sqlx_migrations ORDER BY version") .fetch_all(self) .await?; let migrations = rows .into_iter() .map(|(version, checksum)| AppliedMigration { version, checksum: checksum.into(), }) .collect(); Ok(migrations) }) } fn lock(&mut self) 
-> BoxFuture<'_, Result<(), MigrateError>> { Box::pin(async move { let database_name = current_database(self).await?; let lock_id = generate_lock_id(&database_name); // create an application lock over the database // this function will not return until the lock is acquired // https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS // https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS-TABLE // language=SQL let _ = query("SELECT pg_advisory_lock($1)") .bind(lock_id) .execute(self) .await?; Ok(()) }) } fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { Box::pin(async move { let database_name = current_database(self).await?; let lock_id = generate_lock_id(&database_name); // language=SQL let _ = query("SELECT pg_advisory_unlock($1)") .bind(lock_id) .execute(self) .await?; Ok(()) }) } fn apply<'e: 'm, 'm>( &'e mut self, migration: &'m Migration, ) -> BoxFuture<'m, Result> { Box::pin(async move { let mut tx = self.begin().await?; let start = Instant::now(); // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. // The `execution_time` however can only be measured for the whole transaction. This value _only_ exists for // data lineage and debugging reasons, so it is not super important if it is lost. So we initialize it to -1 // and update it once the actual transaction completed. let _ = tx.execute(&*migration.sql).await?; // language=SQL let _ = query( r#" INSERT INTO _sqlx_migrations ( version, description, success, checksum, execution_time ) VALUES ( $1, $2, TRUE, $3, -1 ) "#, ) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) .execute(&mut *tx) .await?; tx.commit().await?; // Update `elapsed_time`. // NOTE: The process may disconnect/die at this point, so the elapsed time value might be lost. 
We accept // this small risk since this value is not super important. let elapsed = start.elapsed(); // language=SQL let _ = query( r#" UPDATE _sqlx_migrations SET execution_time = $1 WHERE version = $2 "#, ) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) .await?; Ok(elapsed) }) } fn revert<'e: 'm, 'm>( &'e mut self, migration: &'m Migration, ) -> BoxFuture<'m, Result> { Box::pin(async move { // Use a single transaction for the actual migration script and the essential bookeeping so we never // execute migrations twice. See https://github.com/launchbadge/sqlx/issues/1966. let mut tx = self.begin().await?; let start = Instant::now(); let _ = tx.execute(&*migration.sql).await?; // language=SQL let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = $1"#) .bind(migration.version) .execute(&mut *tx) .await?; tx.commit().await?; let elapsed = start.elapsed(); Ok(elapsed) }) } } async fn current_database(conn: &mut PgConnection) -> Result { // language=SQL Ok(query_scalar("SELECT current_database()") .fetch_one(conn) .await?) 
} // inspired from rails: https://github.com/rails/rails/blob/6e49cc77ab3d16c06e12f93158eaf3e507d4120e/activerecord/lib/active_record/migration.rb#L1308 fn generate_lock_id(database_name: &str) -> i64 { const CRC_IEEE: crc::Crc = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); // 0x3d32ad9e chosen by fair dice roll 0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64) } sqlx-postgres-0.7.3/src/options/connect.rs000064400000000000000000000015420072674642500170010ustar 00000000000000use crate::connection::ConnectOptions; use crate::error::Error; use crate::{PgConnectOptions, PgConnection}; use futures_core::future::BoxFuture; use log::LevelFilter; use sqlx_core::Url; use std::time::Duration; impl ConnectOptions for PgConnectOptions { type Connection = PgConnection; fn from_url(url: &Url) -> Result { Self::parse_from_url(url) } fn connect(&self) -> BoxFuture<'_, Result> where Self::Connection: Sized, { Box::pin(PgConnection::establish(self)) } fn log_statements(mut self, level: LevelFilter) -> Self { self.log_settings.log_statements(level); self } fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self { self.log_settings.log_slow_statements(level, duration); self } } sqlx-postgres-0.7.3/src/options/mod.rs000064400000000000000000000520420072674642500161300ustar 00000000000000use std::borrow::Cow; use std::env::var; use std::fmt::{Display, Write}; use std::path::{Path, PathBuf}; pub use ssl_mode::PgSslMode; use crate::{connection::LogSettings, net::tls::CertificateInput}; mod connect; mod parse; mod pgpass; mod ssl_mode; /// Options and flags which can be used to configure a PostgreSQL connection. /// /// A value of `PgConnectOptions` can be parsed from a connection URL, /// as described by [libpq](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING). /// /// The general form for a connection URL is: /// /// ```text /// postgresql://[user[:password]@][host][:port][/dbname][?param1=value1&...] 
/// ``` /// /// This type also implements [`FromStr`][std::str::FromStr] so you can parse it from a string /// containing a connection URL and then further adjust options if necessary (see example below). /// /// ## Parameters /// /// |Parameter|Default|Description| /// |---------|-------|-----------| /// | `sslmode` | `prefer` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`PgSslMode`]. | /// | `sslrootcert` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. | /// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. | /// | `host` | `None` | Path to the directory containing a PostgreSQL unix domain socket, which will be used instead of TCP if set. | /// | `hostaddr` | `None` | Same as `host`, but only accepts IP addresses. | /// | `application-name` | `None` | The name will be displayed in the pg_stat_activity view and included in CSV log entries. | /// | `user` | result of `whoami` | PostgreSQL user name to connect as. | /// | `password` | `None` | Password to be used if the server demands password authentication. | /// | `port` | `5432` | Port number to connect to at the server host, or socket file name extension for Unix-domain connections. | /// | `dbname` | `None` | The database name. | /// | `options` | `None` | The runtime parameters to send to the server at connection start. | /// /// The URL scheme designator can be either `postgresql://` or `postgres://`. /// Each of the URL parts is optional. 
/// /// ```text /// postgresql:// /// postgresql://localhost /// postgresql://localhost:5433 /// postgresql://localhost/mydb /// postgresql://user@localhost /// postgresql://user:secret@localhost /// postgresql://localhost?dbname=mydb&user=postgres&password=postgres /// ``` /// /// # Example /// /// ```rust,no_run /// use sqlx::{Connection, ConnectOptions}; /// use sqlx::postgres::{PgConnectOptions, PgConnection, PgPool, PgSslMode}; /// /// # async fn example() -> sqlx::Result<()> { /// // URL connection string /// let conn = PgConnection::connect("postgres://localhost/mydb").await?; /// /// // Manually-constructed options /// let conn = PgConnectOptions::new() /// .host("secret-host") /// .port(2525) /// .username("secret-user") /// .password("secret-password") /// .ssl_mode(PgSslMode::Require) /// .connect() /// .await?; /// /// // Modifying options parsed from a string /// let mut opts: PgConnectOptions = "postgres://localhost/mydb".parse()?; /// /// // Change the log verbosity level for queries. /// // Information about SQL queries is logged at `DEBUG` level by default. /// opts.log_statements(log::LevelFilter::Trace); /// /// let pool = PgPool::connect_with(&opts).await?; /// # } /// ``` #[derive(Debug, Clone)] pub struct PgConnectOptions { pub(crate) host: String, pub(crate) port: u16, pub(crate) socket: Option, pub(crate) username: String, pub(crate) password: Option, pub(crate) database: Option, pub(crate) ssl_mode: PgSslMode, pub(crate) ssl_root_cert: Option, pub(crate) ssl_client_cert: Option, pub(crate) ssl_client_key: Option, pub(crate) statement_cache_capacity: usize, pub(crate) application_name: Option, pub(crate) log_settings: LogSettings, pub(crate) extra_float_digits: Option>, pub(crate) options: Option, } impl Default for PgConnectOptions { fn default() -> Self { Self::new_without_pgpass().apply_pgpass() } } impl PgConnectOptions { /// Creates a new, default set of options ready for configuration. 
/// /// By default, this reads the following environment variables and sets their /// equivalent options. /// /// * `PGHOST` /// * `PGPORT` /// * `PGUSER` /// * `PGPASSWORD` /// * `PGDATABASE` /// * `PGSSLROOTCERT` /// * `PGSSLCERT` /// * `PGSSLKEY` /// * `PGSSLMODE` /// * `PGAPPNAME` /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new(); /// ``` pub fn new() -> Self { Self::new_without_pgpass().apply_pgpass() } pub fn new_without_pgpass() -> Self { let port = var("PGPORT") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(5432); let host = var("PGHOST").ok().unwrap_or_else(|| default_host(port)); let username = var("PGUSER").ok().unwrap_or_else(whoami::username); let database = var("PGDATABASE").ok(); PgConnectOptions { port, host, socket: None, username, password: var("PGPASSWORD").ok(), database, ssl_root_cert: var("PGSSLROOTCERT").ok().map(CertificateInput::from), ssl_client_cert: var("PGSSLCERT").ok().map(CertificateInput::from), ssl_client_key: var("PGSSLKEY").ok().map(CertificateInput::from), ssl_mode: var("PGSSLMODE") .ok() .and_then(|v| v.parse().ok()) .unwrap_or_default(), statement_cache_capacity: 100, application_name: var("PGAPPNAME").ok(), extra_float_digits: Some("2".into()), log_settings: Default::default(), options: var("PGOPTIONS").ok(), } } pub(crate) fn apply_pgpass(mut self) -> Self { if self.password.is_none() { self.password = pgpass::load_password( &self.host, self.port, &self.username, self.database.as_deref(), ); } self } /// Sets the name of the host to connect to. /// /// If a host name begins with a slash, it specifies /// Unix-domain communication rather than TCP/IP communication; the value is the name of /// the directory in which the socket file is stored. 
/// /// The default behavior when host is not specified, or is empty, /// is to connect to a Unix-domain socket /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .host("localhost"); /// ``` pub fn host(mut self, host: &str) -> Self { self.host = host.to_owned(); self } /// Get the current host. /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .host("127.0.0.1"); /// assert_eq!(options.get_host(), "127.0.0.1"); /// ``` pub fn get_host(&self) -> &str { self.host.as_str() } /// Sets the port to connect to at the server host. /// /// The default port for PostgreSQL is `5432`. /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .port(5432); /// ``` pub fn port(mut self, port: u16) -> Self { self.port = port; self } /// Sets a custom path to a directory containing a unix domain socket, /// switching the connection method from TCP to the corresponding socket. /// /// By default set to `None`. pub fn socket(mut self, path: impl AsRef) -> Self { self.socket = Some(path.as_ref().to_path_buf()); self } /// Sets the username to connect as. /// /// Defaults to be the same as the operating system name of /// the user running the application. /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .username("postgres"); /// ``` pub fn username(mut self, username: &str) -> Self { self.username = username.to_owned(); self } /// Sets the password to use if the server demands password authentication. 
/// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .username("root") /// .password("safe-and-secure"); /// ``` pub fn password(mut self, password: &str) -> Self { self.password = Some(password.to_owned()); self } /// Sets the database name. Defaults to be the same as the user name. /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .database("postgres"); /// ``` pub fn database(mut self, database: &str) -> Self { self.database = Some(database.to_owned()); self } /// Get the current database name. /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .database("postgres"); /// assert!(options.get_database().is_some()); /// ``` pub fn get_database(&self) -> Option<&str> { self.database.as_deref() } /// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated /// with the server. /// /// By default, the SSL mode is [`Prefer`](PgSslMode::Prefer), and the client will /// first attempt an SSL connection but fallback to a non-SSL connection on failure. /// /// Ignored for Unix domain socket communication. /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// .ssl_mode(PgSslMode::Require); /// ``` pub fn ssl_mode(mut self, mode: PgSslMode) -> Self { self.ssl_mode = mode; self } /// Sets the name of a file containing SSL certificate authority (CA) certificate(s). /// If the file exists, the server's certificate will be verified to be signed by /// one of these authorities. 
/// /// # Example /// /// ```rust /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_root_cert("./ca-certificate.crt"); /// ``` pub fn ssl_root_cert(mut self, cert: impl AsRef) -> Self { self.ssl_root_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf())); self } /// Sets the name of a file containing SSL client certificate. /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_client_cert("./client.crt"); /// ``` pub fn ssl_client_cert(mut self, cert: impl AsRef) -> Self { self.ssl_client_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf())); self } /// Sets the SSL client certificate as a PEM-encoded byte slice. /// /// This should be an ASCII-encoded blob that starts with `-----BEGIN CERTIFICATE-----`. /// /// # Example /// Note: embedding SSL certificates and keys in the binary is not advised. /// This is for illustration purposes only. /// /// ```rust /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; /// /// const CERT: &[u8] = b"\ /// -----BEGIN CERTIFICATE----- /// /// -----END CERTIFICATE-----"; /// /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_client_cert_from_pem(CERT); /// ``` pub fn ssl_client_cert_from_pem(mut self, cert: impl AsRef<[u8]>) -> Self { self.ssl_client_cert = Some(CertificateInput::Inline(cert.as_ref().to_vec())); self } /// Sets the name of a file containing SSL client key. 
/// /// # Example /// /// ```rust /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_client_key("./client.key"); /// ``` pub fn ssl_client_key(mut self, key: impl AsRef) -> Self { self.ssl_client_key = Some(CertificateInput::File(key.as_ref().to_path_buf())); self } /// Sets the SSL client key as a PEM-encoded byte slice. /// /// This should be an ASCII-encoded blob that starts with `-----BEGIN PRIVATE KEY-----`. /// /// # Example /// Note: embedding SSL certificates and keys in the binary is not advised. /// This is for illustration purposes only. /// /// ```rust /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; /// /// const KEY: &[u8] = b"\ /// -----BEGIN PRIVATE KEY----- /// /// -----END PRIVATE KEY-----"; /// /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_client_key_from_pem(KEY); /// ``` pub fn ssl_client_key_from_pem(mut self, key: impl AsRef<[u8]>) -> Self { self.ssl_client_key = Some(CertificateInput::Inline(key.as_ref().to_vec())); self } /// Sets PEM encoded trusted SSL Certificate Authorities (CA). /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::{PgSslMode, PgConnectOptions}; /// let options = PgConnectOptions::new() /// // Providing a CA certificate with less than VerifyCa is pointless /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_root_cert_from_pem(vec![]); /// ``` pub fn ssl_root_cert_from_pem(mut self, pem_certificate: Vec) -> Self { self.ssl_root_cert = Some(CertificateInput::Inline(pem_certificate)); self } /// Sets the capacity of the connection's statement cache in a number of stored /// distinct statements. Caching is handled using LRU, meaning when the /// amount of queries hits the defined limit, the oldest statement will get /// dropped. 
/// /// The default cache capacity is 100 statements. pub fn statement_cache_capacity(mut self, capacity: usize) -> Self { self.statement_cache_capacity = capacity; self } /// Sets the application name. Defaults to None /// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .application_name("my-app"); /// ``` pub fn application_name(mut self, application_name: &str) -> Self { self.application_name = Some(application_name.to_owned()); self } /// Sets or removes the `extra_float_digits` connection option. /// /// This changes the default precision of floating-point values returned in text mode (when /// not using prepared statements such as calling methods of [`Executor`] directly). /// /// Historically, Postgres would by default round floating-point values to 6 and 15 digits /// for `float4`/`REAL` (`f32`) and `float8`/`DOUBLE` (`f64`), respectively, which would mean /// that the returned value may not be exactly the same as its representation in Postgres. /// /// The nominal range for this value is `-15` to `3`, where negative values for this option /// cause floating-points to be rounded to that many fewer digits than normal (`-1` causes /// `float4` to be rounded to 5 digits instead of six, or 14 instead of 15 for `float8`), /// positive values cause Postgres to emit that many extra digits of precision over default /// (or simply use maximum precision in Postgres 12 and later), /// and 0 means keep the default behavior (or the "old" behavior described above /// as of Postgres 12). /// /// SQLx sets this value to 3 by default, which tells Postgres to return floating-point values /// at their maximum precision in the hope that the parsed value will be identical to its /// counterpart in Postgres. This is also the default in Postgres 12 and later anyway. 
/// /// However, older versions of Postgres and alternative implementations that talk the Postgres /// protocol may not support this option, or the full range of values. /// /// If you get an error like "unknown option `extra_float_digits`" when connecting, try /// setting this to `None` or consult the manual of your database for the allowed range /// of values. /// /// For more information, see: /// * [Postgres manual, 20.11.2: Client Connection Defaults; Locale and Formatting][20.11.2] /// * [Postgres manual, 8.1.3: Numeric Types; Floating-point Types][8.1.3] /// /// [`Executor`]: crate::executor::Executor /// [20.11.2]: https://www.postgresql.org/docs/current/runtime-config-client.html#RUNTIME-CONFIG-CLIENT-FORMAT /// [8.1.3]: https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-FLOAT /// /// ### Examples /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// /// let mut options = PgConnectOptions::new() /// // for Redshift and Postgres 10 /// .extra_float_digits(2); /// /// let mut options = PgConnectOptions::new() /// // don't send the option at all (Postgres 9 and older) /// .extra_float_digits(None); /// ``` pub fn extra_float_digits(mut self, extra_float_digits: impl Into>) -> Self { self.extra_float_digits = extra_float_digits.into().map(|it| it.to_string().into()); self } /// Set additional startup options for the connection as a list of key-value pairs. 
/// /// # Example /// /// ```rust /// # use sqlx_core::postgres::PgConnectOptions; /// let options = PgConnectOptions::new() /// .options([("geqo", "off"), ("statement_timeout", "5min")]); /// ``` pub fn options(mut self, options: I) -> Self where K: Display, V: Display, I: IntoIterator, { // Do this in here so `options_str` is only set if we have an option to insert let options_str = self.options.get_or_insert_with(String::new); for (k, v) in options { if !options_str.is_empty() { options_str.push(' '); } write!(options_str, "-c {k}={v}").expect("failed to write an option to the string"); } self } /// We try using a socket if hostname starts with `/` or if socket parameter /// is specified. pub(crate) fn fetch_socket(&self) -> Option { match self.socket { Some(ref socket) => { let full_path = format!("{}/.s.PGSQL.{}", socket.display(), self.port); Some(full_path) } None if self.host.starts_with('/') => { let full_path = format!("{}/.s.PGSQL.{}", self.host, self.port); Some(full_path) } _ => None, } } } fn default_host(port: u16) -> String { // try to check for the existence of a unix socket and uses that let socket = format!(".s.PGSQL.{port}"); let candidates = [ "/var/run/postgresql", // Debian "/private/tmp", // OSX (homebrew) "/tmp", // Default ]; for candidate in &candidates { if Path::new(candidate).join(&socket).exists() { return candidate.to_string(); } } // fallback to localhost if no socket was found "localhost".to_owned() } #[test] fn test_options_formatting() { let options = PgConnectOptions::new().options([("geqo", "off")]); assert_eq!(options.options, Some("-c geqo=off".to_string())); let options = options.options([("search_path", "sqlx")]); assert_eq!( options.options, Some("-c geqo=off -c search_path=sqlx".to_string()) ); let options = PgConnectOptions::new().options([("geqo", "off"), ("statement_timeout", "5min")]); assert_eq!( options.options, Some("-c geqo=off -c statement_timeout=5min".to_string()) ); let options = PgConnectOptions::new(); 
assert_eq!(options.options, None); } sqlx-postgres-0.7.3/src/options/parse.rs000064400000000000000000000167540072674642500164750ustar 00000000000000use crate::error::Error; use crate::PgConnectOptions; use sqlx_core::percent_encoding::percent_decode_str; use sqlx_core::Url; use std::net::IpAddr; use std::str::FromStr; impl PgConnectOptions { pub(crate) fn parse_from_url(url: &Url) -> Result { let mut options = Self::new_without_pgpass(); if let Some(host) = url.host_str() { let host_decoded = percent_decode_str(host); options = match host_decoded.clone().next() { Some(b'/') => options.socket(&*host_decoded.decode_utf8().map_err(Error::config)?), _ => options.host(host), } } if let Some(port) = url.port() { options = options.port(port); } let username = url.username(); if !username.is_empty() { options = options.username( &*percent_decode_str(username) .decode_utf8() .map_err(Error::config)?, ); } if let Some(password) = url.password() { options = options.password( &*percent_decode_str(password) .decode_utf8() .map_err(Error::config)?, ); } let path = url.path().trim_start_matches('/'); if !path.is_empty() { options = options.database(path); } for (key, value) in url.query_pairs().into_iter() { match &*key { "sslmode" | "ssl-mode" => { options = options.ssl_mode(value.parse().map_err(Error::config)?); } "sslrootcert" | "ssl-root-cert" | "ssl-ca" => { options = options.ssl_root_cert(&*value); } "sslcert" | "ssl-cert" => options = options.ssl_client_cert(&*value), "sslkey" | "ssl-key" => options = options.ssl_client_key(&*value), "statement-cache-capacity" => { options = options.statement_cache_capacity(value.parse().map_err(Error::config)?); } "host" => { if value.starts_with("/") { options = options.socket(&*value); } else { options = options.host(&*value); } } "hostaddr" => { value.parse::().map_err(Error::config)?; options = options.host(&*value) } "port" => options = options.port(value.parse().map_err(Error::config)?), "dbname" => options = 
options.database(&*value), "user" => options = options.username(&*value), "password" => options = options.password(&*value), "application_name" => options = options.application_name(&*value), "options" => { if let Some(options) = options.options.as_mut() { options.push(' '); options.push_str(&*value); } else { options.options = Some(value.to_string()); } } k if k.starts_with("options[") => { if let Some(key) = k.strip_prefix("options[").unwrap().strip_suffix(']') { options = options.options([(key, &*value)]); } } _ => tracing::warn!(%key, %value, "ignoring unrecognized connect parameter"), } } let options = options.apply_pgpass(); Ok(options) } } impl FromStr for PgConnectOptions { type Err = Error; fn from_str(s: &str) -> Result { let url: Url = s.parse().map_err(Error::config)?; Self::parse_from_url(&url) } } #[test] fn it_parses_socket_correctly_from_parameter() { let url = "postgres:///?host=/var/run/postgres/"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(Some("/var/run/postgres/".into()), opts.socket); } #[test] fn it_parses_host_correctly_from_parameter() { let url = "postgres:///?host=google.database.com"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(None, opts.socket); assert_eq!("google.database.com", &opts.host); } #[test] fn it_parses_hostaddr_correctly_from_parameter() { let url = "postgres:///?hostaddr=8.8.8.8"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(None, opts.socket); assert_eq!("8.8.8.8", &opts.host); } #[test] fn it_parses_port_correctly_from_parameter() { let url = "postgres:///?port=1234"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(None, opts.socket); assert_eq!(1234, opts.port); } #[test] fn it_parses_dbname_correctly_from_parameter() { let url = "postgres:///?dbname=some_db"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(None, opts.socket); assert_eq!(Some("some_db"), opts.database.as_deref()); } #[test] fn 
it_parses_user_correctly_from_parameter() { let url = "postgres:///?user=some_user"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(None, opts.socket); assert_eq!("some_user", opts.username); } #[test] fn it_parses_password_correctly_from_parameter() { let url = "postgres:///?password=some_pass"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(None, opts.socket); assert_eq!(Some("some_pass"), opts.password.as_deref()); } #[test] fn it_parses_application_name_correctly_from_parameter() { let url = "postgres:///?application_name=some_name"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(Some("some_name"), opts.application_name.as_deref()); } #[test] fn it_parses_username_with_at_sign_correctly() { let url = "postgres://user@hostname:password@hostname:5432/database"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!("user@hostname", &opts.username); } #[test] fn it_parses_password_with_non_ascii_chars_correctly() { let url = "postgres://username:p@ssw0rd@hostname:5432/database"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(Some("p@ssw0rd".into()), opts.password); } #[test] fn it_parses_socket_correctly_percent_encoded() { let url = "postgres://%2Fvar%2Flib%2Fpostgres/database"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!(Some("/var/lib/postgres/".into()), opts.socket); } #[test] fn it_parses_socket_correctly_with_username_percent_encoded() { let url = "postgres://some_user@%2Fvar%2Flib%2Fpostgres/database"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!("some_user", opts.username); assert_eq!(Some("/var/lib/postgres/".into()), opts.socket); assert_eq!(Some("database"), opts.database.as_deref()); } #[test] fn it_parses_libpq_options_correctly() { let url = "postgres:///?options=-c%20synchronous_commit%3Doff%20--search_path%3Dpostgres"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!( Some("-c synchronous_commit=off 
--search_path=postgres".into()), opts.options ); } #[test] fn it_parses_sqlx_options_correctly() { let url = "postgres:///?options[synchronous_commit]=off&options[search_path]=postgres"; let opts = PgConnectOptions::from_str(url).unwrap(); assert_eq!( Some("-c synchronous_commit=off -c search_path=postgres".into()), opts.options ); } sqlx-postgres-0.7.3/src/options/pgpass.rs000064400000000000000000000225400072674642500166460ustar 00000000000000use std::borrow::Cow; use std::env::var_os; use std::fs::File; use std::io::{BufRead, BufReader}; use std::path::PathBuf; /// try to load a password from the various pgpass file locations pub fn load_password( host: &str, port: u16, username: &str, database: Option<&str>, ) -> Option { let custom_file = var_os("PGPASSFILE"); if let Some(file) = custom_file { if let Some(password) = load_password_from_file(PathBuf::from(file), host, port, username, database) { return Some(password); } } #[cfg(not(target_os = "windows"))] let default_file = home::home_dir().map(|path| path.join(".pgpass")); #[cfg(target_os = "windows")] let default_file = { use etcetera::BaseStrategy; etcetera::base_strategy::Windows::new() .ok() .map(|basedirs| basedirs.data_dir().join("postgres").join("pgpass.conf")) }; load_password_from_file(default_file?, host, port, username, database) } /// try to extract a password from a pgpass file fn load_password_from_file( path: PathBuf, host: &str, port: u16, username: &str, database: Option<&str>, ) -> Option { let file = File::open(&path).ok()?; #[cfg(target_os = "linux")] { use std::os::unix::fs::PermissionsExt; // check file permissions on linux let metadata = file.metadata().ok()?; let permissions = metadata.permissions(); let mode = permissions.mode(); if mode & 0o77 != 0 { tracing::warn!( path = %path.to_string_lossy(), permissions = format!("{mode:o}"), "Ignoring path. 
Permissions are not strict enough", ); return None; } } let reader = BufReader::new(file); load_password_from_reader(reader, host, port, username, database) } fn load_password_from_reader( mut reader: impl BufRead, host: &str, port: u16, username: &str, database: Option<&str>, ) -> Option { let mut line = String::new(); // https://stackoverflow.com/a/55041833 fn trim_newline(s: &mut String) { if s.ends_with('\n') { s.pop(); if s.ends_with('\r') { s.pop(); } } } while let Ok(n) = reader.read_line(&mut line) { if n == 0 { break; } if line.starts_with('#') { // comment, do nothing } else { // try to load password from line trim_newline(&mut line); if let Some(password) = load_password_from_line(&line, host, port, username, database) { return Some(password); } } line.clear(); } None } /// try to check all fields & extract the password fn load_password_from_line( mut line: &str, host: &str, port: u16, username: &str, database: Option<&str>, ) -> Option { let whole_line = line; // Pgpass line ordering: hostname, port, database, username, password // See: https://www.postgresql.org/docs/9.3/libpq-pgpass.html match line.trim_start().chars().next() { None | Some('#') => None, _ => { matches_next_field(whole_line, &mut line, host)?; matches_next_field(whole_line, &mut line, &port.to_string())?; matches_next_field(whole_line, &mut line, database.unwrap_or_default())?; matches_next_field(whole_line, &mut line, username)?; Some(line.to_owned()) } } } /// check if the next field matches the provided value fn matches_next_field(whole_line: &str, line: &mut &str, value: &str) -> Option<()> { let field = find_next_field(line); match field { Some(field) => { if field == "*" || field == value { Some(()) } else { None } } None => { tracing::warn!(line = whole_line, "Malformed line in pgpass file"); None } } } /// extract the next value from a line in a pgpass file /// /// `line` will get updated to point behind the field and delimiter fn find_next_field<'a>(line: &mut &'a str) -> 
Option> { let mut escaping = false; let mut escaped_string = None; let mut last_added = 0; let char_indicies = line.char_indices(); for (idx, c) in char_indicies { if c == ':' && !escaping { let (field, rest) = line.split_at(idx); *line = &rest[1..]; if let Some(mut escaped_string) = escaped_string { escaped_string += &field[last_added..]; return Some(Cow::Owned(escaped_string)); } else { return Some(Cow::Borrowed(field)); } } else if c == '\\' { let s = escaped_string.get_or_insert_with(String::new); if escaping { s.push('\\'); } else { *s += &line[last_added..idx]; } escaping = !escaping; last_added = idx + 1; } else { escaping = false; } } return None; } #[cfg(test)] mod tests { use super::{find_next_field, load_password_from_line, load_password_from_reader}; use std::borrow::Cow; #[test] fn test_find_next_field() { fn test_case<'a>(mut input: &'a str, result: Option>, rest: &str) { assert_eq!(find_next_field(&mut input), result); assert_eq!(input, rest); } // normal field test_case("foo:bar:baz", Some(Cow::Borrowed("foo")), "bar:baz"); // \ escaped test_case( "foo\\\\:bar:baz", Some(Cow::Owned("foo\\".to_owned())), "bar:baz", ); // : escaped test_case( "foo\\::bar:baz", Some(Cow::Owned("foo:".to_owned())), "bar:baz", ); // unnecessary escape test_case( "foo\\a:bar:baz", Some(Cow::Owned("fooa".to_owned())), "bar:baz", ); // other text after escape test_case( "foo\\\\a:bar:baz", Some(Cow::Owned("foo\\a".to_owned())), "bar:baz", ); // double escape test_case( "foo\\\\\\\\a:bar:baz", Some(Cow::Owned("foo\\\\a".to_owned())), "bar:baz", ); // utf8 support test_case("🦀:bar:baz", Some(Cow::Borrowed("🦀")), "bar:baz"); // missing delimiter (eof) test_case("foo", None, "foo"); // missing delimiter after escape test_case("foo\\:", None, "foo\\:"); // missing delimiter after unused trailing escape test_case("foo\\", None, "foo\\"); } #[test] fn test_load_password_from_line() { // normal assert_eq!( load_password_from_line( "localhost:5432:bar:foo:baz", "localhost", 5432, 
"foo", Some("bar") ), Some("baz".to_owned()) ); // wildcard assert_eq!( load_password_from_line("*:5432:bar:foo:baz", "localhost", 5432, "foo", Some("bar")), Some("baz".to_owned()) ); // accept wildcard with missing db assert_eq!( load_password_from_line("localhost:5432:*:foo:baz", "localhost", 5432, "foo", None), Some("baz".to_owned()) ); // doesn't match assert_eq!( load_password_from_line( "thishost:5432:bar:foo:baz", "thathost", 5432, "foo", Some("bar") ), None ); // malformed entry assert_eq!( load_password_from_line( "localhost:5432:bar:foo", "localhost", 5432, "foo", Some("bar") ), None ); } #[test] fn test_load_password_from_reader() { let file = b"\ localhost:5432:bar:foo:baz\n\ # mixed line endings (also a comment!)\n\ *:5432:bar:foo:baz\r\n\ # trailing space, comment with CRLF! \r\n\ thishost:5432:bar:foo:baz \n\ # malformed line \n\ thathost:5432:foobar:foo\n\ # missing trailing newline\n\ localhost:5432:*:foo:baz "; // normal assert_eq!( load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("bar")), Some("baz".to_owned()) ); // wildcard assert_eq!( load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("foobar")), Some("baz".to_owned()) ); // accept wildcard with missing db assert_eq!( load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", None), Some("baz".to_owned()) ); // doesn't match assert_eq!( load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")), None ); // malformed entry assert_eq!( load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")), None ); } } sqlx-postgres-0.7.3/src/options/ssl_mode.rs000064400000000000000000000032700072674642500171550ustar 00000000000000use crate::error::Error; use std::str::FromStr; /// Options for controlling the level of protection provided for PostgreSQL SSL connections. /// /// It is used by the [`ssl_mode`](super::PgConnectOptions::ssl_mode) method. 
#[derive(Debug, Clone, Copy)] pub enum PgSslMode { /// Only try a non-SSL connection. Disable, /// First try a non-SSL connection; if that fails, try an SSL connection. Allow, /// First try an SSL connection; if that fails, try a non-SSL connection. Prefer, /// Only try an SSL connection. If a root CA file is present, verify the connection /// in the same way as if `VerifyCa` was specified. Require, /// Only try an SSL connection, and verify that the server certificate is issued by a /// trusted certificate authority (CA). VerifyCa, /// Only try an SSL connection; verify that the server certificate is issued by a trusted /// CA and that the requested server host name matches that in the certificate. VerifyFull, } impl Default for PgSslMode { fn default() -> Self { PgSslMode::Prefer } } impl FromStr for PgSslMode { type Err = Error; fn from_str(s: &str) -> Result { Ok(match &*s.to_ascii_lowercase() { "disable" => PgSslMode::Disable, "allow" => PgSslMode::Allow, "prefer" => PgSslMode::Prefer, "require" => PgSslMode::Require, "verify-ca" => PgSslMode::VerifyCa, "verify-full" => PgSslMode::VerifyFull, _ => { return Err(Error::Configuration( format!("unknown value {s:?} for `ssl_mode`").into(), )); } }) } } sqlx-postgres-0.7.3/src/query_result.rs000064400000000000000000000013100072674642500164110ustar 00000000000000use std::iter::{Extend, IntoIterator}; #[derive(Debug, Default)] pub struct PgQueryResult { pub(super) rows_affected: u64, } impl PgQueryResult { pub fn rows_affected(&self) -> u64 { self.rows_affected } } impl Extend for PgQueryResult { fn extend>(&mut self, iter: T) { for elem in iter { self.rows_affected += elem.rows_affected; } } } #[cfg(feature = "any")] impl From for crate::any::AnyQueryResult { fn from(done: PgQueryResult) -> Self { crate::any::AnyQueryResult { rows_affected: done.rows_affected, last_insert_id: None, } } } sqlx-postgres-0.7.3/src/row.rs000064400000000000000000000024110072674642500144600ustar 00000000000000use 
crate::column::ColumnIndex; use crate::error::Error; use crate::message::DataRow; use crate::statement::PgStatementMetadata; use crate::value::PgValueFormat; use crate::{PgColumn, PgValueRef, Postgres}; use std::sync::Arc; pub(crate) use sqlx_core::row::Row; /// Implementation of [`Row`] for PostgreSQL. pub struct PgRow { pub(crate) data: DataRow, pub(crate) format: PgValueFormat, pub(crate) metadata: Arc, } impl Row for PgRow { type Database = Postgres; fn columns(&self) -> &[PgColumn] { &self.metadata.columns } fn try_get_raw(&self, index: I) -> Result, Error> where I: ColumnIndex, { let index = index.index(self)?; let column = &self.metadata.columns[index]; let value = self.data.get(index); Ok(PgValueRef { format: self.format, row: Some(&self.data.storage), type_info: column.type_info.clone(), value, }) } } impl ColumnIndex for &'_ str { fn index(&self, row: &PgRow) -> Result { row.metadata .column_names .get(*self) .ok_or_else(|| Error::ColumnNotFound((*self).into())) .map(|v| *v) } } sqlx-postgres-0.7.3/src/statement.rs000064400000000000000000000046640072674642500156710ustar 00000000000000use super::{PgColumn, PgTypeInfo}; use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::{PgArguments, Postgres}; use std::borrow::Cow; use std::sync::Arc; pub(crate) use sqlx_core::statement::Statement; use sqlx_core::{Either, HashMap}; #[derive(Debug, Clone)] pub struct PgStatement<'q> { pub(crate) sql: Cow<'q, str>, pub(crate) metadata: Arc, } #[derive(Debug, Default)] pub(crate) struct PgStatementMetadata { pub(crate) columns: Vec, // This `Arc` is not redundant; it's used to avoid deep-copying this map for the `Any` backend. 
// See `sqlx-postgres/src/any.rs` pub(crate) column_names: Arc>, pub(crate) parameters: Vec, } impl<'q> Statement<'q> for PgStatement<'q> { type Database = Postgres; fn to_owned(&self) -> PgStatement<'static> { PgStatement::<'static> { sql: Cow::Owned(self.sql.clone().into_owned()), metadata: self.metadata.clone(), } } fn sql(&self) -> &str { &self.sql } fn parameters(&self) -> Option> { Some(Either::Left(&self.metadata.parameters)) } fn columns(&self) -> &[PgColumn] { &self.metadata.columns } impl_statement_query!(PgArguments); } impl ColumnIndex> for &'_ str { fn index(&self, statement: &PgStatement<'_>) -> Result { statement .metadata .column_names .get(*self) .ok_or_else(|| Error::ColumnNotFound((*self).into())) .map(|v| *v) } } // #[cfg(feature = "any")] // impl<'q> From> for crate::any::AnyStatement<'q> { // #[inline] // fn from(statement: PgStatement<'q>) -> Self { // crate::any::AnyStatement::<'q> { // columns: statement // .metadata // .columns // .iter() // .map(|col| col.clone().into()) // .collect(), // column_names: statement.metadata.column_names.clone(), // parameters: Some(Either::Left( // statement // .metadata // .parameters // .iter() // .map(|ty| ty.clone().into()) // .collect(), // )), // sql: statement.sql, // } // } // } sqlx-postgres-0.7.3/src/testing/mod.rs000064400000000000000000000173240072674642500161160ustar 00000000000000use std::fmt::Write; use std::ops::Deref; use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::{Duration, SystemTime}; use futures_core::future::BoxFuture; use once_cell::sync::OnceCell; use crate::connection::Connection; use crate::error::Error; use crate::executor::Executor; use crate::pool::{Pool, PoolOptions}; use crate::query::query; use crate::query_scalar::query_scalar; use crate::{PgConnectOptions, PgConnection, Postgres}; pub(crate) use sqlx_core::testing::*; // Using a blocking `OnceCell` here because the critical sections are short. 
static MASTER_POOL: OnceCell> = OnceCell::new(); // Automatically delete any databases created before the start of the test binary. static DO_CLEANUP: AtomicBool = AtomicBool::new(true); impl TestSupport for Postgres { fn test_context(args: &TestArgs) -> BoxFuture<'_, Result, Error>> { Box::pin(async move { let res = test_context(args).await; res }) } fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { let mut conn = MASTER_POOL .get() .expect("cleanup_test() invoked outside `#[sqlx::test]") .acquire() .await?; conn.execute(&format!("drop database if exists {db_name:?};")[..]) .await?; query("delete from _sqlx_test.databases where db_name = $1") .bind(&db_name) .execute(&mut *conn) .await?; Ok(()) }) } fn cleanup_test_dbs() -> BoxFuture<'static, Result, Error>> { Box::pin(async move { let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); let mut conn = PgConnection::connect(&url).await?; let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap(); let num_deleted = do_cleanup(&mut conn, now).await?; let _ = conn.close().await; Ok(Some(num_deleted)) }) } fn snapshot( _conn: &mut Self::Connection, ) -> BoxFuture<'_, Result, Error>> { // TODO: I want to get the testing feature out the door so this will have to wait, // but I'm keeping the code around for now because I plan to come back to it. todo!() } } async fn test_context(args: &TestArgs) -> Result, Error> { let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); let master_opts = PgConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL"); let pool = PoolOptions::new() // Postgres' normal connection limit is 100 plus 3 superuser connections // We don't want to use the whole cap and there may be fuzziness here due to // concurrently running tests anyway. .max_connections(20) // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. 
.after_release(|_conn, _| Box::pin(async move { Ok(false) })) .connect_lazy_with(master_opts); let master_pool = match MASTER_POOL.try_insert(pool) { Ok(inserted) => inserted, Err((existing, pool)) => { // Sanity checks. assert_eq!( existing.connect_options().host, pool.connect_options().host, "DATABASE_URL changed at runtime, host differs" ); assert_eq!( existing.connect_options().database, pool.connect_options().database, "DATABASE_URL changed at runtime, database differs" ); existing } }; let mut conn = master_pool.acquire().await?; // language=PostgreSQL conn.execute( // Explicit lock avoids this latent bug: https://stackoverflow.com/a/29908840 // I couldn't find a bug on the mailing list for `CREATE SCHEMA` specifically, // but a clearly related bug with `CREATE TABLE` has been known since 2007: // https://www.postgresql.org/message-id/200710222037.l9MKbCJZ098744%40wwwmaster.postgresql.org r#" lock table pg_catalog.pg_namespace in share row exclusive mode; create schema if not exists _sqlx_test; create table if not exists _sqlx_test.databases ( db_name text primary key, test_path text not null, created_at timestamptz not null default now() ); create index if not exists databases_created_at on _sqlx_test.databases(created_at); create sequence if not exists _sqlx_test.database_ids; "#, ) .await?; // Record the current time _before_ we acquire the `DO_CLEANUP` permit. This // prevents the first test thread from accidentally deleting new test dbs // created by other test threads if we're a bit slow. let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap(); // Only run cleanup if the test binary just started. 
if DO_CLEANUP.swap(false, Ordering::SeqCst) { do_cleanup(&mut conn, now).await?; } let new_db_name: String = query_scalar( r#" insert into _sqlx_test.databases(db_name, test_path) select '_sqlx_test_' || nextval('_sqlx_test.database_ids'), $1 returning db_name "#, ) .bind(&args.test_path) .fetch_one(&mut *conn) .await?; conn.execute(&format!("create database {new_db_name:?}")[..]) .await?; Ok(TestContext { pool_opts: PoolOptions::new() // Don't allow a single test to take all the connections. // Most tests shouldn't require more than 5 connections concurrently, // or else they're likely doing too much in one test. .max_connections(5) // Close connections ASAP if left in the idle queue. .idle_timeout(Some(Duration::from_secs(1))) .parent(master_pool.clone()), connect_opts: master_pool .connect_options() .deref() .clone() .database(&new_db_name), db_name: new_db_name, }) } async fn do_cleanup(conn: &mut PgConnection, created_before: Duration) -> Result { // since SystemTime is not monotonic we added a little margin here to avoid race conditions with other threads let created_before = i64::try_from(created_before.as_secs()).unwrap() - 2; let delete_db_names: Vec = query_scalar( "select db_name from _sqlx_test.databases \ where created_at < (to_timestamp($1) at time zone 'UTC')", ) .bind(&created_before) .fetch_all(&mut *conn) .await?; if delete_db_names.is_empty() { return Ok(0); } let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); let delete_db_names = delete_db_names.into_iter(); let mut command = String::new(); for db_name in delete_db_names { command.clear(); writeln!(command, "drop database if exists {db_name:?};").ok(); match conn.execute(&*command).await { Ok(_deleted) => { deleted_db_names.push(db_name); } // Assume a database error just means the DB is still in use. 
Err(Error::Database(dbe)) => { eprintln!("could not clean test database {db_name:?}: {dbe}") } // Bubble up other errors Err(e) => return Err(e), } } query("delete from _sqlx_test.databases where db_name = any($1::text[])") .bind(&deleted_db_names) .execute(&mut *conn) .await?; Ok(deleted_db_names.len()) } sqlx-postgres-0.7.3/src/transaction.rs000064400000000000000000000042720072674642500162050ustar 00000000000000use futures_core::future::BoxFuture; use crate::error::Error; use crate::executor::Executor; use crate::{PgConnection, Postgres}; pub(crate) use sqlx_core::transaction::*; /// Implementation of [`TransactionManager`] for PostgreSQL. pub struct PgTransactionManager; impl TransactionManager for PgTransactionManager { type Database = Postgres; fn begin(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { let rollback = Rollback::new(conn); let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth); rollback.conn.queue_simple_query(&query); rollback.conn.transaction_depth += 1; rollback.conn.wait_until_ready().await?; rollback.defuse(); Ok(()) }) } fn commit(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { if conn.transaction_depth > 0 { conn.execute(&*commit_ansi_transaction_sql(conn.transaction_depth)) .await?; conn.transaction_depth -= 1; } Ok(()) }) } fn rollback(conn: &mut PgConnection) -> BoxFuture<'_, Result<(), Error>> { Box::pin(async move { if conn.transaction_depth > 0 { conn.execute(&*rollback_ansi_transaction_sql(conn.transaction_depth)) .await?; conn.transaction_depth -= 1; } Ok(()) }) } fn start_rollback(conn: &mut PgConnection) { if conn.transaction_depth > 0 { conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth)); conn.transaction_depth -= 1; } } } struct Rollback<'c> { conn: &'c mut PgConnection, defuse: bool, } impl Drop for Rollback<'_> { fn drop(&mut self) { if !self.defuse { PgTransactionManager::start_rollback(self.conn) } } } 
impl<'c> Rollback<'c> { fn new(conn: &'c mut PgConnection) -> Self { Self { conn, defuse: false, } } fn defuse(mut self) { self.defuse = true; } } sqlx-postgres-0.7.3/src/type_info.rs000064400000000000000000001307260072674642500156600ustar 00000000000000#![allow(dead_code)] use std::borrow::Cow; use std::fmt::{self, Display, Formatter}; use std::ops::Deref; use std::sync::Arc; use crate::ext::ustr::UStr; use crate::types::Oid; pub(crate) use sqlx_core::type_info::TypeInfo; /// Type information for a PostgreSQL type. #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgTypeInfo(pub(crate) PgType); impl Deref for PgTypeInfo { type Target = PgType; fn deref(&self) -> &Self::Target { &self.0 } } #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] #[repr(u32)] pub enum PgType { Bool, Bytea, Char, Name, Int8, Int2, Int4, Text, Oid, Json, JsonArray, Point, Lseg, Path, Box, Polygon, Line, LineArray, Cidr, CidrArray, Float4, Float8, Unknown, Circle, CircleArray, Macaddr8, Macaddr8Array, Macaddr, Inet, BoolArray, ByteaArray, CharArray, NameArray, Int2Array, Int4Array, TextArray, BpcharArray, VarcharArray, Int8Array, PointArray, LsegArray, PathArray, BoxArray, Float4Array, Float8Array, PolygonArray, OidArray, MacaddrArray, InetArray, Bpchar, Varchar, Date, Time, Timestamp, TimestampArray, DateArray, TimeArray, Timestamptz, TimestamptzArray, Interval, IntervalArray, NumericArray, Timetz, TimetzArray, Bit, BitArray, Varbit, VarbitArray, Numeric, Record, RecordArray, Uuid, UuidArray, Jsonb, JsonbArray, Int4Range, Int4RangeArray, NumRange, NumRangeArray, TsRange, TsRangeArray, TstzRange, TstzRangeArray, DateRange, DateRangeArray, Int8Range, Int8RangeArray, Jsonpath, JsonpathArray, Money, MoneyArray, // https://www.postgresql.org/docs/9.3/datatype-pseudo.html Void, // A realized user-defined type. 
When a connection sees a DeclareXX variant it resolves // into this one before passing it along to `accepts` or inside of `Value` objects. Custom(Arc), // From [`PgTypeInfo::with_name`] DeclareWithName(UStr), // NOTE: Do we want to bring back type declaration by ID? It's notoriously fragile but // someone may have a user for it DeclareWithOid(Oid), } #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct PgCustomType { #[cfg_attr(feature = "offline", serde(skip))] pub(crate) oid: Oid, pub(crate) name: UStr, pub(crate) kind: PgTypeKind, } #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub enum PgTypeKind { Simple, Pseudo, Domain(PgTypeInfo), Composite(Arc<[(String, PgTypeInfo)]>), Array(PgTypeInfo), Enum(Arc<[String]>), Range(PgTypeInfo), } impl PgTypeInfo { /// Returns the corresponding `PgTypeInfo` if the OID is a built-in type and recognized by SQLx. pub(crate) fn try_from_oid(oid: Oid) -> Option { PgType::try_from_oid(oid).map(Self) } /// Returns the _kind_ (simple, array, enum, etc.) for this type. pub fn kind(&self) -> &PgTypeKind { self.0.kind() } /// Returns the OID for this type, if available. /// /// The OID may not be available if SQLx only knows the type by name. /// It will have to be resolved by a `PgConnection` at runtime which /// will yield a new and semantically distinct `TypeInfo` instance. /// /// This method does not perform any such lookup. /// /// ### Note /// With the exception of [the default `pg_type` catalog][pg_type], type OIDs are *not* stable in PostgreSQL. /// If a type is added by an extension, its OID will be assigned when the `CREATE EXTENSION` statement is executed, /// and so can change depending on what extensions are installed and in what order, as well as the exact /// version of PostgreSQL. 
/// /// [pg_type]: https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat pub fn oid(&self) -> Option { self.0.try_oid() } #[doc(hidden)] pub fn __type_feature_gate(&self) -> Option<&'static str> { if [ PgTypeInfo::DATE, PgTypeInfo::TIME, PgTypeInfo::TIMESTAMP, PgTypeInfo::TIMESTAMPTZ, PgTypeInfo::DATE_ARRAY, PgTypeInfo::TIME_ARRAY, PgTypeInfo::TIMESTAMP_ARRAY, PgTypeInfo::TIMESTAMPTZ_ARRAY, ] .contains(self) { Some("time") } else if [PgTypeInfo::UUID, PgTypeInfo::UUID_ARRAY].contains(self) { Some("uuid") } else if [ PgTypeInfo::JSON, PgTypeInfo::JSONB, PgTypeInfo::JSON_ARRAY, PgTypeInfo::JSONB_ARRAY, ] .contains(self) { Some("json") } else if [ PgTypeInfo::CIDR, PgTypeInfo::INET, PgTypeInfo::CIDR_ARRAY, PgTypeInfo::INET_ARRAY, ] .contains(self) { Some("ipnetwork") } else if [PgTypeInfo::MACADDR].contains(self) { Some("mac_address") } else if [PgTypeInfo::NUMERIC, PgTypeInfo::NUMERIC_ARRAY].contains(self) { Some("bigdecimal") } else { None } } /// Create a `PgTypeInfo` from a type name. /// /// The OID for the type will be fetched from Postgres on use of /// a value of this type. The fetched OID will be cached per-connection. pub const fn with_name(name: &'static str) -> Self { Self(PgType::DeclareWithName(UStr::Static(name))) } /// Create a `PgTypeInfo` from an OID. /// /// Note that the OID for a type is very dependent on the environment. If you only ever use /// one database or if this is an unhandled built-in type, you should be fine. Otherwise, /// you will be better served using [`with_name`](Self::with_name). 
pub const fn with_oid(oid: Oid) -> Self { Self(PgType::DeclareWithOid(oid)) } } // DEVELOPER PRO TIP: find builtin type OIDs easily by grepping this file // https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat // // If you have Postgres running locally you can also try // SELECT oid, typarray FROM pg_type where typname = '' impl PgType { /// Returns the corresponding `PgType` if the OID is a built-in type and recognized by SQLx. pub(crate) fn try_from_oid(oid: Oid) -> Option { Some(match oid.0 { 16 => PgType::Bool, 17 => PgType::Bytea, 18 => PgType::Char, 19 => PgType::Name, 20 => PgType::Int8, 21 => PgType::Int2, 23 => PgType::Int4, 25 => PgType::Text, 26 => PgType::Oid, 114 => PgType::Json, 199 => PgType::JsonArray, 600 => PgType::Point, 601 => PgType::Lseg, 602 => PgType::Path, 603 => PgType::Box, 604 => PgType::Polygon, 628 => PgType::Line, 629 => PgType::LineArray, 650 => PgType::Cidr, 651 => PgType::CidrArray, 700 => PgType::Float4, 701 => PgType::Float8, 705 => PgType::Unknown, 718 => PgType::Circle, 719 => PgType::CircleArray, 774 => PgType::Macaddr8, 775 => PgType::Macaddr8Array, 790 => PgType::Money, 791 => PgType::MoneyArray, 829 => PgType::Macaddr, 869 => PgType::Inet, 1000 => PgType::BoolArray, 1001 => PgType::ByteaArray, 1002 => PgType::CharArray, 1003 => PgType::NameArray, 1005 => PgType::Int2Array, 1007 => PgType::Int4Array, 1009 => PgType::TextArray, 1014 => PgType::BpcharArray, 1015 => PgType::VarcharArray, 1016 => PgType::Int8Array, 1017 => PgType::PointArray, 1018 => PgType::LsegArray, 1019 => PgType::PathArray, 1020 => PgType::BoxArray, 1021 => PgType::Float4Array, 1022 => PgType::Float8Array, 1027 => PgType::PolygonArray, 1028 => PgType::OidArray, 1040 => PgType::MacaddrArray, 1041 => PgType::InetArray, 1042 => PgType::Bpchar, 1043 => PgType::Varchar, 1082 => PgType::Date, 1083 => PgType::Time, 1114 => PgType::Timestamp, 1115 => PgType::TimestampArray, 1182 => PgType::DateArray, 1183 => PgType::TimeArray, 1184 => 
PgType::Timestamptz, 1185 => PgType::TimestamptzArray, 1186 => PgType::Interval, 1187 => PgType::IntervalArray, 1231 => PgType::NumericArray, 1266 => PgType::Timetz, 1270 => PgType::TimetzArray, 1560 => PgType::Bit, 1561 => PgType::BitArray, 1562 => PgType::Varbit, 1563 => PgType::VarbitArray, 1700 => PgType::Numeric, 2278 => PgType::Void, 2249 => PgType::Record, 2287 => PgType::RecordArray, 2950 => PgType::Uuid, 2951 => PgType::UuidArray, 3802 => PgType::Jsonb, 3807 => PgType::JsonbArray, 3904 => PgType::Int4Range, 3905 => PgType::Int4RangeArray, 3906 => PgType::NumRange, 3907 => PgType::NumRangeArray, 3908 => PgType::TsRange, 3909 => PgType::TsRangeArray, 3910 => PgType::TstzRange, 3911 => PgType::TstzRangeArray, 3912 => PgType::DateRange, 3913 => PgType::DateRangeArray, 3926 => PgType::Int8Range, 3927 => PgType::Int8RangeArray, 4072 => PgType::Jsonpath, 4073 => PgType::JsonpathArray, _ => { return None; } }) } pub(crate) fn oid(&self) -> Oid { match self.try_oid() { Some(oid) => oid, None => unreachable!("(bug) use of unresolved type declaration [oid]"), } } pub(crate) fn try_oid(&self) -> Option { Some(match self { PgType::Bool => Oid(16), PgType::Bytea => Oid(17), PgType::Char => Oid(18), PgType::Name => Oid(19), PgType::Int8 => Oid(20), PgType::Int2 => Oid(21), PgType::Int4 => Oid(23), PgType::Text => Oid(25), PgType::Oid => Oid(26), PgType::Json => Oid(114), PgType::JsonArray => Oid(199), PgType::Point => Oid(600), PgType::Lseg => Oid(601), PgType::Path => Oid(602), PgType::Box => Oid(603), PgType::Polygon => Oid(604), PgType::Line => Oid(628), PgType::LineArray => Oid(629), PgType::Cidr => Oid(650), PgType::CidrArray => Oid(651), PgType::Float4 => Oid(700), PgType::Float8 => Oid(701), PgType::Unknown => Oid(705), PgType::Circle => Oid(718), PgType::CircleArray => Oid(719), PgType::Macaddr8 => Oid(774), PgType::Macaddr8Array => Oid(775), PgType::Money => Oid(790), PgType::MoneyArray => Oid(791), PgType::Macaddr => Oid(829), PgType::Inet => Oid(869), 
PgType::BoolArray => Oid(1000), PgType::ByteaArray => Oid(1001), PgType::CharArray => Oid(1002), PgType::NameArray => Oid(1003), PgType::Int2Array => Oid(1005), PgType::Int4Array => Oid(1007), PgType::TextArray => Oid(1009), PgType::BpcharArray => Oid(1014), PgType::VarcharArray => Oid(1015), PgType::Int8Array => Oid(1016), PgType::PointArray => Oid(1017), PgType::LsegArray => Oid(1018), PgType::PathArray => Oid(1019), PgType::BoxArray => Oid(1020), PgType::Float4Array => Oid(1021), PgType::Float8Array => Oid(1022), PgType::PolygonArray => Oid(1027), PgType::OidArray => Oid(1028), PgType::MacaddrArray => Oid(1040), PgType::InetArray => Oid(1041), PgType::Bpchar => Oid(1042), PgType::Varchar => Oid(1043), PgType::Date => Oid(1082), PgType::Time => Oid(1083), PgType::Timestamp => Oid(1114), PgType::TimestampArray => Oid(1115), PgType::DateArray => Oid(1182), PgType::TimeArray => Oid(1183), PgType::Timestamptz => Oid(1184), PgType::TimestamptzArray => Oid(1185), PgType::Interval => Oid(1186), PgType::IntervalArray => Oid(1187), PgType::NumericArray => Oid(1231), PgType::Timetz => Oid(1266), PgType::TimetzArray => Oid(1270), PgType::Bit => Oid(1560), PgType::BitArray => Oid(1561), PgType::Varbit => Oid(1562), PgType::VarbitArray => Oid(1563), PgType::Numeric => Oid(1700), PgType::Void => Oid(2278), PgType::Record => Oid(2249), PgType::RecordArray => Oid(2287), PgType::Uuid => Oid(2950), PgType::UuidArray => Oid(2951), PgType::Jsonb => Oid(3802), PgType::JsonbArray => Oid(3807), PgType::Int4Range => Oid(3904), PgType::Int4RangeArray => Oid(3905), PgType::NumRange => Oid(3906), PgType::NumRangeArray => Oid(3907), PgType::TsRange => Oid(3908), PgType::TsRangeArray => Oid(3909), PgType::TstzRange => Oid(3910), PgType::TstzRangeArray => Oid(3911), PgType::DateRange => Oid(3912), PgType::DateRangeArray => Oid(3913), PgType::Int8Range => Oid(3926), PgType::Int8RangeArray => Oid(3927), PgType::Jsonpath => Oid(4072), PgType::JsonpathArray => Oid(4073), PgType::Custom(ty) => 
ty.oid, PgType::DeclareWithOid(oid) => *oid, PgType::DeclareWithName(_) => { return None; } }) } pub(crate) fn display_name(&self) -> &str { match self { PgType::Bool => "BOOL", PgType::Bytea => "BYTEA", PgType::Char => "\"CHAR\"", PgType::Name => "NAME", PgType::Int8 => "INT8", PgType::Int2 => "INT2", PgType::Int4 => "INT4", PgType::Text => "TEXT", PgType::Oid => "OID", PgType::Json => "JSON", PgType::JsonArray => "JSON[]", PgType::Point => "POINT", PgType::Lseg => "LSEG", PgType::Path => "PATH", PgType::Box => "BOX", PgType::Polygon => "POLYGON", PgType::Line => "LINE", PgType::LineArray => "LINE[]", PgType::Cidr => "CIDR", PgType::CidrArray => "CIDR[]", PgType::Float4 => "FLOAT4", PgType::Float8 => "FLOAT8", PgType::Unknown => "UNKNOWN", PgType::Circle => "CIRCLE", PgType::CircleArray => "CIRCLE[]", PgType::Macaddr8 => "MACADDR8", PgType::Macaddr8Array => "MACADDR8[]", PgType::Macaddr => "MACADDR", PgType::Inet => "INET", PgType::BoolArray => "BOOL[]", PgType::ByteaArray => "BYTEA[]", PgType::CharArray => "\"CHAR\"[]", PgType::NameArray => "NAME[]", PgType::Int2Array => "INT2[]", PgType::Int4Array => "INT4[]", PgType::TextArray => "TEXT[]", PgType::BpcharArray => "CHAR[]", PgType::VarcharArray => "VARCHAR[]", PgType::Int8Array => "INT8[]", PgType::PointArray => "POINT[]", PgType::LsegArray => "LSEG[]", PgType::PathArray => "PATH[]", PgType::BoxArray => "BOX[]", PgType::Float4Array => "FLOAT4[]", PgType::Float8Array => "FLOAT8[]", PgType::PolygonArray => "POLYGON[]", PgType::OidArray => "OID[]", PgType::MacaddrArray => "MACADDR[]", PgType::InetArray => "INET[]", PgType::Bpchar => "CHAR", PgType::Varchar => "VARCHAR", PgType::Date => "DATE", PgType::Time => "TIME", PgType::Timestamp => "TIMESTAMP", PgType::TimestampArray => "TIMESTAMP[]", PgType::DateArray => "DATE[]", PgType::TimeArray => "TIME[]", PgType::Timestamptz => "TIMESTAMPTZ", PgType::TimestamptzArray => "TIMESTAMPTZ[]", PgType::Interval => "INTERVAL", PgType::IntervalArray => "INTERVAL[]", 
PgType::NumericArray => "NUMERIC[]", PgType::Timetz => "TIMETZ", PgType::TimetzArray => "TIMETZ[]", PgType::Bit => "BIT", PgType::BitArray => "BIT[]", PgType::Varbit => "VARBIT", PgType::VarbitArray => "VARBIT[]", PgType::Numeric => "NUMERIC", PgType::Record => "RECORD", PgType::RecordArray => "RECORD[]", PgType::Uuid => "UUID", PgType::UuidArray => "UUID[]", PgType::Jsonb => "JSONB", PgType::JsonbArray => "JSONB[]", PgType::Int4Range => "INT4RANGE", PgType::Int4RangeArray => "INT4RANGE[]", PgType::NumRange => "NUMRANGE", PgType::NumRangeArray => "NUMRANGE[]", PgType::TsRange => "TSRANGE", PgType::TsRangeArray => "TSRANGE[]", PgType::TstzRange => "TSTZRANGE", PgType::TstzRangeArray => "TSTZRANGE[]", PgType::DateRange => "DATERANGE", PgType::DateRangeArray => "DATERANGE[]", PgType::Int8Range => "INT8RANGE", PgType::Int8RangeArray => "INT8RANGE[]", PgType::Jsonpath => "JSONPATH", PgType::JsonpathArray => "JSONPATH[]", PgType::Money => "MONEY", PgType::MoneyArray => "MONEY[]", PgType::Void => "VOID", PgType::Custom(ty) => &*ty.name, PgType::DeclareWithOid(_) => "?", PgType::DeclareWithName(name) => name, } } pub(crate) fn name(&self) -> &str { match self { PgType::Bool => "bool", PgType::Bytea => "bytea", PgType::Char => "char", PgType::Name => "name", PgType::Int8 => "int8", PgType::Int2 => "int2", PgType::Int4 => "int4", PgType::Text => "text", PgType::Oid => "oid", PgType::Json => "json", PgType::JsonArray => "_json", PgType::Point => "point", PgType::Lseg => "lseg", PgType::Path => "path", PgType::Box => "box", PgType::Polygon => "polygon", PgType::Line => "line", PgType::LineArray => "_line", PgType::Cidr => "cidr", PgType::CidrArray => "_cidr", PgType::Float4 => "float4", PgType::Float8 => "float8", PgType::Unknown => "unknown", PgType::Circle => "circle", PgType::CircleArray => "_circle", PgType::Macaddr8 => "macaddr8", PgType::Macaddr8Array => "_macaddr8", PgType::Macaddr => "macaddr", PgType::Inet => "inet", PgType::BoolArray => "_bool", PgType::ByteaArray => 
"_bytea", PgType::CharArray => "_char", PgType::NameArray => "_name", PgType::Int2Array => "_int2", PgType::Int4Array => "_int4", PgType::TextArray => "_text", PgType::BpcharArray => "_bpchar", PgType::VarcharArray => "_varchar", PgType::Int8Array => "_int8", PgType::PointArray => "_point", PgType::LsegArray => "_lseg", PgType::PathArray => "_path", PgType::BoxArray => "_box", PgType::Float4Array => "_float4", PgType::Float8Array => "_float8", PgType::PolygonArray => "_polygon", PgType::OidArray => "_oid", PgType::MacaddrArray => "_macaddr", PgType::InetArray => "_inet", PgType::Bpchar => "bpchar", PgType::Varchar => "varchar", PgType::Date => "date", PgType::Time => "time", PgType::Timestamp => "timestamp", PgType::TimestampArray => "_timestamp", PgType::DateArray => "_date", PgType::TimeArray => "_time", PgType::Timestamptz => "timestamptz", PgType::TimestamptzArray => "_timestamptz", PgType::Interval => "interval", PgType::IntervalArray => "_interval", PgType::NumericArray => "_numeric", PgType::Timetz => "timetz", PgType::TimetzArray => "_timetz", PgType::Bit => "bit", PgType::BitArray => "_bit", PgType::Varbit => "varbit", PgType::VarbitArray => "_varbit", PgType::Numeric => "numeric", PgType::Record => "record", PgType::RecordArray => "_record", PgType::Uuid => "uuid", PgType::UuidArray => "_uuid", PgType::Jsonb => "jsonb", PgType::JsonbArray => "_jsonb", PgType::Int4Range => "int4range", PgType::Int4RangeArray => "_int4range", PgType::NumRange => "numrange", PgType::NumRangeArray => "_numrange", PgType::TsRange => "tsrange", PgType::TsRangeArray => "_tsrange", PgType::TstzRange => "tstzrange", PgType::TstzRangeArray => "_tstzrange", PgType::DateRange => "daterange", PgType::DateRangeArray => "_daterange", PgType::Int8Range => "int8range", PgType::Int8RangeArray => "_int8range", PgType::Jsonpath => "jsonpath", PgType::JsonpathArray => "_jsonpath", PgType::Money => "money", PgType::MoneyArray => "_money", PgType::Void => "void", PgType::Custom(ty) => 
&*ty.name, PgType::DeclareWithOid(_) => "?", PgType::DeclareWithName(name) => name, } } pub(crate) fn kind(&self) -> &PgTypeKind { match self { PgType::Bool => &PgTypeKind::Simple, PgType::Bytea => &PgTypeKind::Simple, PgType::Char => &PgTypeKind::Simple, PgType::Name => &PgTypeKind::Simple, PgType::Int8 => &PgTypeKind::Simple, PgType::Int2 => &PgTypeKind::Simple, PgType::Int4 => &PgTypeKind::Simple, PgType::Text => &PgTypeKind::Simple, PgType::Oid => &PgTypeKind::Simple, PgType::Json => &PgTypeKind::Simple, PgType::JsonArray => &PgTypeKind::Array(PgTypeInfo(PgType::Json)), PgType::Point => &PgTypeKind::Simple, PgType::Lseg => &PgTypeKind::Simple, PgType::Path => &PgTypeKind::Simple, PgType::Box => &PgTypeKind::Simple, PgType::Polygon => &PgTypeKind::Simple, PgType::Line => &PgTypeKind::Simple, PgType::LineArray => &PgTypeKind::Array(PgTypeInfo(PgType::Line)), PgType::Cidr => &PgTypeKind::Simple, PgType::CidrArray => &PgTypeKind::Array(PgTypeInfo(PgType::Cidr)), PgType::Float4 => &PgTypeKind::Simple, PgType::Float8 => &PgTypeKind::Simple, PgType::Unknown => &PgTypeKind::Simple, PgType::Circle => &PgTypeKind::Simple, PgType::CircleArray => &PgTypeKind::Array(PgTypeInfo(PgType::Circle)), PgType::Macaddr8 => &PgTypeKind::Simple, PgType::Macaddr8Array => &PgTypeKind::Array(PgTypeInfo(PgType::Macaddr8)), PgType::Macaddr => &PgTypeKind::Simple, PgType::Inet => &PgTypeKind::Simple, PgType::BoolArray => &PgTypeKind::Array(PgTypeInfo(PgType::Bool)), PgType::ByteaArray => &PgTypeKind::Array(PgTypeInfo(PgType::Bytea)), PgType::CharArray => &PgTypeKind::Array(PgTypeInfo(PgType::Char)), PgType::NameArray => &PgTypeKind::Array(PgTypeInfo(PgType::Name)), PgType::Int2Array => &PgTypeKind::Array(PgTypeInfo(PgType::Int2)), PgType::Int4Array => &PgTypeKind::Array(PgTypeInfo(PgType::Int4)), PgType::TextArray => &PgTypeKind::Array(PgTypeInfo(PgType::Text)), PgType::BpcharArray => &PgTypeKind::Array(PgTypeInfo(PgType::Bpchar)), PgType::VarcharArray => 
&PgTypeKind::Array(PgTypeInfo(PgType::Varchar)), PgType::Int8Array => &PgTypeKind::Array(PgTypeInfo(PgType::Int8)), PgType::PointArray => &PgTypeKind::Array(PgTypeInfo(PgType::Point)), PgType::LsegArray => &PgTypeKind::Array(PgTypeInfo(PgType::Lseg)), PgType::PathArray => &PgTypeKind::Array(PgTypeInfo(PgType::Path)), PgType::BoxArray => &PgTypeKind::Array(PgTypeInfo(PgType::Box)), PgType::Float4Array => &PgTypeKind::Array(PgTypeInfo(PgType::Float4)), PgType::Float8Array => &PgTypeKind::Array(PgTypeInfo(PgType::Float8)), PgType::PolygonArray => &PgTypeKind::Array(PgTypeInfo(PgType::Polygon)), PgType::OidArray => &PgTypeKind::Array(PgTypeInfo(PgType::Oid)), PgType::MacaddrArray => &PgTypeKind::Array(PgTypeInfo(PgType::Macaddr)), PgType::InetArray => &PgTypeKind::Array(PgTypeInfo(PgType::Inet)), PgType::Bpchar => &PgTypeKind::Simple, PgType::Varchar => &PgTypeKind::Simple, PgType::Date => &PgTypeKind::Simple, PgType::Time => &PgTypeKind::Simple, PgType::Timestamp => &PgTypeKind::Simple, PgType::TimestampArray => &PgTypeKind::Array(PgTypeInfo(PgType::Timestamp)), PgType::DateArray => &PgTypeKind::Array(PgTypeInfo(PgType::Date)), PgType::TimeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Time)), PgType::Timestamptz => &PgTypeKind::Simple, PgType::TimestamptzArray => &PgTypeKind::Array(PgTypeInfo(PgType::Timestamptz)), PgType::Interval => &PgTypeKind::Simple, PgType::IntervalArray => &PgTypeKind::Array(PgTypeInfo(PgType::Interval)), PgType::NumericArray => &PgTypeKind::Array(PgTypeInfo(PgType::Numeric)), PgType::Timetz => &PgTypeKind::Simple, PgType::TimetzArray => &PgTypeKind::Array(PgTypeInfo(PgType::Timetz)), PgType::Bit => &PgTypeKind::Simple, PgType::BitArray => &PgTypeKind::Array(PgTypeInfo(PgType::Bit)), PgType::Varbit => &PgTypeKind::Simple, PgType::VarbitArray => &PgTypeKind::Array(PgTypeInfo(PgType::Varbit)), PgType::Numeric => &PgTypeKind::Simple, PgType::Record => &PgTypeKind::Simple, PgType::RecordArray => &PgTypeKind::Array(PgTypeInfo(PgType::Record)), 
PgType::Uuid => &PgTypeKind::Simple, PgType::UuidArray => &PgTypeKind::Array(PgTypeInfo(PgType::Uuid)), PgType::Jsonb => &PgTypeKind::Simple, PgType::JsonbArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonb)), PgType::Int4Range => &PgTypeKind::Range(PgTypeInfo::INT4), PgType::Int4RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int4Range)), PgType::NumRange => &PgTypeKind::Range(PgTypeInfo::NUMERIC), PgType::NumRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::NumRange)), PgType::TsRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMP), PgType::TsRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TsRange)), PgType::TstzRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMPTZ), PgType::TstzRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TstzRange)), PgType::DateRange => &PgTypeKind::Range(PgTypeInfo::DATE), PgType::DateRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::DateRange)), PgType::Int8Range => &PgTypeKind::Range(PgTypeInfo::INT8), PgType::Int8RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int8Range)), PgType::Jsonpath => &PgTypeKind::Simple, PgType::JsonpathArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonpath)), PgType::Money => &PgTypeKind::Simple, PgType::MoneyArray => &PgTypeKind::Array(PgTypeInfo(PgType::Money)), PgType::Void => &PgTypeKind::Pseudo, PgType::Custom(ty) => &ty.kind, PgType::DeclareWithOid(oid) => { unreachable!("(bug) use of unresolved type declaration [oid={}]", oid.0); } PgType::DeclareWithName(name) => { unreachable!("(bug) use of unresolved type declaration [name={name}]"); } } } /// If `self` is an array type, return the type info for its element. /// /// This method should only be called on resolved types: calling it on /// a type that is merely declared (DeclareWithOid/Name) is a bug. pub(crate) fn try_array_element(&self) -> Option> { // We explicitly match on all the `None` cases to ensure an exhaustive match. 
match self { PgType::Bool => None, PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))), PgType::Bytea => None, PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))), PgType::Char => None, PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))), PgType::Name => None, PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))), PgType::Int8 => None, PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))), PgType::Int2 => None, PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))), PgType::Int4 => None, PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))), PgType::Text => None, PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))), PgType::Oid => None, PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))), PgType::Json => None, PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))), PgType::Point => None, PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))), PgType::Lseg => None, PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))), PgType::Path => None, PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))), PgType::Box => None, PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))), PgType::Polygon => None, PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))), PgType::Line => None, PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))), PgType::Cidr => None, PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))), PgType::Float4 => None, PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))), PgType::Float8 => None, PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))), PgType::Circle => None, PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))), PgType::Macaddr8 => None, PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))), PgType::Money => None, PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))), PgType::Macaddr => None, 
PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))), PgType::Inet => None, PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))), PgType::Bpchar => None, PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))), PgType::Varchar => None, PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))), PgType::Date => None, PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))), PgType::Time => None, PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))), PgType::Timestamp => None, PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))), PgType::Timestamptz => None, PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))), PgType::Interval => None, PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))), PgType::Timetz => None, PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))), PgType::Bit => None, PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))), PgType::Varbit => None, PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))), PgType::Numeric => None, PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))), PgType::Record => None, PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))), PgType::Uuid => None, PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))), PgType::Jsonb => None, PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))), PgType::Int4Range => None, PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))), PgType::NumRange => None, PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))), PgType::TsRange => None, PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))), PgType::TstzRange => None, PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))), PgType::DateRange => None, PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))), PgType::Int8Range => None, 
PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))), PgType::Jsonpath => None, PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))), // There is no `UnknownArray` PgType::Unknown => None, // There is no `VoidArray` PgType::Void => None, PgType::Custom(ty) => match &ty.kind { PgTypeKind::Simple => None, PgTypeKind::Pseudo => None, PgTypeKind::Domain(_) => None, PgTypeKind::Composite(_) => None, PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)), PgTypeKind::Enum(_) => None, PgTypeKind::Range(_) => None, }, PgType::DeclareWithOid(oid) => { unreachable!("(bug) use of unresolved type declaration [oid={}]", oid.0); } PgType::DeclareWithName(name) => { unreachable!("(bug) use of unresolved type declaration [name={name}]"); } } } } impl TypeInfo for PgTypeInfo { fn name(&self) -> &str { self.0.display_name() } fn is_null(&self) -> bool { false } fn is_void(&self) -> bool { matches!(self.0, PgType::Void) } } impl PartialEq for PgCustomType { fn eq(&self, other: &PgCustomType) -> bool { other.oid == self.oid } } impl PgTypeInfo { // boolean, state of true or false pub(crate) const BOOL: Self = Self(PgType::Bool); pub(crate) const BOOL_ARRAY: Self = Self(PgType::BoolArray); // binary data types, variable-length binary string pub(crate) const BYTEA: Self = Self(PgType::Bytea); pub(crate) const BYTEA_ARRAY: Self = Self(PgType::ByteaArray); // uuid pub(crate) const UUID: Self = Self(PgType::Uuid); pub(crate) const UUID_ARRAY: Self = Self(PgType::UuidArray); // record pub(crate) const RECORD: Self = Self(PgType::Record); pub(crate) const RECORD_ARRAY: Self = Self(PgType::RecordArray); // // JSON types // https://www.postgresql.org/docs/current/datatype-json.html // pub(crate) const JSON: Self = Self(PgType::Json); pub(crate) const JSON_ARRAY: Self = Self(PgType::JsonArray); pub(crate) const JSONB: Self = Self(PgType::Jsonb); pub(crate) const JSONB_ARRAY: Self = Self(PgType::JsonbArray); pub(crate) const JSONPATH: 
Self = Self(PgType::Jsonpath); pub(crate) const JSONPATH_ARRAY: Self = Self(PgType::JsonpathArray); // // network address types // https://www.postgresql.org/docs/current/datatype-net-types.html // pub(crate) const CIDR: Self = Self(PgType::Cidr); pub(crate) const CIDR_ARRAY: Self = Self(PgType::CidrArray); pub(crate) const INET: Self = Self(PgType::Inet); pub(crate) const INET_ARRAY: Self = Self(PgType::InetArray); pub(crate) const MACADDR: Self = Self(PgType::Macaddr); pub(crate) const MACADDR_ARRAY: Self = Self(PgType::MacaddrArray); pub(crate) const MACADDR8: Self = Self(PgType::Macaddr8); pub(crate) const MACADDR8_ARRAY: Self = Self(PgType::Macaddr8Array); // // character types // https://www.postgresql.org/docs/current/datatype-character.html // // internal type for object names pub(crate) const NAME: Self = Self(PgType::Name); pub(crate) const NAME_ARRAY: Self = Self(PgType::NameArray); // character type, fixed-length, blank-padded pub(crate) const BPCHAR: Self = Self(PgType::Bpchar); pub(crate) const BPCHAR_ARRAY: Self = Self(PgType::BpcharArray); // character type, variable-length with limit pub(crate) const VARCHAR: Self = Self(PgType::Varchar); pub(crate) const VARCHAR_ARRAY: Self = Self(PgType::VarcharArray); // character type, variable-length pub(crate) const TEXT: Self = Self(PgType::Text); pub(crate) const TEXT_ARRAY: Self = Self(PgType::TextArray); // unknown type, transmitted as text pub(crate) const UNKNOWN: Self = Self(PgType::Unknown); // // numeric types // https://www.postgresql.org/docs/current/datatype-numeric.html // // single-byte internal type pub(crate) const CHAR: Self = Self(PgType::Char); pub(crate) const CHAR_ARRAY: Self = Self(PgType::CharArray); // internal type for type ids pub(crate) const OID: Self = Self(PgType::Oid); pub(crate) const OID_ARRAY: Self = Self(PgType::OidArray); // small-range integer; -32768 to +32767 pub(crate) const INT2: Self = Self(PgType::Int2); pub(crate) const INT2_ARRAY: Self = Self(PgType::Int2Array); // 
typical choice for integer; -2147483648 to +2147483647 pub(crate) const INT4: Self = Self(PgType::Int4); pub(crate) const INT4_ARRAY: Self = Self(PgType::Int4Array); // large-range integer; -9223372036854775808 to +9223372036854775807 pub(crate) const INT8: Self = Self(PgType::Int8); pub(crate) const INT8_ARRAY: Self = Self(PgType::Int8Array); // variable-precision, inexact, 6 decimal digits precision pub(crate) const FLOAT4: Self = Self(PgType::Float4); pub(crate) const FLOAT4_ARRAY: Self = Self(PgType::Float4Array); // variable-precision, inexact, 15 decimal digits precision pub(crate) const FLOAT8: Self = Self(PgType::Float8); pub(crate) const FLOAT8_ARRAY: Self = Self(PgType::Float8Array); // user-specified precision, exact pub(crate) const NUMERIC: Self = Self(PgType::Numeric); pub(crate) const NUMERIC_ARRAY: Self = Self(PgType::NumericArray); // user-specified precision, exact pub(crate) const MONEY: Self = Self(PgType::Money); pub(crate) const MONEY_ARRAY: Self = Self(PgType::MoneyArray); // // date/time types // https://www.postgresql.org/docs/current/datatype-datetime.html // // both date and time (no time zone) pub(crate) const TIMESTAMP: Self = Self(PgType::Timestamp); pub(crate) const TIMESTAMP_ARRAY: Self = Self(PgType::TimestampArray); // both date and time (with time zone) pub(crate) const TIMESTAMPTZ: Self = Self(PgType::Timestamptz); pub(crate) const TIMESTAMPTZ_ARRAY: Self = Self(PgType::TimestamptzArray); // date (no time of day) pub(crate) const DATE: Self = Self(PgType::Date); pub(crate) const DATE_ARRAY: Self = Self(PgType::DateArray); // time of day (no date) pub(crate) const TIME: Self = Self(PgType::Time); pub(crate) const TIME_ARRAY: Self = Self(PgType::TimeArray); // time of day (no date), with time zone pub(crate) const TIMETZ: Self = Self(PgType::Timetz); pub(crate) const TIMETZ_ARRAY: Self = Self(PgType::TimetzArray); // time interval pub(crate) const INTERVAL: Self = Self(PgType::Interval); pub(crate) const INTERVAL_ARRAY: Self = 
Self(PgType::IntervalArray); // // geometric types // https://www.postgresql.org/docs/current/datatype-geometric.html // // point on a plane pub(crate) const POINT: Self = Self(PgType::Point); pub(crate) const POINT_ARRAY: Self = Self(PgType::PointArray); // infinite line pub(crate) const LINE: Self = Self(PgType::Line); pub(crate) const LINE_ARRAY: Self = Self(PgType::LineArray); // finite line segment pub(crate) const LSEG: Self = Self(PgType::Lseg); pub(crate) const LSEG_ARRAY: Self = Self(PgType::LsegArray); // rectangular box pub(crate) const BOX: Self = Self(PgType::Box); pub(crate) const BOX_ARRAY: Self = Self(PgType::BoxArray); // open or closed path pub(crate) const PATH: Self = Self(PgType::Path); pub(crate) const PATH_ARRAY: Self = Self(PgType::PathArray); // polygon pub(crate) const POLYGON: Self = Self(PgType::Polygon); pub(crate) const POLYGON_ARRAY: Self = Self(PgType::PolygonArray); // circle pub(crate) const CIRCLE: Self = Self(PgType::Circle); pub(crate) const CIRCLE_ARRAY: Self = Self(PgType::CircleArray); // // bit string types // https://www.postgresql.org/docs/current/datatype-bit.html // pub(crate) const BIT: Self = Self(PgType::Bit); pub(crate) const BIT_ARRAY: Self = Self(PgType::BitArray); pub(crate) const VARBIT: Self = Self(PgType::Varbit); pub(crate) const VARBIT_ARRAY: Self = Self(PgType::VarbitArray); // // range types // https://www.postgresql.org/docs/current/rangetypes.html // pub(crate) const INT4_RANGE: Self = Self(PgType::Int4Range); pub(crate) const INT4_RANGE_ARRAY: Self = Self(PgType::Int4RangeArray); pub(crate) const NUM_RANGE: Self = Self(PgType::NumRange); pub(crate) const NUM_RANGE_ARRAY: Self = Self(PgType::NumRangeArray); pub(crate) const TS_RANGE: Self = Self(PgType::TsRange); pub(crate) const TS_RANGE_ARRAY: Self = Self(PgType::TsRangeArray); pub(crate) const TSTZ_RANGE: Self = Self(PgType::TstzRange); pub(crate) const TSTZ_RANGE_ARRAY: Self = Self(PgType::TstzRangeArray); pub(crate) const DATE_RANGE: Self = 
Self(PgType::DateRange); pub(crate) const DATE_RANGE_ARRAY: Self = Self(PgType::DateRangeArray); pub(crate) const INT8_RANGE: Self = Self(PgType::Int8Range); pub(crate) const INT8_RANGE_ARRAY: Self = Self(PgType::Int8RangeArray); // // pseudo types // https://www.postgresql.org/docs/9.3/datatype-pseudo.html // pub(crate) const VOID: Self = Self(PgType::Void); } impl Display for PgTypeInfo { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.pad(self.name()) } } impl PartialEq for PgType { fn eq(&self, other: &PgType) -> bool { if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) { // If there are OIDs available, use OIDs to perform a direct match a == b } else if matches!( (self, other), (PgType::DeclareWithName(_), PgType::DeclareWithOid(_)) | (PgType::DeclareWithOid(_), PgType::DeclareWithName(_)) ) { // One is a declare-with-name and the other is a declare-with-id // This only occurs in the TEXT protocol with custom types // Just opt-out of type checking here true } else { // Otherwise, perform a match on the name self.name().eq_ignore_ascii_case(other.name()) } } } sqlx-postgres-0.7.3/src/types/array.rs000064400000000000000000000242240072674642500161410ustar 00000000000000use sqlx_core::bytes::Buf; use sqlx_core::types::Text; use std::borrow::Cow; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::type_info::PgType; use crate::types::Oid; use crate::types::Type; use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; /// Provides information necessary to encode and decode Postgres arrays as compatible Rust types. /// /// Implementing this trait for some type `T` enables relevant `Type`,`Encode` and `Decode` impls /// for `Vec`, `&[T]` (slices), `[T; N]` (arrays), etc. 
///
/// ### Note: `#[derive(sqlx::Type)]`
/// If you have the `postgres` feature enabled, `#[derive(sqlx::Type)]` will also generate
/// an impl of this trait for your type if your wrapper is marked `#[sqlx(transparent)]`:
///
/// ```rust,ignore
/// #[derive(sqlx::Type)]
/// #[sqlx(transparent)]
/// struct UserId(i64);
///
/// let user_ids: Vec<UserId> = sqlx::query_scalar("select '{ 123, 456 }'::int8[]")
///     .fetch(&mut pg_connection)
///     .await?;
/// ```
///
/// However, this may cause an error if the type being wrapped does not implement `PgHasArrayType`,
/// e.g. `Vec` itself, because we don't currently support multidimensional arrays:
///
/// ```rust,ignore
/// #[derive(sqlx::Type)] // ERROR: `Vec<i64>` does not implement `PgHasArrayType`
/// #[sqlx(transparent)]
/// struct UserIds(Vec<i64>);
/// ```
///
/// To remedy this, add `#[sqlx(no_pg_array)]`, which disables the generation
/// of the `PgHasArrayType` impl:
///
/// ```rust,ignore
/// #[derive(sqlx::Type)]
/// #[sqlx(transparent, no_pg_array)]
/// struct UserIds(Vec<i64>);
/// ```
///
/// See [the documentation of `Type`][Type] for more details.
pub trait PgHasArrayType {
    /// The Postgres type info for the *array* of `Self` (e.g. `INT4[]` for `i32`).
    fn array_type_info() -> PgTypeInfo;

    /// Whether `ty` is an acceptable array type for `Self`; defaults to exact equality.
    fn array_compatible(ty: &PgTypeInfo) -> bool {
        *ty == Self::array_type_info()
    }
}

impl<T> PgHasArrayType for Option<T>
where
    T: PgHasArrayType,
{
    fn array_type_info() -> PgTypeInfo {
        T::array_type_info()
    }

    fn array_compatible(ty: &PgTypeInfo) -> bool {
        T::array_compatible(ty)
    }
}

impl<T> PgHasArrayType for Text<T> {
    // `Text<T>` is transmitted as its string form, so it arrays like `String`.
    fn array_type_info() -> PgTypeInfo {
        String::array_type_info()
    }

    fn array_compatible(ty: &PgTypeInfo) -> bool {
        String::array_compatible(ty)
    }
}

impl<T> Type<Postgres> for [T]
where
    T: PgHasArrayType,
{
    fn type_info() -> PgTypeInfo {
        T::array_type_info()
    }

    fn compatible(ty: &PgTypeInfo) -> bool {
        T::array_compatible(ty)
    }
}

impl<T> Type<Postgres> for Vec<T>
where
    T: PgHasArrayType,
{
    fn type_info() -> PgTypeInfo {
        T::array_type_info()
    }

    fn compatible(ty: &PgTypeInfo) -> bool {
        T::array_compatible(ty)
    }
}

impl<T, const N: usize> Type<Postgres> for [T; N]
where
    T: PgHasArrayType,
{
    fn type_info() -> PgTypeInfo {
        T::array_type_info()
    }

    fn compatible(ty: &PgTypeInfo) -> bool {
        T::array_compatible(ty)
    }
}

impl<'q, T> Encode<'q, Postgres> for Vec<T>
where
    for<'a> &'a [T]: Encode<'q, Postgres>,
    T: Encode<'q, Postgres>,
{
    #[inline]
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
        self.as_slice().encode_by_ref(buf)
    }
}

impl<'q, T, const N: usize> Encode<'q, Postgres> for [T; N]
where
    for<'a> &'a [T]: Encode<'q, Postgres>,
    T: Encode<'q, Postgres>,
{
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
        self.as_slice().encode_by_ref(buf)
    }
}

impl<'q, T> Encode<'q, Postgres> for &'_ [T]
where
    T: Encode<'q, Postgres> + Type<Postgres>,
{
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
        // Prefer the type the first element says it produces; fall back to the
        // static type info for an empty slice.
        let type_info = if self.len() < 1 {
            T::type_info()
        } else {
            self[0].produces().unwrap_or_else(T::type_info)
        };

        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
        buf.extend(&0_i32.to_be_bytes()); // flags

        // element type
        match type_info.0 {
            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),

            ty => {
                buf.extend(&ty.oid().0.to_be_bytes());
            }
        }
buf.extend(&(self.len() as i32).to_be_bytes()); // len buf.extend(&1_i32.to_be_bytes()); // lower bound for element in self.iter() { buf.encode(element); } IsNull::No } } impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N] where T: for<'a> Decode<'a, Postgres> + Type, { fn decode(value: PgValueRef<'r>) -> Result { // This could be done more efficiently by refactoring the Vec decoding below so that it can // be used for arrays and Vec. let vec: Vec = Decode::decode(value)?; let array: [T; N] = vec.try_into().map_err(|_| "wrong number of elements")?; Ok(array) } } impl<'r, T> Decode<'r, Postgres> for Vec where T: for<'a> Decode<'a, Postgres> + Type, { fn decode(value: PgValueRef<'r>) -> Result { let format = value.format(); match format { PgValueFormat::Binary => { // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L1548 let mut buf = value.as_bytes()?; // number of dimensions in the array let ndim = buf.get_i32(); if ndim == 0 { // zero dimensions is an empty array return Ok(Vec::new()); } if ndim != 1 { return Err(format!("encountered an array of {ndim} dimensions; only one-dimensional arrays are supported").into()); } // appears to have been used in the past to communicate potential NULLS // but reading source code back through our supported postgres versions (9.5+) // this is never used for anything let _flags = buf.get_i32(); // the OID of the element let element_type_oid = Oid(buf.get_u32()); let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid) .or_else(|| value.type_info.try_array_element().map(Cow::into_owned)) .ok_or_else(|| { BoxDynError::from(format!( "failed to resolve array element type for oid {}", element_type_oid.0 )) })?; // length of the array axis let len = buf.get_i32(); // the lower bound, we only support arrays starting from "1" let lower = buf.get_i32(); if lower != 1 { return Err(format!("encountered an array with a lower bound of {lower} in 
the first dimension; only arrays starting at one are supported").into()); } let mut elements = Vec::with_capacity(len as usize); for _ in 0..len { elements.push(T::decode(PgValueRef::get( &mut buf, format, element_type_info.clone(), ))?) } Ok(elements) } PgValueFormat::Text => { // no type is provided from the database for the element let element_type_info = T::type_info(); let s = value.as_str()?; // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L718 // trim the wrapping braces let s = &s[1..(s.len() - 1)]; if s.is_empty() { // short-circuit empty arrays up here return Ok(Vec::new()); } // NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one // that does not. The BOX (not PostGIS) type uses ';' as a delimiter. // TODO: When we add support for BOX we need to figure out some way to make the // delimiter selection let delimiter = ','; let mut done = false; let mut in_quotes = false; let mut in_escape = false; let mut value = String::with_capacity(10); let mut chars = s.chars(); let mut elements = Vec::with_capacity(4); while !done { loop { match chars.next() { Some(ch) => match ch { _ if in_escape => { value.push(ch); in_escape = false; } '"' => { in_quotes = !in_quotes; } '\\' => { in_escape = true; } _ if ch == delimiter && !in_quotes => { break; } _ => { value.push(ch); } }, None => { done = true; break; } } } let value_opt = if value == "NULL" { None } else { Some(value.as_bytes()) }; elements.push(T::decode(PgValueRef { value: value_opt, row: None, type_info: element_type_info.clone(), format, })?); value.clear(); } Ok(elements) } } } } sqlx-postgres-0.7.3/src/types/bigdecimal.rs000064400000000000000000000306100072674642500170770ustar 00000000000000use std::cmp; use bigdecimal::BigDecimal; use num_bigint::{BigInt, Sign}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::numeric::{PgNumeric, 
PgNumericSign}; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; impl Type for BigDecimal { fn type_info() -> PgTypeInfo { PgTypeInfo::NUMERIC } } impl PgHasArrayType for BigDecimal { fn array_type_info() -> PgTypeInfo { PgTypeInfo::NUMERIC_ARRAY } } impl TryFrom for BigDecimal { type Error = BoxDynError; fn try_from(numeric: PgNumeric) -> Result { let (digits, sign, weight) = match numeric { PgNumeric::Number { digits, sign, weight, .. } => (digits, sign, weight), PgNumeric::NotANumber => { return Err("BigDecimal does not support NaN values".into()); } }; if digits.is_empty() { // Postgres returns an empty digit array for 0 but BigInt expects at least one zero return Ok(0u64.into()); } let sign = match sign { PgNumericSign::Positive => Sign::Plus, PgNumericSign::Negative => Sign::Minus, }; // weight is 0 if the decimal point falls after the first base-10000 digit let scale = (digits.len() as i64 - weight as i64 - 1) * 4; // no optimized algorithm for base-10 so use base-100 for faster processing let mut cents = Vec::with_capacity(digits.len() * 2); for digit in &digits { cents.push((digit / 100) as u8); cents.push((digit % 100) as u8); } let bigint = BigInt::from_radix_be(sign, ¢s, 100) .ok_or("PgNumeric contained an out-of-range digit")?; Ok(BigDecimal::new(bigint, scale)) } } impl TryFrom<&'_ BigDecimal> for PgNumeric { type Error = BoxDynError; fn try_from(decimal: &BigDecimal) -> Result { let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16); // NOTE: this unfortunately copies the BigInt internally let (integer, exp) = decimal.as_bigint_and_exponent(); // this routine is specifically optimized for base-10 // FIXME: is there a way to iterate over the digits to avoid the Vec allocation let (sign, base_10) = integer.to_radix_be(10); // weight is positive power of 10000 // exp is the negative power of 10 let weight_10 = base_10.len() as i64 - exp; // scale is 
only nonzero when we have fractional digits // since `exp` is the _negative_ decimal exponent, it tells us // exactly what our scale should be let scale: i16 = cmp::max(0, exp).try_into()?; // there's an implicit +1 offset in the interpretation let weight: i16 = if weight_10 <= 0 { weight_10 / 4 - 1 } else { // the `-1` is a fix for an off by 1 error (4 digits should still be 0 weight) (weight_10 - 1) / 4 } .try_into()?; let digits_len = if base_10.len() % 4 != 0 { base_10.len() / 4 + 1 } else { base_10.len() / 4 }; let offset = weight_10.rem_euclid(4) as usize; let mut digits = Vec::with_capacity(digits_len); if let Some(first) = base_10.get(..offset) { if !first.is_empty() { digits.push(base_10_to_10000(first)); } } else if offset != 0 { digits.push(base_10_to_10000(&base_10) * 10i16.pow((offset - base_10.len()) as u32)); } if let Some(rest) = base_10.get(offset..) { digits.extend( rest.chunks(4) .map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)), ); } while let Some(&0) = digits.last() { digits.pop(); } Ok(PgNumeric::Number { sign: match sign { Sign::Plus | Sign::NoSign => PgNumericSign::Positive, Sign::Minus => PgNumericSign::Negative, }, scale, weight, digits, }) } } /// ### Panics /// If this `BigDecimal` cannot be represented by `PgNumeric`. 
impl Encode<'_, Postgres> for BigDecimal { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { PgNumeric::try_from(self) .expect("BigDecimal magnitude too great for Postgres NUMERIC type") .encode(buf); IsNull::No } fn size_hint(&self) -> usize { // BigDecimal::digits() gives us base-10 digits, so we divide by 4 to get base-10000 digits // and since this is just a hint we just always round up 8 + (self.digits() / 4 + 1) as usize * 2 } } impl Decode<'_, Postgres> for BigDecimal { fn decode(value: PgValueRef<'_>) -> Result { match value.format() { PgValueFormat::Binary => PgNumeric::decode(value.as_bytes()?)?.try_into(), PgValueFormat::Text => Ok(value.as_str()?.parse::()?), } } } #[cfg(test)] mod bigdecimal_to_pgnumeric { use super::{BigDecimal, PgNumeric, PgNumericSign}; use std::convert::TryFrom; #[test] fn zero() { let zero: BigDecimal = "0".parse().unwrap(); assert_eq!( PgNumeric::try_from(&zero).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![] } ); } #[test] fn one() { let one: BigDecimal = "1".parse().unwrap(); assert_eq!( PgNumeric::try_from(&one).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![1] } ); } #[test] fn ten() { let ten: BigDecimal = "10".parse().unwrap(); assert_eq!( PgNumeric::try_from(&ten).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![10] } ); } #[test] fn one_hundred() { let one_hundred: BigDecimal = "100".parse().unwrap(); assert_eq!( PgNumeric::try_from(&one_hundred).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![100] } ); } #[test] fn ten_thousand() { // BigDecimal doesn't normalize here let ten_thousand: BigDecimal = "10000".parse().unwrap(); assert_eq!( PgNumeric::try_from(&ten_thousand).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 1, digits: vec![1] } ); } #[test] fn two_digits() { let two_digits: 
BigDecimal = "12345".parse().unwrap(); assert_eq!( PgNumeric::try_from(&two_digits).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 1, digits: vec![1, 2345] } ); } #[test] fn one_tenth() { let one_tenth: BigDecimal = "0.1".parse().unwrap(); assert_eq!( PgNumeric::try_from(&one_tenth).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 1, weight: -1, digits: vec![1000] } ); } #[test] fn one_hundredth() { let one_hundredth: BigDecimal = "0.01".parse().unwrap(); assert_eq!( PgNumeric::try_from(&one_hundredth).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 2, weight: -1, digits: vec![100] } ); } #[test] fn twelve_thousandths() { let twelve_thousandths: BigDecimal = "0.012".parse().unwrap(); assert_eq!( PgNumeric::try_from(&twelve_thousandths).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 3, weight: -1, digits: vec![120] } ); } #[test] fn decimal_1() { let decimal: BigDecimal = "1.2345".parse().unwrap(); assert_eq!( PgNumeric::try_from(&decimal).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 4, weight: 0, digits: vec![1, 2345] } ); } #[test] fn decimal_2() { let decimal: BigDecimal = "0.12345".parse().unwrap(); assert_eq!( PgNumeric::try_from(&decimal).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 5, weight: -1, digits: vec![1234, 5000] } ); } #[test] fn decimal_3() { let decimal: BigDecimal = "0.01234".parse().unwrap(); assert_eq!( PgNumeric::try_from(&decimal).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 5, weight: -1, digits: vec![0123, 4000] } ); } #[test] fn decimal_4() { let decimal: BigDecimal = "12345.67890".parse().unwrap(); assert_eq!( PgNumeric::try_from(&decimal).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 5, weight: 1, digits: vec![1, 2345, 6789] } ); } #[test] fn one_digit_decimal() { let one_digit_decimal: BigDecimal = "0.00001234".parse().unwrap(); assert_eq!( 
PgNumeric::try_from(&one_digit_decimal).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 8, weight: -2, digits: vec![1234] } ); } #[test] fn issue_423_four_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 let four_digit: BigDecimal = "1234".parse().unwrap(); assert_eq!( PgNumeric::try_from(&four_digit).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![1234] } ); } #[test] fn issue_423_negative_four_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 let negative_four_digit: BigDecimal = "-1234".parse().unwrap(); assert_eq!( PgNumeric::try_from(&negative_four_digit).unwrap(), PgNumeric::Number { sign: PgNumericSign::Negative, scale: 0, weight: 0, digits: vec![1234] } ); } #[test] fn issue_423_eight_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 let eight_digit: BigDecimal = "12345678".parse().unwrap(); assert_eq!( PgNumeric::try_from(&eight_digit).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 1, digits: vec![1234, 5678] } ); } #[test] fn issue_423_negative_eight_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 let negative_eight_digit: BigDecimal = "-12345678".parse().unwrap(); assert_eq!( PgNumeric::try_from(&negative_eight_digit).unwrap(), PgNumeric::Number { sign: PgNumericSign::Negative, scale: 0, weight: 1, digits: vec![1234, 5678] } ); } } sqlx-postgres-0.7.3/src/types/bit_vec.rs000064400000000000000000000056510072674642500164410ustar 00000000000000use crate::{ decode::Decode, encode::{Encode, IsNull}, error::BoxDynError, types::Type, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres, }; use bit_vec::BitVec; use sqlx_core::bytes::Buf; use std::{io, mem}; impl Type for BitVec { fn type_info() -> PgTypeInfo { PgTypeInfo::VARBIT } fn compatible(ty: &PgTypeInfo) -> bool { *ty 
== PgTypeInfo::BIT || *ty == PgTypeInfo::VARBIT } } impl PgHasArrayType for BitVec { fn array_type_info() -> PgTypeInfo { PgTypeInfo::VARBIT_ARRAY } fn array_compatible(ty: &PgTypeInfo) -> bool { *ty == PgTypeInfo::BIT_ARRAY || *ty == PgTypeInfo::VARBIT_ARRAY } } impl Encode<'_, Postgres> for BitVec { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&(self.len() as i32).to_be_bytes()); buf.extend(self.to_bytes()); IsNull::No } fn size_hint(&self) -> usize { mem::size_of::() + self.len() } } impl Decode<'_, Postgres> for BitVec { fn decode(value: PgValueRef<'_>) -> Result { match value.format() { PgValueFormat::Binary => { let mut bytes = value.as_bytes()?; let len = bytes.get_i32(); if len < 0 { Err(io::Error::new( io::ErrorKind::InvalidData, "Negative VARBIT length.", ))? } // The smallest amount of data we can read is one byte let bytes_len = (len as usize + 7) / 8; if bytes.remaining() != bytes_len { Err(io::Error::new( io::ErrorKind::InvalidData, "VARBIT length mismatch.", ))?; } let mut bitvec = BitVec::from_bytes(&bytes); // Chop off zeroes from the back. We get bits in bytes, so if // our bitvec is not in full bytes, extra zeroes are added to // the end. 
while bitvec.len() > len as usize { bitvec.pop(); } Ok(bitvec) } PgValueFormat::Text => { let s = value.as_str()?; let mut bit_vec = BitVec::with_capacity(s.len()); for c in s.chars() { match c { '0' => bit_vec.push(false), '1' => bit_vec.push(true), _ => { Err(io::Error::new( io::ErrorKind::InvalidData, "VARBIT data contains other characters than 1 or 0.", ))?; } } } Ok(bit_vec) } } } } sqlx-postgres-0.7.3/src/types/bool.rs000064400000000000000000000020650072674642500157550ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; impl Type for bool { fn type_info() -> PgTypeInfo { PgTypeInfo::BOOL } } impl PgHasArrayType for bool { fn array_type_info() -> PgTypeInfo { PgTypeInfo::BOOL_ARRAY } } impl Encode<'_, Postgres> for bool { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.push(*self as u8); IsNull::No } } impl Decode<'_, Postgres> for bool { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => value.as_bytes()?[0] != 0, PgValueFormat::Text => match value.as_str()? 
{ "t" => true, "f" => false, s => { return Err(format!("unexpected value {s:?} for boolean").into()); } }, }) } } sqlx-postgres-0.7.3/src/types/bytes.rs000064400000000000000000000066120072674642500161520ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; impl PgHasArrayType for u8 { fn array_type_info() -> PgTypeInfo { PgTypeInfo::BYTEA } } impl PgHasArrayType for &'_ [u8] { fn array_type_info() -> PgTypeInfo { PgTypeInfo::BYTEA_ARRAY } } impl PgHasArrayType for &'_ [u8; N] { fn array_type_info() -> PgTypeInfo { PgTypeInfo::BYTEA_ARRAY } } impl PgHasArrayType for Box<[u8]> { fn array_type_info() -> PgTypeInfo { <[&[u8]] as Type>::type_info() } } impl PgHasArrayType for Vec { fn array_type_info() -> PgTypeInfo { <[&[u8]] as Type>::type_info() } } impl PgHasArrayType for [u8; N] { fn array_type_info() -> PgTypeInfo { <[&[u8]] as Type>::type_info() } } impl Encode<'_, Postgres> for &'_ [u8] { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend_from_slice(self); IsNull::No } } impl Encode<'_, Postgres> for Box<[u8]> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { <&[u8] as Encode>::encode(self.as_ref(), buf) } } impl Encode<'_, Postgres> for Vec { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { <&[u8] as Encode>::encode(self, buf) } } impl Encode<'_, Postgres> for [u8; N] { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { <&[u8] as Encode>::encode(self.as_slice(), buf) } } impl<'r> Decode<'r, Postgres> for &'r [u8] { fn decode(value: PgValueRef<'r>) -> Result { match value.format() { PgValueFormat::Binary => value.as_bytes(), PgValueFormat::Text => { Err("unsupported decode to `&[u8]` of BYTEA in a simple query; use a prepared query or decode to `Vec`".into()) } } } } fn text_hex_decode_input(value: PgValueRef<'_>) -> 
Result<&[u8], BoxDynError> { // BYTEA is formatted as \x followed by hex characters value .as_bytes()? .strip_prefix(b"\\x") .ok_or("text does not start with \\x") .map_err(Into::into) } impl Decode<'_, Postgres> for Box<[u8]> { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => Box::from(value.as_bytes()?), PgValueFormat::Text => Box::from(hex::decode(text_hex_decode_input(value)?)?), }) } } impl Decode<'_, Postgres> for Vec { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => value.as_bytes()?.to_owned(), PgValueFormat::Text => hex::decode(text_hex_decode_input(value)?)?, }) } } impl Decode<'_, Postgres> for [u8; N] { fn decode(value: PgValueRef<'_>) -> Result { let mut bytes = [0u8; N]; match value.format() { PgValueFormat::Binary => { bytes = value.as_bytes()?.try_into()?; } PgValueFormat::Text => hex::decode_to_slice(text_hex_decode_input(value)?, &mut bytes)?, }; Ok(bytes) } } sqlx-postgres-0.7.3/src/types/chrono/date.rs000064400000000000000000000027350072674642500172330ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use chrono::{Duration, NaiveDate}; use std::mem; impl Type for NaiveDate { fn type_info() -> PgTypeInfo { PgTypeInfo::DATE } } impl PgHasArrayType for NaiveDate { fn array_type_info() -> PgTypeInfo { PgTypeInfo::DATE_ARRAY } } impl Encode<'_, Postgres> for NaiveDate { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // DATE is encoded as the days since epoch let days = (*self - postgres_epoch_date()).num_days() as i32; Encode::::encode(&days, buf) } fn size_hint(&self) -> usize { mem::size_of::() } } impl<'r> Decode<'r, Postgres> for NaiveDate { fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { // DATE is encoded 
as the days since epoch let days: i32 = Decode::::decode(value)?; postgres_epoch_date() + Duration::days(days.into()) } PgValueFormat::Text => NaiveDate::parse_from_str(value.as_str()?, "%Y-%m-%d")?, }) } } #[inline] fn postgres_epoch_date() -> NaiveDate { NaiveDate::from_ymd_opt(2000, 1, 1).expect("expected 2000-01-01 to be a valid NaiveDate") } sqlx-postgres-0.7.3/src/types/chrono/datetime.rs000064400000000000000000000072050072674642500201070ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use chrono::{ DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, Offset, TimeZone, Utc, }; use std::mem; impl Type for NaiveDateTime { fn type_info() -> PgTypeInfo { PgTypeInfo::TIMESTAMP } } impl Type for DateTime { fn type_info() -> PgTypeInfo { PgTypeInfo::TIMESTAMPTZ } } impl PgHasArrayType for NaiveDateTime { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TIMESTAMP_ARRAY } } impl PgHasArrayType for DateTime { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TIMESTAMPTZ_ARRAY } } impl Encode<'_, Postgres> for NaiveDateTime { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // FIXME: We should *really* be returning an error, Encode needs to be fallible // TIMESTAMP is encoded as the microseconds since the epoch let us = (*self - postgres_epoch_datetime()) .num_microseconds() .unwrap_or_else(|| panic!("NaiveDateTime out of range for Postgres: {self:?}")); Encode::::encode(&us, buf) } fn size_hint(&self) -> usize { mem::size_of::() } } impl<'r> Decode<'r, Postgres> for NaiveDateTime { fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { // TIMESTAMP is encoded as the microseconds since the epoch let us = Decode::::decode(value)?; postgres_epoch_datetime() + Duration::microseconds(us) } PgValueFormat::Text => { let s = 
value.as_str()?; NaiveDateTime::parse_from_str( s, if s.contains('+') { // Contains a time-zone specifier // This is given for timestamptz for some reason // Postgres already guarantees this to always be UTC "%Y-%m-%d %H:%M:%S%.f%#z" } else { "%Y-%m-%d %H:%M:%S%.f" }, )? } }) } } impl Encode<'_, Postgres> for DateTime { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { Encode::::encode(self.naive_utc(), buf) } fn size_hint(&self) -> usize { mem::size_of::() } } impl<'r> Decode<'r, Postgres> for DateTime { fn decode(value: PgValueRef<'r>) -> Result { let naive = >::decode(value)?; Ok(Local.from_utc_datetime(&naive)) } } impl<'r> Decode<'r, Postgres> for DateTime { fn decode(value: PgValueRef<'r>) -> Result { let naive = >::decode(value)?; Ok(Utc.from_utc_datetime(&naive)) } } impl<'r> Decode<'r, Postgres> for DateTime { fn decode(value: PgValueRef<'r>) -> Result { let naive = >::decode(value)?; Ok(Utc.fix().from_utc_datetime(&naive)) } } #[inline] fn postgres_epoch_datetime() -> NaiveDateTime { NaiveDate::from_ymd_opt(2000, 1, 1) .expect("expected 2000-01-01 to be a valid NaiveDate") .and_hms_opt(0, 0, 0) .expect("expected 2000-01-01T00:00:00 to be a valid NaiveDateTime") } sqlx-postgres-0.7.3/src/types/chrono/mod.rs000064400000000000000000000000420072674642500170620ustar 00000000000000mod date; mod datetime; mod time; sqlx-postgres-0.7.3/src/types/chrono/time.rs000064400000000000000000000033150072674642500172470ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use chrono::{Duration, NaiveTime}; use std::mem; impl Type for NaiveTime { fn type_info() -> PgTypeInfo { PgTypeInfo::TIME } } impl PgHasArrayType for NaiveTime { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TIME_ARRAY } } impl Encode<'_, Postgres> for NaiveTime { fn encode_by_ref(&self, buf: &mut 
PgArgumentBuffer) -> IsNull { // TIME is encoded as the microseconds since midnight // NOTE: panic! is on overflow and 1 day does not have enough micros to overflow let us = (*self - NaiveTime::default()).num_microseconds().unwrap(); Encode::::encode(&us, buf) } fn size_hint(&self) -> usize { mem::size_of::() } } impl<'r> Decode<'r, Postgres> for NaiveTime { fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { // TIME is encoded as the microseconds since midnight let us: i64 = Decode::::decode(value)?; NaiveTime::default() + Duration::microseconds(us) } PgValueFormat::Text => NaiveTime::parse_from_str(value.as_str()?, "%H:%M:%S%.f")?, }) } } #[test] fn check_naive_time_default_is_midnight() { // Just a canary in case this changes. assert_eq!( NaiveTime::from_hms_opt(0, 0, 0), Some(NaiveTime::default()), "implementation assumes `NaiveTime::default()` equals midnight" ); } sqlx-postgres-0.7.3/src/types/citext.rs000064400000000000000000000060750072674642500163270ustar 00000000000000use crate::types::array_compatible; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; use sqlx_core::decode::Decode; use sqlx_core::encode::{Encode, IsNull}; use sqlx_core::error::BoxDynError; use sqlx_core::types::Type; use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::ops::Deref; use std::str::FromStr; /// Case-insensitive text (`citext`) support for Postgres. /// /// Note that SQLx considers the `citext` type to be compatible with `String` /// and its various derivatives, so direct usage of this type is generally unnecessary. /// /// However, it may be needed, for example, when binding a `citext[]` array, /// as Postgres will generally not accept a `text[]` array (mapped from `Vec`) in its place. /// /// See [the Postgres manual, Appendix F, Section 10][PG.F.10] for details on using `citext`. 
/// /// [PG.F.10]: https://www.postgresql.org/docs/current/citext.html /// /// ### Note: Extension Required /// The `citext` extension is not enabled by default in Postgres. You will need to do so explicitly: /// /// ```ignore /// CREATE EXTENSION IF NOT EXISTS "citext"; /// ``` /// /// ### Note: `PartialEq` is Case-Sensitive /// This type derives `PartialEq` which forwards to the implementation on `String`, which /// is case-sensitive. This impl exists mainly for testing. /// /// To properly emulate the case-insensitivity of `citext` would require use of locale-aware /// functions in `libc`, and even then would require querying the locale of the database server /// and setting it locally, which is unsafe. #[derive(Clone, Debug, Default, PartialEq)] pub struct PgCiText(pub String); impl Type for PgCiText { fn type_info() -> PgTypeInfo { // Since `citext` is enabled by an extension, it does not have a stable OID. PgTypeInfo::with_name("citext") } fn compatible(ty: &PgTypeInfo) -> bool { <&str as Type>::compatible(ty) } } impl Deref for PgCiText { type Target = str; fn deref(&self) -> &Self::Target { self.0.as_str() } } impl From for PgCiText { fn from(value: String) -> Self { Self(value) } } impl From for String { fn from(value: PgCiText) -> Self { value.0 } } impl FromStr for PgCiText { type Err = core::convert::Infallible; fn from_str(s: &str) -> Result { Ok(PgCiText(s.parse()?)) } } impl Display for PgCiText { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str(&self.0) } } impl PgHasArrayType for PgCiText { fn array_type_info() -> PgTypeInfo { PgTypeInfo::with_name("_citext") } fn array_compatible(ty: &PgTypeInfo) -> bool { array_compatible::<&str>(ty) } } impl Encode<'_, Postgres> for PgCiText { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { <&str as Encode>::encode(&**self, buf) } } impl Decode<'_, Postgres> for PgCiText { fn decode(value: PgValueRef<'_>) -> Result { Ok(PgCiText(value.as_str()?.to_owned())) } } 
sqlx-postgres-0.7.3/src/types/float.rs000064400000000000000000000031270072674642500161270ustar 00000000000000use byteorder::{BigEndian, ByteOrder}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; impl Type for f32 { fn type_info() -> PgTypeInfo { PgTypeInfo::FLOAT4 } } impl PgHasArrayType for f32 { fn array_type_info() -> PgTypeInfo { PgTypeInfo::FLOAT4_ARRAY } } impl Encode<'_, Postgres> for f32 { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.to_be_bytes()); IsNull::No } } impl Decode<'_, Postgres> for f32 { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_f32(value.as_bytes()?), PgValueFormat::Text => value.as_str()?.parse()?, }) } } impl Type for f64 { fn type_info() -> PgTypeInfo { PgTypeInfo::FLOAT8 } } impl PgHasArrayType for f64 { fn array_type_info() -> PgTypeInfo { PgTypeInfo::FLOAT8_ARRAY } } impl Encode<'_, Postgres> for f64 { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.to_be_bytes()); IsNull::No } } impl Decode<'_, Postgres> for f64 { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_f64(value.as_bytes()?), PgValueFormat::Text => value.as_str()?.parse()?, }) } } sqlx-postgres-0.7.3/src/types/int.rs000064400000000000000000000056050072674642500156170ustar 00000000000000use byteorder::{BigEndian, ByteOrder}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; impl Type for i8 { fn type_info() -> PgTypeInfo { PgTypeInfo::CHAR } } impl PgHasArrayType for i8 { fn array_type_info() -> PgTypeInfo { PgTypeInfo::CHAR_ARRAY } } impl Encode<'_, 
Postgres> for i8 { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.to_be_bytes()); IsNull::No } } impl Decode<'_, Postgres> for i8 { fn decode(value: PgValueRef<'_>) -> Result { // note: in the TEXT encoding, a value of "0" here is encoded as an empty string Ok(value.as_bytes()?.get(0).copied().unwrap_or_default() as i8) } } impl Type for i16 { fn type_info() -> PgTypeInfo { PgTypeInfo::INT2 } } impl PgHasArrayType for i16 { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INT2_ARRAY } } impl Encode<'_, Postgres> for i16 { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.to_be_bytes()); IsNull::No } } impl Decode<'_, Postgres> for i16 { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_i16(value.as_bytes()?), PgValueFormat::Text => value.as_str()?.parse()?, }) } } impl Type for i32 { fn type_info() -> PgTypeInfo { PgTypeInfo::INT4 } } impl PgHasArrayType for i32 { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INT4_ARRAY } } impl Encode<'_, Postgres> for i32 { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.to_be_bytes()); IsNull::No } } impl Decode<'_, Postgres> for i32 { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_i32(value.as_bytes()?), PgValueFormat::Text => value.as_str()?.parse()?, }) } } impl Type for i64 { fn type_info() -> PgTypeInfo { PgTypeInfo::INT8 } } impl PgHasArrayType for i64 { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INT8_ARRAY } } impl Encode<'_, Postgres> for i64 { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.to_be_bytes()); IsNull::No } } impl Decode<'_, Postgres> for i64 { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { PgValueFormat::Binary => BigEndian::read_i64(value.as_bytes()?), PgValueFormat::Text => value.as_str()?.parse()?, }) } } 
sqlx-postgres-0.7.3/src/types/interval.rs000064400000000000000000000253570072674642500166570ustar 00000000000000use std::mem; use byteorder::{NetworkEndian, ReadBytesExt}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; // `PgInterval` is available for direct access to the INTERVAL type #[derive(Debug, Eq, PartialEq, Clone, Hash)] pub struct PgInterval { pub months: i32, pub days: i32, pub microseconds: i64, } impl Type for PgInterval { fn type_info() -> PgTypeInfo { PgTypeInfo::INTERVAL } } impl PgHasArrayType for PgInterval { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INTERVAL_ARRAY } } impl<'de> Decode<'de, Postgres> for PgInterval { fn decode(value: PgValueRef<'de>) -> Result { match value.format() { PgValueFormat::Binary => { let mut buf = value.as_bytes()?; let microseconds = buf.read_i64::()?; let days = buf.read_i32::()?; let months = buf.read_i32::()?; Ok(PgInterval { months, days, microseconds, }) } // TODO: Implement parsing of text mode PgValueFormat::Text => { Err("not implemented: decode `INTERVAL` in text mode (unprepared queries)".into()) } } } } impl Encode<'_, Postgres> for PgInterval { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.microseconds.to_be_bytes()); buf.extend(&self.days.to_be_bytes()); buf.extend(&self.months.to_be_bytes()); IsNull::No } fn size_hint(&self) -> usize { 2 * mem::size_of::() } } // We then implement Encode + Type for std Duration, chrono Duration, and time Duration // This is to enable ease-of-use for encoding when its simple impl Type for std::time::Duration { fn type_info() -> PgTypeInfo { PgTypeInfo::INTERVAL } } impl PgHasArrayType for std::time::Duration { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INTERVAL_ARRAY } } impl Encode<'_, Postgres> for std::time::Duration { fn encode_by_ref(&self, buf: &mut 
PgArgumentBuffer) -> IsNull { PgInterval::try_from(*self) .expect("failed to encode `std::time::Duration`") .encode_by_ref(buf) } fn size_hint(&self) -> usize { 2 * mem::size_of::() } } impl TryFrom for PgInterval { type Error = BoxDynError; /// Convert a `std::time::Duration` to a `PgInterval` /// /// This returns an error if there is a loss of precision using nanoseconds or if there is a /// microsecond overflow. fn try_from(value: std::time::Duration) -> Result { if value.as_nanos() % 1000 != 0 { return Err("PostgreSQL `INTERVAL` does not support nanoseconds precision".into()); } Ok(Self { months: 0, days: 0, microseconds: value.as_micros().try_into()?, }) } } #[cfg(feature = "chrono")] impl Type for chrono::Duration { fn type_info() -> PgTypeInfo { PgTypeInfo::INTERVAL } } #[cfg(feature = "chrono")] impl PgHasArrayType for chrono::Duration { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INTERVAL_ARRAY } } #[cfg(feature = "chrono")] impl Encode<'_, Postgres> for chrono::Duration { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { let pg_interval = PgInterval::try_from(*self).expect("Failed to encode chrono::Duration"); pg_interval.encode_by_ref(buf) } fn size_hint(&self) -> usize { 2 * mem::size_of::() } } #[cfg(feature = "chrono")] impl TryFrom for PgInterval { type Error = BoxDynError; /// Convert a `chrono::Duration` to a `PgInterval`. /// /// This returns an error if there is a loss of precision using nanoseconds or if there is a /// nanosecond overflow. 
fn try_from(value: chrono::Duration) -> Result { value .num_nanoseconds() .map_or::, _>( Err("Overflow has occurred for PostgreSQL `INTERVAL`".into()), |nanoseconds| { if nanoseconds % 1000 != 0 { return Err( "PostgreSQL `INTERVAL` does not support nanoseconds precision".into(), ); } Ok(()) }, )?; value.num_microseconds().map_or( Err("Overflow has occurred for PostgreSQL `INTERVAL`".into()), |microseconds| { Ok(Self { months: 0, days: 0, microseconds: microseconds, }) }, ) } } #[cfg(feature = "time")] impl Type for time::Duration { fn type_info() -> PgTypeInfo { PgTypeInfo::INTERVAL } } #[cfg(feature = "time")] impl PgHasArrayType for time::Duration { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INTERVAL_ARRAY } } #[cfg(feature = "time")] impl Encode<'_, Postgres> for time::Duration { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { let pg_interval = PgInterval::try_from(*self).expect("Failed to encode time::Duration"); pg_interval.encode_by_ref(buf) } fn size_hint(&self) -> usize { 2 * mem::size_of::() } } #[cfg(feature = "time")] impl TryFrom for PgInterval { type Error = BoxDynError; /// Convert a `time::Duration` to a `PgInterval`. /// /// This returns an error if there is a loss of precision using nanoseconds or if there is a /// microsecond overflow. 
fn try_from(value: time::Duration) -> Result { if value.whole_nanoseconds() % 1000 != 0 { return Err("PostgreSQL `INTERVAL` does not support nanoseconds precision".into()); } Ok(Self { months: 0, days: 0, microseconds: value.whole_microseconds().try_into()?, }) } } #[test] fn test_encode_interval() { let mut buf = PgArgumentBuffer::default(); let interval = PgInterval { months: 0, days: 0, microseconds: 0, }; assert!(matches!( Encode::::encode(&interval, &mut buf), IsNull::No )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); buf.clear(); let interval = PgInterval { months: 0, days: 0, microseconds: 1_000, }; assert!(matches!( Encode::::encode(&interval, &mut buf), IsNull::No )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 0, 0, 0, 0, 0, 0]); buf.clear(); let interval = PgInterval { months: 0, days: 0, microseconds: 1_000_000, }; assert!(matches!( Encode::::encode(&interval, &mut buf), IsNull::No )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 15, 66, 64, 0, 0, 0, 0, 0, 0, 0, 0]); buf.clear(); let interval = PgInterval { months: 0, days: 0, microseconds: 3_600_000_000, }; assert!(matches!( Encode::::encode(&interval, &mut buf), IsNull::No )); assert_eq!( &**buf, [0, 0, 0, 0, 214, 147, 164, 0, 0, 0, 0, 0, 0, 0, 0, 0] ); buf.clear(); let interval = PgInterval { months: 0, days: 1, microseconds: 0, }; assert!(matches!( Encode::::encode(&interval, &mut buf), IsNull::No )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]); buf.clear(); let interval = PgInterval { months: 1, days: 0, microseconds: 0, }; assert!(matches!( Encode::::encode(&interval, &mut buf), IsNull::No )); assert_eq!(&**buf, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); buf.clear(); } #[test] fn test_pginterval_std() { // Case for positive duration let interval = PgInterval { days: 0, months: 0, microseconds: 27_000, }; assert_eq!( &PgInterval::try_from(std::time::Duration::from_micros(27_000)).unwrap(), &interval ); // Case when precision loss occurs 
assert!(PgInterval::try_from(std::time::Duration::from_nanos(27_000_001)).is_err()); // Case when microsecond overflow occurs assert!(PgInterval::try_from(std::time::Duration::from_secs(20_000_000_000_000)).is_err()); } #[test] #[cfg(feature = "chrono")] fn test_pginterval_chrono() { // Case for positive duration let interval = PgInterval { days: 0, months: 0, microseconds: 27_000, }; assert_eq!( &PgInterval::try_from(chrono::Duration::microseconds(27_000)).unwrap(), &interval ); // Case for negative duration let interval = PgInterval { days: 0, months: 0, microseconds: -27_000, }; assert_eq!( &PgInterval::try_from(chrono::Duration::microseconds(-27_000)).unwrap(), &interval ); // Case when precision loss occurs assert!(PgInterval::try_from(chrono::Duration::nanoseconds(27_000_001)).is_err()); assert!(PgInterval::try_from(chrono::Duration::nanoseconds(-27_000_001)).is_err()); // Case when nanosecond overflow occurs assert!(PgInterval::try_from(chrono::Duration::seconds(10_000_000_000)).is_err()); assert!(PgInterval::try_from(chrono::Duration::seconds(-10_000_000_000)).is_err()); } #[test] #[cfg(feature = "time")] fn test_pginterval_time() { // Case for positive duration let interval = PgInterval { days: 0, months: 0, microseconds: 27_000, }; assert_eq!( &PgInterval::try_from(time::Duration::microseconds(27_000)).unwrap(), &interval ); // Case for negative duration let interval = PgInterval { days: 0, months: 0, microseconds: -27_000, }; assert_eq!( &PgInterval::try_from(time::Duration::microseconds(-27_000)).unwrap(), &interval ); // Case when precision loss occurs assert!(PgInterval::try_from(time::Duration::nanoseconds(27_000_001)).is_err()); assert!(PgInterval::try_from(time::Duration::nanoseconds(-27_000_001)).is_err()); // Case when microsecond overflow occurs assert!(PgInterval::try_from(time::Duration::seconds(10_000_000_000_000)).is_err()); assert!(PgInterval::try_from(time::Duration::seconds(-10_000_000_000_000)).is_err()); } 
sqlx-postgres-0.7.3/src/types/ipaddr.rs000064400000000000000000000027060072674642500162670ustar 00000000000000use std::net::IpAddr; use ipnetwork::IpNetwork; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; impl Type for IpAddr where IpNetwork: Type, { fn type_info() -> PgTypeInfo { IpNetwork::type_info() } fn compatible(ty: &PgTypeInfo) -> bool { IpNetwork::compatible(ty) } } impl PgHasArrayType for IpAddr { fn array_type_info() -> PgTypeInfo { ::array_type_info() } fn array_compatible(ty: &PgTypeInfo) -> bool { ::array_compatible(ty) } } impl<'db> Encode<'db, Postgres> for IpAddr where IpNetwork: Encode<'db, Postgres>, { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { IpNetwork::from(*self).encode_by_ref(buf) } fn size_hint(&self) -> usize { IpNetwork::from(*self).size_hint() } } impl<'db> Decode<'db, Postgres> for IpAddr where IpNetwork: Decode<'db, Postgres>, { fn decode(value: PgValueRef<'db>) -> Result { let ipnetwork = IpNetwork::decode(value)?; if ipnetwork.is_ipv4() && ipnetwork.prefix() != 32 || ipnetwork.is_ipv6() && ipnetwork.prefix() != 128 { Err("lossy decode from inet/cidr")? } Ok(ipnetwork.ip()) } } sqlx-postgres-0.7.3/src/types/ipnetwork.rs000064400000000000000000000100010072674642500170310ustar 00000000000000use std::net::{Ipv4Addr, Ipv6Addr}; use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/include/utils/inet.h#L39 // Technically this is a magic number here but it doesn't make sense to drag in the whole of `libc` // just for one constant. 
const PGSQL_AF_INET: u8 = 2; // AF_INET const PGSQL_AF_INET6: u8 = PGSQL_AF_INET + 1; impl Type for IpNetwork { fn type_info() -> PgTypeInfo { PgTypeInfo::INET } fn compatible(ty: &PgTypeInfo) -> bool { *ty == PgTypeInfo::CIDR || *ty == PgTypeInfo::INET } } impl PgHasArrayType for IpNetwork { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INET_ARRAY } fn array_compatible(ty: &PgTypeInfo) -> bool { *ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY } } impl Encode<'_, Postgres> for IpNetwork { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L293 // https://github.com/postgres/postgres/blob/574925bfd0a8175f6e161936ea11d9695677ba09/src/backend/utils/adt/network.c#L271 match self { IpNetwork::V4(net) => { buf.push(PGSQL_AF_INET); // ip_family buf.push(net.prefix()); // ip_bits buf.push(0); // is_cidr buf.push(4); // nb (number of bytes) buf.extend_from_slice(&net.ip().octets()) // address } IpNetwork::V6(net) => { buf.push(PGSQL_AF_INET6); // ip_family buf.push(net.prefix()); // ip_bits buf.push(0); // is_cidr buf.push(16); // nb (number of bytes) buf.extend_from_slice(&net.ip().octets()); // address } } IsNull::No } fn size_hint(&self) -> usize { match self { IpNetwork::V4(_) => 8, IpNetwork::V6(_) => 20, } } } impl Decode<'_, Postgres> for IpNetwork { fn decode(value: PgValueRef<'_>) -> Result { let bytes = match value.format() { PgValueFormat::Binary => value.as_bytes()?, PgValueFormat::Text => { return Ok(value.as_str()?.parse()?); } }; if bytes.len() >= 8 { let family = bytes[0]; let prefix = bytes[1]; let _is_cidr = bytes[2] != 0; let len = bytes[3]; match family { PGSQL_AF_INET => { if bytes.len() == 8 && len == 4 { let inet = Ipv4Network::new( Ipv4Addr::new(bytes[4], bytes[5], bytes[6], bytes[7]), prefix, )?; return Ok(IpNetwork::V4(inet)); } } PGSQL_AF_INET6 => { if bytes.len() == 20 && len == 16 { let inet = 
Ipv6Network::new( Ipv6Addr::from([ bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15], bytes[16], bytes[17], bytes[18], bytes[19], ]), prefix, )?; return Ok(IpNetwork::V6(inet)); } } _ => { return Err(format!("unknown ip family {family}").into()); } } } Err("invalid data received when expecting an INET".into()) } } sqlx-postgres-0.7.3/src/types/json.rs000064400000000000000000000054410072674642500157740ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::array_compatible; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use serde::{Deserialize, Serialize}; use serde_json::value::RawValue as JsonRawValue; use serde_json::Value as JsonValue; pub(crate) use sqlx_core::types::{Json, Type}; // // In general, most applications should prefer to store JSON data as jsonb, // unless there are quite specialized needs, such as legacy assumptions // about ordering of object keys. 
impl Type for Json { fn type_info() -> PgTypeInfo { PgTypeInfo::JSONB } fn compatible(ty: &PgTypeInfo) -> bool { *ty == PgTypeInfo::JSON || *ty == PgTypeInfo::JSONB } } impl PgHasArrayType for Json { fn array_type_info() -> PgTypeInfo { PgTypeInfo::JSONB_ARRAY } fn array_compatible(ty: &PgTypeInfo) -> bool { array_compatible::>(ty) } } impl PgHasArrayType for JsonValue { fn array_type_info() -> PgTypeInfo { PgTypeInfo::JSONB_ARRAY } fn array_compatible(ty: &PgTypeInfo) -> bool { array_compatible::(ty) } } impl PgHasArrayType for JsonRawValue { fn array_type_info() -> PgTypeInfo { PgTypeInfo::JSONB_ARRAY } fn array_compatible(ty: &PgTypeInfo) -> bool { array_compatible::(ty) } } impl<'q, T> Encode<'q, Postgres> for Json where T: Serialize, { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // we have a tiny amount of dynamic behavior depending if we are resolved to be JSON // instead of JSONB buf.patch(|buf, ty: &PgTypeInfo| { if *ty == PgTypeInfo::JSON || *ty == PgTypeInfo::JSON_ARRAY { buf[0] = b' '; } }); // JSONB version (as of 2020-03-20) buf.push(1); // the JSON data written to the buffer is the same regardless of parameter type serde_json::to_writer(&mut **buf, &self.0) .expect("failed to serialize to JSON for encoding on transmission to the database"); IsNull::No } } impl<'r, T: 'r> Decode<'r, Postgres> for Json where T: Deserialize<'r>, { fn decode(value: PgValueRef<'r>) -> Result { let mut buf = value.as_bytes()?; if value.format() == PgValueFormat::Binary && value.type_info == PgTypeInfo::JSONB { assert_eq!( buf[0], 1, "unsupported JSONB format version {}; please open an issue", buf[0] ); buf = &buf[1..]; } serde_json::from_slice(buf).map(Json).map_err(Into::into) } } sqlx-postgres-0.7.3/src/types/lquery.rs000064400000000000000000000233450072674642500163470ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgTypeInfo, 
PgValueFormat, PgValueRef, Postgres}; use bitflags::bitflags; use std::fmt::{self, Display, Formatter}; use std::io::Write; use std::ops::Deref; use std::str::FromStr; use crate::types::ltree::{PgLTreeLabel, PgLTreeParseError}; /// Represents lquery specific errors #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum PgLQueryParseError { #[error("lquery cannot be empty")] EmptyString, #[error("unexpected character in lquery")] UnexpectedCharacter, #[error("error parsing integer: {0}")] ParseIntError(#[from] std::num::ParseIntError), #[error("error parsing integer: {0}")] LTreeParrseError(#[from] PgLTreeParseError), /// LQuery version not supported #[error("lquery version not supported")] InvalidLqueryVersion, } /// Container for a Label Tree Query (`lquery`) in Postgres. /// /// See https://www.postgresql.org/docs/current/ltree.html /// /// ### Note: Requires Postgres 13+ /// /// This integration requires that the `lquery` type support the binary format in the Postgres /// wire protocol, which only became available in Postgres 13. /// ([Postgres 13.0 Release Notes, Additional Modules][https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.5.14]) /// /// Ideally, SQLx's Postgres driver should support falling back to text format for types /// which don't have `typsend` and `typrecv` entries in `pg_type`, but that work still needs /// to be done. /// /// ### Note: Extension Required /// The `ltree` extension is not enabled by default in Postgres. 
You will need to do so explicitly: /// /// ```ignore /// CREATE EXTENSION IF NOT EXISTS "ltree"; /// ``` #[derive(Clone, Debug, Default, PartialEq)] pub struct PgLQuery { levels: Vec, } // TODO: maybe a QueryBuilder pattern would be nice here impl PgLQuery { /// creates default/empty lquery pub fn new() -> Self { Self::default() } pub fn from(levels: Vec) -> Self { Self { levels } } /// push a query level pub fn push(&mut self, level: PgLQueryLevel) { self.levels.push(level); } /// pop a query level pub fn pop(&mut self) -> Option { self.levels.pop() } /// creates lquery from an iterator with checking labels pub fn from_iter(levels: I) -> Result where S: Into, I: IntoIterator, { let mut lquery = Self::default(); for level in levels { lquery.push(PgLQueryLevel::from_str(&level.into())?); } Ok(lquery) } } impl IntoIterator for PgLQuery { type Item = PgLQueryLevel; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.levels.into_iter() } } impl FromStr for PgLQuery { type Err = PgLQueryParseError; fn from_str(s: &str) -> Result { Ok(Self { levels: s .split('.') .map(|s| PgLQueryLevel::from_str(s)) .collect::>()?, }) } } impl Display for PgLQuery { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let mut iter = self.levels.iter(); if let Some(label) = iter.next() { write!(f, "{label}")?; for label in iter { write!(f, ".{label}")?; } } Ok(()) } } impl Deref for PgLQuery { type Target = [PgLQueryLevel]; fn deref(&self) -> &Self::Target { &self.levels } } impl Type for PgLQuery { fn type_info() -> PgTypeInfo { // Since `ltree` is enabled by an extension, it does not have a stable OID. 
PgTypeInfo::with_name("lquery") } } impl Encode<'_, Postgres> for PgLQuery { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(1i8.to_le_bytes()); write!(buf, "{self}") .expect("Display implementation panicked while writing to PgArgumentBuffer"); IsNull::No } } impl<'r> Decode<'r, Postgres> for PgLQuery { fn decode(value: PgValueRef<'r>) -> Result { match value.format() { PgValueFormat::Binary => { let bytes = value.as_bytes()?; let version = i8::from_le_bytes([bytes[0]; 1]); if version != 1 { return Err(Box::new(PgLQueryParseError::InvalidLqueryVersion)); } Ok(Self::from_str(std::str::from_utf8(&bytes[1..])?)?) } PgValueFormat::Text => Ok(Self::from_str(value.as_str()?)?), } } } bitflags! { /// Modifiers that can be set to non-star labels #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct PgLQueryVariantFlag: u16 { /// * - Match any label with this prefix, for example foo* matches foobar const ANY_END = 0x01; /// @ - Match case-insensitively, for example a@ matches A const IN_CASE = 0x02; /// % - Match initial underscore-separated words const SUBLEXEME = 0x04; } } impl Display for PgLQueryVariantFlag { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { if self.contains(PgLQueryVariantFlag::ANY_END) { write!(f, "*")?; } if self.contains(PgLQueryVariantFlag::IN_CASE) { write!(f, "@")?; } if self.contains(PgLQueryVariantFlag::SUBLEXEME) { write!(f, "%")?; } Ok(()) } } #[derive(Clone, Debug, PartialEq)] pub struct PgLQueryVariant { label: PgLTreeLabel, modifiers: PgLQueryVariantFlag, } impl Display for PgLQueryVariant { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}{}", self.label, self.modifiers) } } #[derive(Clone, Debug, PartialEq)] pub enum PgLQueryLevel { /// match any label (*) with optional at least / at most numbers Star(Option, Option), /// match any of specified labels with optional flags NonStar(Vec), /// match none of specified labels with optional flags NotNonStar(Vec), } impl FromStr for 
PgLQueryLevel { type Err = PgLQueryParseError; fn from_str(s: &str) -> Result { let bytes = s.as_bytes(); if bytes.is_empty() { Err(PgLQueryParseError::EmptyString) } else { match bytes[0] { b'*' => { if bytes.len() > 1 { let parts = s[2..s.len() - 1].split(',').collect::>(); match parts.len() { 1 => { let number = parts[0].parse()?; Ok(PgLQueryLevel::Star(Some(number), Some(number))) } 2 => Ok(PgLQueryLevel::Star( Some(parts[0].parse()?), Some(parts[1].parse()?), )), _ => Err(PgLQueryParseError::UnexpectedCharacter), } } else { Ok(PgLQueryLevel::Star(None, None)) } } b'!' => Ok(PgLQueryLevel::NotNonStar( s[1..] .split('|') .map(|s| PgLQueryVariant::from_str(s)) .collect::, PgLQueryParseError>>()?, )), _ => Ok(PgLQueryLevel::NonStar( s.split('|') .map(|s| PgLQueryVariant::from_str(s)) .collect::, PgLQueryParseError>>()?, )), } } } } impl FromStr for PgLQueryVariant { type Err = PgLQueryParseError; fn from_str(s: &str) -> Result { let mut label_length = s.len(); let mut rev_iter = s.bytes().rev(); let mut modifiers = PgLQueryVariantFlag::empty(); while let Some(b) = rev_iter.next() { match b { b'@' => modifiers.insert(PgLQueryVariantFlag::IN_CASE), b'*' => modifiers.insert(PgLQueryVariantFlag::ANY_END), b'%' => modifiers.insert(PgLQueryVariantFlag::SUBLEXEME), _ => break, } label_length -= 1; } Ok(PgLQueryVariant { label: PgLTreeLabel::new(&s[0..label_length])?, modifiers, }) } } fn write_variants(f: &mut Formatter<'_>, variants: &[PgLQueryVariant], not: bool) -> fmt::Result { let mut iter = variants.iter(); if let Some(variant) = iter.next() { write!(f, "{}{}", if not { "!" 
} else { "" }, variant)?; for variant in iter { write!(f, ".{variant}")?; } } Ok(()) } impl Display for PgLQueryLevel { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { PgLQueryLevel::Star(Some(at_least), Some(at_most)) => { if at_least == at_most { write!(f, "*{{{at_least}}}") } else { write!(f, "*{{{at_least},{at_most}}}") } } PgLQueryLevel::Star(Some(at_least), _) => write!(f, "*{{{at_least},}}"), PgLQueryLevel::Star(_, Some(at_most)) => write!(f, "*{{,{at_most}}}"), PgLQueryLevel::Star(_, _) => write!(f, "*"), PgLQueryLevel::NonStar(variants) => write_variants(f, &variants, false), PgLQueryLevel::NotNonStar(variants) => write_variants(f, &variants, true), } } } sqlx-postgres-0.7.3/src/types/ltree.rs000064400000000000000000000127550072674642500161440ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use std::fmt::{self, Display, Formatter}; use std::io::Write; use std::ops::Deref; use std::str::FromStr; /// Represents ltree specific errors #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum PgLTreeParseError { /// LTree labels can only contain [A-Za-z0-9_] #[error("ltree label contains invalid characters")] InvalidLtreeLabel, /// LTree version not supported #[error("ltree version not supported")] InvalidLtreeVersion, } #[derive(Clone, Debug, Default, PartialEq)] pub struct PgLTreeLabel(String); impl PgLTreeLabel { pub fn new(label: S) -> Result where String: From, { let label = String::from(label); if label.len() <= 256 && label .bytes() .all(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || c == b'_') { Ok(Self(label)) } else { Err(PgLTreeParseError::InvalidLtreeLabel) } } } impl Deref for PgLTreeLabel { type Target = str; fn deref(&self) -> &Self::Target { self.0.as_str() } } impl FromStr for PgLTreeLabel { type Err = PgLTreeParseError; fn 
from_str(s: &str) -> Result { PgLTreeLabel::new(s) } } impl Display for PgLTreeLabel { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.0) } } /// Container for a Label Tree (`ltree`) in Postgres. /// /// See https://www.postgresql.org/docs/current/ltree.html /// /// ### Note: Requires Postgres 13+ /// /// This integration requires that the `ltree` type support the binary format in the Postgres /// wire protocol, which only became available in Postgres 13. /// ([Postgres 13.0 Release Notes, Additional Modules][https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.5.14]) /// /// Ideally, SQLx's Postgres driver should support falling back to text format for types /// which don't have `typsend` and `typrecv` entries in `pg_type`, but that work still needs /// to be done. /// /// ### Note: Extension Required /// The `ltree` extension is not enabled by default in Postgres. You will need to do so explicitly: /// /// ```ignore /// CREATE EXTENSION IF NOT EXISTS "ltree"; /// ``` #[derive(Clone, Debug, Default, PartialEq)] pub struct PgLTree { labels: Vec, } impl PgLTree { /// creates default/empty ltree pub fn new() -> Self { Self::default() } /// creates ltree from a [Vec] pub fn from(labels: Vec) -> Self { Self { labels } } /// creates ltree from an iterator with checking labels pub fn from_iter(labels: I) -> Result where String: From, I: IntoIterator, { let mut ltree = Self::default(); for label in labels { ltree.push(PgLTreeLabel::new(label)?); } Ok(ltree) } /// push a label to ltree pub fn push(&mut self, label: PgLTreeLabel) { self.labels.push(label); } /// pop a label from ltree pub fn pop(&mut self) -> Option { self.labels.pop() } } impl IntoIterator for PgLTree { type Item = PgLTreeLabel; type IntoIter = std::vec::IntoIter; fn into_iter(self) -> Self::IntoIter { self.labels.into_iter() } } impl FromStr for PgLTree { type Err = PgLTreeParseError; fn from_str(s: &str) -> Result { Ok(Self { labels: s .split('.') .map(|s| 
PgLTreeLabel::new(s)) .collect::, Self::Err>>()?, }) } } impl Display for PgLTree { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let mut iter = self.labels.iter(); if let Some(label) = iter.next() { write!(f, "{label}")?; for label in iter { write!(f, ".{label}")?; } } Ok(()) } } impl Deref for PgLTree { type Target = [PgLTreeLabel]; fn deref(&self) -> &Self::Target { &self.labels } } impl Type for PgLTree { fn type_info() -> PgTypeInfo { // Since `ltree` is enabled by an extension, it does not have a stable OID. PgTypeInfo::with_name("ltree") } } impl PgHasArrayType for PgLTree { fn array_type_info() -> PgTypeInfo { PgTypeInfo::with_name("_ltree") } } impl Encode<'_, Postgres> for PgLTree { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(1i8.to_le_bytes()); write!(buf, "{self}") .expect("Display implementation panicked while writing to PgArgumentBuffer"); IsNull::No } } impl<'r> Decode<'r, Postgres> for PgLTree { fn decode(value: PgValueRef<'r>) -> Result { match value.format() { PgValueFormat::Binary => { let bytes = value.as_bytes()?; let version = i8::from_le_bytes([bytes[0]; 1]); if version != 1 { return Err(Box::new(PgLTreeParseError::InvalidLtreeVersion)); } Ok(Self::from_str(std::str::from_utf8(&bytes[1..])?)?) 
} PgValueFormat::Text => Ok(Self::from_str(value.as_str()?)?), } } } sqlx-postgres-0.7.3/src/types/mac_address.rs000064400000000000000000000024720072674642500172710ustar 00000000000000use mac_address::MacAddress; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; impl Type for MacAddress { fn type_info() -> PgTypeInfo { PgTypeInfo::MACADDR } fn compatible(ty: &PgTypeInfo) -> bool { *ty == PgTypeInfo::MACADDR } } impl PgHasArrayType for MacAddress { fn array_type_info() -> PgTypeInfo { PgTypeInfo::MACADDR_ARRAY } } impl Encode<'_, Postgres> for MacAddress { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend_from_slice(&self.bytes()); // write just the address IsNull::No } fn size_hint(&self) -> usize { 6 } } impl Decode<'_, Postgres> for MacAddress { fn decode(value: PgValueRef<'_>) -> Result { let bytes = match value.format() { PgValueFormat::Binary => value.as_bytes()?, PgValueFormat::Text => { return Ok(value.as_str()?.parse()?); } }; if bytes.len() == 6 { return Ok(MacAddress::new(bytes.try_into().unwrap())); } Err("invalid data received when expecting an MACADDR".into()) } } sqlx-postgres-0.7.3/src/types/mod.rs000064400000000000000000000254620072674642500156070ustar 00000000000000//! Conversions between Rust and **Postgres** types. //! //! # Types //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `bool` | BOOL | //! | `i8` | "CHAR" | //! | `i16` | SMALLINT, SMALLSERIAL, INT2 | //! | `i32` | INT, SERIAL, INT4 | //! | `i64` | BIGINT, BIGSERIAL, INT8 | //! | `f32` | REAL, FLOAT4 | //! | `f64` | DOUBLE PRECISION, FLOAT8 | //! | `&str`, [`String`] | VARCHAR, CHAR(N), TEXT, NAME, CITEXT | //! | `&[u8]`, `Vec` | BYTEA | //! | `()` | VOID | //! | [`PgInterval`] | INTERVAL | //! 
| [`PgRange`](PgRange) | INT8RANGE, INT4RANGE, TSRANGE, TSTZRANGE, DATERANGE, NUMRANGE | //! | [`PgMoney`] | MONEY | //! | [`PgLTree`] | LTREE | //! | [`PgLQuery`] | LQUERY | //! | [`PgCiText`] | CITEXT1 | //! //! 1 SQLx generally considers `CITEXT` to be compatible with `String`, `&str`, etc., //! but this wrapper type is available for edge cases, such as `CITEXT[]` which Postgres //! does not consider to be compatible with `TEXT[]`. //! //! ### [`bigdecimal`](https://crates.io/crates/bigdecimal) //! Requires the `bigdecimal` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `bigdecimal::BigDecimal` | NUMERIC | //! //! ### [`rust_decimal`](https://crates.io/crates/rust_decimal) //! Requires the `rust_decimal` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `rust_decimal::Decimal` | NUMERIC | //! //! ### [`chrono`](https://crates.io/crates/chrono) //! //! Requires the `chrono` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `chrono::DateTime` | TIMESTAMPTZ | //! | `chrono::DateTime` | TIMESTAMPTZ | //! | `chrono::NaiveDateTime` | TIMESTAMP | //! | `chrono::NaiveDate` | DATE | //! | `chrono::NaiveTime` | TIME | //! | [`PgTimeTz`] | TIMETZ | //! //! ### [`time`](https://crates.io/crates/time) //! //! Requires the `time` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `time::PrimitiveDateTime` | TIMESTAMP | //! | `time::OffsetDateTime` | TIMESTAMPTZ | //! | `time::Date` | DATE | //! | `time::Time` | TIME | //! | [`PgTimeTz`] | TIMETZ | //! //! ### [`uuid`](https://crates.io/crates/uuid) //! //! 
Requires the `uuid` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `uuid::Uuid` | UUID | //! //! ### [`ipnetwork`](https://crates.io/crates/ipnetwork) //! //! Requires the `ipnetwork` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `ipnetwork::IpNetwork` | INET, CIDR | //! | `std::net::IpAddr` | INET, CIDR | //! //! Note that because `IpAddr` does not support network prefixes, it is an error to attempt to decode //! an `IpAddr` from a `INET` or `CIDR` value with a network prefix smaller than the address' full width: //! `/32` for IPv4 addresses and `/128` for IPv6 addresses. //! //! `IpNetwork` does not have this limitation. //! //! ### [`mac_address`](https://crates.io/crates/mac_address) //! //! Requires the `mac_address` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `mac_address::MacAddress` | MACADDR | //! //! ### [`bit-vec`](https://crates.io/crates/bit-vec) //! //! Requires the `bit-vec` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | `bit_vec::BitVec` | BIT, VARBIT | //! //! ### [`json`](https://crates.io/crates/serde_json) //! //! Requires the `json` Cargo feature flag. //! //! | Rust type | Postgres type(s) | //! |---------------------------------------|------------------------------------------------------| //! | [`Json`] | JSON, JSONB | //! | `serde_json::Value` | JSON, JSONB | //! | `&serde_json::value::RawValue` | JSON, JSONB | //! //! `Value` and `RawValue` from `serde_json` can be used for unstructured JSON data with //! Postgres. //! //! 
[`Json`](crate::types::Json) can be used for structured JSON data with Postgres. //! //! # [Composite types](https://www.postgresql.org/docs/current/rowtypes.html) //! //! User-defined composite types are supported through a derive for `Type`. //! //! ```text //! CREATE TYPE inventory_item AS ( //! name text, //! supplier_id integer, //! price numeric //! ); //! ``` //! //! ```rust,ignore //! #[derive(sqlx::Type)] //! #[sqlx(type_name = "inventory_item")] //! struct InventoryItem { //! name: String, //! supplier_id: i32, //! price: BigDecimal, //! } //! ``` //! //! Anonymous composite types are represented as tuples. Note that anonymous composites may only //! be returned and not sent to Postgres (this is a limitation of postgres). //! //! # Arrays //! //! One-dimensional arrays are supported as `Vec` or `&[T]` where `T` implements `Type`. //! //! # [Enumerations](https://www.postgresql.org/docs/current/datatype-enum.html) //! //! User-defined enumerations are supported through a derive for `Type`. //! //! ```text //! CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); //! ``` //! //! ```rust,ignore //! #[derive(sqlx::Type)] //! #[sqlx(type_name = "mood", rename_all = "lowercase")] //! enum Mood { Sad, Ok, Happy } //! ``` //! //! Rust enumerations may also be defined to be represented as an integer using `repr`. //! The following type expects a SQL type of `INTEGER` or `INT4` and will convert to/from the //! Rust enumeration. //! //! ```rust,ignore //! #[derive(sqlx::Type)] //! #[repr(i32)] //! enum Mood { Sad = 0, Ok = 1, Happy = 2 } //! ``` //! use crate::type_info::PgTypeKind; use crate::{PgTypeInfo, Postgres}; pub(crate) use sqlx_core::types::{Json, Type}; mod array; mod bool; mod bytes; mod citext; mod float; mod int; mod interval; mod lquery; mod ltree; // Not behind a Cargo feature because we require JSON in the driver implementation. 
mod json; mod money; mod oid; mod range; mod record; mod str; mod text; mod tuple; mod void; #[cfg(any(feature = "chrono", feature = "time"))] mod time_tz; #[cfg(feature = "bigdecimal")] mod bigdecimal; #[cfg(any(feature = "bigdecimal", feature = "rust_decimal"))] mod numeric; #[cfg(feature = "rust_decimal")] mod rust_decimal; #[cfg(feature = "chrono")] mod chrono; #[cfg(feature = "time")] mod time; #[cfg(feature = "uuid")] mod uuid; #[cfg(feature = "ipnetwork")] mod ipnetwork; #[cfg(feature = "ipnetwork")] mod ipaddr; #[cfg(feature = "mac_address")] mod mac_address; #[cfg(feature = "bit-vec")] mod bit_vec; pub use array::PgHasArrayType; pub use citext::PgCiText; pub use interval::PgInterval; pub use lquery::PgLQuery; pub use lquery::PgLQueryLevel; pub use lquery::PgLQueryVariant; pub use lquery::PgLQueryVariantFlag; pub use ltree::PgLTree; pub use ltree::PgLTreeLabel; pub use ltree::PgLTreeParseError; pub use money::PgMoney; pub use oid::Oid; pub use range::PgRange; #[cfg(any(feature = "chrono", feature = "time"))] pub use time_tz::PgTimeTz; // used in derive(Type) for `struct` // but the interface is not considered part of the public API #[doc(hidden)] pub use record::{PgRecordDecoder, PgRecordEncoder}; // Type::compatible impl appropriate for arrays fn array_compatible + ?Sized>(ty: &PgTypeInfo) -> bool { // we require the declared type to be an _array_ with an // element type that is acceptable if let PgTypeKind::Array(element) = &ty.kind() { return E::compatible(&element); } false } sqlx-postgres-0.7.3/src/types/money.rs000064400000000000000000000246410072674642500161550ustar 00000000000000use crate::{ decode::Decode, encode::{Encode, IsNull}, error::BoxDynError, types::Type, {PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}, }; use byteorder::{BigEndian, ByteOrder}; use std::{ io, ops::{Add, AddAssign, Sub, SubAssign}, }; /// The PostgreSQL [`MONEY`] type stores a currency amount with a fixed fractional /// precision. 
The fractional precision is determined by the database's /// `lc_monetary` setting. /// /// Data is read and written as 64-bit signed integers, and conversion into a /// decimal should be done using the right precision. /// /// Reading `MONEY` value in text format is not supported and will cause an error. /// /// ### `locale_frac_digits` /// This parameter corresponds to the number of digits after the decimal separator. /// /// This value must match what Postgres is expecting for the locale set in the database /// or else the decimal value you see on the client side will not match the `money` value /// on the server side. /// /// **For _most_ locales, this value is `2`.** /// /// If you're not sure what locale your database is set to or how many decimal digits it specifies, /// you can execute `SHOW lc_monetary;` to get the locale name, and then look it up in this list /// (you can ignore the `.utf8` prefix): /// https://lh.2xlibre.net/values/frac_digits/ /// /// If that link is dead and you're on a POSIX-compliant system (Unix, FreeBSD) you can also execute: /// /// ```sh /// $ LC_MONETARY= locale -k frac_digits /// ``` /// /// And the value you want is `N` in `frac_digits=N`. If you have shell access to the database /// server you should execute it there as available locales may differ between machines. /// /// Note that if `frac_digits` for the locale is outside the range `[0, 10]`, Postgres assumes /// it's a sentinel value and defaults to 2: /// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/cash.c#L114-L123 /// /// [`MONEY`]: https://www.postgresql.org/docs/current/datatype-money.html #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct PgMoney( /// The raw integer value sent over the wire; for locales with `frac_digits=2` (i.e. most /// of them), this will be the value in whole cents. /// /// E.g. for `select '$123.45'::money` with a locale of `en_US` (`frac_digits=2`), /// this will be `12345`. 
/// /// If the currency of your locale does not have fractional units, e.g. Yen, then this will /// just be the units of the currency. /// /// See the type-level docs for an explanation of `locale_frac_units`. pub i64, ); impl PgMoney { /// Convert the money value into a [`BigDecimal`] using `locale_frac_digits`. /// /// See the type-level docs for an explanation of `locale_frac_digits`. /// /// [`BigDecimal`]: crate::types::BigDecimal #[cfg(feature = "bigdecimal")] pub fn to_bigdecimal(self, locale_frac_digits: i64) -> bigdecimal::BigDecimal { let digits = num_bigint::BigInt::from(self.0); bigdecimal::BigDecimal::new(digits, locale_frac_digits) } /// Convert the money value into a [`Decimal`] using `locale_frac_digits`. /// /// See the type-level docs for an explanation of `locale_frac_digits`. /// /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "rust_decimal")] pub fn to_decimal(self, locale_frac_digits: u32) -> rust_decimal::Decimal { rust_decimal::Decimal::new(self.0, locale_frac_digits) } /// Convert a [`Decimal`] value into money using `locale_frac_digits`. /// /// See the type-level docs for an explanation of `locale_frac_digits`. /// /// Note that `Decimal` has 96 bits of precision, but `PgMoney` only has 63 plus the sign bit. /// If the value is larger than 63 bits it will be truncated. 
/// /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "rust_decimal")] pub fn from_decimal(mut decimal: rust_decimal::Decimal, locale_frac_digits: u32) -> Self { // this is all we need to convert to our expected locale's `frac_digits` decimal.rescale(locale_frac_digits); /// a mask to bitwise-AND with an `i64` to zero the sign bit const SIGN_MASK: i64 = i64::MAX; let is_negative = decimal.is_sign_negative(); let serialized = decimal.serialize(); // interpret bytes `4..12` as an i64, ignoring the sign bit // this is where truncation occurs let value = i64::from_le_bytes( *<&[u8; 8]>::try_from(&serialized[4..12]) .expect("BUG: slice of serialized should be 8 bytes"), ) & SIGN_MASK; // zero out the sign bit // negate if necessary Self(if is_negative { -value } else { value }) } /// Convert a [`BigDecimal`](crate::types::BigDecimal) value into money using the correct precision /// defined in the PostgreSQL settings. The default precision is two. #[cfg(feature = "bigdecimal")] pub fn from_bigdecimal( decimal: bigdecimal::BigDecimal, locale_frac_digits: u32, ) -> Result { use bigdecimal::ToPrimitive; let multiplier = bigdecimal::BigDecimal::new( num_bigint::BigInt::from(10i128.pow(locale_frac_digits)), 0, ); let cents = decimal * multiplier; let money = cents.to_i64().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidData, "Provided BigDecimal could not convert to i64: overflow.", ) })?; Ok(Self(money)) } } impl Type for PgMoney { fn type_info() -> PgTypeInfo { PgTypeInfo::MONEY } } impl PgHasArrayType for PgMoney { fn array_type_info() -> PgTypeInfo { PgTypeInfo::MONEY_ARRAY } } impl From for PgMoney where T: Into, { fn from(num: T) -> Self { Self(num.into()) } } impl Encode<'_, Postgres> for PgMoney { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.0.to_be_bytes()); IsNull::No } } impl Decode<'_, Postgres> for PgMoney { fn decode(value: PgValueRef<'_>) -> Result { match value.format() { PgValueFormat::Binary => { let cents = 
BigEndian::read_i64(value.as_bytes()?); Ok(PgMoney(cents)) } PgValueFormat::Text => { let error = io::Error::new( io::ErrorKind::InvalidData, "Reading a `MONEY` value in text format is not supported.", ); Err(Box::new(error)) } } } } impl Add for PgMoney { type Output = PgMoney; /// Adds two monetary values. /// /// # Panics /// Panics if overflowing the `i64::MAX`. fn add(self, rhs: PgMoney) -> Self::Output { self.0 .checked_add(rhs.0) .map(PgMoney) .expect("overflow adding money amounts") } } impl AddAssign for PgMoney { /// An assigning add for two monetary values. /// /// # Panics /// Panics if overflowing the `i64::MAX`. fn add_assign(&mut self, rhs: PgMoney) { self.0 = self .0 .checked_add(rhs.0) .expect("overflow adding money amounts") } } impl Sub for PgMoney { type Output = PgMoney; /// Subtracts two monetary values. /// /// # Panics /// Panics if underflowing the `i64::MIN`. fn sub(self, rhs: PgMoney) -> Self::Output { self.0 .checked_sub(rhs.0) .map(PgMoney) .expect("overflow subtracting money amounts") } } impl SubAssign for PgMoney { /// An assigning subtract for two monetary values. /// /// # Panics /// Panics if underflowing the `i64::MIN`. 
fn sub_assign(&mut self, rhs: PgMoney) { self.0 = self .0 .checked_sub(rhs.0) .expect("overflow subtracting money amounts") } } #[cfg(test)] mod tests { use super::PgMoney; #[test] fn adding_works() { assert_eq!(PgMoney(3), PgMoney(1) + PgMoney(2)) } #[test] fn add_assign_works() { let mut money = PgMoney(1); money += PgMoney(2); assert_eq!(PgMoney(3), money); } #[test] fn subtracting_works() { assert_eq!(PgMoney(4), PgMoney(5) - PgMoney(1)) } #[test] fn sub_assign_works() { let mut money = PgMoney(1); money -= PgMoney(2); assert_eq!(PgMoney(-1), money); } #[test] #[should_panic] fn add_overflow_panics() { let _ = PgMoney(i64::MAX) + PgMoney(1); } #[test] #[should_panic] fn add_assign_overflow_panics() { let mut money = PgMoney(i64::MAX); money += PgMoney(1); } #[test] #[should_panic] fn sub_overflow_panics() { let _ = PgMoney(i64::MIN) - PgMoney(1); } #[test] #[should_panic] fn sub_assign_overflow_panics() { let mut money = PgMoney(i64::MIN); money -= PgMoney(1); } #[test] #[cfg(feature = "bigdecimal")] fn conversion_to_bigdecimal_works() { let money = PgMoney(12345); assert_eq!( bigdecimal::BigDecimal::new(num_bigint::BigInt::from(12345), 2), money.to_bigdecimal(2) ); } #[test] #[cfg(feature = "rust_decimal")] fn conversion_to_decimal_works() { assert_eq!( rust_decimal::Decimal::new(12345, 2), PgMoney(12345).to_decimal(2) ); } #[test] #[cfg(feature = "rust_decimal")] fn conversion_from_decimal_works() { assert_eq!( PgMoney(12345), PgMoney::from_decimal(rust_decimal::Decimal::new(12345, 2), 2) ); assert_eq!( PgMoney(12345), PgMoney::from_decimal(rust_decimal::Decimal::new(123450, 3), 2) ); assert_eq!( PgMoney(-12345), PgMoney::from_decimal(rust_decimal::Decimal::new(-123450, 3), 2) ); assert_eq!( PgMoney(-12300), PgMoney::from_decimal(rust_decimal::Decimal::new(-123, 0), 2) ); } #[test] #[cfg(feature = "bigdecimal")] fn conversion_from_bigdecimal_works() { let dec = bigdecimal::BigDecimal::new(num_bigint::BigInt::from(12345), 2); assert_eq!(PgMoney(12345), 
PgMoney::from_bigdecimal(dec, 2).unwrap()); } } sqlx-postgres-0.7.3/src/types/numeric.rs000064400000000000000000000113720072674642500164650ustar 00000000000000use sqlx_core::bytes::Buf; use crate::error::BoxDynError; use crate::PgArgumentBuffer; /// Represents a `NUMERIC` value in the **Postgres** wire protocol. #[derive(Debug, PartialEq, Eq)] pub(crate) enum PgNumeric { /// Equivalent to the `'NaN'` value in Postgres. The result of, e.g. `1 / 0`. NotANumber, /// A populated `NUMERIC` value. /// /// A description of these fields can be found here (although the type being described is the /// version for in-memory calculations, the field names are the same): /// https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L224-L269 Number { /// The sign of the value: positive (also set for 0 and -0), or negative. sign: PgNumericSign, /// The digits of the number in base-10000 with the most significant digit first /// (big-endian). /// /// The length of this vector must not overflow `i16` for the binary protocol. /// /// *Note*: the `Encode` implementation will panic if any digit is `>= 10000`. digits: Vec, /// The scaling factor of the number, such that the value will be interpreted as /// /// ```text /// digits[0] * 10,000 ^ weight /// + digits[1] * 10,000 ^ (weight - 1) /// ... /// + digits[N] * 10,000 ^ (weight - N) where N = digits.len() - 1 /// ``` /// May be negative. weight: i16, /// How many _decimal_ (base-10) digits following the decimal point to consider in /// arithmetic regardless of how many actually follow the decimal point as determined by /// `weight`--the comment in the Postgres code linked above recommends using this only for /// ignoring unnecessary trailing zeroes (as trimming nonzero digits means reducing the /// precision of the value). /// /// Must be `>= 0`. 
scale: i16, }, } // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L167-L170 const SIGN_POS: u16 = 0x0000; const SIGN_NEG: u16 = 0x4000; const SIGN_NAN: u16 = 0xC000; // overflows i16 (C equivalent truncates from integer literal) /// Possible sign values for [PgNumeric]. #[derive(Copy, Clone, Debug, PartialEq, Eq)] #[repr(u16)] pub(crate) enum PgNumericSign { Positive = SIGN_POS, Negative = SIGN_NEG, } impl PgNumericSign { fn try_from_u16(val: u16) -> Result { match val { SIGN_POS => Ok(PgNumericSign::Positive), SIGN_NEG => Ok(PgNumericSign::Negative), SIGN_NAN => unreachable!("sign value for NaN passed to PgNumericSign"), _ => Err(format!("invalid value for PgNumericSign: {val:#04X}").into()), } } } impl PgNumeric { pub(crate) fn decode(mut buf: &[u8]) -> Result { // https://github.com/postgres/postgres/blob/bcd1c3630095e48bc3b1eb0fc8e8c8a7c851eba1/src/backend/utils/adt/numeric.c#L874 let num_digits = buf.get_u16(); let weight = buf.get_i16(); let sign = buf.get_u16(); let scale = buf.get_i16(); if sign == SIGN_NAN { Ok(PgNumeric::NotANumber) } else { let digits: Vec<_> = (0..num_digits).map(|_| buf.get_i16()).collect::<_>(); Ok(PgNumeric::Number { sign: PgNumericSign::try_from_u16(sign)?, scale, weight, digits, }) } } /// ### Panics /// /// * If `digits.len()` overflows `i16` /// * If any element in `digits` is greater than or equal to 10000 pub(crate) fn encode(&self, buf: &mut PgArgumentBuffer) { match *self { PgNumeric::Number { ref digits, sign, scale, weight, } => { let digits_len: i16 = digits .len() .try_into() .expect("PgNumeric.digits.len() should not overflow i16"); buf.extend(&digits_len.to_be_bytes()); buf.extend(&weight.to_be_bytes()); buf.extend(&(sign as i16).to_be_bytes()); buf.extend(&scale.to_be_bytes()); for digit in digits { debug_assert!(*digit < 10000, "PgNumeric digits must be in base-10000"); buf.extend(&digit.to_be_bytes()); } } PgNumeric::NotANumber => { 
buf.extend(&0_i16.to_be_bytes()); buf.extend(&0_i16.to_be_bytes()); buf.extend(&SIGN_NAN.to_be_bytes()); buf.extend(&0_i16.to_be_bytes()); } } } } sqlx-postgres-0.7.3/src/types/oid.rs000064400000000000000000000035070072674642500155770ustar 00000000000000use byteorder::{BigEndian, ByteOrder}; use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; /// The PostgreSQL [`OID`] type stores an object identifier, /// used internally by PostgreSQL as primary keys for various system tables. /// /// [`OID`]: https://www.postgresql.org/docs/current/datatype-oid.html #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, Default)] pub struct Oid( /// The raw unsigned integer value sent over the wire pub u32, ); impl Oid { pub(crate) fn incr_one(&mut self) { self.0 = self.0.wrapping_add(1); } } impl Type for Oid { fn type_info() -> PgTypeInfo { PgTypeInfo::OID } } impl PgHasArrayType for Oid { fn array_type_info() -> PgTypeInfo { PgTypeInfo::OID_ARRAY } } impl Encode<'_, Postgres> for Oid { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(&self.0.to_be_bytes()); IsNull::No } } impl Decode<'_, Postgres> for Oid { fn decode(value: PgValueRef<'_>) -> Result { Ok(Self(match value.format() { PgValueFormat::Binary => BigEndian::read_u32(value.as_bytes()?), PgValueFormat::Text => value.as_str()?.parse()?, })) } } impl Serialize for Oid { fn serialize(&self, serializer: S) -> Result where S: Serializer, { self.0.serialize(serializer) } } impl<'de> Deserialize<'de> for Oid { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { u32::deserialize(deserializer).map(Self) } } sqlx-postgres-0.7.3/src/types/range.rs000064400000000000000000000346530072674642500161260ustar 00000000000000use std::fmt::{self, Debug, Display, 
Formatter}; use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive}; use bitflags::bitflags; use sqlx_core::bytes::Buf; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::type_info::PgTypeKind; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/include/utils/rangetypes.h#L35-L44 bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] struct RangeFlags: u8 { const EMPTY = 0x01; const LB_INC = 0x02; const UB_INC = 0x04; const LB_INF = 0x08; const UB_INF = 0x10; const LB_NULL = 0x20; // not used const UB_NULL = 0x40; // not used const CONTAIN_EMPTY = 0x80; // internal } } #[derive(Debug, PartialEq, Eq, Clone)] pub struct PgRange { pub start: Bound, pub end: Bound, } impl From<[Bound; 2]> for PgRange { fn from(v: [Bound; 2]) -> Self { let [start, end] = v; Self { start, end } } } impl From<(Bound, Bound)> for PgRange { fn from(v: (Bound, Bound)) -> Self { Self { start: v.0, end: v.1, } } } impl From> for PgRange { fn from(v: Range) -> Self { Self { start: Bound::Included(v.start), end: Bound::Excluded(v.end), } } } impl From> for PgRange { fn from(v: RangeFrom) -> Self { Self { start: Bound::Included(v.start), end: Bound::Unbounded, } } } impl From> for PgRange { fn from(v: RangeInclusive) -> Self { let (start, end) = v.into_inner(); Self { start: Bound::Included(start), end: Bound::Included(end), } } } impl From> for PgRange { fn from(v: RangeTo) -> Self { Self { start: Bound::Unbounded, end: Bound::Excluded(v.end), } } } impl From> for PgRange { fn from(v: RangeToInclusive) -> Self { Self { start: Bound::Unbounded, end: Bound::Included(v.end), } } } impl RangeBounds for PgRange { fn start_bound(&self) -> Bound<&T> { match self.start { Bound::Included(ref start) => Bound::Included(start), Bound::Excluded(ref 
start) => Bound::Excluded(start), Bound::Unbounded => Bound::Unbounded, } } fn end_bound(&self) -> Bound<&T> { match self.end { Bound::Included(ref end) => Bound::Included(end), Bound::Excluded(ref end) => Bound::Excluded(end), Bound::Unbounded => Bound::Unbounded, } } } impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::INT4_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::INT8_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } #[cfg(feature = "bigdecimal")] impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::NUM_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } #[cfg(feature = "rust_decimal")] impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::NUM_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } #[cfg(feature = "chrono")] impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::DATE_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } #[cfg(feature = "chrono")] impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::TS_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } #[cfg(feature = "chrono")] impl Type for PgRange> { fn type_info() -> PgTypeInfo { PgTypeInfo::TSTZ_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::>(ty) } } #[cfg(feature = "time")] impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::DATE_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } #[cfg(feature = "time")] impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::TS_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } #[cfg(feature = "time")] impl Type for PgRange { fn type_info() -> PgTypeInfo { PgTypeInfo::TSTZ_RANGE } fn compatible(ty: &PgTypeInfo) -> bool { range_compatible::(ty) } } impl PgHasArrayType for PgRange 
{ fn array_type_info() -> PgTypeInfo { PgTypeInfo::INT4_RANGE_ARRAY } } impl PgHasArrayType for PgRange { fn array_type_info() -> PgTypeInfo { PgTypeInfo::INT8_RANGE_ARRAY } } #[cfg(feature = "bigdecimal")] impl PgHasArrayType for PgRange { fn array_type_info() -> PgTypeInfo { PgTypeInfo::NUM_RANGE_ARRAY } } #[cfg(feature = "rust_decimal")] impl PgHasArrayType for PgRange { fn array_type_info() -> PgTypeInfo { PgTypeInfo::NUM_RANGE_ARRAY } } #[cfg(feature = "chrono")] impl PgHasArrayType for PgRange { fn array_type_info() -> PgTypeInfo { PgTypeInfo::DATE_RANGE_ARRAY } } #[cfg(feature = "chrono")] impl PgHasArrayType for PgRange { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TS_RANGE_ARRAY } } #[cfg(feature = "chrono")] impl PgHasArrayType for PgRange> { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TSTZ_RANGE_ARRAY } } #[cfg(feature = "time")] impl PgHasArrayType for PgRange { fn array_type_info() -> PgTypeInfo { PgTypeInfo::DATE_RANGE_ARRAY } } #[cfg(feature = "time")] impl PgHasArrayType for PgRange { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TS_RANGE_ARRAY } } #[cfg(feature = "time")] impl PgHasArrayType for PgRange { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TSTZ_RANGE_ARRAY } } impl<'q, T> Encode<'q, Postgres> for PgRange where T: Encode<'q, Postgres>, { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L245 let mut flags = RangeFlags::empty(); flags |= match self.start { Bound::Included(_) => RangeFlags::LB_INC, Bound::Unbounded => RangeFlags::LB_INF, Bound::Excluded(_) => RangeFlags::empty(), }; flags |= match self.end { Bound::Included(_) => RangeFlags::UB_INC, Bound::Unbounded => RangeFlags::UB_INF, Bound::Excluded(_) => RangeFlags::empty(), }; buf.push(flags.bits()); if let Bound::Included(v) | Bound::Excluded(v) = &self.start { buf.encode(v); } if let Bound::Included(v) | Bound::Excluded(v) = 
&self.end { buf.encode(v); } // ranges are themselves never null IsNull::No } } impl<'r, T> Decode<'r, Postgres> for PgRange where T: Type + for<'a> Decode<'a, Postgres>, { fn decode(value: PgValueRef<'r>) -> Result { match value.format { PgValueFormat::Binary => { let element_ty = if let PgTypeKind::Range(element) = &value.type_info.0.kind() { element } else { return Err(format!("unexpected non-range type {}", value.type_info).into()); }; let mut buf = value.as_bytes()?; let mut start = Bound::Unbounded; let mut end = Bound::Unbounded; let flags = RangeFlags::from_bits_truncate(buf.get_u8()); if flags.contains(RangeFlags::EMPTY) { return Ok(PgRange { start, end }); } if !flags.contains(RangeFlags::LB_INF) { let value = T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?; start = if flags.contains(RangeFlags::LB_INC) { Bound::Included(value) } else { Bound::Excluded(value) }; } if !flags.contains(RangeFlags::UB_INF) { let value = T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?; end = if flags.contains(RangeFlags::UB_INC) { Bound::Included(value) } else { Bound::Excluded(value) }; } Ok(PgRange { start, end }) } PgValueFormat::Text => { // https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L2046 let mut start = None; let mut end = None; let s = value.as_str()?; // remember the bounds let sb = s.as_bytes(); let lower = sb[0] as char; let upper = sb[sb.len() - 1] as char; // trim the wrapping braces/brackets let s = &s[1..(s.len() - 1)]; let mut chars = s.chars(); let mut element = String::new(); let mut done = false; let mut quoted = false; let mut in_quotes = false; let mut in_escape = false; let mut prev_ch = '\0'; let mut count = 0; while !done { element.clear(); loop { match chars.next() { Some(ch) => { match ch { _ if in_escape => { element.push(ch); in_escape = false; } '"' if in_quotes => { in_quotes = false; } '"' => { in_quotes = true; quoted = true; 
if prev_ch == '"' { element.push('"') } } '\\' if !in_escape => { in_escape = true; } ',' if !in_quotes => break, _ => { element.push(ch); } } prev_ch = ch; } None => { done = true; break; } } } count += 1; if !(element.is_empty() && !quoted) { let value = Some(T::decode(PgValueRef { type_info: T::type_info(), format: PgValueFormat::Text, value: Some(element.as_bytes()), row: None, })?); if count == 1 { start = value; } else if count == 2 { end = value; } else { return Err("more than 2 elements found in a range".into()); } } } let start = parse_bound(lower, start)?; let end = parse_bound(upper, end)?; Ok(PgRange { start, end }) } } } } fn parse_bound(ch: char, value: Option) -> Result, BoxDynError> { Ok(if let Some(value) = value { match ch { '(' | ')' => Bound::Excluded(value), '[' | ']' => Bound::Included(value), _ => { return Err(format!( "expected `(`, ')', '[', or `]` but found `{ch}` for range literal" ) .into()); } } } else { Bound::Unbounded }) } impl Display for PgRange where T: Display, { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match &self.start { Bound::Unbounded => f.write_str("(,")?, Bound::Excluded(v) => write!(f, "({v},")?, Bound::Included(v) => write!(f, "[{v},")?, } match &self.end { Bound::Unbounded => f.write_str(")")?, Bound::Excluded(v) => write!(f, "{v})")?, Bound::Included(v) => write!(f, "{v}]")?, } Ok(()) } } fn range_compatible>(ty: &PgTypeInfo) -> bool { // we require the declared type to be a _range_ with an // element type that is acceptable if let PgTypeKind::Range(element) = &ty.kind() { return E::compatible(&element); } false } sqlx-postgres-0.7.3/src/types/record.rs000064400000000000000000000141440072674642500163010ustar 00000000000000use sqlx_core::bytes::Buf; use crate::decode::Decode; use crate::encode::Encode; use crate::error::{mismatched_types, BoxDynError}; use crate::type_info::TypeInfo; use crate::type_info::{PgType, PgTypeKind}; use crate::types::Oid; use crate::types::Type; use crate::{PgArgumentBuffer, 
PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; #[doc(hidden)] pub struct PgRecordEncoder<'a> { buf: &'a mut PgArgumentBuffer, off: usize, num: u32, } impl<'a> PgRecordEncoder<'a> { #[doc(hidden)] pub fn new(buf: &'a mut PgArgumentBuffer) -> Self { let off = buf.len(); // reserve space for a field count buf.extend(&(0_u32).to_be_bytes()); Self { buf, off, num: 0 } } #[doc(hidden)] pub fn finish(&mut self) { // fill in the record length self.buf[self.off..(self.off + 4)].copy_from_slice(&self.num.to_be_bytes()); } #[doc(hidden)] pub fn encode<'q, T>(&mut self, value: T) -> &mut Self where 'a: 'q, T: Encode<'q, Postgres> + Type, { let ty = value.produces().unwrap_or_else(T::type_info); if let PgType::DeclareWithName(name) = ty.0 { // push a hole for this type ID // to be filled in on query execution self.buf.patch_type_by_name(&name); } else { // write type id self.buf.extend(&ty.0.oid().0.to_be_bytes()); } self.buf.encode(value); self.num += 1; self } } #[doc(hidden)] pub struct PgRecordDecoder<'r> { buf: &'r [u8], typ: PgTypeInfo, fmt: PgValueFormat, ind: usize, } impl<'r> PgRecordDecoder<'r> { #[doc(hidden)] pub fn new(value: PgValueRef<'r>) -> Result { let fmt = value.format(); let mut buf = value.as_bytes()?; let typ = value.type_info; match fmt { PgValueFormat::Binary => { let _len = buf.get_u32(); } PgValueFormat::Text => { // remove the enclosing `(` .. 
`)` buf = &buf[1..(buf.len() - 1)]; } } Ok(Self { buf, fmt, typ, ind: 0, }) } #[doc(hidden)] pub fn try_decode(&mut self) -> Result where T: for<'a> Decode<'a, Postgres> + Type, { if self.buf.is_empty() { return Err(format!("no field `{0}` found on record", self.ind).into()); } match self.fmt { PgValueFormat::Binary => { let element_type_oid = Oid(self.buf.get_u32()); let element_type_opt = match self.typ.0.kind() { PgTypeKind::Simple if self.typ.0 == PgType::Record => { PgTypeInfo::try_from_oid(element_type_oid) } PgTypeKind::Composite(fields) => { let ty = fields[self.ind].1.clone(); if ty.0.oid() != element_type_oid { return Err("unexpected mismatch of composite type information".into()); } Some(ty) } _ => { return Err( "unexpected non-composite type being decoded as a composite type" .into(), ); } }; if let Some(ty) = &element_type_opt { if !ty.is_null() && !T::compatible(ty) { return Err(mismatched_types::(ty)); } } let element_type = element_type_opt .ok_or_else(|| BoxDynError::from(format!("custom types in records are not fully supported yet: failed to retrieve type info for field {} with type oid {}", self.ind, element_type_oid.0)))?; self.ind += 1; T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type)) } PgValueFormat::Text => { let mut element = String::new(); let mut quoted = false; let mut in_quotes = false; let mut in_escape = false; let mut prev_ch = '\0'; while !self.buf.is_empty() { let ch = self.buf.get_u8() as char; match ch { _ if in_escape => { element.push(ch); in_escape = false; } '"' if in_quotes => { in_quotes = false; } '"' => { in_quotes = true; quoted = true; if prev_ch == '"' { element.push('"') } } '\\' if !in_escape => { in_escape = true; } ',' if !in_quotes => break, _ => { element.push(ch); } } prev_ch = ch; } let buf = if element.is_empty() && !quoted { // completely empty input means NULL None } else { Some(element.as_bytes()) }; // NOTE: we do not call [`accepts`] or give a chance to from a user as // TEXT sequences are 
not strongly typed T::decode(PgValueRef { // NOTE: We pass `0` as the type ID because we don't have a reasonable value // we could use. type_info: PgTypeInfo::with_oid(Oid(0)), format: self.fmt, value: buf, row: None, }) } } } } sqlx-postgres-0.7.3/src/types/rust_decimal.rs000064400000000000000000000306110072674642500174730ustar 00000000000000use rust_decimal::{prelude::Zero, Decimal}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::numeric::{PgNumeric, PgNumericSign}; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use rust_decimal::MathematicalOps; impl Type for Decimal { fn type_info() -> PgTypeInfo { PgTypeInfo::NUMERIC } } impl PgHasArrayType for Decimal { fn array_type_info() -> PgTypeInfo { PgTypeInfo::NUMERIC_ARRAY } } impl TryFrom for Decimal { type Error = BoxDynError; fn try_from(numeric: PgNumeric) -> Result { let (digits, sign, mut weight, scale) = match numeric { PgNumeric::Number { digits, sign, weight, scale, } => (digits, sign, weight, scale), PgNumeric::NotANumber => { return Err("Decimal does not support NaN values".into()); } }; if digits.is_empty() { // Postgres returns an empty digit array for 0 return Ok(0u64.into()); } let mut value = Decimal::ZERO; // Sum over `digits`, multiply each by its weight and add it to `value`. 
for digit in digits { let mul = Decimal::from(10_000i16) .checked_powi(weight as i64) .ok_or("value not representable as rust_decimal::Decimal")?; let part = Decimal::from(digit) * mul; value = value .checked_add(part) .ok_or("value not representable as rust_decimal::Decimal")?; weight = weight.checked_sub(1).ok_or("weight underflowed")?; } match sign { PgNumericSign::Positive => value.set_sign_positive(true), PgNumericSign::Negative => value.set_sign_negative(true), } value.rescale(scale as u32); Ok(value) } } impl TryFrom<&'_ Decimal> for PgNumeric { type Error = BoxDynError; fn try_from(decimal: &Decimal) -> Result { // `Decimal` added `is_zero()` as an inherent method in a more recent version if Zero::is_zero(decimal) { return Ok(PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![], }); } let scale = decimal.scale() as u16; // A serialized version of the decimal number. The resulting byte array // will have the following representation: // // Bytes 1-4: flags // Bytes 5-8: lo portion of m // Bytes 9-12: mid portion of m // Bytes 13-16: high portion of m let mut mantissa = u128::from_le_bytes(decimal.serialize()); // chop off the flags mantissa >>= 32; // If our scale is not a multiple of 4, we need to go to the next // multiple. let groups_diff = scale % 4; if groups_diff > 0 { let remainder = 4 - groups_diff as u32; let power = 10u32.pow(remainder as u32) as u128; mantissa = mantissa * power; } // Array to store max mantissa of Decimal in Postgres decimal format. let mut digits = Vec::with_capacity(8); // Convert to base-10000. while mantissa != 0 { digits.push((mantissa % 10_000) as i16); mantissa /= 10_000; } // Change the endianness. digits.reverse(); // Weight is number of digits on the left side of the decimal. let digits_after_decimal = (scale + 3) as u16 / 4; let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; // Remove non-significant zeroes. 
while let Some(&0) = digits.last() { digits.pop(); } Ok(PgNumeric::Number { sign: match decimal.is_sign_negative() { false => PgNumericSign::Positive, true => PgNumericSign::Negative, }, scale: scale as i16, weight, digits, }) } } /// ### Panics /// If this `Decimal` cannot be represented by `PgNumeric`. impl Encode<'_, Postgres> for Decimal { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { PgNumeric::try_from(self) .expect("Decimal magnitude too great for Postgres NUMERIC type") .encode(buf); IsNull::No } } impl Decode<'_, Postgres> for Decimal { fn decode(value: PgValueRef<'_>) -> Result { match value.format() { PgValueFormat::Binary => PgNumeric::decode(value.as_bytes()?)?.try_into(), PgValueFormat::Text => Ok(value.as_str()?.parse::()?), } } } #[cfg(test)] mod decimal_to_pgnumeric { use super::{Decimal, PgNumeric, PgNumericSign}; use std::convert::TryFrom; #[test] fn zero() { let zero: Decimal = "0".parse().unwrap(); assert_eq!( PgNumeric::try_from(&zero).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![] } ); } #[test] fn one() { let one: Decimal = "1".parse().unwrap(); assert_eq!( PgNumeric::try_from(&one).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![1] } ); } #[test] fn ten() { let ten: Decimal = "10".parse().unwrap(); assert_eq!( PgNumeric::try_from(&ten).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![10] } ); } #[test] fn one_hundred() { let one_hundred: Decimal = "100".parse().unwrap(); assert_eq!( PgNumeric::try_from(&one_hundred).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![100] } ); } #[test] fn ten_thousand() { // Decimal doesn't normalize here let ten_thousand: Decimal = "10000".parse().unwrap(); assert_eq!( PgNumeric::try_from(&ten_thousand).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 1, digits: vec![1] } ); 
} #[test] fn two_digits() { let two_digits: Decimal = "12345".parse().unwrap(); assert_eq!( PgNumeric::try_from(&two_digits).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 1, digits: vec![1, 2345] } ); } #[test] fn one_tenth() { let one_tenth: Decimal = "0.1".parse().unwrap(); assert_eq!( PgNumeric::try_from(&one_tenth).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 1, weight: -1, digits: vec![1000] } ); } #[test] fn decimal_1() { let decimal: Decimal = "1.2345".parse().unwrap(); assert_eq!( PgNumeric::try_from(&decimal).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 4, weight: 0, digits: vec![1, 2345] } ); } #[test] fn decimal_2() { let decimal: Decimal = "0.12345".parse().unwrap(); assert_eq!( PgNumeric::try_from(&decimal).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 5, weight: -1, digits: vec![1234, 5000] } ); } #[test] fn decimal_3() { let decimal: Decimal = "0.01234".parse().unwrap(); assert_eq!( PgNumeric::try_from(&decimal).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 5, weight: -1, digits: vec![0123, 4000] } ); } #[test] fn decimal_4() { let decimal: Decimal = "12345.67890".parse().unwrap(); let expected_numeric = PgNumeric::Number { sign: PgNumericSign::Positive, scale: 5, weight: 1, digits: vec![1, 2345, 6789], }; assert_eq!(PgNumeric::try_from(&decimal).unwrap(), expected_numeric); let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); assert_eq!(actual_decimal, decimal); assert_eq!(actual_decimal.mantissa(), 1234567890); assert_eq!(actual_decimal.scale(), 5); } #[test] fn one_digit_decimal() { let one_digit_decimal: Decimal = "0.00001234".parse().unwrap(); let expected_numeric = PgNumeric::Number { sign: PgNumericSign::Positive, scale: 8, weight: -2, digits: vec![1234], }; assert_eq!( PgNumeric::try_from(&one_digit_decimal).unwrap(), expected_numeric ); let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); 
assert_eq!(actual_decimal, one_digit_decimal); assert_eq!(actual_decimal.mantissa(), 1234); assert_eq!(actual_decimal.scale(), 8); } #[test] fn issue_423_four_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 let four_digit: Decimal = "1234".parse().unwrap(); assert_eq!( PgNumeric::try_from(&four_digit).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 0, digits: vec![1234] } ); } #[test] fn issue_423_negative_four_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 let negative_four_digit: Decimal = "-1234".parse().unwrap(); assert_eq!( PgNumeric::try_from(&negative_four_digit).unwrap(), PgNumeric::Number { sign: PgNumericSign::Negative, scale: 0, weight: 0, digits: vec![1234] } ); } #[test] fn issue_423_eight_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 let eight_digit: Decimal = "12345678".parse().unwrap(); assert_eq!( PgNumeric::try_from(&eight_digit).unwrap(), PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, weight: 1, digits: vec![1234, 5678] } ); } #[test] fn issue_423_negative_eight_digit() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/423 let negative_eight_digit: Decimal = "-12345678".parse().unwrap(); assert_eq!( PgNumeric::try_from(&negative_eight_digit).unwrap(), PgNumeric::Number { sign: PgNumericSign::Negative, scale: 0, weight: 1, digits: vec![1234, 5678] } ); } #[test] fn issue_2247_trailing_zeros() { // This is a regression test for https://github.com/launchbadge/sqlx/issues/2247 let one_hundred: Decimal = "100.00".parse().unwrap(); let expected_numeric = PgNumeric::Number { sign: PgNumericSign::Positive, scale: 2, weight: 0, digits: vec![100], }; assert_eq!(PgNumeric::try_from(&one_hundred).unwrap(), expected_numeric); let actual_decimal = Decimal::try_from(expected_numeric).unwrap(); assert_eq!(actual_decimal, one_hundred); 
assert_eq!(actual_decimal.mantissa(), 10000); assert_eq!(actual_decimal.scale(), 2); } #[test] fn issue_666_trailing_zeroes_at_max_precision() {} } sqlx-postgres-0.7.3/src/types/str.rs000064400000000000000000000072660072674642500156420ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::array_compatible; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef, Postgres}; use std::borrow::Cow; impl Type for str { fn type_info() -> PgTypeInfo { PgTypeInfo::TEXT } fn compatible(ty: &PgTypeInfo) -> bool { [ PgTypeInfo::TEXT, PgTypeInfo::NAME, PgTypeInfo::BPCHAR, PgTypeInfo::VARCHAR, PgTypeInfo::UNKNOWN, PgTypeInfo::with_name("citext"), ] .contains(ty) } } impl Type for Cow<'_, str> { fn type_info() -> PgTypeInfo { <&str as Type>::type_info() } fn compatible(ty: &PgTypeInfo) -> bool { <&str as Type>::compatible(ty) } } impl Type for Box { fn type_info() -> PgTypeInfo { <&str as Type>::type_info() } fn compatible(ty: &PgTypeInfo) -> bool { <&str as Type>::compatible(ty) } } impl Type for String { fn type_info() -> PgTypeInfo { <&str as Type>::type_info() } fn compatible(ty: &PgTypeInfo) -> bool { <&str as Type>::compatible(ty) } } impl PgHasArrayType for &'_ str { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TEXT_ARRAY } fn array_compatible(ty: &PgTypeInfo) -> bool { array_compatible::<&str>(ty) } } impl PgHasArrayType for Cow<'_, str> { fn array_type_info() -> PgTypeInfo { <&str as PgHasArrayType>::array_type_info() } fn array_compatible(ty: &PgTypeInfo) -> bool { <&str as PgHasArrayType>::array_compatible(ty) } } impl PgHasArrayType for Box { fn array_type_info() -> PgTypeInfo { <&str as PgHasArrayType>::array_type_info() } fn array_compatible(ty: &PgTypeInfo) -> bool { <&str as PgHasArrayType>::array_compatible(ty) } } impl PgHasArrayType for String { fn array_type_info() -> PgTypeInfo { <&str as PgHasArrayType>::array_type_info() } fn 
array_compatible(ty: &PgTypeInfo) -> bool { <&str as PgHasArrayType>::array_compatible(ty) } } impl Encode<'_, Postgres> for &'_ str { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { buf.extend(self.as_bytes()); IsNull::No } } impl Encode<'_, Postgres> for Cow<'_, str> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { match self { Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), } } } impl Encode<'_, Postgres> for Box { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { <&str as Encode>::encode(&**self, buf) } } impl Encode<'_, Postgres> for String { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { <&str as Encode>::encode(&**self, buf) } } impl<'r> Decode<'r, Postgres> for &'r str { fn decode(value: PgValueRef<'r>) -> Result { Ok(value.as_str()?) } } impl<'r> Decode<'r, Postgres> for Cow<'r, str> { fn decode(value: PgValueRef<'r>) -> Result { Ok(Cow::Borrowed(value.as_str()?)) } } impl<'r> Decode<'r, Postgres> for Box { fn decode(value: PgValueRef<'r>) -> Result { Ok(Box::from(value.as_str()?)) } } impl Decode<'_, Postgres> for String { fn decode(value: PgValueRef<'_>) -> Result { Ok(value.as_str()?.to_owned()) } } sqlx-postgres-0.7.3/src/types/text.rs000064400000000000000000000030040072674642500160000ustar 00000000000000use crate::{PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres}; use sqlx_core::decode::Decode; use sqlx_core::encode::{Encode, IsNull}; use sqlx_core::error::BoxDynError; use sqlx_core::types::{Text, Type}; use std::fmt::Display; use std::str::FromStr; use std::io::Write; impl Type for Text { fn type_info() -> PgTypeInfo { >::type_info() } fn compatible(ty: &PgTypeInfo) -> bool { >::compatible(ty) } } impl<'q, T> Encode<'q, Postgres> for Text where T: Display, { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // Unfortunately, our API design doesn't give us a way to bubble up the error here. 
// // Fortunately, writing to `Vec` is infallible so the only possible source of // errors is from the implementation of `Display::fmt()` itself, // where the onus is on the user. // // The blanket impl of `ToString` also panics if there's an error, so this is not // unprecedented. // // However, the panic should be documented anyway. write!(**buf, "{}", self.0).expect("unexpected error from `Display::fmt()`"); IsNull::No } } impl<'r, T> Decode<'r, Postgres> for Text where T: FromStr, BoxDynError: From<::Err>, { fn decode(value: PgValueRef<'r>) -> Result { let s: &str = Decode::::decode(value)?; Ok(Self(s.parse()?)) } } sqlx-postgres-0.7.3/src/types/time/date.rs000064400000000000000000000026420072674642500166760ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::time::PG_EPOCH; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use std::mem; use time::macros::format_description; use time::{Date, Duration}; impl Type for Date { fn type_info() -> PgTypeInfo { PgTypeInfo::DATE } } impl PgHasArrayType for Date { fn array_type_info() -> PgTypeInfo { PgTypeInfo::DATE_ARRAY } } impl Encode<'_, Postgres> for Date { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // DATE is encoded as the days since epoch let days = (*self - PG_EPOCH).whole_days() as i32; Encode::::encode(&days, buf) } fn size_hint(&self) -> usize { mem::size_of::() } } impl<'r> Decode<'r, Postgres> for Date { fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { // DATE is encoded as the days since epoch let days: i32 = Decode::::decode(value)?; PG_EPOCH + Duration::days(days.into()) } PgValueFormat::Text => Date::parse( value.as_str()?, &format_description!("[year]-[month]-[day]"), )?, }) } } sqlx-postgres-0.7.3/src/types/time/datetime.rs000064400000000000000000000063510072674642500175560ustar 
00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::time::PG_EPOCH; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use std::borrow::Cow; use std::mem; use time::macros::format_description; use time::macros::offset; use time::{Duration, OffsetDateTime, PrimitiveDateTime}; impl Type for PrimitiveDateTime { fn type_info() -> PgTypeInfo { PgTypeInfo::TIMESTAMP } } impl Type for OffsetDateTime { fn type_info() -> PgTypeInfo { PgTypeInfo::TIMESTAMPTZ } } impl PgHasArrayType for PrimitiveDateTime { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TIMESTAMP_ARRAY } } impl PgHasArrayType for OffsetDateTime { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TIMESTAMPTZ_ARRAY } } impl Encode<'_, Postgres> for PrimitiveDateTime { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // TIMESTAMP is encoded as the microseconds since the epoch let us = (*self - PG_EPOCH.midnight()).whole_microseconds() as i64; Encode::::encode(&us, buf) } fn size_hint(&self) -> usize { mem::size_of::() } } impl<'r> Decode<'r, Postgres> for PrimitiveDateTime { fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { // TIMESTAMP is encoded as the microseconds since the epoch let us = Decode::::decode(value)?; PG_EPOCH.midnight() + Duration::microseconds(us) } PgValueFormat::Text => { let s = value.as_str()?; // If there is no decimal point we need to add one. let s = if s.contains('.') { Cow::Borrowed(s) } else { Cow::Owned(format!("{s}.0")) }; // Contains a time-zone specifier // This is given for timestamptz for some reason // Postgres already guarantees this to always be UTC if s.contains('+') { PrimitiveDateTime::parse(&*s, &format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond][offset_hour]"))? 
} else { PrimitiveDateTime::parse( &*s, &format_description!( "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]" ), )? } } }) } } impl Encode<'_, Postgres> for OffsetDateTime { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { let utc = self.to_offset(offset!(UTC)); let primitive = PrimitiveDateTime::new(utc.date(), utc.time()); Encode::::encode(&primitive, buf) } fn size_hint(&self) -> usize { mem::size_of::() } } impl<'r> Decode<'r, Postgres> for OffsetDateTime { fn decode(value: PgValueRef<'r>) -> Result { Ok(>::decode(value)?.assume_utc()) } } sqlx-postgres-0.7.3/src/types/time/mod.rs000064400000000000000000000001640072674642500165350ustar 00000000000000mod date; mod datetime; mod time; #[rustfmt::skip] const PG_EPOCH: ::time::Date = ::time::macros::date!(2000-1-1); sqlx-postgres-0.7.3/src/types/time/time.rs000064400000000000000000000030110072674642500167060ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use std::mem; use time::macros::format_description; use time::{Duration, Time}; impl Type for Time { fn type_info() -> PgTypeInfo { PgTypeInfo::TIME } } impl PgHasArrayType for Time { fn array_type_info() -> PgTypeInfo { PgTypeInfo::TIME_ARRAY } } impl Encode<'_, Postgres> for Time { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { // TIME is encoded as the microseconds since midnight let us = (*self - Time::MIDNIGHT).whole_microseconds() as i64; Encode::::encode(&us, buf) } fn size_hint(&self) -> usize { mem::size_of::() } } impl<'r> Decode<'r, Postgres> for Time { fn decode(value: PgValueRef<'r>) -> Result { Ok(match value.format() { PgValueFormat::Binary => { // TIME is encoded as the microseconds since midnight let us = Decode::::decode(value)?; Time::MIDNIGHT + Duration::microseconds(us) } PgValueFormat::Text => Time::parse( 
value.as_str()?, // Postgres will not include the subsecond part if it's zero. &format_description!("[hour]:[minute]:[second][optional [.[subsecond]]]"), )?, }) } } sqlx-postgres-0.7.3/src/types/time_tz.rs000064400000000000000000000135660072674642500165050ustar 00000000000000use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; use crate::types::Type; use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; use byteorder::{BigEndian, ReadBytesExt}; use std::io::Cursor; use std::mem; #[cfg(feature = "time")] type DefaultTime = ::time::Time; #[cfg(all(not(feature = "time"), feature = "chrono"))] type DefaultTime = ::chrono::NaiveTime; #[cfg(feature = "time")] type DefaultOffset = ::time::UtcOffset; #[cfg(all(not(feature = "time"), feature = "chrono"))] type DefaultOffset = ::chrono::FixedOffset; /// Represents a moment of time, in a specified timezone. /// /// # Warning /// /// `PgTimeTz` provides `TIMETZ` and is supported only for reading from legacy databases. /// [PostgreSQL recommends] to use `TIMESTAMPTZ` instead. /// /// [PostgreSQL recommends]: https://wiki.postgresql.org/wiki/Don't_Do_This#Don.27t_use_timetz #[derive(Debug, PartialEq, Clone, Copy)] pub struct PgTimeTz