statrs-0.18.0/.cargo_vcs_info.json0000644000000001360000000000100124420ustar { "git": { "sha1": "2f402503593972578e83ff07959ee9fa4a31a488" }, "path_in_vcs": "" }statrs-0.18.0/CHANGELOG.md000064400000000000000000000337671046102023000130630ustar 00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [0.18.0] - 2024-12-02 ### ✨ Added - Added more inverse cumulative distribution functions. - Introduced feature flags: `rand` and `nalgebra`. - Added the `std_dev` method to the `Distribution` trait explicitly. - Supported sampling integers from discrete distributions. - Added support for the Gumbel distribution. ### ⚠️ Breaking Changes - Migrated multivariate distributions to generic dimensions. - Replaced `StatsError` with module-level error types in `distribution` and its children. - Changed `checked_logit`, `checked_multinomial`, and similar methods to return `Option` to handle invalid inputs. - Changed `Chi` distribution to use `u64` for degrees of freedom. ### 🛠️ Changed - Upgraded `nalgebra` to version 0.33. - Upgrades MSRV to 1.65+ - Improved documentation and added examples (e.g., for Hypergeometric distribution). - Added MSRV (Minimum Supported Rust Version) metadata to `Cargo.toml` and documentation. - Introduced coverage reporting with `llvm-cov`. - Updated CI to check all feature combinations and ensure MSRV compliance. - Added an MSRV badge to `crates.io`. ### ✅ Fixed - Corrected formatting issues in documentation. - Fixed several `rustdoc` warnings. - Expanded test coverage for Dirichlet and Multinomial distributions. - Improved ergonomics at cli for tests and ensured compatibility with updated NIST data. ### ❌ Removed - Replaced `StatsError` with module-level error types. - Deprecated the `error` module and preformatted NIST data. 
- Removed `rustfmt.toml` as part of CI clean-up. ### 🎉 New Contributors - @SabrinaJewson and @alimf17 made their first contributions! ## [0.17.1] - 2024-06-08 ### Details #### Changed - Release statrs version 0.17.1 by @YeungOnion #### Fixed - Code in benches still needs criterion by @YeungOnion [unreleased]: https://github.com/YeungOnion/statrs/compare/v0.17.1..HEAD [0.17.1]: https://github.com/YeungOnion/statrs/compare/v0.17.0..v0.17.1 ## [0.17.0](https://github.com/statrs-dev/statrs/compare/v0.16.0...v0.17.0) - 2024-05-30 ### Added - specializes `inverse_cdf()` for Uniform (#166) - Add an easy way to get the standard normal distribution. (#228) - reject constructing Uniform of infinite support (#218) - extend `StatsError` for finiteness (#218) - default implementation of survival function with generics (#179) - update `MultivariateNormal` API - construct from nalgebra with `MultivariateNormal::new_from_nalgebra` (#177) - support `std::vec` vector input in addition to `nalgebra` vectors (#199) ### Fixed - Update nalgebra to 0.32 (#187) - for Gamma with shape < 1 there is no mode, so `None` is returned instead of a negative number (#212) - fix precision of `::inverse_cdf` with some Newton-Raphson steps (#227) - adds test case from #200 - fix integer bisection for default implementation of `::inverse_cdf` (#220) - also add tests from (#185) ### Other - Remove "nightly" feature and drop testing requirement for `nightly` (#234) - Allow some imprecision in specific test case (#215) - Update CI (#215) - Check formatting in CI via rustfmt - Expand CI test job - Add clippy job to CI - update README with formatting and adding to "Contributing" (#213) - Add test asserting that `StatsError` is Sync & Send (#226) - Rename private struct NonNAN to NonNan (#222) - Remove `lazy-static` dependency and make FCACHE a proper const (#211) - crate examples shall be in docstrings instead of README (#213) - alias `inverse_cdf` as "quantile function" in docs (#213) - docstrings with math shall be
`text` instead of `ignore` (#213) ## [0.16.0] - Adds an `sf` method to the `ContinuousCDF` and `DiscreteCDF` traits - Calculates the survival function (CDF complement) for the distribution. - Survival function implemented for all distributions implementing `ContinuousCDF` and `DiscreteCDF` - See [PR description](https://github.com/statrs-dev/statrs/pull/172) for in-depth changes - update `nalgebra` to `0.29` ## [v0.15.0](https://www.github.com/statrs-dev/statrs/compare/v0.14.0...v0.15.0) - upgrade `nalgebra` to `0.27.1` to avoid RUSTSEC-2021-0070 ## [v0.14.0](https://www.github.com/statrs-dev/statrs/compare/v0.13.0...v0.14.0) - upgrade `rand` dependency to `0.8` - fix inaccurate sampling of `Gamma` - Implemented Empirical distribution - Implemented Laplace distribution - Removed Checked\* traits - Almost clippy-clean - Almost fully enabled rustfmt - Begin applying consistent numeric relative-accuracy targets with the approx crate - Introduce macro to generate testing boilerplate, though not all tests use it yet - Moved to dynamic vectors in the MultivariateNormal distribution - Reduced a number of distribution-specific traits into the Distribution and DiscreteDistribution traits ## [v0.13.0](https://www.github.com/statrs-dev/statrs/compare/v0.12.0...v0.13.0) - Implemented `MultivariateNormal` distribution (depends on `nalgebra 0.19`) - Implemented `Dirac` distribution - Implemented `Negative Binomial` distribution ## [v0.12.0](https://www.github.com/statrs-dev/statrs/compare/v0.11.0...v0.12.0) - upgrade `rand` dependency to `0.7` ## [v0.11.0](https://www.github.com/statrs-dev/statrs/compare/v0.10.0...v0.11.0) - upgrade `rand` dependency to `0.6` - Implement `CheckedInverseCDF` and `InverseCDF` for `Normal` distribution ## [v0.10.0](https://www.github.com/statrs-dev/statrs/compare/v0.9.0...v0.10.0) - upgrade `rand` dependency to `0.5` - Removes the `Distribution` trait in favor of the `rand::distributions::Distribution` trait - Removed functions deprecated in `0.8.0`
(`periodic`, `periodic_custom`, `sinusoidal`, `sinusoidal_custom`) ## [v0.9.0](https://www.github.com/statrs-dev/statrs/compare/v0.8.0...v0.9.0) - implemented infinite sequence generator for periodic sequence - implemented infinite sequence generator for sinusoidal sequence - implemented infinite sequence generator for square sequence - implemented infinite sequence generator for triangle sequence - implemented infinite sequence generator for sawtooth sequence - deprecate old non-infinite iterators in favor of new infinite iterators with `take` - Implemented `Pareto` distribution - Implemented `Entropy` trait for the `Categorical` distribution - Add a `checked_` interface to all distribution methods and functions that may panic ## [v0.8.0](https://www.github.com/statrs-dev/statrs/compare/v0.7.0...v0.8.0) - `cdf(x)`, `pdf(x)` and `pmf(x)` now return the correct value instead of panicking when `x` is outside the range of values that the distribution can attain. - Fixed a bug in the `Uniform` distribution implementation where samples were drawn from range `[min, max + 1)` instead of `[min, max]`. The samples are now drawn correctly from the range `[min, max]`. - Implement `generate::log_spaced` function - Implement `generate::Periodic` iterator - Implement `generate::Sinusoidal` iterator - Implement `generate::Square` iterator - Implement `generate::Triangle` iterator - Implement `generate::Sawtooth` iterator - Deprecate `generate::periodic` and `generate::periodic_custom` - Deprecate `generate::sinusoidal` and `generate::sinusoidal_custom` Note: A recent commit to the Rust nightly build causes compile errors when using empty slices with the `Statistics` trait, specifically the `Statistics::min` and `Statistics::max` methods. This only affects the case where the compiler must infer the type of the empty slice: ``` use statrs::statistics::Statistics; // compile error! 
Assumes the use of Ord::min rather than // Statistics::min let x = []; assert!(x.min().is_nan()); ``` The fix is to pin the type of the empty slice: ``` // no compile error let x: [f64; 0] = []; assert!(x.min().is_nan()); ``` Since the regression affects a very slim edge case and the fix is very simple, no breaking changes to the `Statistics` API were deemed necessary ## [v0.7.0](https://www.github.com/statrs-dev/statrs/compare/v0.6.0...v0.7.0) - Implemented `Categorical` distribution - Implemented `Erlang` distribution - Implemented `Multinomial` distribution - New `InverseCDF` trait for distributions that implement the inverse CDF function ## [v0.6.0](https://www.github.com/statrs-dev/statrs/compare/v0.5.1...v0.6.0) - `gamma::gamma_ur`, `gamma::gamma_ui`, `gamma::gamma_lr`, and `gamma::gamma_li` now follow the strict gamma function domain, panicking if `a` or `x` are not in `(0, +inf)` - `beta::beta_reg` no longer allows `0.0` for `a` or `b` arguments - `InverseGamma` distribution no longer accepts `f64::INFINITY` as valid arguments for `shape` or `rate` as the value is nonsense - `Binomial::cdf` no longer accepts arguments outside the domain of `[0, n]` - `Bernoulli::cdf` no longer accepts arguments outside the domain of `[0, 1]` - `DiscreteUniform::cdf` no longer accepts arguments outside the domain of `[min, max]` - `Uniform::cdf` no longer accepts arguments outside the domain of `[min, max]` - `Triangular::cdf` no longer accepts arguments outside the domain of `[min, max]` - `FisherSnedecor` no longer accepts `f64::INFINITY` as a valid argument for `freedom_1` or `freedom_2` - `FisherSnedecor::cdf` no longer accepts arguments outside the domain of `[0, +inf)` - `Geometric::cdf` no longer accepts non-positive arguments - `Normal` now uses the Ziggurat method to generate random samples.
This also affects all distributions depending on `Normal` for sampling including `Chi`, `LogNormal`, `Gamma`, and `StudentsT` - `Exponential` now uses the Ziggurat method to generate random samples. - `Binomial` now implements `Univariate` rather than `Univariate`, meaning `Binomial::min` and `Binomial::max` now return `u64` - `Bernoulli` now implements `Univariate` rather than `Univariate`, meaning `Bernoulli::min` and `Bernoulli::max` now return `u64` - `Geometric` now implements `Univariate` rather than `Univariate`, meaning `Geometric::min` and `Geometric::max` now return `u64` - `Poisson` now implements `Univariate` rather than `Univariate`, meaning `Poisson::min` and `Poisson::max` now return `u64` - `Binomial` now implements `Mode` instead of `Mode` - `Bernoulli` now implements `Mode` instead of `Mode` - `Poisson` now implements `Mode` instead of `Mode` - `Geometric` now implements `Mode` instead of `Mode` - `Hypergeometric` now implements `Mode` instead of `Mode` - `Binomial` now implements `Discrete` rather than `Discrete` - `Bernoulli` now implements `Discrete` rather than `Discrete` - `Geometric` now implements `Discrete` rather than `Discrete` - `Hypergeometric` now implements `Discrete` rather than `Discrete` - `Poisson` now implements `Discrete` rather than `Discrete` ## [v0.5.1](https://www.github.com/statrs-dev/statrs/compare/v0.5.0...v0.5.1) - Fixed critical bug in `normal::sample_unchecked` where it was returning `NaN` ## [v0.5.0](https://www.github.com/statrs-dev/statrs/compare/v0.4.0...v0.5.0) - Implemented the `logistic::logistic` special function - Implemented the `logistic::logit` special function - Implemented the `factorial::multinomial` special function - Implemented the `harmonic::harmonic` special function - Implemented the `harmonic::gen_harmonic` special function - Implemented the `InverseGamma` distribution - Implemented the `Geometric` distribution - Implemented the `Hypergeometric` distribution - `gamma::gamma_ur` now panics when `x
> 0` or `a == f64::NEG_INFINITY`. In addition, it also returns `f64::NAN` when `a == f64::INFINITY` and `0.0` when `x == f64::INFINITY` - `Gamma::pdf` and `Gamma::ln_pdf` now return `f64::NAN` if any of `shape`, `rate`, or `x` are `f64::INFINITY` - `Binomial::pdf` and `Binomial::ln_pdf` now panic if `x > n` or `x < 0` - `Bernoulli::pdf` and `Bernoulli::ln_pdf` now panic if `x > 1` or `x < 0` ## [v0.4.0] - Implemented the `exponential::integral` special function - Implemented the `Cauchy` (otherwise known as the `Lorentz`) distribution - Implemented the `Dirichlet` distribution - `Continuous` and `Discrete` traits no longer depend on the `Distribution` trait ## [v0.3.2] - Implemented the `FisherSnedecor` (F) distribution ## [v0.3.1] - Removed print statements from `ln_pdf` method in `Beta` distribution ## [v0.3.0] - Moved methods `min` and `max` out of trait `Univariate` into their own respective traits `Min` and `Max` - Traits `Min`, `Max`, `Mean`, `Variance`, `Entropy`, `Skewness`, `Median`, and `Mode` moved from `distribution` module to `statistics` module - `Mean`, `Variance`, `Entropy`, `Skewness`, `Median`, and `Mode` no longer depend on `Distribution` trait - `Mean`, `Variance`, `Skewness`, and `Mode` are now generic over only one type, the return type, due to not depending on `Distribution` anymore - `order_statistic`, `median`, `quantile`, `percentile`, `lower_quartile`, `upper_quartile`, `interquartile_range`, and `ranks` methods removed from `Statistics` trait. - `min`, `max`, `mean`, `variance`, and `std_dev` methods added to `Statistics` trait - `Statistics` trait now implemented for all types implementing `IntoIterator` where `Item` implements `Borrow`. Slice now implicitly implements `Statistics` through this new implementation.
- Slice still implements `Min`, `Max`, `Mean`, and `Variance` but now through the `Statistics` implementation rather than its own implementation - `InplaceStatistics` renamed to `OrderStatistics`, all methods in `InplaceStatistics` have `_inplace` trimmed from method name. - Inverse DiGamma function implemented with signature `gamma::inv_digamma(x: f64) -> f64` ## [v0.2.0] - Created `statistics` module and `Statistics` trait - `Statistics` trait implementation for `[f64]` - Implemented `Beta` distribution - Added `Modulus` trait and implementations for `f32`, `f64`, `i32`, `i64`, `u32`, and `u64` in `euclid` module - Added periodic and sinusoidal vector generation functions in `generate` module
statrs-0.18.0/Cargo.toml
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents.
[package] edition = "2021" rust-version = "1.65.0" name = "statrs" version = "0.18.0" authors = ["Michael Ma"] build = false include = [ "CHANGELOG.md", "LICENSE.md", "src/", "tests/", ] autobins = false autoexamples = false autotests = false autobenches = false description = "Statistical computing library for Rust" homepage = "https://github.com/statrs-dev/statrs" readme = "README.md" keywords = [ "probability", "statistics", "stats", "distribution", "math", ] categories = ["science"] license = "MIT" repository = "https://github.com/statrs-dev/statrs" [package.metadata.docs.rs] all-features = true rustdoc-args = [ "--cfg", "docsrs", ] [lib] name = "statrs" path = "src/lib.rs" [[test]] name = "nist_tests" path = "tests/nist_tests.rs" [dependencies.approx] version = "0.5.0" [dependencies.nalgebra] version = "0.33" features = ["std"] optional = true default-features = false [dependencies.num-traits] version = "0.2.14" [dependencies.rand] version = "0.8" optional = true [dev-dependencies.anyhow] version = "1.0" [dev-dependencies.criterion] version = "0.5" [dev-dependencies.nalgebra] version = "0.33" features = ["macros"] default-features = false [features] default = [ "nalgebra", "rand", ] nalgebra = ["dep:nalgebra"] rand = [ "dep:rand", "nalgebra?/rand", ] [lints.rust.unexpected_cfgs] level = "warn" priority = 0 check-cfg = ["cfg(coverage_nightly)"]
statrs-0.18.0/Cargo.toml.orig
[package] name = "statrs" version = "0.18.0" authors = ["Michael Ma"] description = "Statistical computing library for Rust" license = "MIT" keywords = ["probability", "statistics", "stats", "distribution", "math"] categories = ["science"] homepage = "https://github.com/statrs-dev/statrs" repository = "https://github.com/statrs-dev/statrs" edition = "2021" include = ["CHANGELOG.md", "LICENSE.md", "src/", "tests/"] # When changing MSRV: Also update the README rust-version = "1.65.0" [lib] name = "statrs" path = "src/lib.rs"
[[bench]] name = "order_statistics" harness = false required-features = ["rand"] [features] default = ["nalgebra", "rand"] nalgebra = ["dep:nalgebra"] rand = ["dep:rand", "nalgebra?/rand"] [dependencies] approx = "0.5.0" num-traits = "0.2.14" [dependencies.rand] version = "0.8" optional = true [dependencies.nalgebra] version = "0.33" optional = true default-features = false features = ["std"] [dev-dependencies] criterion = "0.5" anyhow = "1.0" [dev-dependencies.nalgebra] version = "0.33" default-features = false features = ["macros"] [lints.rust.unexpected_cfgs] level = "warn" # Set by cargo-llvm-cov when running on nightly check-cfg = ['cfg(coverage_nightly)'] [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"]
statrs-0.18.0/LICENSE.md
MIT License Copyright (c) 2016 Michael Ma Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
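The `[features]` tables in the manifests above enable both optional dependencies by default (`default = ["nalgebra", "rand"]`). The `rand` feature is declared as `["dep:rand", "nalgebra?/rand"]`.

A downstream crate that only needs densities, CDFs and summary statistics could plausibly opt out of those defaults, as in the sketch below. Two things change when you do:

- Without the `rand` feature, the `#[cfg(feature = "rand")]` sampling impls are not compiled.
- Without `nalgebra`, the nalgebra-backed multivariate distributions, such as `MultivariateNormal`, are unavailable.

The `"0.18"` version requirement is an assumption; adjust it to the release you actually use.

```toml
[dependencies]
# Sketch: disable both default features (`nalgebra` and `rand`).
statrs = { version = "0.18", default-features = false }

# Alternative sketch: keep sampling support but skip nalgebra.
# Because of the `?` in "nalgebra?/rand", enabling `rand` alone
# does not pull in nalgebra.
# statrs = { version = "0.18", default-features = false, features = ["rand"] }
```

The second variant relies on Cargo's weak dependency feature syntax (`dep?/feature`). That syntax forwards a feature to an optional dependency only when that dependency is already enabled.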
statrs-0.18.0/README.md
# statrs ![tests][actions-test-badge] [![MIT licensed][license-badge]](./LICENSE.md) [![Crate][crates-badge]][crates-url] [![docs.rs][docsrs-badge]][docs-url] [![codecov-statrs][codecov-badge]][codecov-url] ![Crates.io MSRV][crates-msrv-badge] [actions-test-badge]: https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg [crates-badge]: https://img.shields.io/crates/v/statrs.svg [crates-url]: https://crates.io/crates/statrs [license-badge]: https://img.shields.io/badge/license-MIT-blue.svg [docsrs-badge]: https://img.shields.io/docsrs/statrs [docs-url]: https://docs.rs/statrs/*/statrs [codecov-badge]: https://codecov.io/gh/statrs-dev/statrs/graph/badge.svg?token=XtMSMYXvIf [codecov-url]: https://codecov.io/gh/statrs-dev/statrs [crates-msrv-badge]: https://img.shields.io/crates/msrv/statrs Statrs provides a host of statistical utilities for Rust scientific computing. Included are a number of common distributions that can be sampled (e.g. Normal, Exponential, Student's T, Gamma, Uniform) plus common statistical functions like the gamma function, beta function, and error function. This library began as a port of the statistical capabilities in the C# Math.NET library. Unit tests are borrowed from Math.NET when possible and filled in when not. Planned for future releases are continued implementations of distributions as well as porting over more statistical utilities. Please check out the documentation [here][docs-url]. ## Usage Add the most recent release to your `Cargo.toml`: ```toml [dependencies] statrs = "*" # replace * with the latest version of the crate. ``` For examples, view [the docs](https://docs.rs/statrs/*/statrs/). ### Running tests To run all suggested tests, you'll need to download some data from NIST; the script `tests/gather_nist_data.sh` downloads and formats this data.
```sh cargo test ./tests/gather_nist_data.sh && cargo test -- --include-ignored nist_ ``` To change where the data is downloaded, set the environment variable `STATRS_NIST_DATA_DIR` when running both the script and the tests. ## Minimum supported Rust version (MSRV) This crate requires a Rust version of 1.65.0 or higher. Increases in MSRV will be considered a semver non-breaking API change and require a version increase (PATCH until 1.0.0, MINOR after 1.0.0). ## Contributing Thanks for helping to improve the project! **No contribution is too small and all contributions are valued.** Suggestions if you don't know where to start: - [documentation][docs-url] is a great place to start, as you'll be able to identify the value of existing documentation better than its authors. - tests are valuable in demonstrating correct behavior; you can review test coverage on the [CodeCov Report][codecov-url]. - check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). - look at features in other tools you'd like to see in statrs - Math.NET's - [Distributions](https://github.com/mathnet/mathnet-numerics/tree/master/src/Numerics/Distributions) - [Statistics](https://github.com/mathnet/mathnet-numerics/tree/master/src/Numerics/Statistics) - scipy.stats ### How to contribute Clone the repo: ``` git clone https://github.com/statrs-dev/statrs ``` Create a feature branch: ``` git checkout -b <feature_branch> master ``` Write your code and docs, then ensure it is formatted: ``` cargo fmt ``` Add `--check` to view the diff without making file changes; this is how our CI checks formatting. After committing your code: ```shell git push -u <remote> <feature_branch> # with `git` gh pr create --head <feature_branch> # with GitHub's CLI ``` Then submit a PR, preferably referencing the relevant issue, if it exists. ### Commit messages Please be explicit and purposeful with commit messages.
[Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/#summary) are encouraged. #### Bad ``` Modify test code ``` #### Good ``` test: Update statrs::distribution::Normal test_cdf ``` ### Communication Expectations Please allow at least one week before pinging issues/PRs.
statrs-0.18.0/src/consts.rs
//! Defines mathematical expressions commonly used when computing distribution //! values as constants /// Constant value for `sqrt(2 * pi)` pub const SQRT_2PI: f64 = 2.5066282746310005024157652848110452530069867406099; /// Constant value for `ln(pi)` pub const LN_PI: f64 = 1.1447298858494001741434273513530587116472948129153; /// Constant value for `ln(sqrt(2 * pi))` pub const LN_SQRT_2PI: f64 = 0.91893853320467274178032973640561763986139747363778; /// Constant value for `ln(sqrt(2 * pi * e))` pub const LN_SQRT_2PIE: f64 = 1.4189385332046727417803297364056176398613974736378; /// Constant value for `ln(2 * sqrt(e / pi))` pub const LN_2_SQRT_E_OVER_PI: f64 = 0.6207822376352452223455184457816472122518527279025978; /// Constant value for `2 * sqrt(e / pi)` pub const TWO_SQRT_E_OVER_PI: f64 = 1.8603827342052657173362492472666631120594218414085755; /// Constant value for the Euler-Mascheroni constant `lim(n -> inf) { sum(k=1 -> n) /// { 1/k } - ln(n) }` pub const EULER_MASCHERONI: f64 = 0.5772156649015328606065120900824024310421593359399235988057672348849; /// Targeted accuracy instantiated over `f64` pub const ACC: f64 = 10e-11;
statrs-0.18.0/src/distribution/bernoulli.rs
use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF}; use crate::statistics::*; /// Implements the /// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution) /// distribution which is a special case of the /// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution) /// distribution where `n = 1` (referenced
[Here](./struct.Binomial.html)) /// /// # Examples /// /// ``` /// use statrs::distribution::{Bernoulli, Discrete}; /// use statrs::statistics::Distribution; /// /// let n = Bernoulli::new(0.5).unwrap(); /// assert_eq!(n.mean().unwrap(), 0.5); /// assert_eq!(n.pmf(0), 0.5); /// assert_eq!(n.pmf(1), 0.5); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Bernoulli { b: Binomial, } impl Bernoulli { /// Constructs a new bernoulli distribution with /// the given `p` probability of success. /// /// # Errors /// /// Returns an error if `p` is `NaN`, less than `0.0` /// or greater than `1.0` /// /// # Examples /// /// ``` /// use statrs::distribution::Bernoulli; /// /// let mut result = Bernoulli::new(0.5); /// assert!(result.is_ok()); /// /// result = Bernoulli::new(-0.5); /// assert!(result.is_err()); /// ``` pub fn new(p: f64) -> Result<Bernoulli, BinomialError> { Binomial::new(p, 1).map(|b| Bernoulli { b }) } /// Returns the probability of success `p` of the /// bernoulli distribution. /// /// # Examples /// /// ``` /// use statrs::distribution::Bernoulli; /// /// let n = Bernoulli::new(0.5).unwrap(); /// assert_eq!(n.p(), 0.5); /// ``` pub fn p(&self) -> f64 { self.b.p() } /// Returns the number of trials `n` of the /// bernoulli distribution. Will always be `1`.
/// /// # Examples /// /// ``` /// use statrs::distribution::Bernoulli; /// /// let n = Bernoulli::new(0.5).unwrap(); /// assert_eq!(n.n(), 1); /// ``` pub fn n(&self) -> u64 { 1 } } impl std::fmt::Display for Bernoulli { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Bernoulli({})", self.p()) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<bool> for Bernoulli { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> bool { rng.gen_bool(self.p()) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Bernoulli { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { rng.sample::<bool, _>(self) as u8 as f64 } } impl DiscreteCDF<u64, f64> for Bernoulli { /// Calculates the cumulative distribution /// function for the bernoulli distribution at `x`. /// /// # Formula /// /// ```text /// if x < 0 { 0 } /// else if x >= 1 { 1 } /// else { 1 - p } /// ``` fn cdf(&self, x: u64) -> f64 { self.b.cdf(x) } /// Calculates the survival function for the /// bernoulli distribution at `x`.
/// /// # Formula /// /// ```text /// if x < 0 { 1 } /// else if x >= 1 { 0 } /// else { p } /// ``` fn sf(&self, x: u64) -> f64 { self.b.sf(x) } } impl Min<u64> for Bernoulli { /// Returns the minimum value in the domain of the /// bernoulli distribution representable by a 64- /// bit integer /// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> u64 { 0 } } impl Max<u64> for Bernoulli { /// Returns the maximum value in the domain of the /// bernoulli distribution representable by a 64- /// bit integer /// /// # Formula /// /// ```text /// 1 /// ``` fn max(&self) -> u64 { 1 } } impl Distribution<f64> for Bernoulli { /// Returns the mean of the bernoulli /// distribution /// /// # Formula /// /// ```text /// p /// ``` fn mean(&self) -> Option<f64> { self.b.mean() } /// Returns the variance of the bernoulli /// distribution /// /// # Formula /// /// ```text /// p * (1 - p) /// ``` fn variance(&self) -> Option<f64> { self.b.variance() } /// Returns the entropy of the bernoulli /// distribution /// /// # Formula /// /// ```text /// q = (1 - p) /// -q * ln(q) - p * ln(p) /// ``` fn entropy(&self) -> Option<f64> { self.b.entropy() } /// Returns the skewness of the bernoulli /// distribution /// /// # Formula /// /// ```text /// q = (1 - p) /// (1 - 2p) / sqrt(p * q) /// ``` fn skewness(&self) -> Option<f64> { self.b.skewness() } } impl Median<f64> for Bernoulli { /// Returns the median of the bernoulli /// distribution /// /// # Formula /// /// ```text /// if p < 0.5 { 0 } /// else if p > 0.5 { 1 } /// else { 0.5 } /// ``` fn median(&self) -> f64 { self.b.median() } } impl Mode<Option<u64>> for Bernoulli { /// Returns the mode of the bernoulli distribution /// /// # Formula /// /// ```text /// if p < 0.5 { 0 } /// else { 1 } /// ``` fn mode(&self) -> Option<u64> { self.b.mode() } } impl Discrete<u64, f64> for Bernoulli { /// Calculates the probability mass function for the /// bernoulli distribution at `x`.
/// /// # Formula /// /// ```text /// if x == 0 { 1 - p } /// else { p } /// ``` fn pmf(&self, x: u64) -> f64 { self.b.pmf(x) } /// Calculates the log probability mass function for the /// bernoulli distribution at `x`. /// /// # Formula /// /// ```text /// if x == 0 { ln(1 - p) } /// else { ln(p) } /// ``` fn ln_pmf(&self, x: u64) -> f64 { self.b.ln_pmf(x) } } #[rustfmt::skip] #[cfg(test)] mod testing { use super::*; use crate::testing_boiler; testing_boiler!(p: f64; Bernoulli; BinomialError); #[test] fn test_create() { create_ok(0.0); create_ok(0.3); create_ok(1.0); } #[test] fn test_bad_create() { create_err(f64::NAN); create_err(-1.0); create_err(2.0); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg); test_relative(0.3, 1., cdf(1)); } #[test] fn test_sf_upper_bound() { let sf = |arg: u64| move |x: Bernoulli| x.sf(arg); test_relative(0.3, 0., sf(1)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg); test_relative(0.0, 1.0, cdf(0)); test_relative(0.0, 1.0, cdf(1)); test_absolute(0.3, 0.7, 1e-15, cdf(0)); test_absolute(0.7, 0.3, 1e-15, cdf(0)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Bernoulli| x.sf(arg); test_relative(0.0, 0.0, sf(0)); test_relative(0.0, 0.0, sf(1)); test_absolute(0.3, 0.3, 1e-15, sf(0)); test_absolute(0.7, 0.7, 1e-15, sf(0)); } }
statrs-0.18.0/src/distribution/beta.rs
use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; /// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Beta, Continuous}; /// use statrs::statistics::*; /// use statrs::prec; /// /// let n = Beta::new(2.0, 2.0).unwrap(); /// assert_eq!(n.mean().unwrap(), 0.5); /// assert!(prec::almost_eq(n.pdf(0.5), 1.5, 1e-14)); /// ``` #[derive(Copy, Clone, PartialEq, Debug)]
pub struct Beta { shape_a: f64, shape_b: f64, } /// Represents the errors that can occur when creating a [`Beta`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum BetaError { /// Shape A is NaN, infinite, zero or negative. ShapeAInvalid, /// Shape B is NaN, infinite, zero or negative. ShapeBInvalid, } impl std::fmt::Display for BetaError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, infinite, zero or negative"), BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, infinite, zero or negative"), } } } impl std::error::Error for BetaError {} impl Beta { /// Constructs a new beta distribution with shapeA (α) of `shape_a` /// and shapeB (β) of `shape_b` /// /// # Errors /// /// Returns an error if `shape_a` or `shape_b` are `NaN` or infinite. /// Also returns an error if `shape_a <= 0.0` or `shape_b <= 0.0` /// /// # Examples /// /// ``` /// use statrs::distribution::Beta; /// /// let mut result = Beta::new(2.0, 2.0); /// assert!(result.is_ok()); /// /// result = Beta::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta, BetaError> { if shape_a.is_nan() || shape_a.is_infinite() || shape_a <= 0.0 { return Err(BetaError::ShapeAInvalid); } if shape_b.is_nan() || shape_b.is_infinite() || shape_b <= 0.0 { return Err(BetaError::ShapeBInvalid); } Ok(Beta { shape_a, shape_b }) } /// Returns the shapeA (α) of the beta distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Beta; /// /// let n = Beta::new(1.0, 2.0).unwrap(); /// assert_eq!(n.shape_a(), 1.0); /// ``` pub fn shape_a(&self) -> f64 { self.shape_a } /// Returns the shapeB (β) of the beta distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Beta; /// /// let n = Beta::new(1.0, 2.0).unwrap(); /// assert_eq!(n.shape_b(), 2.0); /// ``` pub fn shape_b(&self) -> f64 { self.shape_b } }
impl std::fmt::Display for Beta { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Beta(a={}, b={})", self.shape_a, self.shape_b) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Beta { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { // Generated by sampling two gamma distributions and normalizing. let x = super::gamma::sample_unchecked(rng, self.shape_a, 1.0); let y = super::gamma::sample_unchecked(rng, self.shape_b, 1.0); x / (x + y) } } impl ContinuousCDF<f64, f64> for Beta { /// Calculates the cumulative distribution function for the beta /// distribution at `x`. /// /// # Formula /// /// ```text /// I_x(α, β) /// ``` /// /// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized /// lower incomplete beta function. fn cdf(&self, x: f64) -> f64 { if x < 0.0 { 0.0 } else if x >= 1.0 { 1.0 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { x } else { beta::beta_reg(self.shape_a, self.shape_b, x) } } /// Calculates the survival function for the beta distribution at `x`. /// /// # Formula /// /// ```text /// I_(1-x)(β, α) /// ``` /// /// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized /// lower incomplete beta function. fn sf(&self, x: f64) -> f64 { if x < 0.0 { 1.0 } else if x >= 1.0 { 0.0 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 1. - x } else { beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x) } } /// Calculates the inverse cumulative distribution function for the beta /// distribution at `x`. /// /// # Panics /// /// If `x` is not in `[0, 1]`. /// /// # Formula /// /// ```text /// I^{-1}_x(α, β) /// ``` /// /// where `α` is shapeA, `β` is shapeB, and `I^{-1}_x` is the inverse of the /// regularized lower incomplete beta function `I_x`. 
fn inverse_cdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { panic!("x must be in [0, 1]"); } else { beta::inv_beta_reg(self.shape_a, self.shape_b, x) } } } impl Min<f64> for Beta { /// Returns the minimum value in the domain of the beta distribution /// representable by a double precision float. /// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> f64 { 0.0 } } impl Max<f64> for Beta { /// Returns the maximum value in the domain of the beta distribution /// representable by a double precision float. /// /// # Formula /// /// ```text /// 1 /// ``` fn max(&self) -> f64 { 1.0 } } impl Distribution<f64> for Beta { /// Returns the mean of the beta distribution. /// /// # Formula /// /// ```text /// α / (α + β) /// ``` /// /// where `α` is shapeA and `β` is shapeB. fn mean(&self) -> Option<f64> { Some(self.shape_a / (self.shape_a + self.shape_b)) } /// Returns the variance of the beta distribution. /// /// # Formula /// /// ```text /// (α * β) / ((α + β)^2 * (α + β + 1)) /// ``` /// /// where `α` is shapeA and `β` is shapeB. fn variance(&self) -> Option<f64> { Some( self.shape_a * self.shape_b / ((self.shape_a + self.shape_b) * (self.shape_a + self.shape_b) * (self.shape_a + self.shape_b + 1.0)), ) } /// Returns the entropy of the beta distribution. /// /// # Formula /// /// ```text /// ln(B(α, β)) - (α - 1)ψ(α) - (β - 1)ψ(β) + (α + β - 2)ψ(α + β) /// ``` /// /// where `α` is shapeA, `β` is shapeB and `ψ` is the digamma function. fn entropy(&self) -> Option<f64> { Some( beta::ln_beta(self.shape_a, self.shape_b) - (self.shape_a - 1.0) * gamma::digamma(self.shape_a) - (self.shape_b - 1.0) * gamma::digamma(self.shape_b) + (self.shape_a + self.shape_b - 2.0) * gamma::digamma(self.shape_a + self.shape_b), ) } /// Returns the skewness of the Beta distribution. /// /// # Formula /// /// ```text /// 2(β - α) * sqrt(α + β + 1) / ((α + β + 2) * sqrt(αβ)) /// ``` /// /// where `α` is shapeA and `β` is shapeB. 
fn skewness(&self) -> Option<f64> { Some( 2.0 * (self.shape_b - self.shape_a) * (self.shape_a + self.shape_b + 1.0).sqrt() / ((self.shape_a + self.shape_b + 2.0) * (self.shape_a * self.shape_b).sqrt()), ) } } impl Mode<Option<f64>> for Beta { /// Returns the mode of the Beta distribution. Returns `None` if `α <= 1` /// or `β <= 1`. /// /// # Remarks /// /// Since the mode is only a unique interior point for `α > 1, β > 1`, those /// are the only values we allow. We may consider relaxing this constraint /// in the future. /// /// # Formula /// /// ```text /// (α - 1) / (α + β - 2) /// ``` /// /// where `α` is shapeA and `β` is shapeB. fn mode(&self) -> Option<f64> { // TODO: perhaps relax constraint in order to allow calculation // of the 'anti-mode' if self.shape_a <= 1.0 || self.shape_b <= 1.0 { None } else { Some((self.shape_a - 1.0) / (self.shape_a + self.shape_b - 2.0)) } } } impl Continuous<f64, f64> for Beta { /// Calculates the probability density function for the beta distribution /// at `x`. /// /// # Formula /// /// ```text /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β) /// /// x^(α - 1) * (1 - x)^(β - 1) / B(α, β) /// ``` /// /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function. fn pdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { 0.0 } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 1.0 } else if self.shape_a > 80.0 || self.shape_b > 80.0 { self.ln_pdf(x).exp() } else { let bb = gamma::gamma(self.shape_a + self.shape_b) / (gamma::gamma(self.shape_a) * gamma::gamma(self.shape_b)); bb * x.powf(self.shape_a - 1.0) * (1.0 - x).powf(self.shape_b - 1.0) } } /// Calculates the log probability density function for the beta /// distribution at `x`. /// /// # Formula /// /// ```text /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β) /// /// ln(x^(α - 1) * (1 - x)^(β - 1) / B(α, β)) /// ``` /// /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function. 
fn ln_pdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { f64::NEG_INFINITY } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 0.0 } else { let aa = gamma::ln_gamma(self.shape_a + self.shape_b) - gamma::ln_gamma(self.shape_a) - gamma::ln_gamma(self.shape_b); let bb = if ulps_eq!(self.shape_a, 1.0) && x == 0.0 { 0.0 } else if x == 0.0 { f64::NEG_INFINITY } else { (self.shape_a - 1.0) * x.ln() }; let cc = if ulps_eq!(self.shape_b, 1.0) && ulps_eq!(x, 1.0) { 0.0 } else if ulps_eq!(x, 1.0) { f64::NEG_INFINITY } else { (self.shape_b - 1.0) * (1.0 - x).ln() }; aa + bb + cc } } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use super::super::internal::*; use crate::testing_boiler; testing_boiler!(a: f64, b: f64; Beta; BetaError); #[test] fn test_create() { let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0)]; for (a, b) in valid { create_ok(a, b); } } #[test] fn test_bad_create() { let invalid = [ (0.0, 0.0), (0.0, 0.1), (1.0, 0.0), (0.5, f64::INFINITY), (f64::INFINITY, 0.5), (f64::NAN, 1.0), (1.0, f64::NAN), (f64::NAN, f64::NAN), (1.0, -1.0), (-1.0, 1.0), (-1.0, -1.0), (f64::INFINITY, f64::INFINITY), ]; for (a, b) in invalid { create_err(a, b); } } #[test] fn test_mean() { let f = |x: Beta| x.mean().unwrap(); let test = [ ((1.0, 1.0), 0.5), ((9.0, 1.0), 0.9), ((5.0, 100.0), 0.047619047619047619047616), ]; for ((a, b), res) in test { test_relative(a, b, res, f); } } #[test] fn test_variance() { let f = |x: Beta| x.variance().unwrap(); let test = [ ((1.0, 1.0), 1.0 / 12.0), ((9.0, 1.0), 9.0 / 1100.0), ((5.0, 100.0), 500.0 / 1168650.0), ]; for ((a, b), res) in test { test_relative(a, b, res, f); } } #[test] fn test_entropy() { let f = |x: Beta| x.entropy().unwrap(); let test = [ ((9.0, 1.0), -1.3083356884473304939016015), ((5.0, 100.0), -2.52016231876027436794592), ]; for ((a, b), res) in test { test_relative(a, b, res, f); } test_absolute(1.0, 1.0, 0.0, 1e-14, f); } #[test] fn test_skewness() { let skewness = |x: Beta| 
x.skewness().unwrap(); test_relative(1.0, 1.0, 0.0, skewness); test_relative(9.0, 1.0, -1.4740554623801777107177478829, skewness); test_relative(5.0, 100.0, 0.817594109275534303545831591, skewness); } #[test] fn test_mode() { let mode = |x: Beta| x.mode().unwrap(); test_relative(5.0, 100.0, 0.038834951456310676243255386, mode); } #[test] fn test_mode_shape_a_lte_1() { test_none(1.0, 5.0, |dist| dist.mode()); } #[test] fn test_mode_shape_b_lte_1() { test_none(5.0, 1.0, |dist| dist.mode()); } #[test] fn test_min_max() { let min = |x: Beta| x.min(); let max = |x: Beta| x.max(); test_relative(1.0, 1.0, 0.0, min); test_relative(1.0, 1.0, 1.0, max); } #[test] fn test_pdf() { let f = |arg: f64| move |x: Beta| x.pdf(arg); let test = [ ((1.0, 1.0), 0.0, 1.0), ((1.0, 1.0), 0.5, 1.0), ((1.0, 1.0), 1.0, 1.0), ((9.0, 1.0), 0.0, 0.0), ((9.0, 1.0), 0.5, 0.03515625), ((9.0, 1.0), 1.0, 9.0), ((5.0, 100.0), 0.0, 0.0), ((5.0, 100.0), 0.5, 4.534102298350337661e-23), ((5.0, 100.0), 1.0, 0.0), ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, f(x)); } } #[test] fn test_pdf_input_lt_0() { let pdf = |arg: f64| move |x: Beta| x.pdf(arg); test_relative(1.0, 1.0, 0.0, pdf(-1.0)); } #[test] fn test_pdf_input_gt_1() { let pdf = |arg: f64| move |x: Beta| x.pdf(arg); test_relative(1.0, 1.0, 0.0, pdf(2.0)); } #[test] fn test_ln_pdf() { let f = |arg: f64| move |x: Beta| x.ln_pdf(arg); let test = [ ((1.0, 1.0), 0.0, 0.0), ((1.0, 1.0), 0.5, 0.0), ((1.0, 1.0), 1.0, 0.0), ((9.0, 1.0), 0.0, f64::NEG_INFINITY), ((9.0, 1.0), 0.5, -3.347952867143343092547366497), ((9.0, 1.0), 1.0, 2.1972245773362193827904904738), ((5.0, 100.0), 0.0, f64::NEG_INFINITY), ((5.0, 100.0), 0.5, -51.447830024537682154565870), ((5.0, 100.0), 1.0, f64::NEG_INFINITY), ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, f(x)); } } #[test] fn test_ln_pdf_input_lt_0() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); } 
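For integer shapes the beta function has the closed form B(α, β) = (α - 1)!(β - 1)!/(α + β - 1)!. That means several expectations in the `test_pdf` and `test_ln_pdf` tables can be reproduced by hand. For example, Beta(9, 1) at x = 0.5 gives 9 * 0.5^8 = 9/256 = 0.03515625, whose logarithm is ln 9 - 8 ln 2 ≈ -3.3479528671. The sketch below is a hypothetical std-only helper (`beta_pdf_int` is not a statrs function) and is not a replacement for the gamma-function based `pdf`.

```rust
/// n! as f64; exact for the small n used here.
fn factorial(n: u32) -> f64 {
    (1..=n).map(f64::from).product()
}

/// Beta(a, b) density for integer shapes a, b >= 1, using
/// B(a, b) = (a - 1)! (b - 1)! / (a + b - 1)!. Hypothetical helper, not statrs.
fn beta_pdf_int(a: u32, b: u32, x: f64) -> f64 {
    assert!(a >= 1 && b >= 1, "shapes must be positive integers");
    if !(0.0..=1.0).contains(&x) {
        return 0.0; // outside the support, as in `Beta::pdf`
    }
    let beta_fn = factorial(a - 1) * factorial(b - 1) / factorial(a + b - 1);
    // powi(0) returns 1.0 even for a base of 0.0, which matches the boundary
    // values in the test table (e.g. Beta(9, 1) at x = 1.0 gives 9.0).
    x.powi(a as i32 - 1) * (1.0 - x).powi(b as i32 - 1) / beta_fn
}

fn main() {
    let p = beta_pdf_int(9, 1, 0.5);
    println!("pdf = {}, ln_pdf = {}", p, p.ln());
}
```

Factorials overflow `f64` precision quickly, which is one reason a general implementation works with (log-)gamma functions instead.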
#[test] fn test_ln_pdf_input_gt_1() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(2.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); let test = [ ((1.0, 1.0), 0.0, 0.0), ((1.0, 1.0), 0.5, 0.5), ((1.0, 1.0), 1.0, 1.0), ((9.0, 1.0), 0.0, 0.0), ((9.0, 1.0), 0.5, 0.001953125), ((9.0, 1.0), 1.0, 1.0), ((5.0, 100.0), 0.0, 0.0), ((5.0, 100.0), 0.5, 1.0), ((5.0, 100.0), 1.0, 1.0), ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, cdf(x)); } } #[test] fn test_sf() { let sf = |arg: f64| move |x: Beta| x.sf(arg); let test = [ ((1.0, 1.0), 0.0, 1.0), ((1.0, 1.0), 0.5, 0.5), ((1.0, 1.0), 1.0, 0.0), ((9.0, 1.0), 0.0, 1.0), ((9.0, 1.0), 0.5, 0.998046875), ((9.0, 1.0), 1.0, 0.0), ((5.0, 100.0), 0.0, 1.0), ((5.0, 100.0), 0.5, 0.0), ((5.0, 100.0), 1.0, 0.0), ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, sf(x)); } } #[test] fn test_inverse_cdf() { // let inverse_cdf = |arg: f64| move |x: Beta| x.inverse_cdf(arg); let func = |arg: f64| move |x: Beta| x.inverse_cdf(x.cdf(arg)); let test = [ ((1.0, 1.0), 0.0, 0.0), ((1.0, 1.0), 0.5, 0.5), ((1.0, 1.0), 1.0, 1.0), ((9.0, 1.0), 0.0, 0.0), ((9.0, 1.0), 0.001953125, 0.001953125), ((9.0, 1.0), 0.5, 0.5), ((9.0, 1.0), 1.0, 1.0), ((5.0, 100.0), 0.0, 0.0), ((5.0, 100.0), 0.01, 0.01), ((5.0, 100.0), 1.0, 1.0), ]; for ((a, b), x, expect) in test { test_relative(a, b, expect, func(x)); }; } #[test] fn test_cdf_input_lt_0() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); test_relative(1.0, 1.0, 0.0, cdf(-1.0)); } #[test] fn test_cdf_input_gt_1() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); test_relative(1.0, 1.0, 1.0, cdf(2.0)); } #[test] fn test_sf_input_lt_0() { let sf = |arg: f64| move |x: Beta| x.sf(arg); test_relative(1.0, 1.0, 1.0, sf(-1.0)); } #[test] fn test_sf_input_gt_1() { let sf = |arg: f64| move |x: Beta| x.sf(arg); test_relative(1.0, 1.0, 0.0, sf(2.0)); } #[test] fn test_continuous() { 
test::check_continuous_distribution(&create_ok(1.2, 3.4), 0.0, 1.0); test::check_continuous_distribution(&create_ok(4.5, 6.7), 0.0, 1.0); } } statrs-0.18.0/src/distribution/binomial.rs000064400000000000000000000416221046102023000167050ustar 00000000000000use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, factorial}; use crate::statistics::*; use std::f64; /// Implements the /// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Binomial, Discrete}; /// use statrs::statistics::Distribution; /// /// let n = Binomial::new(0.5, 5).unwrap(); /// assert_eq!(n.mean().unwrap(), 2.5); /// assert_eq!(n.pmf(0), 0.03125); /// assert_eq!(n.pmf(3), 0.3125); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Binomial { p: f64, n: u64, } /// Represents the errors that can occur when creating a [`Binomial`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum BinomialError { /// The probability is NaN or not in `[0, 1]`. ProbabilityInvalid, } impl std::fmt::Display for BinomialError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { BinomialError::ProbabilityInvalid => write!(f, "Probability is NaN or not in [0, 1]"), } } } impl std::error::Error for BinomialError {} impl Binomial { /// Constructs a new binomial distribution /// with probability of success `p` on each of /// `n` independent trials.
/// /// # Errors /// /// Returns an error if `p` is `NaN`, less than `0.0`, /// or greater than `1.0`. /// /// # Examples /// /// ``` /// use statrs::distribution::Binomial; /// /// let mut result = Binomial::new(0.5, 5); /// assert!(result.is_ok()); /// /// result = Binomial::new(-0.5, 5); /// assert!(result.is_err()); /// ``` pub fn new(p: f64, n: u64) -> Result<Binomial, BinomialError> { if p.is_nan() || !(0.0..=1.0).contains(&p) { Err(BinomialError::ProbabilityInvalid) } else { Ok(Binomial { p, n }) } } /// Returns the probability of success `p` of /// the binomial distribution. /// /// # Examples /// /// ``` /// use statrs::distribution::Binomial; /// /// let n = Binomial::new(0.5, 5).unwrap(); /// assert_eq!(n.p(), 0.5); /// ``` pub fn p(&self) -> f64 { self.p } /// Returns the number of trials `n` of the /// binomial distribution. /// /// # Examples /// /// ``` /// use statrs::distribution::Binomial; /// /// let n = Binomial::new(0.5, 5).unwrap(); /// assert_eq!(n.n(), 5); /// ``` pub fn n(&self) -> u64 { self.n } } impl std::fmt::Display for Binomial { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Bin({},{})", self.p, self.n) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<u64> for Binomial { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 { (0..self.n).fold(0, |acc, _| { let n: f64 = rng.gen(); if n < self.p { acc + 1 } else { acc } }) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Binomial { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { rng.sample::<u64, _>(self) as f64 } } impl DiscreteCDF<u64, f64> for Binomial { /// Calculates the cumulative distribution function for the /// binomial distribution at `x`. /// /// # Formula /// /// ```text /// I_(1 - p)(n - x, 1 + x) /// ``` /// /// where `I_(x)(a, b)` is the regularized incomplete beta function. fn cdf(&self, x: u64) -> f64 { if x >= self.n { 1.0 } else { let k = x; 
beta::beta_reg((self.n - k) as f64, k as f64 + 1.0, 1.0 - self.p) } } /// Calculates the survival function for the /// binomial distribution at `x`. /// /// # Formula /// /// ```text /// I_(p)(x + 1, n - x) /// ``` /// /// where `I_(x)(a, b)` is the regularized incomplete beta function. fn sf(&self, x: u64) -> f64 { if x >= self.n { 0.0 } else { let k = x; beta::beta_reg(k as f64 + 1.0, (self.n - k) as f64, self.p) } } } impl Min<u64> for Binomial { /// Returns the minimum value in the domain of the /// binomial distribution representable by a 64-bit /// integer /// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> u64 { 0 } } impl Max<u64> for Binomial { /// Returns the maximum value in the domain of the /// binomial distribution representable by a 64-bit /// integer /// /// # Formula /// /// ```text /// n /// ``` fn max(&self) -> u64 { self.n } } impl Distribution<f64> for Binomial { /// Returns the mean of the binomial distribution /// /// # Formula /// /// ```text /// p * n /// ``` fn mean(&self) -> Option<f64> { Some(self.p * self.n as f64) } /// Returns the variance of the binomial distribution /// /// # Formula /// /// ```text /// n * p * (1 - p) /// ``` fn variance(&self) -> Option<f64> { Some(self.p * (1.0 - self.p) * self.n as f64) } /// Returns the entropy of the binomial distribution, computed exactly /// by summing over the support. /// /// # Formula /// /// ```text /// -Σ(P(X = k) * ln(P(X = k))) /// ``` /// /// where `Σ` is the sum over `k = 0..=n`. For large `n` and fixed `p` this /// is approximately `(1 / 2) * ln(2 * π * e * n * p * (1 - p))`. fn entropy(&self) -> Option<f64> { let entr = if self.p == 0.0 || ulps_eq!(self.p, 1.0) { 0.0 } else { (0..self.n + 1).fold(0.0, |acc, x| { let p = self.pmf(x); // skip masses that underflow to zero, since 0 * ln(0) is NaN if p > 0.0 { acc - p * p.ln() } else { acc } }) }; Some(entr) } /// Returns the skewness of the binomial distribution /// /// # Formula /// /// ```text /// (1 - 2p) / sqrt(n * p * (1 - p)) /// ``` fn skewness(&self) -> Option<f64> { Some((1.0 - 2.0 * self.p) / (self.n as f64 * self.p * (1.0 - self.p)).sqrt()) } } impl Median<f64> for Binomial { /// Returns the median of the binomial distribution /// /// # Formula /// /// ```text /// floor(n * p) /// ``` fn median(&self) -> f64 { 
(self.p * self.n as f64).floor() } } impl Mode<Option<u64>> for Binomial { /// Returns the mode for the binomial distribution /// /// # Formula /// /// ```text /// floor((n + 1) * p) /// ``` fn mode(&self) -> Option<u64> { let mode = if self.p == 0.0 { 0 } else if ulps_eq!(self.p, 1.0) { self.n } else { ((self.n as f64 + 1.0) * self.p).floor() as u64 }; Some(mode) } } impl Discrete<u64, f64> for Binomial { /// Calculates the probability mass function for the binomial /// distribution at `x` /// /// # Formula /// /// ```text /// (n choose k) * p^k * (1 - p)^(n - k) /// ``` fn pmf(&self, x: u64) -> f64 { if x > self.n { 0.0 } else if self.p == 0.0 { if x == 0 { 1.0 } else { 0.0 } } else if ulps_eq!(self.p, 1.0) { if x == self.n { 1.0 } else { 0.0 } } else { (factorial::ln_binomial(self.n, x) + x as f64 * self.p.ln() + (self.n - x) as f64 * (1.0 - self.p).ln()) .exp() } } /// Calculates the log probability mass function for the binomial /// distribution at `x` /// /// # Formula /// /// ```text /// ln((n choose k) * p^k * (1 - p)^(n - k)) /// ``` fn ln_pmf(&self, x: u64) -> f64 { if x > self.n { f64::NEG_INFINITY } else if self.p == 0.0 { if x == 0 { 0.0 } else { f64::NEG_INFINITY } } else if ulps_eq!(self.p, 1.0) { if x == self.n { 0.0 } else { f64::NEG_INFINITY } } else { factorial::ln_binomial(self.n, x) + x as f64 * self.p.ln() + (self.n - x) as f64 * (1.0 - self.p).ln() } } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(p: f64, n: u64; Binomial; BinomialError); #[test] fn test_create() { create_ok(0.0, 4); create_ok(0.3, 3); create_ok(1.0, 2); } #[test] fn test_bad_create() { create_err(f64::NAN, 1); create_err(-1.0, 1); create_err(2.0, 1); } #[test] fn test_mean() { let mean = |x: Binomial| x.mean().unwrap(); test_exact(0.0, 4, 0.0, mean); test_absolute(0.3, 3, 0.9, 1e-15, mean); test_exact(1.0, 2, 2.0, mean); } #[test] fn test_variance() { let variance = |x: Binomial| x.variance().unwrap(); 
test_exact(0.0, 4, 0.0, variance); test_exact(0.3, 3, 0.63, variance); test_exact(1.0, 2, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Binomial| x.entropy().unwrap(); test_exact(0.0, 4, 0.0, entropy); test_absolute(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy); test_exact(1.0, 2, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Binomial| x.skewness().unwrap(); test_exact(0.0, 4, f64::INFINITY, skewness); test_exact(0.3, 3, 0.503952630678969636286, skewness); test_exact(1.0, 2, f64::NEG_INFINITY, skewness); } #[test] fn test_median() { let median = |x: Binomial| x.median(); test_exact(0.0, 4, 0.0, median); test_exact(0.3, 3, 0.0, median); test_exact(1.0, 2, 2.0, median); } #[test] fn test_mode() { let mode = |x: Binomial| x.mode().unwrap(); test_exact(0.0, 4, 0, mode); test_exact(0.3, 3, 1, mode); test_exact(1.0, 2, 2, mode); } #[test] fn test_min_max() { let min = |x: Binomial| x.min(); let max = |x: Binomial| x.max(); test_exact(0.3, 10, 0, min); test_exact(0.3, 10, 10, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Binomial| x.pmf(arg); test_exact(0.0, 1, 1.0, pmf(0)); test_exact(0.0, 1, 0.0, pmf(1)); test_exact(0.0, 3, 1.0, pmf(0)); test_exact(0.0, 3, 0.0, pmf(1)); test_exact(0.0, 3, 0.0, pmf(3)); test_exact(0.0, 10, 1.0, pmf(0)); test_exact(0.0, 10, 0.0, pmf(1)); test_exact(0.0, 10, 0.0, pmf(10)); test_exact(0.3, 1, 0.69999999999999995559107901499373838305473327636719, pmf(0)); test_exact(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1)); test_exact(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0)); test_absolute(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1)); test_absolute(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3)); test_absolute(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0)); test_absolute(0.3, 10, 
0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1)); test_absolute(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10)); test_exact(1.0, 1, 0.0, pmf(0)); test_exact(1.0, 1, 1.0, pmf(1)); test_exact(1.0, 3, 0.0, pmf(0)); test_exact(1.0, 3, 0.0, pmf(1)); test_exact(1.0, 3, 1.0, pmf(3)); test_exact(1.0, 10, 0.0, pmf(0)); test_exact(1.0, 10, 0.0, pmf(1)); test_exact(1.0, 10, 1.0, pmf(10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Binomial| x.ln_pmf(arg); test_exact(0.0, 1, 0.0, ln_pmf(0)); test_exact(0.0, 1, f64::NEG_INFINITY, ln_pmf(1)); test_exact(0.0, 3, 0.0, ln_pmf(0)); test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(1)); test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(3)); test_exact(0.0, 10, 0.0, ln_pmf(0)); test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(1)); test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(10)); test_exact(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0)); test_exact(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1)); test_exact(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0)); test_absolute(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1)); test_absolute(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3)); test_exact(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0)); test_absolute(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1)); test_exact(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10)); test_exact(1.0, 1, f64::NEG_INFINITY, ln_pmf(0)); test_exact(1.0, 1, 0.0, ln_pmf(1)); test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(0)); test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(1)); test_exact(1.0, 3, 0.0, ln_pmf(3)); test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(0)); test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(1)); test_exact(1.0, 10, 0.0, ln_pmf(10)); } #[test] fn test_cdf() { let cdf = 
|arg: u64| move |x: Binomial| x.cdf(arg); test_exact(0.0, 1, 1.0, cdf(0)); test_exact(0.0, 1, 1.0, cdf(1)); test_exact(0.0, 3, 1.0, cdf(0)); test_exact(0.0, 3, 1.0, cdf(1)); test_exact(0.0, 3, 1.0, cdf(3)); test_exact(0.0, 10, 1.0, cdf(0)); test_exact(0.0, 10, 1.0, cdf(1)); test_exact(0.0, 10, 1.0, cdf(10)); test_absolute(0.3, 1, 0.7, 1e-15, cdf(0)); test_exact(0.3, 1, 1.0, cdf(1)); test_absolute(0.3, 3, 0.343, 1e-14, cdf(0)); test_absolute(0.3, 3, 0.784, 1e-15, cdf(1)); test_exact(0.3, 3, 1.0, cdf(3)); test_absolute(0.3, 10, 0.0282475249, 1e-16, cdf(0)); test_absolute(0.3, 10, 0.1493083459, 1e-14, cdf(1)); test_exact(0.3, 10, 1.0, cdf(10)); test_exact(1.0, 1, 0.0, cdf(0)); test_exact(1.0, 1, 1.0, cdf(1)); test_exact(1.0, 3, 0.0, cdf(0)); test_exact(1.0, 3, 0.0, cdf(1)); test_exact(1.0, 3, 1.0, cdf(3)); test_exact(1.0, 10, 0.0, cdf(0)); test_exact(1.0, 10, 0.0, cdf(1)); test_exact(1.0, 10, 1.0, cdf(10)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Binomial| x.sf(arg); test_exact(0.0, 1, 0.0, sf(0)); test_exact(0.0, 1, 0.0, sf(1)); test_exact(0.0, 3, 0.0, sf(0)); test_exact(0.0, 3, 0.0, sf(1)); test_exact(0.0, 3, 0.0, sf(3)); test_exact(0.0, 10, 0.0, sf(0)); test_exact(0.0, 10, 0.0, sf(1)); test_exact(0.0, 10, 0.0, sf(10)); test_absolute(0.3, 1, 0.3, 1e-15, sf(0)); test_exact(0.3, 1, 0.0, sf(1)); test_absolute(0.3, 3, 0.657, 1e-14, sf(0)); test_absolute(0.3, 3, 0.216, 1e-15, sf(1)); test_exact(0.3, 3, 0.0, sf(3)); test_absolute(0.3, 10, 0.9717524751000001, 1e-16, sf(0)); test_absolute(0.3, 10, 0.850691654100002, 1e-14, sf(1)); test_exact(0.3, 10, 0.0, sf(10)); test_exact(1.0, 1, 1.0, sf(0)); test_exact(1.0, 1, 0.0, sf(1)); test_exact(1.0, 3, 1.0, sf(0)); test_exact(1.0, 3, 1.0, sf(1)); test_exact(1.0, 3, 0.0, sf(3)); test_exact(1.0, 10, 1.0, sf(0)); test_exact(1.0, 10, 1.0, sf(1)); test_exact(1.0, 10, 0.0, sf(10)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: Binomial| x.cdf(arg); test_exact(0.5, 3, 1.0, cdf(5)); } #[test] fn 
test_sf_upper_bound() { let sf = |arg: u64| move |x: Binomial| x.sf(arg); test_exact(0.5, 3, 0.0, sf(5)); } #[test] fn test_inverse_cdf() { let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg); test_exact(0.4, 5, 2, invcdf(0.3456)); // cases in issue #185 test_exact(0.018, 465, 1, invcdf(3.472e-4)); test_exact(0.5, 6, 4, invcdf(0.75)); } #[test] fn test_cdf_inverse_cdf() { let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg)); test_exact(0.3, 10, 3, cdf_invcdf(3)); test_exact(0.3, 10, 4, cdf_invcdf(4)); test_exact(0.5, 6, 4, cdf_invcdf(4)); } #[test] fn test_discrete() { test::check_discrete_distribution(&create_ok(0.3, 5), 5); test::check_discrete_distribution(&create_ok(0.7, 10), 10); } } statrs-0.18.0/src/distribution/categorical.rs000064400000000000000000000412621046102023000173700ustar 00000000000000use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; use std::f64; /// Implements the /// [Categorical](https://en.wikipedia.org/wiki/Categorical_distribution) /// distribution, also known as the generalized Bernoulli or discrete /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Categorical, Discrete}; /// use statrs::statistics::Distribution; /// use statrs::prec; /// /// let n = Categorical::new(&[0.0, 1.0, 2.0]).unwrap(); /// assert!(prec::almost_eq(n.mean().unwrap(), 5.0 / 3.0, 1e-15)); /// assert_eq!(n.pmf(1), 1.0 / 3.0); /// ``` #[derive(Clone, PartialEq, Debug)] pub struct Categorical { norm_pmf: Vec<f64>, cdf: Vec<f64>, sf: Vec<f64>, } /// Represents the errors that can occur when creating a [`Categorical`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum CategoricalError { /// The probability mass is empty. ProbMassEmpty, /// The probabilities sum up to zero. ProbMassSumZero, /// The probability mass contains at least one element which is NaN or less than zero. 
ProbMassHasInvalidElements, } impl std::fmt::Display for CategoricalError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"), CategoricalError::ProbMassSumZero => write!(f, "Probabilities sum up to zero"), CategoricalError::ProbMassHasInvalidElements => write!( f, "Probability mass contains at least one element which is NaN or less than zero" ), } } } impl std::error::Error for CategoricalError {} impl Categorical { /// Constructs a new categorical distribution /// with the probability masses defined by `prob_mass`. /// /// # Errors /// /// Returns an error if `prob_mass` is empty, the sum of /// the elements in `prob_mass` is 0, or any element is less than /// 0 or is `f64::NAN`. /// /// # Note /// /// The elements in `prob_mass` do not need to be normalized. /// /// # Examples /// /// ``` /// use statrs::distribution::Categorical; /// /// let mut result = Categorical::new(&[0.0, 1.0, 2.0]); /// assert!(result.is_ok()); /// /// result = Categorical::new(&[0.0, -1.0, 2.0]); /// assert!(result.is_err()); /// ``` pub fn new(prob_mass: &[f64]) -> Result<Categorical, CategoricalError> { if prob_mass.is_empty() { return Err(CategoricalError::ProbMassEmpty); } let mut prob_sum = 0.0; for &p in prob_mass { if p.is_nan() || p < 0.0 { return Err(CategoricalError::ProbMassHasInvalidElements); } prob_sum += p; } if prob_sum == 0.0 { return Err(CategoricalError::ProbMassSumZero); } // extract un-normalized cdf let cdf = prob_mass_to_cdf(prob_mass); // extract un-normalized sf let sf = cdf_to_sf(&cdf); // extract normalized probability mass let sum = cdf[cdf.len() - 1]; let mut norm_pmf = vec![0.0; prob_mass.len()]; norm_pmf .iter_mut() .zip(prob_mass.iter()) .for_each(|(np, pm)| *np = *pm / sum); Ok(Categorical { norm_pmf, cdf, sf }) } fn cdf_max(&self) -> f64 { *self.cdf.last().unwrap() } } impl std::fmt::Display for Categorical { fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Cat({:#?})", self.norm_pmf) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<usize> for Categorical { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> usize { sample_unchecked(rng, &self.cdf) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<u64> for Categorical { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 { sample_unchecked(rng, &self.cdf) as u64 } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Categorical { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { sample_unchecked(rng, &self.cdf) as f64 } } impl DiscreteCDF<u64, f64> for Categorical { /// Calculates the cumulative distribution function for the categorical /// distribution at `x`. /// /// # Formula /// /// ```text /// sum(p_j) for j in 0..=x /// ``` /// /// where `p_j` is the normalized probability mass for the `j`th category. fn cdf(&self, x: u64) -> f64 { if x >= self.cdf.len() as u64 { 1.0 } else { self.cdf.get(x as usize).unwrap() / self.cdf_max() } } /// Calculates the survival function for the categorical distribution /// at `x`. /// /// # Formula /// /// ```text /// sum(p_j) for j in (x + 1)..k /// ``` /// /// where `p_j` is the normalized probability mass for the `j`th category /// and `k` is the number of categories. fn sf(&self, x: u64) -> f64 { if x >= self.sf.len() as u64 { 0.0 } else { self.sf.get(x as usize).unwrap() / self.cdf_max() } } /// Calculates the inverse cumulative distribution function for the /// categorical distribution at `x`. /// /// # Panics /// /// If `x <= 0.0` or `x >= 1.0`. /// /// # Formula /// /// ```text /// i /// ``` /// /// where `i` is the first index such that `x <= f(i)` /// and `f(x)` is defined as `p_x + f(x - 1)` and `f(0) = p_0` where /// `p_x` is the `x`th probability mass. fn inverse_cdf(&self, x: f64) -> u64 { if x >= 1.0 || x <= 0.0 { panic!("x must be in (0, 1)") } let denorm_prob = x * self.cdf_max(); binary_index(&self.cdf, denorm_prob) as u64 } } impl Min<u64> for Categorical { /// 
Returns the minimum value in the domain of the /// categorical distribution representable by a 64-bit /// integer /// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> u64 { 0 } } impl Max<u64> for Categorical { /// Returns the maximum value in the domain of the /// categorical distribution representable by a 64-bit /// integer /// /// # Formula /// /// ```text /// k - 1 /// ``` /// /// where `k` is the number of categories fn max(&self) -> u64 { self.cdf.len() as u64 - 1 } } impl Distribution<f64> for Categorical { /// Returns the mean of the categorical distribution /// /// # Formula /// /// ```text /// Σ(j * p_j) /// ``` /// /// where `p_j` is the `j`th probability mass, /// `Σ` is the sum from `0` to `k - 1`, /// and `k` is the number of categories fn mean(&self) -> Option<f64> { Some( self.norm_pmf .iter() .enumerate() .fold(0.0, |acc, (idx, &val)| acc + idx as f64 * val), ) } /// Returns the variance of the categorical distribution /// /// # Formula /// /// ```text /// Σ(p_j * (j - μ)^2) /// ``` /// /// where `p_j` is the `j`th probability mass, `μ` is the mean, /// `Σ` is the sum from `0` to `k - 1`, /// and `k` is the number of categories fn variance(&self) -> Option<f64> { let mu = self.mean()?; let var = self .norm_pmf .iter() .enumerate() .fold(0.0, |acc, (idx, &val)| { let r = idx as f64 - mu; acc + r * r * val }); Some(var) } /// Returns the entropy of the categorical distribution /// /// # Formula /// /// ```text /// -Σ(p_j * ln(p_j)) /// ``` /// /// where `p_j` is the `j`th probability mass, /// `Σ` is the sum from `0` to `k - 1`, /// and `k` is the number of categories fn entropy(&self) -> Option<f64> { let entr = -self .norm_pmf .iter() .filter(|&&p| p > 0.0) .map(|p| p * p.ln()) .sum::<f64>(); Some(entr) } } impl Median<f64> for Categorical { /// Returns the median of the categorical distribution /// /// # Formula /// /// ```text /// CDF^-1(0.5) /// ``` fn median(&self) -> f64 { self.inverse_cdf(0.5) as f64 } } impl Discrete<u64, f64> for Categorical { /// Calculates the probability mass function for the categorical /// 
distribution at `x` /// /// # Formula /// /// ```text /// p_x /// ``` fn pmf(&self, x: u64) -> f64 { *self.norm_pmf.get(x as usize).unwrap_or(&0.0) } /// Calculates the log probability mass function for the categorical /// distribution at `x` fn ln_pmf(&self, x: u64) -> f64 { self.pmf(x).ln() } } /// Draws a sample from the categorical distribution described by `cdf` /// without doing any bounds checking #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> usize { let draw = rng.gen::() * cdf.last().unwrap(); cdf.iter().position(|val| *val >= draw).unwrap() } /// Computes the cdf from the given probability masses. Performs /// no parameter or bounds checking. pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec { let mut cdf = Vec::with_capacity(prob_mass.len()); prob_mass.iter().fold(0.0, |s, p| { let sum = s + p; cdf.push(sum); sum }); cdf } /// Computes the sf from the given cumulative densities. /// Performs no parameter or bounds checking. pub fn cdf_to_sf(cdf: &[f64]) -> Vec { let max = *cdf.last().unwrap(); cdf.iter().map(|x| max - x).collect() } // Returns the index of val if placed into the sorted search array. // If val is greater than all elements, it therefore would return // the length of the array (N). If val is less than all elements, it would // return 0. Otherwise val returns the index of the first element larger than // it within the search array. 
fn binary_index(search: &[f64], val: f64) -> usize {
    use std::cmp;

    let mut low = 0_isize;
    let mut high = search.len() as isize - 1;
    while low <= high {
        let mid = low + ((high - low) / 2);
        let el = *search.get(mid as usize).unwrap();
        if el > val {
            high = mid - 1;
        } else if el < val {
            low = mid.saturating_add(1);
        } else {
            return mid as usize;
        }
    }
    cmp::min(search.len(), cmp::max(low, 0) as usize)
}

#[test]
fn test_prob_mass_to_cdf() {
    let arr = [0.0, 0.5, 0.5, 3.0, 1.1];
    let res = prob_mass_to_cdf(&arr);
    assert_eq!(res, [0.0, 0.5, 1.0, 4.0, 5.1]);
}

#[test]
fn test_binary_index() {
    let arr = [0.0, 3.0, 5.0, 9.0, 10.0];
    assert_eq!(0, binary_index(&arr, -1.0));
    assert_eq!(2, binary_index(&arr, 5.0));
    assert_eq!(3, binary_index(&arr, 5.2));
    assert_eq!(5, binary_index(&arr, 10.1));
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(prob_mass: &[f64]; Categorical; CategoricalError);

    #[test]
    fn test_create() {
        create_ok(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn test_bad_create() {
        let invalid: &[(&[f64], CategoricalError)] = &[
            (&[], CategoricalError::ProbMassEmpty),
            (&[-1.0, 1.0], CategoricalError::ProbMassHasInvalidElements),
            (&[0.0, 0.0, 0.0], CategoricalError::ProbMassSumZero),
        ];

        for &(prob_mass, err) in invalid {
            test_create_err(prob_mass, err);
        }
    }

    #[test]
    fn test_mean() {
        let mean = |x: Categorical| x.mean().unwrap();
        test_exact(&[0.0, 0.25, 0.5, 0.25], 2.0, mean);
        test_exact(&[0.0, 1.0, 2.0, 1.0], 2.0, mean);
        test_exact(&[0.0, 0.5, 0.5], 1.5, mean);
        test_exact(&[0.75, 0.25], 0.25, mean);
        test_exact(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Categorical| x.variance().unwrap();
        test_exact(&[0.0, 0.25, 0.5, 0.25], 0.5, variance);
        test_exact(&[0.0, 1.0, 2.0, 1.0], 0.5, variance);
        test_exact(&[0.0, 0.5, 0.5], 0.25, variance);
        test_exact(&[0.75, 0.25], 0.1875, variance);
        test_exact(&[1.0,
0.0, 1.0], 1.0, variance); } #[test] fn test_entropy() { let entropy = |x: Categorical| x.entropy().unwrap(); test_exact(&[0.0, 1.0], 0.0, entropy); test_absolute(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy); test_absolute(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy); test_absolute(&vec![1.0; 100], 100f64.ln(), 1e-14, entropy); test_absolute(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy); } #[test] fn test_median() { let median = |x: Categorical| x.median(); test_exact(&[0.0, 3.0, 1.0, 1.0], 1.0, median); test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, median); } #[test] fn test_min_max() { let min = |x: Categorical| x.min(); let max = |x: Categorical| x.max(); test_exact(&[4.0, 2.5, 2.5, 1.0], 0, min); test_exact(&[4.0, 2.5, 2.5, 1.0], 3, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Categorical| x.pmf(arg); test_exact(&[0.0, 0.25, 0.5, 0.25], 0.0, pmf(0)); test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(1)); test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(3)); } #[test] fn test_pmf_x_too_high() { let pmf = |arg: u64| move |x: Categorical| x.pmf(arg); test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg); test_exact(&[0.0, 0.25, 0.5, 0.25], 0f64.ln(), ln_pmf(0)); test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(1)); test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(3)); } #[test] fn test_ln_pmf_x_too_high() { let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg); test_exact(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Categorical| x.cdf(arg); test_exact(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1)); test_exact(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0)); test_exact(&[4.0, 2.5, 2.5, 1.0], 0.4, cdf(0)); test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(3)); test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Categorical| x.sf(arg); test_exact(&[0.0, 
3.0, 1.0, 1.0], 2.0 / 5.0, sf(1));
        test_exact(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.6, sf(0));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(3));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    #[test]
    fn test_cdf_input_high() {
        let cdf = |arg: u64| move |x: Categorical| x.cdf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4));
    }

    #[test]
    fn test_sf_input_high() {
        let sf = |arg: u64| move |x: Categorical| x.sf(arg);
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4));
    }

    #[test]
    fn test_cdf_sf_mirror() {
        let mass = [4.0, 2.5, 2.5, 1.0];
        let cat = Categorical::new(&mass).unwrap();
        assert_eq!(cat.cdf(0), 1.-cat.sf(0));
        assert_eq!(cat.cdf(1), 1.-cat.sf(1));
        assert_eq!(cat.cdf(2), 1.-cat.sf(2));
        assert_eq!(cat.cdf(3), 1.-cat.sf(3));
    }

    #[test]
    fn test_inverse_cdf() {
        let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg);
        test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.2));
        test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.5));
        test_exact(&[0.0, 3.0, 1.0, 1.0], 3, inverse_cdf(0.95));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 0, inverse_cdf(0.2));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 1, inverse_cdf(0.5));
        test_exact(&[4.0, 2.5, 2.5, 1.0], 3, inverse_cdf(0.95));
    }

    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_low() {
        let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]);
        dist.inverse_cdf(0.0);
    }

    #[test]
    #[should_panic]
    fn test_inverse_cdf_input_high() {
        let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]);
        dist.inverse_cdf(1.0);
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&create_ok(&[1.0, 2.0, 3.0, 4.0]), 4);
        test::check_discrete_distribution(&create_ok(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5);
    }
}
statrs-0.18.0/src/distribution/cauchy.rs000064400000000000000000000443551046102023000163750ustar 00000000000000
use crate::distribution::{Continuous, ContinuousCDF};
use crate::statistics::*;
use std::f64;

/// Implements the [Cauchy](https://en.wikipedia.org/wiki/Cauchy_distribution)
/// distribution, also known as the Lorentz distribution.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Cauchy, Continuous};
/// use statrs::statistics::Mode;
///
/// let n = Cauchy::new(0.0, 1.0).unwrap();
/// assert_eq!(n.mode().unwrap(), 0.0);
/// assert_eq!(n.pdf(1.0), 0.1591549430918953357689);
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Cauchy {
    location: f64,
    scale: f64,
}

/// Represents the errors that can occur when creating a [`Cauchy`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum CauchyError {
    /// The location is NaN.
    LocationInvalid,

    /// The scale is NaN, zero or less than zero.
    ScaleInvalid,
}

impl std::fmt::Display for CauchyError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            CauchyError::LocationInvalid => write!(f, "Location is NaN"),
            CauchyError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"),
        }
    }
}

impl std::error::Error for CauchyError {}

impl Cauchy {
    /// Constructs a new Cauchy distribution with the given
    /// location and scale.
    ///
    /// # Errors
    ///
    /// Returns an error if location or scale are `NaN` or `scale <= 0.0`
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Cauchy;
    ///
    /// let mut result = Cauchy::new(0.0, 1.0);
    /// assert!(result.is_ok());
    ///
    /// result = Cauchy::new(0.0, -1.0);
    /// assert!(result.is_err());
    /// ```
    pub fn new(location: f64, scale: f64) -> Result<Cauchy, CauchyError> {
        if location.is_nan() {
            return Err(CauchyError::LocationInvalid);
        }

        if scale.is_nan() || scale <= 0.0 {
            return Err(CauchyError::ScaleInvalid);
        }

        Ok(Cauchy { location, scale })
    }

    /// Returns the location of the Cauchy distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Cauchy;
    ///
    /// let n = Cauchy::new(0.0, 1.0).unwrap();
    /// assert_eq!(n.location(), 0.0);
    /// ```
    pub fn location(&self) -> f64 {
        self.location
    }

    /// Returns the scale of the Cauchy distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Cauchy;
    ///
    /// let n = Cauchy::new(0.0, 1.0).unwrap();
    /// assert_eq!(n.scale(), 1.0);
    /// ```
    pub fn scale(&self) -> f64 {
        self.scale
    }
}

impl std::fmt::Display for Cauchy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Cauchy({}, {})", self.location, self.scale)
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Cauchy {
    fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 {
        self.location + self.scale * (f64::consts::PI * (r.gen::<f64>() - 0.5)).tan()
    }
}

impl ContinuousCDF<f64, f64> for Cauchy {
    /// Calculates the cumulative distribution function for the
    /// Cauchy distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// (1 / π) * arctan((x - x_0) / γ) + 0.5
    /// ```
    ///
    /// where `x_0` is the location and `γ` is the scale
    fn cdf(&self, x: f64) -> f64 {
        (1.0 / f64::consts::PI) * ((x - self.location) / self.scale).atan() + 0.5
    }

    /// Calculates the survival function for the
    /// Cauchy distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// (1 / π) * arctan(-(x - x_0) / γ) + 0.5
    /// ```
    ///
    /// where `x_0` is the location and `γ` is the scale.
    /// Note that this is identical to the cdf except for
    /// the negated argument to the arctan function.
    fn sf(&self, x: f64) -> f64 {
        (1.0 / f64::consts::PI) * ((self.location - x) / self.scale).atan() + 0.5
    }

    /// Calculates the inverse cumulative distribution function for the
    /// Cauchy distribution at `x`
    ///
    /// # Panics
    ///
    /// If `x < 0.0` or `x > 1.0`
    ///
    /// # Formula
    ///
    /// ```text
    /// x_0 + γ tan((x - 0.5) π)
    /// ```
    ///
    /// where `x_0` is the location and `γ` is the scale
    fn inverse_cdf(&self, x: f64) -> f64 {
        if !(0.0..=1.0).contains(&x) {
            panic!("x must be in [0, 1]");
        } else {
            self.location + self.scale * (f64::consts::PI * (x - 0.5)).tan()
        }
    }
}

impl Min<f64> for Cauchy {
    /// Returns the minimum value in the domain of the Cauchy
    /// distribution representable by a double precision float
    ///
    /// # Formula
    ///
    /// ```text
    /// f64::NEG_INFINITY
    /// ```
    fn min(&self) -> f64 {
        f64::NEG_INFINITY
    }
}

impl Max<f64> for Cauchy {
    /// Returns the maximum value in the domain of the Cauchy
    /// distribution representable by a double precision float
    ///
    /// # Formula
    ///
    /// ```text
    /// f64::INFINITY
    /// ```
    fn max(&self) -> f64 {
        f64::INFINITY
    }
}

impl Distribution<f64> for Cauchy {
    /// Returns the entropy of the Cauchy distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(γ) + ln(4π)
    /// ```
    ///
    /// where `γ` is the scale
    fn entropy(&self) -> Option<f64> {
        Some((4.0 * f64::consts::PI * self.scale).ln())
    }
}

impl Median<f64> for Cauchy {
    /// Returns the median of the Cauchy distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// x_0
    /// ```
    ///
    /// where `x_0` is the location
    fn median(&self) -> f64 {
        self.location
    }
}

impl Mode<Option<f64>> for Cauchy {
    /// Returns the mode of the Cauchy distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// x_0
    /// ```
    ///
    /// where `x_0` is the location
    fn mode(&self) -> Option<f64> {
        Some(self.location)
    }
}

impl Continuous<f64, f64> for Cauchy {
    /// Calculates the probability density function for the Cauchy
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// 1 / (πγ * (1 + ((x - x_0) / γ)^2))
    /// ```
    ///
    /// where `x_0` is the location and `γ` is the scale
    fn pdf(&self, x: f64) -> f64 {
        1.0 / (f64::consts::PI
            * self.scale
            * (1.0
                + ((x - self.location) / self.scale) * ((x - self.location) / self.scale)))
    }

    /// Calculates the log probability density function for the Cauchy
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(1 / (πγ * (1 + ((x - x_0) / γ)^2)))
    /// ```
    ///
    /// where `x_0` is the location and `γ` is the scale
    fn ln_pdf(&self, x: f64) -> f64 {
        -(f64::consts::PI
            * self.scale
            * (1.0
                + ((x - self.location) / self.scale) * ((x - self.location) / self.scale)))
            .ln()
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(location: f64, scale: f64; Cauchy; CauchyError);

    #[test]
    fn test_create() {
        create_ok(0.0, 0.1);
        create_ok(0.0, 1.0);
        create_ok(0.0, 10.0);
        create_ok(10.0, 11.0);
        create_ok(-5.0, 100.0);
        create_ok(0.0, f64::INFINITY);
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (f64::NAN, 1.0, CauchyError::LocationInvalid),
            (1.0, f64::NAN, CauchyError::ScaleInvalid),
            (f64::NAN, f64::NAN, CauchyError::LocationInvalid),
            (1.0, 0.0, CauchyError::ScaleInvalid),
        ];

        for (location, scale, err) in invalid {
            test_create_err(location, scale, err);
        }
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Cauchy| x.entropy().unwrap();
        test_exact(0.0, 2.0, 3.224171427529236102395, entropy);
        test_exact(0.1, 4.0, 3.917318608089181411812, entropy);
        test_exact(1.0, 10.0, 4.833609339963336476996, entropy);
        test_exact(10.0, 11.0, 4.92891951976766133704, entropy);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Cauchy| x.mode().unwrap();
        test_exact(0.0, 2.0, 0.0, mode);
        test_exact(0.1, 4.0, 0.1, mode);
        test_exact(1.0, 10.0, 1.0, mode);
        test_exact(10.0, 11.0, 10.0, mode);
        test_exact(0.0, f64::INFINITY, 0.0, mode);
    }

    #[test]
    fn test_median() {
        let median = |x: Cauchy| x.median();
        test_exact(0.0, 2.0, 0.0, median);
        test_exact(0.1, 4.0, 0.1, median);
        test_exact(1.0, 10.0, 1.0,
median); test_exact(10.0, 11.0, 10.0, median); test_exact(0.0, f64::INFINITY, 0.0, median); } #[test] fn test_min_max() { let min = |x: Cauchy| x.min(); let max = |x: Cauchy| x.max(); test_exact(0.0, 1.0, f64::NEG_INFINITY, min); test_exact(0.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Cauchy| x.pdf(arg); test_exact(0.0, 0.1, 0.001272730452554141029739, pdf(-5.0)); test_exact(0.0, 0.1, 0.03151583031522679916216, pdf(-1.0)); test_absolute(0.0, 0.1, 3.183098861837906715378, 1e-14, pdf(0.0)); test_exact(0.0, 0.1, 0.03151583031522679916216, pdf(1.0)); test_exact(0.0, 0.1, 0.001272730452554141029739, pdf(5.0)); test_absolute(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(-5.0)); test_exact(0.0, 1.0, 0.1591549430918953357689, pdf(-1.0)); test_exact(0.0, 1.0, 0.3183098861837906715378, pdf(0.0)); test_exact(0.0, 1.0, 0.1591549430918953357689, pdf(1.0)); test_absolute(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(5.0)); test_exact(0.0, 10.0, 0.02546479089470325372302, pdf(-5.0)); test_exact(0.0, 10.0, 0.03151583031522679916216, pdf(-1.0)); test_exact(0.0, 10.0, 0.03183098861837906715378, pdf(0.0)); test_exact(0.0, 10.0, 0.03151583031522679916216, pdf(1.0)); test_exact(0.0, 10.0, 0.02546479089470325372302, pdf(5.0)); test_exact(-5.0, 100.0, 0.003183098861837906715378, pdf(-5.0)); test_absolute(-5.0, 100.0, 0.003178014039374906864395, 1e-17, pdf(-1.0)); test_exact(-5.0, 100.0, 0.003175160959439308444267, pdf(0.0)); test_exact(-5.0, 100.0, 0.003171680810918599756255, pdf(1.0)); test_absolute(-5.0, 100.0, 0.003151583031522679916216, 1e-17, pdf(5.0)); test_exact(0.0, f64::INFINITY, 0.0, pdf(-5.0)); test_exact(0.0, f64::INFINITY, 0.0, pdf(-1.0)); test_exact(0.0, f64::INFINITY, 0.0, pdf(0.0)); test_exact(0.0, f64::INFINITY, 0.0, pdf(1.0)); test_exact(0.0, f64::INFINITY, 0.0, pdf(5.0)); test_exact(f64::INFINITY, 1.0, 0.0, pdf(-5.0)); test_exact(f64::INFINITY, 1.0, 0.0, pdf(-1.0)); test_exact(f64::INFINITY, 1.0, 0.0, pdf(0.0)); 
test_exact(f64::INFINITY, 1.0, 0.0, pdf(1.0)); test_exact(f64::INFINITY, 1.0, 0.0, pdf(5.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Cauchy| x.ln_pdf(arg); test_exact(0.0, 0.1, -6.666590723732973542744, ln_pdf(-5.0)); test_absolute(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); test_exact(0.0, 0.1, 1.157855207144645509875, ln_pdf(0.0)); test_absolute(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); test_exact(0.0, 0.1, -6.666590723732973542744, ln_pdf(5.0)); test_exact(0.0, 1.0, -4.402826423870882219615, ln_pdf(-5.0)); test_absolute(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(-1.0)); test_exact(0.0, 1.0, -1.144729885849400174143, ln_pdf(0.0)); test_absolute(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(1.0)); test_exact(0.0, 1.0, -4.402826423870882219615, ln_pdf(5.0)); test_exact(0.0, 10.0, -3.670458530157655613928, ln_pdf(-5.0)); test_absolute(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); test_exact(0.0, 10.0, -3.447314978843445858161, ln_pdf(0.0)); test_absolute(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); test_exact(0.0, 10.0, -3.670458530157655613928, ln_pdf(5.0)); test_exact(-5.0, 100.0, -5.749900071837491542179, ln_pdf(-5.0)); test_exact(-5.0, 100.0, -5.751498793201188569872, ln_pdf(-1.0)); test_exact(-5.0, 100.0, -5.75239695203607874116, ln_pdf(0.0)); test_exact(-5.0, 100.0, -5.75349360734762171285, ln_pdf(1.0)); test_exact(-5.0, 100.0, -5.759850402690659625027, ln_pdf(5.0)); test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-1.0)); test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.0)); test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); 
test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(1.0)); test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(5.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Cauchy| x.cdf(arg); test_absolute(0.0, 0.1, 0.006365349100972796679298, 1e-16, cdf(-5.0)); test_absolute(0.0, 0.1, 0.03172551743055356951498, 1e-16, cdf(-1.0)); test_exact(0.0, 0.1, 0.5, cdf(0.0)); test_exact(0.0, 0.1, 0.968274482569446430485, cdf(1.0)); test_exact(0.0, 0.1, 0.9936346508990272033207, cdf(5.0)); test_absolute(0.0, 1.0, 0.06283295818900118381375, 1e-16, cdf(-5.0)); test_exact(0.0, 1.0, 0.25, cdf(-1.0)); test_exact(0.0, 1.0, 0.5, cdf(0.0)); test_exact(0.0, 1.0, 0.75, cdf(1.0)); test_exact(0.0, 1.0, 0.9371670418109988161863, cdf(5.0)); test_exact(0.0, 10.0, 0.3524163823495667258246, cdf(-5.0)); test_exact(0.0, 10.0, 0.468274482569446430485, cdf(-1.0)); test_exact(0.0, 10.0, 0.5, cdf(0.0)); test_exact(0.0, 10.0, 0.531725517430553569515, cdf(1.0)); test_exact(0.0, 10.0, 0.6475836176504332741754, cdf(5.0)); test_exact(-5.0, 100.0, 0.5, cdf(-5.0)); test_exact(-5.0, 100.0, 0.5127256113479918307809, cdf(-1.0)); test_exact(-5.0, 100.0, 0.5159022512561763751816, cdf(0.0)); test_exact(-5.0, 100.0, 0.5190757242358362337495, cdf(1.0)); test_exact(-5.0, 100.0, 0.531725517430553569515, cdf(5.0)); test_exact(0.0, f64::INFINITY, 0.5, cdf(-5.0)); test_exact(0.0, f64::INFINITY, 0.5, cdf(-1.0)); test_exact(0.0, f64::INFINITY, 0.5, cdf(0.0)); test_exact(0.0, f64::INFINITY, 0.5, cdf(1.0)); test_exact(0.0, f64::INFINITY, 0.5, cdf(5.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(-5.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(-1.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(0.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(1.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Cauchy| x.sf(arg); test_absolute(0.0, 0.1, 0.9936346508990272, 1e-16, sf(-5.0)); test_absolute(0.0, 0.1, 0.9682744825694465, 1e-16, sf(-1.0)); test_exact(0.0, 0.1, 0.5, 
sf(0.0)); test_absolute(0.0, 0.1, 0.03172551743055352, 1e-16, sf(1.0)); test_exact(0.0, 0.1, 0.006365349100972806, sf(5.0)); test_absolute(0.0, 1.0, 0.9371670418109989, 1e-16, sf(-5.0)); test_exact(0.0, 1.0, 0.75, sf(-1.0)); test_exact(0.0, 1.0, 0.5, sf(0.0)); test_exact(0.0, 1.0, 0.25, sf(1.0)); test_exact(0.0, 1.0, 0.06283295818900114, sf(5.0)); test_exact(0.0, 10.0, 0.6475836176504333, sf(-5.0)); test_exact(0.0, 10.0, 0.5317255174305535, sf(-1.0)); test_exact(0.0, 10.0, 0.5, sf(0.0)); test_exact(0.0, 10.0, 0.4682744825694464, sf(1.0)); test_exact(0.0, 10.0, 0.35241638234956674, sf(5.0)); test_exact(-5.0, 100.0, 0.5, sf(-5.0)); test_exact(-5.0, 100.0, 0.4872743886520082, sf(-1.0)); test_exact(-5.0, 100.0, 0.4840977487438236, sf(0.0)); test_exact(-5.0, 100.0, 0.48092427576416374, sf(1.0)); test_exact(-5.0, 100.0, 0.4682744825694464, sf(5.0)); test_exact(0.0, f64::INFINITY, 0.5, sf(-5.0)); test_exact(0.0, f64::INFINITY, 0.5, sf(-1.0)); test_exact(0.0, f64::INFINITY, 0.5, sf(0.0)); test_exact(0.0, f64::INFINITY, 0.5, sf(1.0)); test_exact(0.0, f64::INFINITY, 0.5, sf(5.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(-5.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(-1.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(0.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(1.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(5.0)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: Cauchy| x.inverse_cdf(x.cdf(arg)); test_absolute(0.0, 0.1, -5.0, 1e-10, func(-5.0)); test_absolute(0.0, 0.1, -1.0, 1e-14, func(-1.0)); test_exact(0.0, 0.1, 0.0, func(0.0)); test_absolute(0.0, 0.1, 1.0, 1e-14, func(1.0)); test_absolute(0.0, 0.1, 5.0, 1e-10, func(5.0)); test_absolute(0.0, 1.0, -5.0, 1e-14, func(-5.0)); test_absolute(0.0, 1.0, -1.0, 1e-15, func(-1.0)); test_exact(0.0, 1.0, 0.0, func(0.0)); test_absolute(0.0, 1.0, 1.0, 1e-15, func(1.0)); test_absolute(0.0, 1.0, 5.0, 1e-14, func(5.0)); test_absolute(0.0, 10.0, -5.0, 1e-14, func(-5.0)); test_absolute(0.0, 10.0, -1.0, 1e-14, func(-1.0)); 
        test_exact(0.0, 10.0, 0.0, func(0.0));
        test_absolute(0.0, 10.0, 1.0, 1e-14, func(1.0));
        test_absolute(0.0, 10.0, 5.0, 1e-14, func(5.0));
        test_exact(-5.0, 100.0, -5.0, func(-5.0));
        test_absolute(-5.0, 100.0, -1.0, 1e-10, func(-1.0));
        test_absolute(-5.0, 100.0, 0.0, 1e-14, func(0.0));
        test_absolute(-5.0, 100.0, 1.0, 1e-14, func(1.0));
        test_absolute(-5.0, 100.0, 5.0, 1e-10, func(5.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(-1.2, 3.4), -1500.0, 1500.0);
        test::check_continuous_distribution(&create_ok(-4.5, 6.7), -5000.0, 5000.0);
    }
}
statrs-0.18.0/src/distribution/chi.rs000064400000000000000000000356151046102023000156630ustar 00000000000000
use crate::distribution::{Continuous, ContinuousCDF};
use crate::function::gamma;
use crate::statistics::*;
use std::f64;
use std::num::NonZeroU64;

/// Implements the [Chi](https://en.wikipedia.org/wiki/Chi_distribution)
/// distribution
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Chi, Continuous};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = Chi::new(2).unwrap();
/// assert!(prec::almost_eq(n.mean().unwrap(), 1.25331413731550025121, 1e-14));
/// assert!(prec::almost_eq(n.pdf(1.0), 0.60653065971263342360, 1e-15));
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Chi {
    freedom: NonZeroU64,
}

/// Represents the errors that can occur when creating a [`Chi`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum ChiError {
    /// The degrees of freedom are zero.
    FreedomInvalid,
}

impl std::fmt::Display for ChiError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            ChiError::FreedomInvalid => {
                write!(f, "Degrees of freedom are zero")
            }
        }
    }
}

impl std::error::Error for ChiError {}

impl Chi {
    /// Constructs a new chi distribution
    /// with `freedom` degrees of freedom
    ///
    /// # Errors
    ///
    /// Returns an error if `freedom` is equal to `0`.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Chi;
    ///
    /// let mut result = Chi::new(2);
    /// assert!(result.is_ok());
    ///
    /// result = Chi::new(0);
    /// assert!(result.is_err());
    /// ```
    pub fn new(freedom: u64) -> Result<Chi, ChiError> {
        match NonZeroU64::new(freedom) {
            Some(freedom) => Ok(Self { freedom }),
            None => Err(ChiError::FreedomInvalid),
        }
    }

    /// Returns the degrees of freedom of the chi distribution.
    /// Guaranteed to be non-zero.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Chi;
    ///
    /// let n = Chi::new(2).unwrap();
    /// assert_eq!(n.freedom(), 2);
    /// ```
    pub fn freedom(&self) -> u64 {
        self.freedom.get()
    }
}

impl std::fmt::Display for Chi {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "χ_{}", self.freedom)
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Chi {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        (0..self.freedom())
            .fold(0.0, |acc, _| {
                acc + super::normal::sample_unchecked(rng, 0.0, 1.0).powf(2.0)
            })
            .sqrt()
    }
}

impl ContinuousCDF<f64, f64> for Chi {
    /// Calculates the cumulative distribution function for the chi
    /// distribution at `x`.
    ///
    /// # Formula
    ///
    /// ```text
    /// P(k / 2, x^2 / 2)
    /// ```
    ///
    /// where `k` is the degrees of freedom and `P` is
    /// the regularized lower incomplete Gamma function
    fn cdf(&self, x: f64) -> f64 {
        if x == f64::INFINITY {
            1.0
        } else if x <= 0.0 {
            0.0
        } else {
            gamma::gamma_lr(self.freedom() as f64 / 2.0, x * x / 2.0)
        }
    }

    /// Calculates the survival function for the chi
    /// distribution at `x`.
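    ///
    /// # Examples
    ///
    /// As a rough illustration, `sf(x)` should complement `cdf(x)` up to
    /// floating-point rounding, and for `k = 2` the CDF reduces to the
    /// closed form `1 - exp(-x^2 / 2)`. The `1e-12` tolerance is a loose
    /// choice for this sketch, not a documented accuracy bound.
    ///
    /// ```
    /// use statrs::distribution::{Chi, ContinuousCDF};
    /// use statrs::prec;
    ///
    /// let n = Chi::new(2).unwrap();
    /// // Closed form for k = 2: P(1, x^2 / 2) = 1 - exp(-x^2 / 2)
    /// assert!(prec::almost_eq(n.cdf(1.0), 1.0 - (-0.5f64).exp(), 1e-12));
    /// // The survival function is the complement of the CDF
    /// assert!(prec::almost_eq(n.cdf(1.0) + n.sf(1.0), 1.0, 1e-12));
    /// ```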
    ///
    /// # Formula
    ///
    /// ```text
    /// Q(k / 2, x^2 / 2)
    /// ```
    ///
    /// where `k` is the degrees of freedom and `Q` is
    /// the regularized upper incomplete Gamma function
    fn sf(&self, x: f64) -> f64 {
        if x == f64::INFINITY {
            0.0
        } else if x <= 0.0 {
            1.0
        } else {
            gamma::gamma_ur(self.freedom() as f64 / 2.0, x * x / 2.0)
        }
    }
}

impl Min<f64> for Chi {
    /// Returns the minimum value in the domain of the chi distribution
    /// representable by a double precision float
    ///
    /// # Formula
    ///
    /// ```text
    /// 0
    /// ```
    fn min(&self) -> f64 {
        0.0
    }
}

impl Max<f64> for Chi {
    /// Returns the maximum value in the domain of the chi distribution
    /// representable by a double precision float
    ///
    /// # Formula
    ///
    /// ```text
    /// f64::INFINITY
    /// ```
    fn max(&self) -> f64 {
        f64::INFINITY
    }
}

impl Distribution<f64> for Chi {
    /// Returns the mean of the chi distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// sqrt(2) * Γ((k + 1) / 2) / Γ(k / 2)
    /// ```
    ///
    /// where `k` is degrees of freedom and `Γ` is the gamma function
    fn mean(&self) -> Option<f64> {
        let freedom = self.freedom() as f64;
        if self.freedom() > 300 {
            // Large n approximation based on the Stirling series approximation to the Gamma function
            // This avoids calling the Gamma function with large arguments, where both values
            // overflow to infinity and their ratio becomes NaN
            //
            // Relative accuracy follows O(1/n^4) and at 300 d.o.f. is better than 1e-12
            // For an f32 impl the threshold should be changed to 150
            Some(
                (freedom.sqrt())
                    / ((1.0 + 0.25 / freedom)
                        * (1.0 + 0.03125 / (freedom * freedom))
                        * (1.0 - 0.046875 / (freedom * freedom * freedom))),
            )
        } else {
            let mean = f64::consts::SQRT_2 * gamma::gamma((freedom + 1.0) / 2.0)
                / gamma::gamma(freedom / 2.0);
            Some(mean)
        }
    }

    /// Returns the variance of the chi distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// k - μ^2
    /// ```
    ///
    /// where `k` is degrees of freedom and `μ` is the mean
    /// of the distribution
    fn variance(&self) -> Option<f64> {
        let mean = self.mean()?;
        Some(self.freedom() as f64 - mean * mean)
    }

    /// Returns the entropy of the chi distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(Γ(k / 2)) + 0.5 * (k - ln(2) - (k - 1) * ψ(k / 2))
    /// ```
    ///
    /// where `k` is degrees of freedom, `Γ` is the gamma function,
    /// and `ψ` is the digamma function
    fn entropy(&self) -> Option<f64> {
        let freedom = self.freedom() as f64;
        let entr = gamma::ln_gamma(freedom / 2.0)
            + (freedom - (2.0f64).ln() - (freedom - 1.0) * gamma::digamma(freedom / 2.0)) / 2.0;
        Some(entr)
    }

    /// Returns the skewness of the chi distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// (μ / σ^3) * (1 - 2σ^2)
    /// ```
    ///
    /// where `μ` is the mean and `σ` the standard deviation
    /// of the distribution
    fn skewness(&self) -> Option<f64> {
        let sigma = self.std_dev()?;
        let skew = self.mean()?
            * (1.0 - 2.0 * sigma * sigma)
            / (sigma * sigma * sigma);
        Some(skew)
    }
}

impl Mode<Option<f64>> for Chi {
    /// Returns the mode for the chi distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// sqrt(k - 1)
    /// ```
    ///
    /// where `k` is the degrees of freedom
    fn mode(&self) -> Option<f64> {
        Some(((self.freedom() - 1) as f64).sqrt())
    }
}

impl Continuous<f64, f64> for Chi {
    /// Calculates the probability density function for the chi
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// (2^(1 - (k / 2)) * x^(k - 1) * e^(-x^2 / 2)) / Γ(k / 2)
    /// ```
    ///
    /// where `k` is the degrees of freedom and `Γ` is the gamma function
    fn pdf(&self, x: f64) -> f64 {
        if x == f64::INFINITY || x <= 0.0 {
            0.0
        } else if self.freedom() > 160 {
            self.ln_pdf(x).exp()
        } else {
            let freedom = self.freedom() as f64;
            (2.0f64).powf(1.0 - freedom / 2.0)
                * x.powf(freedom - 1.0)
                * (-x * x / 2.0).exp()
                / gamma::gamma(freedom / 2.0)
        }
    }

    /// Calculates the log probability density function for the chi distribution
    /// at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// ln((2^(1 - (k / 2)) * x^(k - 1) * e^(-x^2 / 2)) / Γ(k / 2))
    /// ```
    ///
    /// where `k` is the degrees of freedom and `Γ` is the gamma function
    fn ln_pdf(&self, x: f64) -> f64 {
        if x == f64::INFINITY || x <= 0.0 {
            f64::NEG_INFINITY
        } else {
            let freedom = self.freedom() as f64;
            (1.0 - freedom / 2.0) * (2.0f64).ln() + ((freedom - 1.0) * x.ln())
                - x * x / 2.0
                - gamma::ln_gamma(freedom / 2.0)
        }
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(freedom: u64; Chi; ChiError);

    #[test]
    fn test_create() {
        create_ok(1);
        create_ok(3);
    }

    #[test]
    fn test_bad_create() {
        create_err(0);
    }

    #[test]
    fn test_mean() {
        let mean = |x: Chi| x.mean().unwrap();
        test_absolute(1, 0.7978845608028653558799, 1e-15, mean);
        test_absolute(2, 1.25331413731550025121, 1e-14, mean);
        test_absolute(5, 2.12769216214097428235, 1e-14, mean);
        test_absolute(336, 18.31666925443713, 1e-12, mean);
    }

    #[test]
    fn test_large_dof_mean_not_nan() {
        for i in 1..1000 {
let mean = Chi::new(i).unwrap().mean().unwrap(); assert!(!mean.is_nan(), "Chi mean for {i} dof was {mean}"); } } #[test] fn test_variance() { let variance = |x: Chi| x.variance().unwrap(); test_absolute(1, 0.3633802276324186569245, 1e-15, variance); test_absolute(2, 0.42920367320510338077, 1e-14, variance); test_absolute(3, 0.4535209105296746277, 1e-14, variance); } #[test] fn test_entropy() { let entropy = |x: Chi| x.entropy().unwrap(); test_absolute(1, 0.7257913526447274323631, 1e-15, entropy); test_absolute(2, 0.9420342421707937755946, 1e-15, entropy); test_absolute(3, 0.99615419810620560239, 1e-14, entropy); } #[test] fn test_skewness() { let skewness = |x: Chi| x.skewness().unwrap(); test_absolute(1, 0.995271746431156042444, 1e-14, skewness); test_absolute(3, 0.485692828049590809, 1e-12, skewness); } #[test] fn test_mode() { let mode = |x: Chi| x.mode().unwrap(); test_exact(1, 0.0, mode); test_exact(2, 1.0, mode); test_exact(3, f64::consts::SQRT_2, mode); } #[test] fn test_min_max() { let min = |x: Chi| x.min(); let max = |x: Chi| x.max(); test_exact(1, 0.0, min); test_exact(2, 0.0, min); test_exact(2, 0.0, min); test_exact(3, 0.0, min); test_exact(1, f64::INFINITY, max); test_exact(2, f64::INFINITY, max); test_exact(2, f64::INFINITY, max); test_exact(3, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Chi| x.pdf(arg); test_exact(1, 0.0, pdf(0.0)); test_absolute(1, 0.79390509495402353102, 1e-15, pdf(0.1)); test_absolute(1, 0.48394144903828669960, 1e-15, pdf(1.0)); test_absolute(1, 2.1539520085086552718e-7, 1e-22, pdf(5.5)); test_exact(1, 0.0, pdf(f64::INFINITY)); test_exact(2, 0.0, pdf(0.0)); test_absolute(2, 0.099501247919268231335, 1e-16, pdf(0.1)); test_absolute(2, 0.60653065971263342360, 1e-15, pdf(1.0)); test_absolute(2, 1.4847681768496578863e-6, 1e-21, pdf(5.5)); test_exact(2, 0.0, pdf(f64::INFINITY)); test_exact(2, 0.0, pdf(0.0)); test_exact(2, 0.0, pdf(f64::INFINITY)); test_absolute(170, 0.5644678498668440878, 1e-13, 
pdf(13.0)); } #[test] fn test_neg_pdf() { let pdf = |arg: f64| move |x: Chi| x.pdf(arg); test_exact(1, 0.0, pdf(-1.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Chi| x.ln_pdf(arg); test_exact(1, f64::NEG_INFINITY, ln_pdf(0.0)); test_absolute(1, -0.23079135264472743236, 1e-15, ln_pdf(0.1)); test_absolute(1, -0.72579135264472743236, 1e-15, ln_pdf(1.0)); test_absolute(1, -15.350791352644727432, 1e-14, ln_pdf(5.5)); test_exact(1, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); test_exact(2, f64::NEG_INFINITY, ln_pdf(0.0)); test_absolute(2, -2.3075850929940456840, 1e-15, ln_pdf(0.1)); test_absolute(2, -0.5, 1e-15, ln_pdf(1.0)); test_absolute(2, -13.420251907761574765, 1e-15, ln_pdf(5.5)); test_exact(2, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); test_exact(2, f64::NEG_INFINITY, ln_pdf(0.0)); test_exact(2, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); test_absolute(170, -0.57187185030600516424237, 1e-13, ln_pdf(13.0)); } #[test] fn test_neg_ln_pdf() { let ln_pdf = |arg: f64| move |x: Chi| x.ln_pdf(arg); test_exact(1, f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Chi| x.cdf(arg); test_exact(1, 0.0, cdf(0.0)); test_absolute(1, 0.079655674554057962931, 1e-16, cdf(0.1)); test_absolute(1, 0.68268949213708589717, 1e-15, cdf(1.0)); test_exact(1, 0.99999996202087506822, cdf(5.5)); test_exact(1, 1.0, cdf(f64::INFINITY)); test_exact(2, 0.0, cdf(0.0)); test_absolute(2, 0.0049875208073176866474, 1e-17, cdf(0.1)); test_exact(2, 1.0, cdf(f64::INFINITY)); test_exact(2, 0.0, cdf(0.0)); test_exact(2, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Chi| x.sf(arg); test_exact(1, 1.0, sf(0.0)); test_absolute(1, 0.920344325445942, 1e-16, sf(0.1)); test_absolute(1, 0.31731050786291404, 1e-15, sf(1.0)); test_absolute(1, 3.797912493177544e-8, 1e-15, sf(5.5)); test_exact(1, 0.0, sf(f64::INFINITY)); test_exact(2, 1.0, sf(0.0)); test_absolute(2, 0.9950124791926823, 1e-17, sf(0.1)); test_absolute(2, 
0.6065306597126333, 1e-15, sf(1.0)); test_absolute(2, 2.699578503363014e-7, 1e-15, sf(5.5)); test_exact(2, 0.0, sf(f64::INFINITY)); test_exact(2, 1.0, sf(0.0)); test_exact(2, 0.0, sf(f64::INFINITY)); } #[test] fn test_neg_cdf() { let cdf = |arg: f64| move |x: Chi| x.cdf(arg); test_exact(1, 0.0, cdf(-1.0)); } #[test] fn test_neg_sf() { let sf = |arg: f64| move |x: Chi| x.sf(arg); test_exact(1, 1.0, sf(-1.0)); } #[test] fn test_continuous() { test::check_continuous_distribution(&create_ok(1), 0.0, 10.0); test::check_continuous_distribution(&create_ok(2), 0.0, 10.0); test::check_continuous_distribution(&create_ok(5), 0.0, 10.0); } } statrs-0.18.0/src/distribution/chi_squared.rs000064400000000000000000000205371046102023000174040ustar 00000000000000use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; use std::f64; /// Implements the /// [Chi-squared](https://en.wikipedia.org/wiki/Chi-squared_distribution) /// distribution which is a special case of the /// [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) distribution /// (referenced [Here](./struct.Gamma.html)) /// /// # Examples /// /// ``` /// use statrs::distribution::{ChiSquared, Continuous}; /// use statrs::statistics::Distribution; /// use statrs::prec; /// /// let n = ChiSquared::new(3.0).unwrap(); /// assert_eq!(n.mean().unwrap(), 3.0); /// assert!(prec::almost_eq(n.pdf(4.0), 0.107981933026376103901, 1e-15)); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct ChiSquared { freedom: f64, g: Gamma, } impl ChiSquared { /// Constructs a new chi-squared distribution with `freedom` /// degrees of freedom. This is equivalent to a Gamma distribution /// with a shape of `freedom / 2.0` and a rate of `0.5`. 
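The sentence above states that a chi-squared distribution with `k` degrees of freedom is a Gamma distribution with shape `k / 2` and rate `0.5`. The std-only sketch below checks that identity numerically: it evaluates the chi-squared density from the formula given in the `pdf` docs and the Gamma(shape `a`, rate `b`) density `b^a * x^(a - 1) * e^(-b * x) / Γ(a)`, and asserts they agree. The helpers `gamma_half`, `gamma_pdf` and `chi_squared_pdf` are hypothetical and not part of statrs; `gamma_half` only covers integer `k`, where Γ(k / 2) can be built exactly from Γ(1/2) = √π and Γ(1) = 1.

```rust
use std::f64::consts::PI;

/// Hypothetical helper: Γ(k / 2) for a positive integer `k`, built from
/// Γ(1/2) = √π, Γ(1) = 1 and the recurrence Γ(s + 1) = s * Γ(s).
fn gamma_half(k: u32) -> f64 {
    assert!(k > 0, "k must be positive");
    // Start at Γ(1/2) for odd k and at Γ(1) for even k.
    let (mut s, mut g) = if k % 2 == 1 { (0.5, PI.sqrt()) } else { (1.0, 1.0) };
    let target = k as f64 / 2.0;
    while s < target {
        g *= s;
        s += 1.0;
    }
    g
}

/// Hypothetical helper: Gamma density with shape `a` and rate `b`,
/// b^a * x^(a - 1) * e^(-b * x) / Γ(a), where `gamma_a` = Γ(a).
fn gamma_pdf(a: f64, b: f64, gamma_a: f64, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    b.powf(a) * x.powf(a - 1.0) * (-b * x).exp() / gamma_a
}

/// Hypothetical helper: chi-squared density written from the documented formula
/// x^(k/2 - 1) * e^(-x/2) / (2^(k/2) * Γ(k/2)).
fn chi_squared_pdf(k: u32, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    let half_k = k as f64 / 2.0;
    x.powf(half_k - 1.0) * (-x / 2.0).exp() / (2f64.powf(half_k) * gamma_half(k))
}

fn main() {
    // Chi-squared(k) and Gamma(k / 2, rate 1/2) should give the same density.
    for k in 1..=6u32 {
        for &x in &[0.5, 1.0, 4.0, 9.0] {
            let direct = chi_squared_pdf(k, x);
            let via_gamma = gamma_pdf(k as f64 / 2.0, 0.5, gamma_half(k), x);
            assert!((direct - via_gamma).abs() < 1e-12, "mismatch at k = {k}, x = {x}");
        }
    }
    println!("chi-squared(3) pdf at 4.0: {}", chi_squared_pdf(3, 4.0));
}
```

For `k = 3` and `x = 4.0` the helper should land on the `0.107981933026376...` value used in the `ChiSquared` doc example, up to floating-point rounding.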
/// /// # Errors /// /// Returns an error if `freedom` is `NaN` or less than /// or equal to `0.0` /// /// # Examples /// /// ``` /// use statrs::distribution::ChiSquared; /// /// let mut result = ChiSquared::new(3.0); /// assert!(result.is_ok()); /// /// result = ChiSquared::new(0.0); /// assert!(result.is_err()); /// ``` pub fn new(freedom: f64) -> Result<ChiSquared, GammaError> { Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g }) } /// Returns the degrees of freedom of the chi-squared /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::ChiSquared; /// /// let n = ChiSquared::new(3.0).unwrap(); /// assert_eq!(n.freedom(), 3.0); /// ``` pub fn freedom(&self) -> f64 { self.freedom } /// Returns the shape of the underlying Gamma distribution /// /// # Examples /// /// ``` /// use statrs::distribution::ChiSquared; /// /// let n = ChiSquared::new(3.0).unwrap(); /// assert_eq!(n.shape(), 3.0 / 2.0); /// ``` pub fn shape(&self) -> f64 { self.g.shape() } /// Returns the rate of the underlying Gamma distribution /// /// # Examples /// /// ``` /// use statrs::distribution::ChiSquared; /// /// let n = ChiSquared::new(3.0).unwrap(); /// assert_eq!(n.rate(), 0.5); /// ``` pub fn rate(&self) -> f64 { self.g.rate() } } impl std::fmt::Display for ChiSquared { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "χ^2_{}", self.freedom) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for ChiSquared { fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 { ::rand::distributions::Distribution::sample(&self.g, r) } } impl ContinuousCDF<f64, f64> for ChiSquared { /// Calculates the cumulative distribution function for the /// chi-squared distribution at `x` /// /// # Formula /// /// ```text /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) /// ``` /// /// where `k` is the degrees of freedom, `Γ` is the gamma function, /// and `γ` is the lower incomplete gamma function fn cdf(&self, x: f64) -> f64 { self.g.cdf(x) }
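When `k` is even, the Gamma shape `k / 2` is an integer and the regularized lower incomplete gamma function in the `cdf` formula has a closed (Erlang) form: `P(X <= x) = 1 - e^(-x/2) * Σ_{j=0}^{k/2-1} (x/2)^j / j!`. The sketch below implements only that special case with a hypothetical helper `chi_squared_cdf_even`; it is not statrs API and is only meant for small to moderate `k` and `x`, whereas the code above delegates to `Gamma::cdf`, which covers any real `k > 0`.

```rust
/// Hypothetical helper: chi-squared CDF for a positive, even number of
/// degrees of freedom `k`. With k = 2m the Gamma shape m is an integer, so
/// P(X <= x) = 1 - e^(-x/2) * Σ_{j=0}^{m-1} (x/2)^j / j!
fn chi_squared_cdf_even(k: u32, x: f64) -> f64 {
    assert!(k > 0 && k % 2 == 0, "k must be a positive even integer");
    if x <= 0.0 {
        return 0.0;
    }
    if x.is_infinite() {
        return 1.0;
    }
    let half_x = x / 2.0;
    // term_j = e^(-x/2) * (x/2)^j / j!, built up incrementally from j = 0.
    let mut term = (-half_x).exp();
    let mut sf = term;
    for j in 1..k / 2 {
        term *= half_x / j as f64;
        sf += term;
    }
    // The accumulated sum is the survival function; the CDF is its complement.
    1.0 - sf
}

fn main() {
    // k = 2: exponential with rate 1/2, so P(X <= 1) = 1 - e^(-1/2).
    let p2 = chi_squared_cdf_even(2, 1.0);
    // k = 4: P(X <= 2) = 1 - e^(-1) * (1 + 1).
    let p4 = chi_squared_cdf_even(4, 2.0);
    println!("k = 2: {p2}, k = 4: {p4}");
}
```

For `k = 2` and `x = 1` the complement `e^(-1/2) ≈ 0.6065` is the same value the Chi tests expect from `sf(1.0)` with 2 degrees of freedom, since `P(χ > 1) = P(χ² > 1)`.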
/// Calculates the survival function for the /// chi-squared distribution at `x` /// /// # Formula /// /// ```text /// (1 / Γ(k / 2)) * Γ(k / 2, x / 2) /// ``` /// /// where `k` is the degrees of freedom, `Γ(k / 2)` is the gamma function, /// and `Γ(k / 2, x / 2)` is the upper incomplete gamma function fn sf(&self, x: f64) -> f64 { self.g.sf(x) } /// Calculates the inverse cumulative distribution function for the /// chi-squared distribution at probability `p` /// /// # Formula /// /// ```text /// 2 * γ^{-1}(k / 2, p * Γ(k / 2)) /// ``` /// /// where `k` is the degrees of freedom, `Γ` is the gamma function, /// and `γ^{-1}(k / 2, ·)` is the inverse of the lower incomplete gamma /// function in its second argument fn inverse_cdf(&self, p: f64) -> f64 { self.g.inverse_cdf(p) } } impl Min<f64> for ChiSquared { /// Returns the minimum value in the domain of the /// chi-squared distribution representable by a double precision /// float /// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> f64 { 0.0 } } impl Max<f64> for ChiSquared { /// Returns the maximum value in the domain of the /// chi-squared distribution representable by a double precision /// float /// /// # Formula /// /// ```text /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY } } impl Distribution<f64> for ChiSquared { /// Returns the mean of the chi-squared distribution /// /// # Formula /// /// ```text /// k /// ``` /// /// where `k` is the degrees of freedom fn mean(&self) -> Option<f64> { self.g.mean() } /// Returns the variance of the chi-squared distribution /// /// # Formula /// /// ```text /// 2k /// ``` /// /// where `k` is the degrees of freedom fn variance(&self) -> Option<f64> { self.g.variance() } /// Returns the entropy of the chi-squared distribution /// /// # Formula /// /// ```text /// (k / 2) + ln(2 * Γ(k / 2)) + (1 - (k / 2)) * ψ(k / 2) /// ``` /// /// where `k` is the degrees of freedom, `Γ` is the gamma function, /// and `ψ` is the digamma function fn entropy(&self) -> Option<f64> { self.g.entropy() } /// Returns the skewness of the chi-squared distribution /// ///
# Formula /// /// ```text /// sqrt(8 / k) /// ``` /// /// where `k` is the degrees of freedom fn skewness(&self) -> Option<f64> { self.g.skewness() } } impl Median<f64> for ChiSquared { /// Returns an approximation of the median of the chi-squared distribution /// /// # Formula /// /// ```text /// k * (1 - (2 / 9k))^3 /// ``` /// /// For `k < 1` the expanded form of this expression is evaluated in full; /// for `k >= 1` only its leading terms, `k - 2/3`, are returned fn median(&self) -> f64 { if self.freedom < 1.0 { // if k is small, calculate using expansion of formula self.freedom - 2.0 / 3.0 + 12.0 / (81.0 * self.freedom) - 8.0 / (729.0 * self.freedom * self.freedom) } else { // if k is large enough, median heads toward k - 2/3 self.freedom - 2.0 / 3.0 } } } impl Mode<Option<f64>> for ChiSquared { /// Returns the mode of the chi-squared distribution /// /// # Formula /// /// ```text /// k - 2 /// ``` /// /// where `k` is the degrees of freedom. Returns `None` for `k < 2`, where /// the underlying Gamma shape `k / 2` is less than `1` fn mode(&self) -> Option<f64> { self.g.mode() } } impl Continuous<f64, f64> for ChiSquared { /// Calculates the probability density function for the chi-squared /// distribution at `x` /// /// # Formula /// /// ```text /// 1 / (2^(k / 2) * Γ(k / 2)) * x^((k / 2) - 1) * e^(-x / 2) /// ``` /// /// where `k` is the degrees of freedom and `Γ` is the gamma function fn pdf(&self, x: f64) -> f64 { self.g.pdf(x) } /// Calculates the log probability density function for the chi-squared /// distribution at `x` /// /// # Formula /// /// ```text /// ln(1 / (2^(k / 2) * Γ(k / 2)) * x^((k / 2) - 1) * e^(-x / 2)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { self.g.ln_pdf(x) } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(freedom: f64; ChiSquared; GammaError); #[test] fn test_median() { let median = |x: ChiSquared| x.median(); test_absolute(0.5, 0.0857338820301783264746, 1e-16, median); test_exact(1.0, 1.0 - 2.0 / 3.0, median); test_exact(2.0, 2.0 - 2.0 / 3.0, median); test_exact(2.5, 2.5 - 2.0 / 3.0, median); test_exact(3.0, 3.0 - 2.0 / 3.0, median); } #[test] fn test_continuous() { // TODO: figure out why this test fails:
//test::check_continuous_distribution(&create_ok(1.0), 0.0, 10.0); test::check_continuous_distribution(&create_ok(2.0), 0.0, 10.0); test::check_continuous_distribution(&create_ok(5.0), 0.0, 50.0); } } statrs-0.18.0/src/distribution/dirac.rs000064400000000000000000000160061046102023000161730ustar 00000000000000use crate::distribution::ContinuousCDF; use crate::statistics::*; /// Implements the [Dirac Delta](https://en.wikipedia.org/wiki/Dirac_delta_function#As_a_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Dirac, Continuous}; /// use statrs::statistics::Distribution; /// /// let n = Dirac::new(3.0).unwrap(); /// assert_eq!(n.mean().unwrap(), 3.0); /// ``` #[derive(Debug, Copy, Clone, PartialEq)] pub struct Dirac(f64); /// Represents the errors that can occur when creating a [`Dirac`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum DiracError { /// The value v is NaN. ValueInvalid, } impl std::fmt::Display for DiracError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { DiracError::ValueInvalid => write!(f, "Value v is NaN"), } } } impl std::error::Error for DiracError {} impl Dirac { /// Constructs a new dirac distribution function at value `v`. /// /// # Errors /// /// Returns an error if `v` is not-a-number. 
/// /// # Examples /// /// ``` /// use statrs::distribution::Dirac; /// /// let mut result = Dirac::new(0.0); /// assert!(result.is_ok()); /// /// result = Dirac::new(f64::NAN); /// assert!(result.is_err()); /// ``` pub fn new(v: f64) -> Result<Dirac, DiracError> { if v.is_nan() { Err(DiracError::ValueInvalid) } else { Ok(Dirac(v)) } } } impl std::fmt::Display for Dirac { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "δ_{}", self.0) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Dirac { fn sample<R: ::rand::Rng + ?Sized>(&self, _: &mut R) -> f64 { self.0 } } impl ContinuousCDF<f64, f64> for Dirac { /// Calculates the cumulative distribution function for the /// dirac distribution at `x` /// /// The value is 1 if `x >= v` and 0 otherwise. fn cdf(&self, x: f64) -> f64 { if x < self.0 { 0.0 } else { 1.0 } } /// Calculates the survival function for the /// dirac distribution at `x` /// /// The value is 0 if `x >= v` and 1 otherwise. fn sf(&self, x: f64) -> f64 { if x < self.0 { 1.0 } else { 0.0 } } } impl Min<f64> for Dirac { /// Returns the minimum value in the domain of the /// dirac distribution representable by a double precision float /// /// # Formula /// /// ```text /// v /// ``` fn min(&self) -> f64 { self.0 } } impl Max<f64> for Dirac { /// Returns the maximum value in the domain of the /// dirac distribution representable by a double precision float /// /// # Formula /// /// ```text /// v /// ``` fn max(&self) -> f64 { self.0 } } impl Distribution<f64> for Dirac { /// Returns the mean of the dirac distribution /// /// # Remarks /// /// Since the distribution produces `v` with probability 1, the mean /// is `v`. fn mean(&self) -> Option<f64> { Some(self.0) } /// Returns the variance of the dirac distribution /// /// # Formula /// /// ```text /// 0 /// ``` /// /// Since only one value can be produced, the variance is zero.
fn variance(&self) -> Option<f64> { Some(0.0) } /// Returns the entropy of the dirac distribution /// /// # Formula /// /// ```text /// 0 /// ``` /// /// Since this distribution has full certainty, it encodes no information fn entropy(&self) -> Option<f64> { Some(0.0) } /// Returns the skewness of the dirac distribution /// /// # Formula /// /// ```text /// 0 /// ``` fn skewness(&self) -> Option<f64> { Some(0.0) } } impl Median<f64> for Dirac { /// Returns the median of the dirac distribution /// /// # Formula /// /// ```text /// v /// ``` /// /// where `v` is the point of the dirac distribution fn median(&self) -> f64 { self.0 } } impl Mode<Option<f64>> for Dirac { /// Returns the mode of the dirac distribution /// /// # Formula /// /// ```text /// v /// ``` /// /// where `v` is the point of the dirac distribution fn mode(&self) -> Option<f64> { Some(self.0) } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::testing_boiler; testing_boiler!(v: f64; Dirac; DiracError); #[test] fn test_create() { create_ok(10.0); create_ok(-5.0); create_ok(10.0); create_ok(100.0); create_ok(f64::INFINITY); } #[test] fn test_bad_create() { create_err(f64::NAN); } #[test] fn test_variance() { let variance = |x: Dirac| x.variance().unwrap(); test_exact(0.0, 0.0, variance); test_exact(-5.0, 0.0, variance); test_exact(f64::INFINITY, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Dirac| x.entropy().unwrap(); test_exact(0.0, 0.0, entropy); test_exact(f64::INFINITY, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Dirac| x.skewness().unwrap(); test_exact(0.0, 0.0, skewness); test_exact(4.0, 0.0, skewness); test_exact(0.3, 0.0, skewness); test_exact(f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Dirac| x.mode().unwrap(); test_exact(0.0, 0.0, mode); test_exact(3.0, 3.0, mode); test_exact(f64::INFINITY, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Dirac| x.median(); test_exact(0.0, 0.0, median); test_exact(3.0, 3.0, median);
test_exact(f64::INFINITY, f64::INFINITY, median); } #[test] fn test_min_max() { let min = |x: Dirac| x.min(); let max = |x: Dirac| x.max(); test_exact(0.0, 0.0, min); test_exact(3.0, 3.0, min); test_exact(f64::INFINITY, f64::INFINITY, min); test_exact(0.0, 0.0, max); test_exact(3.0, 3.0, max); test_exact(f64::NEG_INFINITY, f64::NEG_INFINITY, max); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Dirac| x.cdf(arg); test_exact(0.0, 1.0, cdf(0.0)); test_exact(3.0, 1.0, cdf(3.0)); test_exact(f64::INFINITY, 0.0, cdf(1.0)); test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Dirac| x.sf(arg); test_exact(0.0, 0.0, sf(0.0)); test_exact(3.0, 0.0, sf(3.0)); test_exact(f64::INFINITY, 1.0, sf(1.0)); test_exact(f64::INFINITY, 0.0, sf(f64::INFINITY)); } } statrs-0.18.0/src/distribution/dirichlet.rs000064400000000000000000000416641046102023000170700ustar 00000000000000use crate::distribution::Continuous; use crate::function::gamma; use crate::prec; use crate::statistics::*; use nalgebra::{Dim, Dyn, OMatrix, OVector}; use std::f64; /// Implements the /// [Dirichlet](https://en.wikipedia.org/wiki/Dirichlet_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Dirichlet, Continuous}; /// use statrs::statistics::Distribution; /// use nalgebra::DVector; /// use statrs::statistics::MeanN; /// /// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap(); /// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.0 / 6.0, 1.0 / 3.0, 0.5])); /// assert_eq!(n.pdf(&DVector::from_vec(vec![0.33333, 0.33333, 0.33333])), 2.222155556222205); /// ``` #[derive(Clone, PartialEq, Debug)] pub struct Dirichlet<D> where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>, { alpha: OVector<f64, D>, } /// Represents the errors that can occur when creating a [`Dirichlet`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum DirichletError { /// Alpha contains fewer than two elements. AlphaTooShort, /// Alpha contains an element that is NaN, infinite, zero or less than zero. AlphaHasInvalidElements, } impl std::fmt::Display for DirichletError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { DirichletError::AlphaTooShort => write!(f, "Alpha contains fewer than two elements"), DirichletError::AlphaHasInvalidElements => write!( f, "Alpha contains an element that is NaN, infinite, zero or less than zero" ), } } } impl std::error::Error for DirichletError {} impl Dirichlet<Dyn> { /// Constructs a new dirichlet distribution with the given /// concentration parameters (alpha) /// /// # Errors /// /// Returns an error if any element `x` of alpha satisfies `x <= 0.0` /// or is `NaN` or infinite, or if alpha has fewer than 2 elements /// /// # Examples /// /// ``` /// use statrs::distribution::Dirichlet; /// use nalgebra::DVector; /// /// let alpha_ok = vec![1.0, 2.0, 3.0]; /// let mut result = Dirichlet::new(alpha_ok); /// assert!(result.is_ok()); /// /// let alpha_err = vec![0.0]; /// result = Dirichlet::new(alpha_err); /// assert!(result.is_err()); /// ``` pub fn new(alpha: Vec<f64>) -> Result<Self, DirichletError> { Self::new_from_nalgebra(alpha.into()) } /// Constructs a new dirichlet distribution with the given /// concentration parameter (alpha) repeated `n` times /// /// # Errors /// /// Returns an error if `alpha <= 0.0`, `alpha` is `NaN` or infinite, /// or if `n < 2` /// /// # Examples /// /// ``` /// use statrs::distribution::Dirichlet; /// /// let mut result = Dirichlet::new_with_param(1.0, 3); /// assert!(result.is_ok()); /// /// result = Dirichlet::new_with_param(0.0, 1); /// assert!(result.is_err()); /// ``` pub fn new_with_param(alpha: f64, n: usize) -> Result<Self, DirichletError> { Self::new(vec![alpha; n]) } } impl<D> Dirichlet<D> where D: Dim, nalgebra::DefaultAllocator:
nalgebra::allocator::Allocator<D>, { /// Constructs a new distribution with the given vector for `alpha`, /// taking ownership of it without cloning /// /// # Errors /// /// Returns an error if the vector has fewer than two elements or if any /// element of alpha is not both finite and positive pub fn new_from_nalgebra(alpha: OVector<f64, D>) -> Result<Self, DirichletError> { if alpha.len() < 2 { return Err(DirichletError::AlphaTooShort); } if alpha.iter().any(|&a_i| !a_i.is_finite() || a_i <= 0.0) { return Err(DirichletError::AlphaHasInvalidElements); } Ok(Self { alpha }) } /// Returns a reference to the concentration parameters (alpha) of /// the dirichlet distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Dirichlet; /// use nalgebra::DVector; /// /// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap(); /// assert_eq!(n.alpha(), &DVector::from_vec(vec![1.0, 2.0, 3.0])); /// ``` pub fn alpha(&self) -> &nalgebra::OVector<f64, D> { &self.alpha } fn alpha_sum(&self) -> f64 { self.alpha.sum() } /// Returns the entropy of the dirichlet distribution /// /// # Formula /// /// ```text /// ln(B(α)) - (K - α_0)ψ(α_0) - Σ((α_i - 1)ψ(α_i)) /// ``` /// /// where /// /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// /// `α_0` is the sum of all concentration parameters, /// `K` is the number of concentration parameters, `ψ` is the digamma /// function, `α_i` is the `i`th concentration parameter, /// and `Σ` is the sum from `1` to `K` pub fn entropy(&self) -> Option<f64> { let sum = self.alpha_sum(); let num = self.alpha.iter().fold(0.0, |acc, &x| { acc + gamma::ln_gamma(x) + (x - 1.0) * gamma::digamma(x) }); let entr = -gamma::ln_gamma(sum) + (sum - self.alpha.len() as f64) * gamma::digamma(sum) - num; Some(entr) } } impl<D> std::fmt::Display for Dirichlet<D> where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Dir({}, {})", self.alpha.len(), &self.alpha) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature =
"rand")))] impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for Dirichlet<D> where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>, { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> OVector<f64, D> { let mut sum = 0.0; let mut samples: OVector<f64, D> = OVector::from_iterator_generic( self.alpha.shape_generic().0, nalgebra::Const::<1>, self.alpha.iter().map(|&a| { let sample = super::gamma::sample_unchecked(rng, a, 1.0); sum += sample; sample }), ); samples /= sum; samples } } impl<D> MeanN<OVector<f64, D>> for Dirichlet<D> where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>, { /// Returns the means of the dirichlet distribution /// /// # Formula /// /// ```text /// α_i / α_0 /// ``` /// /// for the `i`th element where `α_i` is the `i`th concentration parameter /// and `α_0` is the sum of all concentration parameters fn mean(&self) -> Option<OVector<f64, D>> { let sum = self.alpha_sum(); Some(self.alpha.map(|x| x / sum)) } } impl<D> VarianceN<OMatrix<f64, D, D>> for Dirichlet<D> where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D>, { /// Returns the covariance matrix of the dirichlet distribution /// /// # Formula /// /// ```text /// Var(X_i) = (α_i * (α_0 - α_i)) / (α_0^2 * (α_0 + 1)) /// Cov(X_i, X_j) = -(α_i * α_j) / (α_0^2 * (α_0 + 1)), i != j /// ``` /// /// where `α_i` is the `i`th concentration parameter /// and `α_0` is the sum of all concentration parameters fn variance(&self) -> Option<OMatrix<f64, D, D>> { let sum = self.alpha_sum(); let normalizing = sum * sum * (sum + 1.0); let mut cov = OMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing)); let mut offdiag = |x: usize, y: usize| { let elt = -self.alpha[x] * self.alpha[y] / normalizing; cov[(x, y)] = elt; cov[(y, x)] = elt; }; for i in 0..self.alpha.len() { for j in 0..i { offdiag(i, j); } } Some(cov) } } impl<D> Continuous<&OVector<f64, D>, f64> for Dirichlet<D> where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D> + nalgebra::allocator::Allocator<D, D> + nalgebra::allocator::Allocator<nalgebra::Const<1>, D>, { /// Calculates the probability density function for the dirichlet /// distribution /// with given `x`'s corresponding to the
concentration parameters for this /// distribution /// /// # Panics /// /// If any element in `x` is not in `(0, 1)`, the elements in `x` do not /// sum to `1` within a tolerance of `1e-4`, or if `x` is not the same /// length as the vector of concentration parameters for this distribution /// /// # Formula /// /// ```text /// (1 / B(α)) * Π(x_i^(α_i - 1)) /// ``` /// /// where /// /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// /// `α` is the vector of concentration parameters, `α_i` is the `i`th /// concentration parameter, `x_i` is the `i`th argument corresponding to /// the `i`th concentration parameter, `Γ` is the gamma function, /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters fn pdf(&self, x: &OVector<f64, D>) -> f64 { self.ln_pdf(x).exp() } /// Calculates the log probability density function for the dirichlet /// distribution with given `x`'s corresponding to the concentration /// parameters for this distribution /// /// # Panics /// /// If any element in `x` is not in `(0, 1)`, the elements in `x` do not /// sum to `1` within a tolerance of `1e-4`, or if `x` is not the same /// length as the vector of concentration parameters for this distribution /// /// # Formula /// /// ```text /// ln((1 / B(α)) * Π(x_i^(α_i - 1))) /// ``` /// /// where /// /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// /// `α` is the vector of concentration parameters, `α_i` is the `i`th /// concentration parameter, `x_i` is the `i`th argument corresponding to /// the `i`th concentration parameter, `Γ` is the gamma function, /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters fn ln_pdf(&self, x: &OVector<f64, D>) -> f64 { if self.alpha.len() != x.len() { panic!("Arguments must have correct dimensions."); } let mut term = 0.0; let mut sum_x = 0.0; let mut sum_alpha = 0.0; for (&x_i, &alpha_i) in
x.iter().zip(self.alpha.iter()) { assert!(0.0 < x_i && x_i < 1.0, "Arguments must be in (0, 1)"); term += (alpha_i - 1.0) * x_i.ln() - gamma::ln_gamma(alpha_i); sum_x += x_i; sum_alpha += alpha_i; } assert!( prec::almost_eq(sum_x, 1.0, 1e-4), "Arguments must sum up to 1" ); term + gamma::ln_gamma(sum_alpha) } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use std::fmt::{Debug, Display}; use nalgebra::{dmatrix, dvector, vector, DimMin, OVector}; fn try_create<D>(alpha: OVector<f64, D>) -> Dirichlet<D> where D: DimMin<D, Output = D>, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>, { let mvn = Dirichlet::new_from_nalgebra(alpha); assert!(mvn.is_ok()); mvn.unwrap() } fn bad_create_case<D>(alpha: OVector<f64, D>) where D: DimMin<D, Output = D>, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>, { let dd = Dirichlet::new_from_nalgebra(alpha); assert!(dd.is_err()); } fn test_almost<D, T, F>(alpha: OVector<f64, D>, expected: T, acc: f64, eval: F) where T: Debug + Display + approx::RelativeEq<Epsilon = f64>, F: FnOnce(Dirichlet<D>) -> T, D: DimMin<D, Output = D>, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator<D>, { let dd = try_create(alpha); let x = eval(dd); assert_relative_eq!(expected, x, epsilon = acc); } #[test] fn test_create() { try_create(vector![1.0, 2.0]); try_create(vector![1.0, 2.0, 3.0, 4.0, 5.0]); assert!(Dirichlet::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]).is_ok()); // try_create(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } #[test] fn test_bad_create() { bad_create_case(vector![1.0, f64::NAN]); bad_create_case(vector![1.0, 0.0]); bad_create_case(vector![1.0, f64::INFINITY]); bad_create_case(vector![-1.0, 2.0]); bad_create_case(vector![1.0]); bad_create_case(vector![1.0, 2.0, 0.0, 4.0, 5.0]); bad_create_case(vector![1.0, f64::NAN, 3.0, 4.0, 5.0]); bad_create_case(vector![0.0, 0.0, 0.0]); bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } #[test] fn test_mean() { let mean = |dd: Dirichlet<_>| dd.mean().unwrap(); test_almost(vec![0.5;
5].into(), vec![1.0 / 5.0; 5].into(), 1e-15, mean); test_almost( dvector![0.1, 0.2, 0.3, 0.4], dvector![0.1, 0.2, 0.3, 0.4], 1e-15, mean, ); test_almost( dvector![1.0, 2.0, 3.0, 4.0], dvector![0.1, 0.2, 0.3, 0.4], 1e-15, mean, ); } #[test] fn test_variance() { let variance = |dd: Dirichlet<_>| dd.variance().unwrap(); test_almost( dvector![1.0, 2.0], dmatrix![0.055555555555555, -0.055555555555555; -0.055555555555555, 0.055555555555555; ], 1e-15, variance, ); test_almost( dvector![0.1, 0.2, 0.3, 0.4], dmatrix![0.045, -0.010, -0.015, -0.020; -0.010, 0.080, -0.030, -0.040; -0.015, -0.030, 0.105, -0.060; -0.020, -0.040, -0.060, 0.120; ], 1e-15, variance, ); } // #[test] // fn test_std_dev() { // let alpha = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; // let sum = alpha.iter().fold(0.0, |acc, x| acc + x); // let n = Dirichlet::new(&alpha).unwrap(); // let res = n.std_dev(); // for i in 1..11 { // let f = i as f64; // assert_almost_eq!(res[i-1], (f * (sum - f) / (sum * sum * (sum + 1.0))).sqrt(), 1e-15); // } // } #[test] fn test_entropy() { let entropy = |x: Dirichlet<_>| x.entropy().unwrap(); test_almost( vector![0.1, 0.3, 0.5, 0.8], -17.46469081094079, 1e-30, entropy, ); test_almost( vector![0.1, 0.2, 0.3, 0.4], -21.53881433791513, 1e-30, entropy, ); } #[test] fn test_pdf() { let pdf = |arg| move |x: Dirichlet<_>| x.pdf(&arg); test_almost( vector![0.1, 0.3, 0.5, 0.8], 18.77225681167061, 1e-12, pdf([0.01, 0.03, 0.5, 0.46].into()), ); test_almost( vector![0.1, 0.3, 0.5, 0.8], 0.8314656481199253, 1e-14, pdf([0.1, 0.2, 0.3, 0.4].into()), ); } #[test] fn test_ln_pdf() { let ln_pdf = |arg| move |x: Dirichlet<_>| x.ln_pdf(&arg); test_almost( vector![0.1, 0.3, 0.5, 0.8], 18.77225681167061_f64.ln(), 1e-12, ln_pdf([0.01, 0.03, 0.5, 0.46].into()), ); test_almost( vector![0.1, 0.3, 0.5, 0.8], 0.8314656481199253_f64.ln(), 1e-14, ln_pdf([0.1, 0.2, 0.3, 0.4].into()), ); } #[test] #[should_panic] fn test_pdf_bad_input_length() { let n = try_create(dvector![0.1, 0.3, 0.5, 
0.8]); n.pdf(&dvector![0.5]); } #[test] #[should_panic] fn test_pdf_bad_input_range() { let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); n.pdf(&vector![1.5, 0.0, 0.0, 0.0]); } #[test] #[should_panic] fn test_pdf_bad_input_sum() { let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); n.pdf(&vector![0.5, 0.25, 0.8, 0.9]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_length() { let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]); n.ln_pdf(&dvector![0.5]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_range() { let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); n.ln_pdf(&vector![1.5, 0.0, 0.0, 0.0]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_sum() { let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); n.ln_pdf(&vector![0.5, 0.25, 0.8, 0.9]); } #[test] fn test_error_is_sync_send() { fn assert_sync_send<T: Sync + Send>() {} assert_sync_send::<DirichletError>(); } } statrs-0.18.0/src/distribution/discrete_uniform.rs000064400000000000000000000261531046102023000204560ustar 00000000000000use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; /// Implements the [Discrete /// Uniform](https://en.wikipedia.org/wiki/Discrete_uniform_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{DiscreteUniform, Discrete}; /// use statrs::statistics::Distribution; /// /// let n = DiscreteUniform::new(0, 5).unwrap(); /// assert_eq!(n.mean().unwrap(), 2.5); /// assert_eq!(n.pmf(3), 1.0 / 6.0); /// ``` #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct DiscreteUniform { min: i64, max: i64, } /// Represents the errors that can occur when creating a [`DiscreteUniform`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum DiscreteUniformError { /// The maximum is less than the minimum.
MinMaxInvalid, } impl std::fmt::Display for DiscreteUniformError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { DiscreteUniformError::MinMaxInvalid => write!(f, "Maximum is less than minimum"), } } } impl std::error::Error for DiscreteUniformError {} impl DiscreteUniform { /// Constructs a new discrete uniform distribution with a minimum value /// of `min` and a maximum value of `max`. /// /// # Errors /// /// Returns an error if `max < min` /// /// # Examples /// /// ``` /// use statrs::distribution::DiscreteUniform; /// /// let mut result = DiscreteUniform::new(0, 5); /// assert!(result.is_ok()); /// /// result = DiscreteUniform::new(5, 0); /// assert!(result.is_err()); /// ``` pub fn new(min: i64, max: i64) -> Result<DiscreteUniform, DiscreteUniformError> { if max < min { Err(DiscreteUniformError::MinMaxInvalid) } else { Ok(DiscreteUniform { min, max }) } } } impl std::fmt::Display for DiscreteUniform { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Uni([{}, {}])", self.min, self.max) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<i64> for DiscreteUniform { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> i64 { rng.gen_range(self.min..=self.max) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for DiscreteUniform { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { rng.sample::<i64, _>(self) as f64 } } impl DiscreteCDF<i64, f64> for DiscreteUniform { /// Calculates the cumulative distribution function for the /// discrete uniform distribution at `x` /// /// # Formula /// /// ```text /// (x - min + 1) / (max - min + 1) /// ``` /// /// for `min <= x < max`; the result is `0` for `x < min` and `1` for /// `x >= max` fn cdf(&self, x: i64) -> f64 { if x < self.min { 0.0 } else if x >= self.max { 1.0 } else { let lower = self.min as f64; let upper = self.max as f64; let ans = (x as f64 - lower + 1.0) / (upper - lower + 1.0); if ans > 1.0 { 1.0 } else { ans } } } fn sf(&self, x: i64) -> f64
    {
        // 1. - self.cdf(x)
        if x < self.min {
            1.0
        } else if x >= self.max {
            0.0
        } else {
            let lower = self.min as f64;
            let upper = self.max as f64;
            let ans = (upper - x as f64) / (upper - lower + 1.0);
            if ans > 1.0 {
                1.0
            } else {
                ans
            }
        }
    }
}

impl Min<i64> for DiscreteUniform {
    /// Returns the minimum value in the domain of the discrete uniform
    /// distribution
    ///
    /// # Remarks
    ///
    /// This is the same value as the minimum passed into the constructor
    fn min(&self) -> i64 {
        self.min
    }
}

impl Max<i64> for DiscreteUniform {
    /// Returns the maximum value in the domain of the discrete uniform
    /// distribution
    ///
    /// # Remarks
    ///
    /// This is the same value as the maximum passed into the constructor
    fn max(&self) -> i64 {
        self.max
    }
}

impl Distribution<f64> for DiscreteUniform {
    /// Returns the mean of the discrete uniform distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// (min + max) / 2
    /// ```
    fn mean(&self) -> Option<f64> {
        Some((self.min + self.max) as f64 / 2.0)
    }

    /// Returns the variance of the discrete uniform distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// ((max - min + 1)^2 - 1) / 12
    /// ```
    fn variance(&self) -> Option<f64> {
        let diff = (self.max - self.min) as f64;
        Some(((diff + 1.0) * (diff + 1.0) - 1.0) / 12.0)
    }

    /// Returns the entropy of the discrete uniform distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(max - min + 1)
    /// ```
    fn entropy(&self) -> Option<f64> {
        let diff = (self.max - self.min) as f64;
        Some((diff + 1.0).ln())
    }

    /// Returns the skewness of the discrete uniform distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// 0
    /// ```
    fn skewness(&self) -> Option<f64> {
        Some(0.0)
    }
}

impl Median<f64> for DiscreteUniform {
    /// Returns the median of the discrete uniform distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// (max + min) / 2
    /// ```
    fn median(&self) -> f64 {
        (self.min + self.max) as f64 / 2.0
    }
}

impl Mode<Option<i64>> for DiscreteUniform {
    /// Returns the mode for the discrete uniform distribution
    ///
    /// # Remarks
    ///
    /// Every value in `[min, max]` is equally likely, so every value is a
    /// mode. This implementation returns the middle element, rounded down
    /// when `min + max` is odd.
    ///
    /// # Formula
    ///
    /// ```text
    /// floor((max + min) / 2)
    /// ```
    fn mode(&self) -> Option<i64> {
        Some(((self.min + self.max) as f64 / 2.0).floor() as i64)
    }
}

impl Discrete<i64, f64> for DiscreteUniform {
    /// Calculates the probability mass function for the discrete uniform
    /// distribution at `x`
    ///
    /// # Remarks
    ///
    /// Returns `0.0` if `x` is not in `[min, max]`
    ///
    /// # Formula
    ///
    /// ```text
    /// 1 / (max - min + 1)
    /// ```
    fn pmf(&self, x: i64) -> f64 {
        if x >= self.min && x <= self.max {
            1.0 / (self.max - self.min + 1) as f64
        } else {
            0.0
        }
    }

    /// Calculates the log probability mass function for the discrete uniform
    /// distribution at `x`
    ///
    /// # Remarks
    ///
    /// Returns `f64::NEG_INFINITY` if `x` is not in `[min, max]`
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(1 / (max - min + 1))
    /// ```
    fn ln_pmf(&self, x: i64) -> f64 {
        if x >= self.min && x <= self.max {
            -((self.max - self.min + 1) as f64).ln()
        } else {
            f64::NEG_INFINITY
        }
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing_boiler;

    testing_boiler!(min: i64, max: i64; DiscreteUniform; DiscreteUniformError);

    #[test]
    fn test_create() {
        create_ok(-10, 10);
        create_ok(0, 4);
        create_ok(10, 20);
        create_ok(20, 20);
    }

    #[test]
    fn test_bad_create() {
        create_err(-1, -2);
        create_err(6, 5);
    }

    #[test]
    fn test_mean() {
        let mean = |x: DiscreteUniform| x.mean().unwrap();
        test_exact(-10, 10, 0.0, mean);
        test_exact(0, 4, 2.0, mean);
        test_exact(10, 20, 15.0, mean);
        test_exact(20, 20, 20.0, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: DiscreteUniform| x.variance().unwrap();
        test_exact(-10, 10, 36.66666666666666666667, variance);
        test_exact(0, 4, 2.0, variance);
        test_exact(10, 20, 10.0, variance);
        test_exact(20, 20, 0.0, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: DiscreteUniform| x.entropy().unwrap();
        test_exact(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy);
        test_exact(0, 4,
            1.6094379124341003746007593332261876395256013542685181, entropy);
        test_exact(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy);
        test_exact(20, 20, 0.0, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: DiscreteUniform| x.skewness().unwrap();
        test_exact(-10, 10, 0.0, skewness);
        test_exact(0, 4, 0.0, skewness);
        test_exact(10, 20, 0.0, skewness);
        test_exact(20, 20, 0.0, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: DiscreteUniform| x.median();
        test_exact(-10, 10, 0.0, median);
        test_exact(0, 4, 2.0, median);
        test_exact(10, 20, 15.0, median);
        test_exact(20, 20, 20.0, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: DiscreteUniform| x.mode().unwrap();
        test_exact(-10, 10, 0, mode);
        test_exact(0, 4, 2, mode);
        test_exact(10, 20, 15, mode);
        test_exact(20, 20, 20, mode);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: i64| move |x: DiscreteUniform| x.pmf(arg);
        test_exact(-10, 10, 0.04761904761904761904762, pmf(-5));
        test_exact(-10, 10, 0.04761904761904761904762, pmf(1));
        test_exact(-10, 10, 0.04761904761904761904762, pmf(10));
        test_exact(-10, -10, 0.0, pmf(0));
        test_exact(-10, -10, 1.0, pmf(-10));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: i64| move |x: DiscreteUniform| x.ln_pmf(arg);
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5));
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1));
        test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10));
        test_exact(-10, -10, f64::NEG_INFINITY, ln_pmf(0));
        test_exact(-10, -10, 0.0, ln_pmf(-10));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(-10, 10, 0.2857142857142857142857, cdf(-5));
        test_exact(-10, 10, 0.5714285714285714285714, cdf(1));
        test_exact(-10, 10, 1.0, cdf(10));
        test_exact(-10, -10, 1.0, cdf(-10));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_exact(-10, 10, 0.7142857142857142857143,
            sf(-5));
        test_exact(-10, 10, 0.42857142857142855, sf(1));
        test_exact(-10, 10, 0.0, sf(10));
        test_exact(-10, -10, 0.0, sf(-10));
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(0, 3, 0.0, cdf(-1));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg);
        test_exact(0, 3, 1.0, sf(-1));
    }

    #[test]
    fn test_cdf_upper_bound() {
        let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg);
        test_exact(0, 3, 1.0, cdf(5));
    }
}

statrs-0.18.0/src/distribution/empirical.rs

use crate::distribution::ContinuousCDF;
use crate::statistics::*;
use non_nan::NonNan;
use std::collections::btree_map::{BTreeMap, Entry};
use std::convert::Infallible;
use std::ops::Bound;

mod non_nan {
    use core::cmp::Ordering;

    #[derive(Clone, Copy, PartialEq, Debug)]
    pub struct NonNan<T>(T);

    impl<T: Copy> NonNan<T> {
        pub fn get(self) -> T {
            self.0
        }
    }

    impl NonNan<f64> {
        #[inline]
        pub fn new(x: f64) -> Option<Self> {
            if x.is_nan() {
                None
            } else {
                Some(Self(x))
            }
        }
    }

    impl<T: PartialEq> Eq for NonNan<T> {}

    impl<T: PartialOrd> PartialOrd for NonNan<T> {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl<T: PartialOrd> Ord for NonNan<T> {
        fn cmp(&self, other: &Self) -> Ordering {
            self.0.partial_cmp(&other.0).unwrap()
        }
    }
}

/// Implements the [Empirical
/// Distribution](https://en.wikipedia.org/wiki/Empirical_distribution_function)
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Continuous, Empirical};
/// use statrs::statistics::Distribution;
///
/// let samples = vec![0.0, 5.0, 10.0];
///
/// let empirical = Empirical::from_iter(samples);
/// assert_eq!(empirical.mean().unwrap(), 5.0);
/// ```
#[derive(Clone, PartialEq, Debug)]
pub struct Empirical {
    // keys are data points, values are number of data points with equal value
    data: BTreeMap<NonNan<f64>, u64>,

    // The following fields are only logically valid if !data.is_empty():
    /// Total amount of data points (== sum of all _values_ inside self.data).
    /// Must be 0 iff data.is_empty()
    sum: u64,
    mean: f64,
    var: f64,
}

impl Empirical {
    /// Constructs a new, empty empirical distribution.
    ///
    /// Note that this will always succeed and never return the [`Err`][Result::Err] variant.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Empirical;
    ///
    /// let mut result = Empirical::new();
    /// assert!(result.is_ok());
    /// ```
    pub fn new() -> Result<Empirical, Infallible> {
        Ok(Empirical {
            data: BTreeMap::new(),
            sum: 0,
            mean: 0.0,
            var: 0.0,
        })
    }

    /// Adds a data point to the distribution and updates the running mean
    /// and variance. `NaN` values are silently ignored.
    pub fn add(&mut self, data_point: f64) {
        let map_key = match NonNan::new(data_point) {
            Some(valid) => valid,
            None => return,
        };

        self.sum += 1;
        let sum = self.sum as f64;
        self.var += (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
        self.mean += (data_point - self.mean) / sum;

        self.data
            .entry(map_key)
            .and_modify(|c| *c += 1)
            .or_insert(1);
    }

    /// Removes one occurrence of `data_point` from the distribution. Does
    /// nothing if `data_point` is `NaN` or is not present.
    pub fn remove(&mut self, data_point: f64) {
        let map_key = match NonNan::new(data_point) {
            Some(valid) => valid,
            None => return,
        };

        let mut entry = match self.data.entry(map_key) {
            Entry::Occupied(entry) => entry,
            Entry::Vacant(_) => return, // no entry found
        };

        if *entry.get() == 1 {
            entry.remove();
            if self.data.is_empty() {
                // logically, this should not need special handling.
                // FP math can result in mean or var being != 0.0 though.
                self.sum = 0;
                self.mean = 0.0;
                self.var = 0.0;
                return;
            }
        } else {
            *entry.get_mut() -= 1;
        }

        // update mean and var by inverting the streaming update in `add`
        let sum = self.sum as f64;
        self.mean = (sum * self.mean - data_point) / (sum - 1.);
        self.var -= (sum - 1.) * (data_point - self.mean) * (data_point - self.mean) / sum;
        self.sum -= 1;
    }

    // Due to issues with rounding and floating-point accuracy the default
    // implementation may be ill-behaved.
    // Specialized inverse cdfs should be used whenever possible.
    // Performs a binary search on the domain of `cdf` to obtain an approximation
    // of `F^-1(p) := inf { x | F(x) >= p }`. Performance may therefore be
    // lacking.
    // This function is identical to the default method implementation in the
    // `ContinuousCDF` trait and is used to implement the rand trait `Distribution`.
    fn __inverse_cdf(&self, p: f64) -> f64 {
        if p == 0.0 {
            return self.min();
        };
        if p == 1.0 {
            return self.max();
        };
        let mut high = 2.0;
        let mut low = -high;
        while self.cdf(low) > p {
            low = low + low;
        }
        while self.cdf(high) < p {
            high = high + high;
        }
        let mut i = 16;
        while i != 0 {
            let mid = (high + low) / 2.0;
            if self.cdf(mid) >= p {
                high = mid;
            } else {
                low = mid;
            }
            i -= 1;
        }
        (high + low) / 2.0
    }
}

impl std::fmt::Display for Empirical {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut enumerated_values = self
            .data
            .iter()
            .flat_map(|(x, &count)| std::iter::repeat(x.get()).take(count as usize));

        if let Some(x) = enumerated_values.next() {
            write!(f, "Empirical([{x:.3e}")?;
        } else {
            return write!(f, "Empirical(∅)");
        }

        for val in enumerated_values.by_ref().take(4) {
            write!(f, ", {val:.3e}")?;
        }
        if enumerated_values.next().is_some() {
            write!(f, ", ...")?;
        }
        write!(f, "])")
    }
}

impl FromIterator<f64> for Empirical {
    fn from_iter<T: IntoIterator<Item = f64>>(iter: T) -> Self {
        let mut empirical = Self::new().unwrap();
        for elt in iter {
            empirical.add(elt);
        }
        empirical
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Empirical {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        use crate::distribution::Uniform;

        let uniform = Uniform::new(0.0, 1.0).unwrap();
        self.__inverse_cdf(uniform.sample(rng))
    }
}

/// Panics if number of samples is zero
impl Max<f64> for Empirical {
    fn max(&self) -> f64 {
        self.data.keys().rev().map(|key| key.get()).next().unwrap()
    }
}

/// Panics if number of samples is zero
impl Min<f64> for Empirical {
    fn min(&self) -> f64 {
        self.data.keys().map(|key| key.get()).next().unwrap()
    }
}

impl Distribution<f64> for Empirical {
    fn mean(&self) -> Option<f64> {
        if self.data.is_empty() {
            None
        } else {
            Some(self.mean)
        }
    }

    fn variance(&self) -> Option<f64> {
        if self.data.is_empty() {
            None
        } else {
            Some(self.var / (self.sum as f64 - 1.))
        }
    }
}

impl ContinuousCDF<f64, f64> for Empirical {
    fn cdf(&self, x: f64) -> f64 {
        let start = Bound::Unbounded;
        let end = Bound::Included(NonNan::new(x).expect("x must not be NaN"));

        let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum();
        sum as f64 / self.sum as f64
    }

    fn sf(&self, x: f64) -> f64 {
        let start = Bound::Excluded(NonNan::new(x).expect("x must not be NaN"));
        let end = Bound::Unbounded;

        let sum: u64 = self.data.range((start, end)).map(|(_, v)| v).sum();
        sum as f64 / self.sum as f64
    }

    fn inverse_cdf(&self, p: f64) -> f64 {
        self.__inverse_cdf(p)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_nan() {
        let mut empirical = Empirical::new().unwrap();

        // should not panic
        empirical.add(f64::NAN);
    }

    #[test]
    fn test_remove_nan() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(5.2);
        // should not panic
        empirical.remove(f64::NAN);
    }

    #[test]
    fn test_remove_nonexisting() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(5.2);
        // should not panic
        empirical.remove(10.0);
    }

    #[test]
    fn test_remove_all() {
        let mut empirical = Empirical::new().unwrap();

        empirical.add(17.123);
        empirical.add(-10.0);
        empirical.add(0.0);
        empirical.remove(-10.0);
        empirical.remove(17.123);
        empirical.remove(0.0);

        assert!(empirical.mean().is_none());
        assert!(empirical.variance().is_none());
    }

    #[test]
    fn test_mean() {
        fn test_mean_for_samples(expected_mean: f64, samples: Vec<f64>) {
            let dist = Empirical::from_iter(samples);
            assert_relative_eq!(dist.mean().unwrap(), expected_mean);
        }

        let dist = Empirical::from_iter(vec![]);
        assert!(dist.mean().is_none());

        test_mean_for_samples(4.0, vec![4.0; 100]);
        test_mean_for_samples(-0.2, vec![-0.2; 100]);
        test_mean_for_samples(28.5, vec![21.3, 38.4, 12.7, 41.6]);
    }

    #[test]
    fn test_var() {
        fn test_var_for_samples(expected_var: f64, samples: Vec<f64>) {
            let dist = Empirical::from_iter(samples);
            assert_relative_eq!(dist.variance().unwrap(), expected_var);
        }

        let dist =
            Empirical::from_iter(vec![]);
        assert!(dist.variance().is_none());

        test_var_for_samples(0.0, vec![4.0; 100]);
        test_var_for_samples(0.0, vec![-0.2; 100]);
        test_var_for_samples(190.36666666666667, vec![21.3, 38.4, 12.7, 41.6]);
    }

    #[test]
    fn test_cdf() {
        let samples = vec![5.0, 10.0];
        let mut empirical = Empirical::from_iter(samples);
        assert_eq!(empirical.cdf(0.0), 0.0);
        assert_eq!(empirical.cdf(5.0), 0.5);
        assert_eq!(empirical.cdf(5.5), 0.5);
        assert_eq!(empirical.cdf(6.0), 0.5);
        assert_eq!(empirical.cdf(10.0), 1.0);
        assert_eq!(empirical.min(), 5.0);
        assert_eq!(empirical.max(), 10.0);
        empirical.add(2.0);
        empirical.add(2.0);
        assert_eq!(empirical.cdf(0.0), 0.0);
        assert_eq!(empirical.cdf(5.0), 0.75);
        assert_eq!(empirical.cdf(5.5), 0.75);
        assert_eq!(empirical.cdf(6.0), 0.75);
        assert_eq!(empirical.cdf(10.0), 1.0);
        assert_eq!(empirical.min(), 2.0);
        assert_eq!(empirical.max(), 10.0);
        let unchanged = empirical.clone();
        empirical.add(2.0);
        empirical.remove(2.0);
        // because of rounding errors, this doesn't hold in general
        // due to the mean and variance being calculated in a streaming way
        assert_eq!(unchanged, empirical);
    }

    #[test]
    fn test_sf() {
        let samples = vec![5.0, 10.0];
        let mut empirical = Empirical::from_iter(samples);
        assert_eq!(empirical.sf(0.0), 1.0);
        assert_eq!(empirical.sf(5.0), 0.5);
        assert_eq!(empirical.sf(5.5), 0.5);
        assert_eq!(empirical.sf(6.0), 0.5);
        assert_eq!(empirical.sf(10.0), 0.0);
        assert_eq!(empirical.min(), 5.0);
        assert_eq!(empirical.max(), 10.0);
        empirical.add(2.0);
        empirical.add(2.0);
        assert_eq!(empirical.sf(0.0), 1.0);
        assert_eq!(empirical.sf(5.0), 0.25);
        assert_eq!(empirical.sf(5.5), 0.25);
        assert_eq!(empirical.sf(6.0), 0.25);
        assert_eq!(empirical.sf(10.0), 0.0);
        assert_eq!(empirical.min(), 2.0);
        assert_eq!(empirical.max(), 10.0);
        let unchanged = empirical.clone();
        empirical.add(2.0);
        empirical.remove(2.0);
        // because of rounding errors, this doesn't hold in general
        // due to the mean and variance being calculated in a streaming way
        assert_eq!(unchanged, empirical);
    }

    #[test]
    fn test_display() {
        let mut e = Empirical::new().unwrap();
        assert_eq!(e.to_string(), "Empirical(∅)");
        e.add(1.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0])");
        e.add(1.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0])");
        e.add(2.0);
        assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0, 2.000e0])");
        e.add(2.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0])"
        );
        e.add(5.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0])"
        );
        e.add(5.0);
        assert_eq!(
            e.to_string(),
            "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0, ...])"
        );
    }
}

statrs-0.18.0/src/distribution/erlang.rs

use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError};
use crate::statistics::*;

/// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution)
/// distribution
/// which is a special case of the
/// [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution)
/// distribution
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Erlang, Continuous};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = Erlang::new(3, 1.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 3.0);
/// assert!(prec::almost_eq(n.pdf(2.0), 0.270670566473225383788, 1e-15));
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Erlang {
    g: Gamma,
}

impl Erlang {
    /// Constructs a new Erlang distribution with a shape (k)
    /// of `shape` and a rate (λ) of `rate`
    ///
    /// # Errors
    ///
    /// Returns an error if `rate` is `NaN`.
    /// Also returns an error if `shape == 0` or `rate <= 0.0`
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Erlang;
    ///
    /// let mut result = Erlang::new(3, 1.0);
    /// assert!(result.is_ok());
    ///
    /// result = Erlang::new(0, 0.0);
    /// assert!(result.is_err());
    /// ```
    pub fn new(shape: u64, rate: f64) -> Result<Erlang, GammaError> {
        Gamma::new(shape as f64, rate).map(|g| Erlang { g })
    }

    /// Returns the shape (k) of the Erlang distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Erlang;
    ///
    /// let n = Erlang::new(3, 1.0).unwrap();
    /// assert_eq!(n.shape(), 3);
    /// ```
    pub fn shape(&self) -> u64 {
        self.g.shape() as u64
    }

    /// Returns the rate (λ) of the Erlang distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Erlang;
    ///
    /// let n = Erlang::new(3, 1.0).unwrap();
    /// assert_eq!(n.rate(), 1.0);
    /// ```
    pub fn rate(&self) -> f64 {
        self.g.rate()
    }
}

impl std::fmt::Display for Erlang {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "E({}, {})", self.rate(), self.shape())
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Erlang {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        ::rand::distributions::Distribution::sample(&self.g, rng)
    }
}

impl ContinuousCDF<f64, f64> for Erlang {
    /// Calculates the cumulative distribution function for the Erlang
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// γ(k, λx) / (k - 1)!
    /// ```
    ///
    /// where `k` is the shape, `λ` is the rate, and `γ` is the lower
    /// incomplete gamma function
    fn cdf(&self, x: f64) -> f64 {
        self.g.cdf(x)
    }

    /// Calculates the survival function for the Erlang
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// Γ(k, λx) / (k - 1)!
    /// ```
    ///
    /// where `k` is the shape, `λ` is the rate, and `Γ` is the upper
    /// incomplete gamma function
    fn sf(&self, x: f64) -> f64 {
        self.g.sf(x)
    }

    /// Calculates the inverse cumulative distribution function for the Erlang
    /// distribution at `p`
    ///
    /// # Formula
    ///
    /// ```text
    /// γ^{-1}(k, (k - 1)! p) / λ
    /// ```
    ///
    /// where `k` is the shape, `λ` is the rate, and `γ^{-1}` is the inverse
    /// of the lower incomplete gamma function with respect to its second argument
    fn inverse_cdf(&self, p: f64) -> f64 {
        self.g.inverse_cdf(p)
    }
}

impl Min<f64> for Erlang {
    /// Returns the minimum value in the domain of the
    /// Erlang distribution representable by a double precision
    /// float
    ///
    /// # Formula
    ///
    /// ```text
    /// 0
    /// ```
    fn min(&self) -> f64 {
        self.g.min()
    }
}

impl Max<f64> for Erlang {
    /// Returns the maximum value in the domain of the
    /// Erlang distribution representable by a double precision
    /// float
    ///
    /// # Formula
    ///
    /// ```text
    /// f64::INFINITY
    /// ```
    fn max(&self) -> f64 {
        self.g.max()
    }
}

impl Distribution<f64> for Erlang {
    /// Returns the mean of the Erlang distribution
    ///
    /// # Remarks
    ///
    /// Returns `shape` if `rate == f64::INFINITY`. This behavior
    /// is borrowed from the Math.NET implementation
    ///
    /// # Formula
    ///
    /// ```text
    /// k / λ
    /// ```
    ///
    /// where `k` is the shape and `λ` is the rate
    fn mean(&self) -> Option<f64> {
        self.g.mean()
    }

    /// Returns the variance of the Erlang distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// k / λ^2
    /// ```
    ///
    /// where `k` is the shape and `λ` is the rate
    fn variance(&self) -> Option<f64> {
        self.g.variance()
    }

    /// Returns the entropy of the Erlang distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// k - ln(λ) + ln(Γ(k)) + (1 - k) * ψ(k)
    /// ```
    ///
    /// where `k` is the shape, `λ` is the rate, `Γ` is the gamma function,
    /// and `ψ` is the digamma function
    fn entropy(&self) -> Option<f64> {
        self.g.entropy()
    }

    /// Returns the skewness of the Erlang distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// 2 / sqrt(k)
    /// ```
    ///
    /// where `k` is the shape
    fn skewness(&self) -> Option<f64> {
        self.g.skewness()
    }
}

impl Mode<Option<f64>> for Erlang {
    /// Returns the mode for the Erlang distribution
    ///
    /// # Remarks
    ///
    /// Returns `shape` if `rate == f64::INFINITY`. This behavior
    /// is borrowed from the Math.NET implementation
    ///
    /// # Formula
    ///
    /// ```text
    /// (k - 1) / λ
    /// ```
    ///
    /// where `k` is the shape and `λ` is the rate
    fn mode(&self) -> Option<f64> {
        self.g.mode()
    }
}

impl Continuous<f64, f64> for Erlang {
    /// Calculates the probability density function for the Erlang distribution
    /// at `x`
    ///
    /// # Remarks
    ///
    /// Returns `NAN` if `rate` is `INF`
    /// or if `x` is `INF`
    ///
    /// # Formula
    ///
    /// ```text
    /// (λ^k / Γ(k)) * x^(k - 1) * e^(-λ * x)
    /// ```
    ///
    /// where `k` is the shape, `λ` is the rate, and `Γ` is the gamma function
    fn pdf(&self, x: f64) -> f64 {
        self.g.pdf(x)
    }

    /// Calculates the log probability density function for the Erlang
    /// distribution at `x`
    ///
    /// # Remarks
    ///
    /// Returns `NAN` if `rate` is `INF`
    /// or if `x` is `INF`
    ///
    /// # Formula
    ///
    /// ```text
    /// ln((λ^k / Γ(k)) * x^(k - 1) * e ^(-λ * x))
    /// ```
    ///
    /// where `k` is the shape, `λ` is the rate, and `Γ` is the gamma function
    fn ln_pdf(&self, x: f64) -> f64 {
        self.g.ln_pdf(x)
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(shape: u64, rate: f64; Erlang; GammaError);

    #[test]
    fn test_create() {
        create_ok(1, 0.1);
        create_ok(1, 1.0);
        create_ok(10, 10.0);
        create_ok(10, 1.0);
        create_ok(10, f64::INFINITY);
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0, 1.0, GammaError::ShapeInvalid),
            (1, 0.0, GammaError::RateInvalid),
            (1, f64::NAN, GammaError::RateInvalid),
            (1, -1.0, GammaError::RateInvalid),
        ];

        for (s, r, err) in invalid {
            test_create_err(s, r, err);
        }
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(1, 2.5), 0.0, 20.0);
        test::check_continuous_distribution(&create_ok(2, 1.5), 0.0, 20.0);
        test::check_continuous_distribution(&create_ok(3, 0.5), 0.0, 20.0);
    }
}

statrs-0.18.0/src/distribution/exponential.rs

use crate::distribution::{Continuous, ContinuousCDF};
use crate::statistics::*;
use std::f64;

/// Implements the
/// [Exponential](https://en.wikipedia.org/wiki/Exponential_distribution)
/// distribution and is a special case of the
/// [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) distribution
/// (referenced [here](./struct.Gamma.html))
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Exp, Continuous};
/// use statrs::statistics::Distribution;
///
/// let n = Exp::new(1.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 1.0);
/// assert_eq!(n.pdf(1.0), 0.3678794411714423215955);
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Exp {
    rate: f64,
}

/// Represents the errors that can occur when creating an [`Exp`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum ExpError {
    /// The rate is NaN, zero or less than zero.
    RateInvalid,
}

impl std::fmt::Display for ExpError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            ExpError::RateInvalid => write!(f, "Rate is NaN, zero or less than zero"),
        }
    }
}

impl std::error::Error for ExpError {}

impl Exp {
    /// Constructs a new exponential distribution with a
    /// rate (λ) of `rate`.
    ///
    /// # Errors
    ///
    /// Returns an error if rate is `NaN` or `rate <= 0.0`.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Exp;
    ///
    /// let mut result = Exp::new(1.0);
    /// assert!(result.is_ok());
    ///
    /// result = Exp::new(-1.0);
    /// assert!(result.is_err());
    /// ```
    pub fn new(rate: f64) -> Result<Exp, ExpError> {
        if rate.is_nan() || rate <= 0.0 {
            Err(ExpError::RateInvalid)
        } else {
            Ok(Exp { rate })
        }
    }

    /// Returns the rate of the exponential distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Exp;
    ///
    /// let n = Exp::new(1.0).unwrap();
    /// assert_eq!(n.rate(), 1.0);
    /// ```
    pub fn rate(&self) -> f64 {
        self.rate
    }
}

impl std::fmt::Display for Exp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Exp({})", self.rate)
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Exp {
    fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 {
        use crate::distribution::ziggurat;

        ziggurat::sample_exp_1(r) / self.rate
    }
}

impl ContinuousCDF<f64, f64> for Exp {
    /// Calculates the cumulative distribution function for the
    /// exponential distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// 1 - e^(-λ * x)
    /// ```
    ///
    /// where `λ` is the rate
    fn cdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            0.0
        } else {
            1.0 - (-self.rate * x).exp()
        }
    }

    /// Calculates the survival function for the
    /// exponential distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// e^(-λ * x)
    /// ```
    ///
    /// where `λ` is the rate
    fn sf(&self, x: f64) -> f64 {
        if x < 0.0 {
            1.0
        } else {
            (-self.rate * x).exp()
        }
    }

    /// Calculates the inverse cumulative distribution function.
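    ///
    /// # Examples
    ///
    /// A round trip through `cdf` and back through `inverse_cdf` should
    /// recover `x` up to floating-point error. The `1e-12` tolerance below is
    /// an assumed bound chosen for illustration, not a documented accuracy
    /// guarantee.
    ///
    /// ```
    /// use statrs::distribution::{ContinuousCDF, Exp};
    ///
    /// let d = Exp::new(2.0).unwrap();
    /// let x = 0.75;
    /// // cdf(0.75) = 1 - e^(-1.5), roughly 0.7769
    /// let p = d.cdf(x);
    /// assert!((d.inverse_cdf(p) - x).abs() < 1e-12);
    /// ```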
    ///
    /// # Formula
    ///
    /// ```text
    /// -ln(1 - p) / λ
    /// ```
    ///
    /// where `p` is the probability and `λ` is the rate
    fn inverse_cdf(&self, p: f64) -> f64 {
        -(-p).ln_1p() / self.rate
    }
}

impl Min<f64> for Exp {
    /// Returns the minimum value in the domain of the exponential
    /// distribution representable by a double precision float
    ///
    /// # Formula
    ///
    /// ```text
    /// 0
    /// ```
    fn min(&self) -> f64 {
        0.0
    }
}

impl Max<f64> for Exp {
    /// Returns the maximum value in the domain of the exponential
    /// distribution representable by a double precision float
    ///
    /// # Formula
    ///
    /// ```text
    /// f64::INFINITY
    /// ```
    fn max(&self) -> f64 {
        f64::INFINITY
    }
}

impl Distribution<f64> for Exp {
    /// Returns the mean of the exponential distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// 1 / λ
    /// ```
    ///
    /// where `λ` is the rate
    fn mean(&self) -> Option<f64> {
        Some(1.0 / self.rate)
    }

    /// Returns the variance of the exponential distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// 1 / λ^2
    /// ```
    ///
    /// where `λ` is the rate
    fn variance(&self) -> Option<f64> {
        Some(1.0 / (self.rate * self.rate))
    }

    /// Returns the entropy of the exponential distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// 1 - ln(λ)
    /// ```
    ///
    /// where `λ` is the rate
    fn entropy(&self) -> Option<f64> {
        Some(1.0 - self.rate.ln())
    }

    /// Returns the skewness of the exponential distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// 2
    /// ```
    fn skewness(&self) -> Option<f64> {
        Some(2.0)
    }
}

impl Median<f64> for Exp {
    /// Returns the median of the exponential distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(2) / λ
    /// ```
    ///
    /// where `λ` is the rate
    fn median(&self) -> f64 {
        f64::consts::LN_2 / self.rate
    }
}

impl Mode<Option<f64>> for Exp {
    /// Returns the mode of the exponential distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// 0
    /// ```
    fn mode(&self) -> Option<f64> {
        Some(0.0)
    }
}

impl Continuous<f64, f64> for Exp {
    /// Calculates the probability density function for the exponential
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// λ * e^(-λ * x)
    /// ```
    ///
    /// where `λ` is the rate
    fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            0.0
        } else {
            self.rate * (-self.rate * x).exp()
        }
    }

    /// Calculates the log probability density function for the exponential
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(λ * e^(-λ * x))
    /// ```
    ///
    /// where `λ` is the rate
    fn ln_pdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            f64::NEG_INFINITY
        } else {
            self.rate.ln() - self.rate * x
        }
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(rate: f64; Exp; ExpError);

    #[test]
    fn test_create() {
        create_ok(0.1);
        create_ok(1.0);
        create_ok(10.0);
    }

    #[test]
    fn test_bad_create() {
        create_err(f64::NAN);
        create_err(0.0);
        create_err(-1.0);
        create_err(-10.0);
    }

    #[test]
    fn test_mean() {
        let mean = |x: Exp| x.mean().unwrap();
        test_exact(0.1, 10.0, mean);
        test_exact(1.0, 1.0, mean);
        test_exact(10.0, 0.1, mean);
    }

    #[test]
    fn test_variance() {
        let variance = |x: Exp| x.variance().unwrap();
        test_absolute(0.1, 100.0, 1e-13, variance);
        test_exact(1.0, 1.0, variance);
        test_exact(10.0, 0.01, variance);
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: Exp| x.entropy().unwrap();
        test_absolute(0.1, 3.302585092994045684018, 1e-15, entropy);
        test_exact(1.0, 1.0, entropy);
        test_absolute(10.0, -1.302585092994045684018, 1e-15, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Exp| x.skewness().unwrap();
        test_exact(0.1, 2.0, skewness);
        test_exact(1.0, 2.0, skewness);
        test_exact(10.0, 2.0, skewness);
    }

    #[test]
    fn test_median() {
        let median = |x: Exp| x.median();
        test_absolute(0.1, 6.931471805599453094172, 1e-15, median);
        test_exact(1.0, f64::consts::LN_2, median);
        test_exact(10.0, 0.06931471805599453094172, median);
    }

    #[test]
    fn test_mode() {
        let mode = |x: Exp| x.mode().unwrap();
        test_exact(0.1, 0.0, mode);
        test_exact(1.0, 0.0, mode);
        test_exact(10.0, 0.0, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: Exp| x.min();
        let max = |x: Exp| x.max();
        test_exact(0.1, 0.0,
            min);
        test_exact(1.0, 0.0, min);
        test_exact(10.0, 0.0, min);
        test_exact(0.1, f64::INFINITY, max);
        test_exact(1.0, f64::INFINITY, max);
        test_exact(10.0, f64::INFINITY, max);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: f64| move |x: Exp| x.pdf(arg);
        test_exact(0.1, 0.1, pdf(0.0));
        test_exact(1.0, 1.0, pdf(0.0));
        test_exact(10.0, 10.0, pdf(0.0));
        test_is_nan(f64::INFINITY, pdf(0.0));
        test_exact(0.1, 0.09900498337491680535739, pdf(0.1));
        test_absolute(1.0, 0.9048374180359595731642, 1e-15, pdf(0.1));
        test_exact(10.0, 3.678794411714423215955, pdf(0.1));
        test_is_nan(f64::INFINITY, pdf(0.1));
        test_exact(0.1, 0.09048374180359595731642, pdf(1.0));
        test_exact(1.0, 0.3678794411714423215955, pdf(1.0));
        test_absolute(10.0, 4.539992976248485153559e-4, 1e-19, pdf(1.0));
        test_is_nan(f64::INFINITY, pdf(1.0));
        test_exact(0.1, 0.0, pdf(f64::INFINITY));
        test_exact(1.0, 0.0, pdf(f64::INFINITY));
        test_exact(10.0, 0.0, pdf(f64::INFINITY));
        test_is_nan(f64::INFINITY, pdf(f64::INFINITY));
    }

    #[test]
    fn test_neg_pdf() {
        let pdf = |arg: f64| move |x: Exp| x.pdf(arg);
        test_exact(0.1, 0.0, pdf(-1.0));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: Exp| x.ln_pdf(arg);
        test_absolute(0.1, -2.302585092994045684018, 1e-15, ln_pdf(0.0));
        test_exact(1.0, 0.0, ln_pdf(0.0));
        test_exact(10.0, 2.302585092994045684018, ln_pdf(0.0));
        test_is_nan(f64::INFINITY, ln_pdf(0.0));
        test_absolute(0.1, -2.312585092994045684018, 1e-15, ln_pdf(0.1));
        test_exact(1.0, -0.1, ln_pdf(0.1));
        test_absolute(10.0, 1.302585092994045684018, 1e-15, ln_pdf(0.1));
        test_is_nan(f64::INFINITY, ln_pdf(0.1));
        test_exact(0.1, -2.402585092994045684018, ln_pdf(1.0));
        test_exact(1.0, -1.0, ln_pdf(1.0));
        test_exact(10.0, -7.697414907005954315982, ln_pdf(1.0));
        test_is_nan(f64::INFINITY, ln_pdf(1.0));
        test_exact(0.1, f64::NEG_INFINITY, ln_pdf(f64::INFINITY));
        test_exact(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY));
        test_exact(10.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY));
        test_is_nan(f64::INFINITY, ln_pdf(f64::INFINITY));
    }
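
    #[test]
    fn test_cdf_sf_complement() {
        // Illustrative consistency check: `cdf` and `sf` are computed by
        // separate formulas, so confirm they still sum to 1, both on the
        // support and at a negative point outside it. The 1e-15 absolute
        // tolerance is an assumed bound, not a documented guarantee.
        for &rate in &[0.1, 1.0, 10.0] {
            let d = Exp::new(rate).unwrap();
            for &x in &[-1.0, 0.0, 0.1, 1.0, 10.0] {
                assert!((d.cdf(x) + d.sf(x) - 1.0).abs() < 1e-15);
            }
        }
    }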
    #[test]
    fn test_neg_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: Exp| x.ln_pdf(arg);
        test_exact(0.1, f64::NEG_INFINITY, ln_pdf(-1.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: Exp| x.cdf(arg);
        test_exact(0.1, 0.0, cdf(0.0));
        test_exact(1.0, 0.0, cdf(0.0));
        test_exact(10.0, 0.0, cdf(0.0));
        test_is_nan(f64::INFINITY, cdf(0.0));
        test_absolute(0.1, 0.009950166250831946426094, 1e-16, cdf(0.1));
        test_absolute(1.0, 0.0951625819640404268358, 1e-16, cdf(0.1));
        test_exact(10.0, 0.6321205588285576784045, cdf(0.1));
        test_exact(f64::INFINITY, 1.0, cdf(0.1));
        test_absolute(0.1, 0.0951625819640404268358, 1e-16, cdf(1.0));
        test_exact(1.0, 0.6321205588285576784045, cdf(1.0));
        test_exact(10.0, 0.9999546000702375151485, cdf(1.0));
        test_exact(f64::INFINITY, 1.0, cdf(1.0));
        test_exact(0.1, 1.0, cdf(f64::INFINITY));
        test_exact(1.0, 1.0, cdf(f64::INFINITY));
        test_exact(10.0, 1.0, cdf(f64::INFINITY));
        test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY));
    }

    #[test]
    fn test_inverse_cdf() {
        let distribution = Exp::new(0.42).unwrap();
        assert_eq!(distribution.median(), distribution.inverse_cdf(0.5));
        let distribution = Exp::new(0.042).unwrap();
        assert_eq!(distribution.median(), distribution.inverse_cdf(0.5));
        let distribution = Exp::new(0.0042).unwrap();
        assert_eq!(distribution.median(), distribution.inverse_cdf(0.5));
        let distribution = Exp::new(0.33).unwrap();
        assert_eq!(distribution.median(), distribution.inverse_cdf(0.5));
        let distribution = Exp::new(0.033).unwrap();
        assert_eq!(distribution.median(), distribution.inverse_cdf(0.5));
        let distribution = Exp::new(0.0033).unwrap();
        assert_eq!(distribution.median(), distribution.inverse_cdf(0.5));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: Exp| x.sf(arg);
        test_exact(0.1, 1.0, sf(0.0));
        test_exact(1.0, 1.0, sf(0.0));
        test_exact(10.0, 1.0, sf(0.0));
        test_is_nan(f64::INFINITY, sf(0.0));
        test_absolute(0.1, 0.9900498337491681, 1e-16, sf(0.1));
        test_absolute(1.0, 0.9048374180359595, 1e-16, sf(0.1));
        test_absolute(10.0, 0.36787944117144233, 1e-15, sf(0.1));
        test_exact(f64::INFINITY, 0.0, sf(0.1));
    }

    #[test]
    fn test_neg_cdf() {
        let cdf = |arg: f64| move |x: Exp| x.cdf(arg);
        test_exact(0.1, 0.0, cdf(-1.0));
    }

    #[test]
    fn test_neg_sf() {
        let sf = |arg: f64| move |x: Exp| x.sf(arg);
        test_exact(0.1, 1.0, sf(-1.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(0.5), 0.0, 10.0);
        test::check_continuous_distribution(&create_ok(1.5), 0.0, 20.0);
        test::check_continuous_distribution(&create_ok(2.5), 0.0, 50.0);
    }
}
statrs-0.18.0/src/distribution/fisher_snedecor.rs000064400000000000000000000476561046102023000202720ustar 00000000000000
use crate::distribution::{Continuous, ContinuousCDF};
use crate::function::beta;
use crate::statistics::*;
use std::f64;

/// Implements the
/// [Fisher-Snedecor](https://en.wikipedia.org/wiki/F-distribution) distribution,
/// also commonly known as the F-distribution.
///
/// # Examples
///
/// ```
/// use statrs::distribution::{FisherSnedecor, Continuous};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = FisherSnedecor::new(3.0, 3.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 3.0);
/// assert!(prec::almost_eq(n.pdf(1.0), 0.318309886183790671538, 1e-15));
/// ```
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct FisherSnedecor {
    freedom_1: f64,
    freedom_2: f64,
}

/// Represents the errors that can occur when creating a [`FisherSnedecor`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum FisherSnedecorError {
    /// `freedom_1` is NaN, infinite, zero or less than zero.
    Freedom1Invalid,

    /// `freedom_2` is NaN, infinite, zero or less than zero.
    Freedom2Invalid,
}

impl std::fmt::Display for FisherSnedecorError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            FisherSnedecorError::Freedom1Invalid => {
                write!(f, "freedom_1 is NaN, infinite, zero or less than zero.")
            }
            FisherSnedecorError::Freedom2Invalid => {
                write!(f, "freedom_2 is NaN, infinite, zero or less than zero.")
            }
        }
    }
}

impl std::error::Error for FisherSnedecorError {}

impl FisherSnedecor {
    /// Constructs a new Fisher-Snedecor distribution with
    /// degrees of freedom `freedom_1` and `freedom_2`
    ///
    /// # Errors
    ///
    /// Returns an error if `freedom_1` or `freedom_2` is `NaN` or infinite.
    /// Also returns an error if `freedom_1 <= 0.0` or `freedom_2 <= 0.0`
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::FisherSnedecor;
    ///
    /// let mut result = FisherSnedecor::new(1.0, 1.0);
    /// assert!(result.is_ok());
    ///
    /// result = FisherSnedecor::new(0.0, 0.0);
    /// assert!(result.is_err());
    /// ```
    pub fn new(freedom_1: f64, freedom_2: f64) -> Result<FisherSnedecor, FisherSnedecorError> {
        if !freedom_1.is_finite() || freedom_1 <= 0.0 {
            return Err(FisherSnedecorError::Freedom1Invalid);
        }

        if !freedom_2.is_finite() || freedom_2 <= 0.0 {
            return Err(FisherSnedecorError::Freedom2Invalid);
        }

        Ok(FisherSnedecor {
            freedom_1,
            freedom_2,
        })
    }

    /// Returns the first degree of freedom for the
    /// Fisher-Snedecor distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::FisherSnedecor;
    ///
    /// let n = FisherSnedecor::new(2.0, 3.0).unwrap();
    /// assert_eq!(n.freedom_1(), 2.0);
    /// ```
    pub fn freedom_1(&self) -> f64 {
        self.freedom_1
    }

    /// Returns the second degree of freedom for the
    /// Fisher-Snedecor distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::FisherSnedecor;
    ///
    /// let n = FisherSnedecor::new(2.0, 3.0).unwrap();
    /// assert_eq!(n.freedom_2(), 3.0);
    /// ```
    pub fn freedom_2(&self) -> f64 {
        self.freedom_2
    }
}

impl std::fmt::Display for FisherSnedecor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "F({},{})", self.freedom_1, self.freedom_2)
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for FisherSnedecor {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        // A Gamma(k / 2, rate 0.5) draw is a chi-squared draw with k degrees of
        // freedom, and (X1 / d1) / (X2 / d2) is F(d1, d2) distributed
        (super::gamma::sample_unchecked(rng, self.freedom_1 / 2.0, 0.5) * self.freedom_2)
            / (super::gamma::sample_unchecked(rng, self.freedom_2 / 2.0, 0.5) * self.freedom_1)
    }
}

impl ContinuousCDF<f64, f64> for FisherSnedecor {
    /// Calculates the cumulative distribution function for the Fisher-Snedecor
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// I_((d1 * x) / (d1 * x + d2))(d1 / 2, d2 / 2)
    /// ```
    ///
    /// where `d1` is the first degree of freedom, `d2` is
    /// the second degree of freedom, and `I` is the regularized incomplete
    /// beta function
    fn cdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            0.0
        } else if x.is_infinite() {
            1.0
        } else {
            beta::beta_reg(
                self.freedom_1 / 2.0,
                self.freedom_2 / 2.0,
                self.freedom_1 * x / (self.freedom_1 * x + self.freedom_2),
            )
        }
    }

    /// Calculates the survival function for the Fisher-Snedecor
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// I_(1 - (d1 * x) / (d1 * x + d2))(d2 / 2, d1 / 2)
    /// ```
    ///
    /// where `d1` is the first degree of freedom, `d2` is
    /// the second degree of freedom, and `I` is the regularized incomplete
    /// beta function
    fn sf(&self, x: f64) -> f64 {
        if x < 0.0 {
            1.0
        } else if x.is_infinite() {
            0.0
        } else {
            beta::beta_reg(
                self.freedom_2 / 2.0,
                self.freedom_1 / 2.0,
                1. - ((self.freedom_1 * x) / (self.freedom_1 * x + self.freedom_2)),
            )
        }
    }

    /// Calculates the inverse cumulative distribution function for the
    /// Fisher-Snedecor distribution at probability `x`
    ///
    /// # Panics
    ///
    /// If `x` is not in `[0, 1]`
    ///
    /// # Formula
    ///
    /// ```text
    /// z = I^{-1}_(x)(d1 / 2, d2 / 2)
    /// d2 / (d1 (1 / z - 1))
    /// ```
    ///
    /// where `d1` is the first degree of freedom, `d2` is
    /// the second degree of freedom, and `I` is the regularized incomplete
    /// beta function
    fn inverse_cdf(&self, x: f64) -> f64 {
        if !(0.0..=1.0).contains(&x) {
            panic!("x must be in [0, 1]");
        } else {
            let z = beta::inv_beta_reg(self.freedom_1 / 2.0, self.freedom_2 / 2.0, x);
            self.freedom_2 / (self.freedom_1 * (1.0 / z - 1.0))
        }
    }
}

impl Min<f64> for FisherSnedecor {
    /// Returns the minimum value in the domain of the
    /// Fisher-Snedecor distribution representable by a double precision
    /// float
    ///
    /// # Formula
    ///
    /// ```text
    /// 0
    /// ```
    fn min(&self) -> f64 {
        0.0
    }
}

impl Max<f64> for FisherSnedecor {
    /// Returns the maximum value in the domain of the
    /// Fisher-Snedecor distribution representable by a double precision
    /// float
    ///
    /// # Formula
    ///
    /// ```text
    /// f64::INFINITY
    /// ```
    fn max(&self) -> f64 {
        f64::INFINITY
    }
}

impl Distribution<f64> for FisherSnedecor {
    /// Returns the mean of the Fisher-Snedecor distribution
    ///
    /// # Remarks
    ///
    /// Returns `None` if `freedom_2 <= 2.0`
    ///
    /// # Formula
    ///
    /// ```text
    /// d2 / (d2 - 2)
    /// ```
    ///
    /// where `d2` is the second degree of freedom
    fn mean(&self) -> Option<f64> {
        if self.freedom_2 <= 2.0 {
            None
        } else {
            Some(self.freedom_2 / (self.freedom_2 - 2.0))
        }
    }

    /// Returns the variance of the Fisher-Snedecor distribution
    ///
    /// # Remarks
    ///
    /// Returns `None` if `freedom_2 <= 4.0`
    ///
    /// # Formula
    ///
    /// ```text
    /// (2 * d2^2 * (d1 + d2 - 2)) / (d1 * (d2 - 2)^2 * (d2 - 4))
    /// ```
    ///
    /// where `d1` is the first degree of freedom and `d2` is
    /// the second degree of freedom
    fn variance(&self) -> Option<f64> {
        if self.freedom_2 <= 4.0 {
            None
        } else {
            let val =
                (2.0 * self.freedom_2 * self.freedom_2 * (self.freedom_1 + self.freedom_2 - 2.0))
                    / (self.freedom_1
                        * (self.freedom_2 - 2.0)
                        * (self.freedom_2 - 2.0)
                        * (self.freedom_2 - 4.0));
            Some(val)
        }
    }

    /// Returns the skewness of the Fisher-Snedecor distribution
    ///
    /// # Remarks
    ///
    /// Returns `None` if `freedom_2 <= 6.0`
    ///
    /// # Formula
    ///
    /// ```text
    /// ((2d1 + d2 - 2) * sqrt(8 * (d2 - 4))) / ((d2 - 6) * sqrt(d1 * (d1 + d2 - 2)))
    /// ```
    ///
    /// where `d1` is the first degree of freedom and `d2` is
    /// the second degree of freedom
    fn skewness(&self) -> Option<f64> {
        if self.freedom_2 <= 6.0 {
            None
        } else {
            let val = ((2.0 * self.freedom_1 + self.freedom_2 - 2.0)
                * (8.0 * (self.freedom_2 - 4.0)).sqrt())
                / ((self.freedom_2 - 6.0)
                    * (self.freedom_1 * (self.freedom_1 + self.freedom_2 - 2.0)).sqrt());
            Some(val)
        }
    }
}

impl Mode<Option<f64>> for FisherSnedecor {
    /// Returns the mode for the Fisher-Snedecor distribution
    ///
    /// # Remarks
    ///
    /// Returns `None` if `freedom_1 <= 2.0`
    ///
    /// # Formula
    ///
    /// ```text
    /// ((d1 - 2) / d1) * (d2 / (d2 + 2))
    /// ```
    ///
    /// where `d1` is the first degree of freedom and `d2` is
    /// the second degree of freedom
    fn mode(&self) -> Option<f64> {
        if self.freedom_1 <= 2.0 {
            None
        } else {
            let val = (self.freedom_2 * (self.freedom_1 - 2.0))
                / (self.freedom_1 * (self.freedom_2 + 2.0));
            Some(val)
        }
    }
}

impl Continuous<f64, f64> for FisherSnedecor {
    /// Calculates the probability density function for the Fisher-Snedecor
    /// distribution at `x`
    ///
    /// # Remarks
    ///
    /// Returns `0.0` if `x <= 0.0` or `x` is infinite
    ///
    /// # Formula
    ///
    /// ```text
    /// sqrt(((d1 * x) ^ d1 * d2 ^ d2) / (d1 * x + d2) ^ (d1 + d2)) / (x * β(d1 / 2, d2 / 2))
    /// ```
    ///
    /// where `d1` is the first degree of freedom, `d2` is
    /// the second degree of freedom, and `β` is the beta function
    fn pdf(&self, x: f64) -> f64 {
        if x.is_infinite() || x <= 0.0 {
            0.0
        } else {
            ((self.freedom_1 * x).powf(self.freedom_1) * self.freedom_2.powf(self.freedom_2)
                / (self.freedom_1 * x + self.freedom_2).powf(self.freedom_1 + self.freedom_2))
            .sqrt()
                / (x * beta::beta(self.freedom_1 / 2.0, self.freedom_2 / 2.0))
        }
    }

    /// Calculates the log probability density function for the Fisher-Snedecor
    /// distribution at `x`
    ///
    /// # Remarks
    ///
    /// Returns `f64::NEG_INFINITY` if `x <= 0.0` or `x` is infinite
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(sqrt(((d1 * x) ^ d1 * d2 ^ d2) / (d1 * x + d2) ^ (d1 + d2)) / (x * β(d1 / 2, d2 / 2)))
    /// ```
    ///
    /// where `d1` is the first degree of freedom, `d2` is
    /// the second degree of freedom, and `β` is the beta function
    fn ln_pdf(&self, x: f64) -> f64 {
        self.pdf(x).ln()
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor; FisherSnedecorError);

    #[test]
    fn test_create() {
        create_ok(0.1, 0.1);
        create_ok(1.0, 0.1);
        create_ok(10.0, 0.1);
        create_ok(0.1, 1.0);
        create_ok(1.0, 1.0);
        create_ok(10.0, 1.0);
    }

    #[test]
    fn test_bad_create() {
        test_create_err(f64::INFINITY, 0.1, FisherSnedecorError::Freedom1Invalid);
        test_create_err(0.1, f64::INFINITY, FisherSnedecorError::Freedom2Invalid);
        create_err(f64::NAN, f64::NAN);
        create_err(0.0, f64::NAN);
        create_err(-1.0, f64::NAN);
        create_err(-10.0, f64::NAN);
        create_err(f64::NAN, 0.0);
        create_err(0.0, 0.0);
        create_err(-1.0, 0.0);
        create_err(-10.0, 0.0);
        create_err(f64::NAN, -1.0);
        create_err(0.0, -1.0);
        create_err(-1.0, -1.0);
        create_err(-10.0, -1.0);
        create_err(f64::NAN, -10.0);
        create_err(0.0, -10.0);
        create_err(-1.0, -10.0);
        create_err(-10.0, -10.0);
        create_err(f64::INFINITY, f64::INFINITY);
    }

    #[test]
    fn test_mean() {
        let mean = |x: FisherSnedecor| x.mean().unwrap();
        test_exact(0.1, 10.0, 1.25, mean);
        test_exact(1.0, 10.0, 1.25, mean);
        test_exact(10.0, 10.0, 1.25, mean);
    }

    #[test]
    fn test_mean_with_low_d2() {
        test_none(0.1, 0.1, |dist| dist.mean());
    }

    #[test]
    fn test_variance() {
        let variance = |x: FisherSnedecor| x.variance().unwrap();
        test_absolute(0.1, 10.0, 42.1875, 1e-14, variance);
        test_exact(1.0, 10.0, 4.6875, variance);
        test_exact(10.0, 10.0, 0.9375, variance);
    }

    #[test]
    fn test_variance_with_low_d2() {
        test_none(0.1, 0.1, |dist| dist.variance());
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: FisherSnedecor| x.skewness().unwrap();
        test_absolute(0.1, 10.0, 15.78090735784977089658, 1e-14, skewness);
        test_exact(1.0, 10.0, 5.773502691896257645091, skewness);
        test_exact(10.0, 10.0, 3.614784456460255759501, skewness);
    }

    #[test]
    fn test_skewness_with_low_d2() {
        test_none(0.1, 0.1, |dist| dist.skewness());
    }

    #[test]
    fn test_mode() {
        let mode = |x: FisherSnedecor| x.mode().unwrap();
        test_exact(10.0, 0.1, 0.0380952380952380952381, mode);
        test_exact(10.0, 1.0, 4.0 / 15.0, mode);
        test_exact(10.0, 10.0, 2.0 / 3.0, mode);
    }

    #[test]
    fn test_mode_with_low_d1() {
        test_none(0.1, 0.1, |dist| dist.mode());
    }

    #[test]
    fn test_min_max() {
        let min = |x: FisherSnedecor| x.min();
        let max = |x: FisherSnedecor| x.max();
        test_exact(1.0, 1.0, 0.0, min);
        test_exact(1.0, 1.0, f64::INFINITY, max);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: f64| move |x: FisherSnedecor| x.pdf(arg);
        test_absolute(0.1, 0.1, 0.0234154207226588982471, 1e-16, pdf(1.0));
        test_absolute(1.0, 0.1, 0.0396064560910663979961, 1e-16, pdf(1.0));
        test_absolute(10.0, 0.1, 0.0418440630400545297349, 1e-14, pdf(1.0));
        test_absolute(0.1, 1.0, 0.0396064560910663979961, 1e-16, pdf(1.0));
        test_absolute(1.0, 1.0, 0.1591549430918953357689, 1e-16, pdf(1.0));
        test_absolute(10.0, 1.0, 0.230361989229138647108, 1e-16, pdf(1.0));
        test_absolute(0.1, 0.1, 0.00221546909694001013517, 1e-18, pdf(10.0));
        test_absolute(1.0, 0.1, 0.00369960370387922619592, 1e-17, pdf(10.0));
        test_absolute(10.0, 0.1, 0.00390179721174142927402, 1e-15, pdf(10.0));
        test_absolute(0.1, 1.0, 0.00319864073359931548273, 1e-17, pdf(10.0));
        test_absolute(1.0, 1.0, 0.009150765837179460915678, 1e-17, pdf(10.0));
        test_absolute(10.0, 1.0, 0.0116493859171442148446, 1e-17, pdf(10.0));
        test_absolute(0.1, 10.0, 0.00305087016058573989694, 1e-15, pdf(10.0));
        test_absolute(1.0, 10.0, 0.00271897749113479577864, 1e-17, pdf(10.0));
        test_absolute(10.0, 10.0, 2.4289227234060500084E-4, 1e-18, pdf(10.0));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: FisherSnedecor| x.ln_pdf(arg);
        test_absolute(0.1, 0.1, 0.0234154207226588982471f64.ln(), 1e-15, ln_pdf(1.0));
        test_absolute(1.0, 0.1, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0));
        test_absolute(10.0, 0.1, 0.0418440630400545297349f64.ln(), 1e-13, ln_pdf(1.0));
        test_absolute(0.1, 1.0, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0));
        test_absolute(1.0, 1.0, 0.1591549430918953357689f64.ln(), 1e-15, ln_pdf(1.0));
        test_absolute(10.0, 1.0, 0.230361989229138647108f64.ln(), 1e-15, ln_pdf(1.0));
        test_exact(0.1, 0.1, 0.00221546909694001013517f64.ln(), ln_pdf(10.0));
        test_absolute(1.0, 0.1, 0.00369960370387922619592f64.ln(), 1e-15, ln_pdf(10.0));
        test_absolute(10.0, 0.1, 0.00390179721174142927402f64.ln(), 1e-13, ln_pdf(10.0));
        test_absolute(0.1, 1.0, 0.00319864073359931548273f64.ln(), 1e-15, ln_pdf(10.0));
        test_absolute(1.0, 1.0, 0.009150765837179460915678f64.ln(), 1e-15, ln_pdf(10.0));
        test_exact(10.0, 1.0, 0.0116493859171442148446f64.ln(), ln_pdf(10.0));
        test_absolute(0.1, 10.0, 0.00305087016058573989694f64.ln(), 1e-13, ln_pdf(10.0));
        test_exact(1.0, 10.0, 0.00271897749113479577864f64.ln(), ln_pdf(10.0));
        test_absolute(10.0, 10.0, 2.4289227234060500084E-4f64.ln(), 1e-14, ln_pdf(10.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: FisherSnedecor| x.cdf(arg);
        test_absolute(0.1, 0.1, 0.44712986033425140335, 1e-15, cdf(0.1));
        test_absolute(1.0, 0.1, 0.08156522095104674015, 1e-15, cdf(0.1));
        test_absolute(10.0, 0.1, 0.033184005716276536322, 1e-13, cdf(0.1));
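        // Sketch of an additional consistency check (not part of the upstream
        // suite): if X ~ F(d1, d2) then 1 / X ~ F(d2, d1), so
        // cdf_{d1,d2}(x) + cdf_{d2,d1}(1 / x) == 1 for x > 0. The parameter
        // triples and the 1e-12 tolerance are illustrative choices.
        for &(d1, d2, x) in &[(1.0, 0.1, 1.0), (10.0, 1.0, 0.1), (10.0, 10.0, 2.5)] {
            let forward = FisherSnedecor::new(d1, d2).unwrap();
            let swapped = FisherSnedecor::new(d2, d1).unwrap();
            let total = forward.cdf(x) + swapped.cdf(1.0 / x);
            assert!(
                (total - 1.0).abs() < 1e-12,
                "d1 = {}, d2 = {}, x = {}: cdf sum = {}",
                d1, d2, x, total
            );
        }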
        test_absolute(0.1, 1.0, 0.74378710917986379989, 1e-15, cdf(0.1));
        test_absolute(1.0, 1.0, 0.1949822290421366451595, 1e-16, cdf(0.1));
        test_absolute(10.0, 1.0, 0.0101195597354337146205, 1e-17, cdf(0.1));
        test_absolute(0.1, 0.1, 0.5, 1e-15, cdf(1.0));
        test_absolute(1.0, 0.1, 0.16734351500944271141, 1e-14, cdf(1.0));
        test_absolute(10.0, 0.1, 0.12207560664741704938, 1e-13, cdf(1.0));
        test_absolute(0.1, 1.0, 0.83265648499055728859, 1e-15, cdf(1.0));
        test_absolute(1.0, 1.0, 0.5, 1e-15, cdf(1.0));
        test_absolute(10.0, 1.0, 0.340893132302059872675, 1e-15, cdf(1.0));
    }

    #[test]
    fn test_cdf_lower_bound() {
        let cdf = |arg: f64| move |x: FisherSnedecor| x.cdf(arg);
        test_exact(0.1, 0.1, 0.0, cdf(-1.0));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: FisherSnedecor| x.sf(arg);
        test_absolute(0.1, 0.1, 0.5528701396657489, 1e-12, sf(0.1));
        test_absolute(1.0, 0.1, 0.9184347790489533, 1e-12, sf(0.1));
        test_absolute(10.0, 0.1, 0.9668159942836896, 1e-12, sf(0.1));
        test_absolute(0.1, 1.0, 0.25621289082013654, 1e-12, sf(0.1));
        test_absolute(1.0, 1.0, 0.8050177709578634, 1e-12, sf(0.1));
        test_absolute(10.0, 1.0, 0.9898804402645662, 1e-12, sf(0.1));
        test_absolute(0.1, 0.1, 0.5, 1e-15, sf(1.0));
        test_absolute(1.0, 0.1, 0.8326564849905562, 1e-12, sf(1.0));
        test_absolute(10.0, 0.1, 0.8779243933525519, 1e-12, sf(1.0));
        test_absolute(0.1, 1.0, 0.16734351500944344, 1e-12, sf(1.0));
        test_absolute(1.0, 1.0, 0.5, 1e-12, sf(1.0));
        test_absolute(10.0, 1.0, 0.65910686769794, 1e-12, sf(1.0));
    }

    #[test]
    fn test_inverse_cdf() {
        let func = |arg: f64| move |x: FisherSnedecor| x.inverse_cdf(x.cdf(arg));
        test_absolute(0.1, 0.1, 0.1, 1e-12, func(0.1));
        test_absolute(1.0, 0.1, 0.1, 1e-12, func(0.1));
        test_absolute(10.0, 0.1, 0.1, 1e-12, func(0.1));
        test_absolute(0.1, 1.0, 0.1, 1e-12, func(0.1));
        test_absolute(1.0, 1.0, 0.1, 1e-12, func(0.1));
        test_absolute(10.0, 1.0, 0.1, 1e-12, func(0.1));
        test_absolute(0.1, 0.1, 1.0, 1e-13, func(1.0));
        test_absolute(1.0, 0.1, 1.0, 1e-12, func(1.0));
        test_absolute(10.0, 0.1, 1.0, 1e-12, func(1.0));
        test_absolute(0.1, 1.0, 1.0, 1e-12, func(1.0));
        test_absolute(1.0, 1.0, 1.0, 1e-12, func(1.0));
        test_absolute(10.0, 1.0, 1.0, 1e-12, func(1.0));
    }

    #[test]
    fn test_sf_lower_bound() {
        let sf = |arg: f64| move |x: FisherSnedecor| x.sf(arg);
        test_exact(0.1, 0.1, 1.0, sf(-1.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(10.0, 10.0), 0.0, 10.0);
    }
}
statrs-0.18.0/src/distribution/gamma.rs000064400000000000000000000470471046102023000162040ustar 00000000000000
use crate::distribution::{Continuous, ContinuousCDF};
use crate::function::gamma;
use crate::prec;
use crate::statistics::*;

/// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution)
/// distribution
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Gamma, Continuous};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = Gamma::new(3.0, 1.0).unwrap();
/// assert_eq!(n.mean().unwrap(), 3.0);
/// assert!(prec::almost_eq(n.pdf(2.0), 0.270670566473225383788, 1e-15));
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Gamma {
    shape: f64,
    rate: f64,
}

/// Represents the errors that can occur when creating a [`Gamma`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum GammaError {
    /// The shape is NaN, zero or less than zero.
    ShapeInvalid,

    /// The rate is NaN, zero or less than zero.
    RateInvalid,

    /// The shape and rate are both infinite.
    ShapeAndRateInfinite,
}

impl std::fmt::Display for GammaError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            GammaError::ShapeInvalid => write!(f, "Shape is NaN, zero or less than zero."),
            GammaError::RateInvalid => write!(f, "Rate is NaN, zero or less than zero."),
            GammaError::ShapeAndRateInfinite => write!(f, "Shape and rate are infinite"),
        }
    }
}

impl std::error::Error for GammaError {}

impl Gamma {
    /// Constructs a new gamma distribution with a shape (α)
    /// of `shape` and a rate (β) of `rate`
    ///
    /// # Errors
    ///
    /// Returns an error if `shape` or `rate` is `NaN`, if `shape <= 0.0` or
    /// `rate <= 0.0`, or if `shape` and `rate` are both infinite
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Gamma;
    ///
    /// let mut result = Gamma::new(3.0, 1.0);
    /// assert!(result.is_ok());
    ///
    /// result = Gamma::new(0.0, 0.0);
    /// assert!(result.is_err());
    /// ```
    pub fn new(shape: f64, rate: f64) -> Result<Gamma, GammaError> {
        if shape.is_nan() || shape <= 0.0 {
            return Err(GammaError::ShapeInvalid);
        }

        if rate.is_nan() || rate <= 0.0 {
            return Err(GammaError::RateInvalid);
        }

        if shape.is_infinite() && rate.is_infinite() {
            return Err(GammaError::ShapeAndRateInfinite);
        }

        Ok(Gamma { shape, rate })
    }

    /// Returns the shape (α) of the gamma distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Gamma;
    ///
    /// let n = Gamma::new(3.0, 1.0).unwrap();
    /// assert_eq!(n.shape(), 3.0);
    /// ```
    pub fn shape(&self) -> f64 {
        self.shape
    }

    /// Returns the rate (β) of the gamma distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Gamma;
    ///
    /// let n = Gamma::new(3.0, 1.0).unwrap();
    /// assert_eq!(n.rate(), 1.0);
    /// ```
    pub fn rate(&self) -> f64 {
        self.rate
    }
}

impl std::fmt::Display for Gamma {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Γ({}, {})", self.shape, self.rate)
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Gamma {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        sample_unchecked(rng, self.shape, self.rate)
    }
}

impl ContinuousCDF<f64, f64> for Gamma {
    /// Calculates the cumulative distribution function for the gamma
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// (1 / Γ(α)) * γ(α, β * x)
    /// ```
    ///
    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
    /// and `γ` is the lower incomplete gamma function
    fn cdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            0.0
        } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
            1.0
        } else if self.rate.is_infinite() {
            0.0
        } else if x.is_infinite() {
            1.0
        } else {
            gamma::gamma_lr(self.shape, x * self.rate)
        }
    }

    /// Calculates the survival function for the gamma
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// (1 / Γ(α)) * Γ(α, β * x)
    /// ```
    ///
    /// where `α` is the shape, `β` is the rate, `Γ(α)` is the gamma function,
    /// and `Γ(α, β * x)` is the upper incomplete gamma function
    fn sf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            1.0
        } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() {
            0.0
        } else if self.rate.is_infinite() {
            1.0
        } else if x.is_infinite() {
            0.0
        } else {
            gamma::gamma_ur(self.shape, x * self.rate)
        }
    }

    /// Calculates the inverse cumulative distribution function (quantile
    /// function) for the gamma distribution at probability `p`
    ///
    /// The root is first bracketed by halving/doubling, narrowed with a short
    /// bisection search, and then refined with Newton-Raphson steps.
    ///
    /// # Panics
    ///
    /// If `p` is not in `[0, 1]`
    fn inverse_cdf(&self, p: f64) -> f64 {
        if !(0.0..=1.0).contains(&p) {
            panic!("default inverse_cdf implementation should be provided probability on [0,1]")
        }
        if p == 0.0 {
            return self.min();
        }
        if p == 1.0 {
            return self.max();
        }

        // Bracket the root: shrink `low` until cdf(low) <= p and grow `high`
        // until cdf(high) >= p
        let mut high = 2.0;
        let mut low = 1.0;
        while self.cdf(low) > p {
            low /= 2.0;
        }
        while self.cdf(high) < p {
            high *= 2.0;
        }

        // Bisection search, at most 8 iterations
        let mut x_0 = (high + low) / 2.0;
        for _ in 0..8 {
            if self.cdf(x_0) >= p {
                high = x_0;
            } else {
                low = x_0;
            }
            if prec::convergence(&mut x_0, (high + low) / 2.0) {
                break;
            }
        }

        // Newton-Raphson refinement, at least one and at most four steps
        for _ in 0..4 {
            let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0);
            if prec::convergence(&mut x_0, x_next) {
                break;
            }
        }
        x_0
    }
}

impl Min<f64> for Gamma {
    /// Returns the minimum value in the domain of the
    /// gamma distribution representable by a double precision
    /// float
    ///
    /// # Formula
    ///
    /// ```text
    /// 0
    /// ```
    fn min(&self) -> f64 {
        0.0
    }
}

impl Max<f64> for Gamma {
    /// Returns the maximum value in the domain of the
    /// gamma distribution representable by a double precision
    /// float
    ///
    /// # Formula
    ///
    /// ```text
    /// f64::INFINITY
    /// ```
    fn max(&self) -> f64 {
        f64::INFINITY
    }
}

impl Distribution<f64> for Gamma {
    /// Returns the mean of the gamma distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// α / β
    /// ```
    ///
    /// where `α` is the shape and `β` is the rate
    fn mean(&self) -> Option<f64> {
        Some(self.shape / self.rate)
    }

    /// Returns the variance of the gamma distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// α / β^2
    /// ```
    ///
    /// where `α` is the shape and `β` is the rate
    fn variance(&self) -> Option<f64> {
        Some(self.shape / (self.rate * self.rate))
    }

    /// Returns the entropy of the gamma distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// α - ln(β) + ln(Γ(α)) + (1 - α) * ψ(α)
    /// ```
    ///
    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
    /// and `ψ` is the digamma function
    fn entropy(&self) -> Option<f64> {
        let entr = self.shape - self.rate.ln()
            + gamma::ln_gamma(self.shape)
            + (1.0 - self.shape) * gamma::digamma(self.shape);
        Some(entr)
    }

    /// Returns the skewness of the gamma distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// 2 / sqrt(α)
    /// ```
    ///
    /// where `α` is the shape
    fn skewness(&self) -> Option<f64> {
        Some(2.0 / self.shape.sqrt())
    }
}

impl Mode<Option<f64>> for Gamma {
    /// Returns the mode for the gamma distribution
    ///
    /// # Remarks
    ///
    /// Returns `None` if `α < 1`
    ///
    /// # Formula
    ///
    /// ```text
    /// (α - 1) / β
    /// ```
    ///
    /// where `α` is the shape and `β` is the rate
    fn mode(&self) -> Option<f64> {
        if self.shape < 1.0 {
            None
        } else {
            Some((self.shape - 1.0) / self.rate)
        }
    }
}

impl Continuous<f64, f64> for Gamma {
    /// Calculates the probability density function for the gamma distribution
    /// at `x`
    ///
    /// # Remarks
    ///
    /// May return `NAN` if `shape` or `rate` is `f64::INFINITY`.
    /// Returns `0.0` if `x` is `f64::INFINITY` and both parameters are finite
    ///
    /// # Formula
    ///
    /// ```text
    /// (β^α / Γ(α)) * x^(α - 1) * e^(-β * x)
    /// ```
    ///
    /// where `α` is the shape, `β` is the rate, and `Γ` is the gamma function
    fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            0.0
        } else if ulps_eq!(self.shape, 1.0) {
            self.rate * (-self.rate * x).exp()
        } else if self.shape > 160.0 {
            self.ln_pdf(x).exp()
        } else if x.is_infinite() {
            0.0
        } else {
            self.rate.powf(self.shape) * x.powf(self.shape - 1.0) * (-self.rate * x).exp()
                / gamma::gamma(self.shape)
        }
    }

    /// Calculates the log probability density function for the gamma
    /// distribution at `x`
    ///
    /// # Remarks
    ///
    /// May return `NAN` if `shape` or `rate` is `f64::INFINITY`.
    /// Returns `f64::NEG_INFINITY` if `x` is `f64::INFINITY` and both
    /// parameters are finite
    ///
    /// # Formula
    ///
    /// ```text
    /// ln((β^α / Γ(α)) * x^(α - 1) * e ^(-β * x))
    /// ```
    ///
    /// where `α` is the shape, `β` is the rate, and `Γ` is the gamma function
    fn ln_pdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            f64::NEG_INFINITY
        } else if ulps_eq!(self.shape, 1.0) {
            self.rate.ln() - self.rate * x
        } else if x.is_infinite() {
            f64::NEG_INFINITY
        } else {
            self.shape * self.rate.ln() + (self.shape - 1.0) * x.ln()
                - self.rate * x
                - gamma::ln_gamma(self.shape)
        }
    }
}

/// Samples from a gamma distribution with a shape of `shape` and a
/// rate of `rate` using `rng` as the source of randomness. Implementation from:
///
/// _"A Simple Method for Generating Gamma Variables"_ - Marsaglia & Tsang
///
/// ACM Transactions on Mathematical Software, Vol. 26, No. 3, September 2000,
/// Pages 363-372
#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, shape: f64, rate: f64) -> f64 {
    // For shape < 1, sample with shape + 1 and scale the result by U^(1 / shape)
    let mut a = shape;
    let mut afix = 1.0;
    if shape < 1.0 {
        a = shape + 1.0;
        afix = rng.gen::<f64>().powf(1.0 / shape);
    }

    let d = a - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    loop {
        let mut x;
        let mut v;
        // Draw standard normal values until v = 1 + c * x is positive
        loop {
            x = super::normal::sample_unchecked(rng, 0.0, 1.0);
            v = 1.0 + c * x;
            if v > 0.0 {
                break;
            }
        }

        v = v * v * v;
        x = x * x;
        let u: f64 = rng.gen();
        // Cheap squeeze test first, then the exact logarithmic acceptance test
        if u < 1.0 - 0.0331 * x * x || u.ln() < 0.5 * x + d * (1.0 - v + v.ln()) {
            return afix * d * v / rate;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(shape: f64, rate: f64; Gamma; GammaError);

    #[test]
    fn test_create() {
        let valid = [
            (1.0, 0.1),
            (1.0, 1.0),
            (10.0, 10.0),
            (10.0, 1.0),
            (10.0, f64::INFINITY),
        ];

        for (s, r) in valid {
            create_ok(s, r);
        }
    }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (0.0, 0.0, GammaError::ShapeInvalid),
            (1.0, f64::NAN, GammaError::RateInvalid),
            (1.0, -1.0, GammaError::RateInvalid),
            (-1.0, 1.0, GammaError::ShapeInvalid),
            (-1.0, -1.0, GammaError::ShapeInvalid),
            (-1.0, f64::NAN, GammaError::ShapeInvalid),
            (
                f64::INFINITY,
                f64::INFINITY,
                GammaError::ShapeAndRateInfinite,
            ),
        ];
        for (s, r, err) in invalid {
            test_create_err(s, r, err);
        }
    }

    #[test]
    fn test_mean() {
        let f = |x: Gamma| x.mean().unwrap();
        let test = [
            ((1.0, 0.1), 10.0),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 1.0),
            ((10.0, 1.0), 10.0),
            ((10.0, f64::INFINITY), 0.0),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    #[test]
    fn test_variance() {
        let f = |x: Gamma| x.variance().unwrap();
        let test = [
            ((1.0, 0.1), 100.0),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 0.1),
            ((10.0, 1.0), 10.0),
            ((10.0, f64::INFINITY), 0.0),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    #[test]
    fn test_entropy() {
        let f = |x: Gamma| x.entropy().unwrap();
        let test = [
            ((1.0, 0.1), 3.302585092994045628506840223),
            ((1.0, 1.0), 1.0),
            ((10.0, 10.0), 0.2334690854869339583626209),
            ((10.0, 1.0), 2.53605417848097964238061239),
            ((10.0, f64::INFINITY), f64::NEG_INFINITY),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    #[test]
    fn test_skewness() {
        let f = |x: Gamma| x.skewness().unwrap();
        let test = [
            ((1.0, 0.1), 2.0),
            ((1.0, 1.0), 2.0),
            ((10.0, 10.0), 0.6324555320336758663997787),
            ((10.0, 1.0), 0.63245553203367586639977870),
            ((10.0, f64::INFINITY), 0.6324555320336758),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    #[test]
    fn test_mode() {
        let f = |x: Gamma| x.mode().unwrap();
        let test = [((1.0, 0.1), 0.0), ((1.0, 1.0), 0.0)];
        for &((s, r), res) in test.iter() {
            test_absolute(s, r, res, 10e-6, f);
        }
        let test = [
            ((10.0, 10.0), 0.9),
            ((10.0, 1.0), 9.0),
            ((10.0, f64::INFINITY), 0.0),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    #[test]
    fn test_min_max() {
        let f = |x: Gamma| x.min();
        let test = [
            ((1.0, 0.1), 0.0),
            ((1.0, 1.0), 0.0),
            ((10.0, 10.0), 0.0),
            ((10.0, 1.0), 0.0),
            ((10.0, f64::INFINITY), 0.0),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
        let f = |x: Gamma| x.max();
        let test = [
            ((1.0, 0.1), f64::INFINITY),
            ((1.0, 1.0), f64::INFINITY),
            ((10.0, 10.0), f64::INFINITY),
            ((10.0, 1.0), f64::INFINITY),
            ((10.0, f64::INFINITY), f64::INFINITY),
        ];
        for ((s, r), res) in test {
            test_relative(s, r, res, f);
        }
    }

    #[test]
    fn test_pdf() {
        let f = |arg: f64| move |x: Gamma| x.pdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.090483741803595961836995),
            ((1.0, 0.1), 10.0, 0.036787944117144234201693),
            ((1.0, 1.0), 1.0, 0.367879441171442321595523),
            ((1.0, 1.0), 10.0, 0.000045399929762484851535),
            ((10.0, 10.0), 1.0, 1.251100357211332989847649),
            ((10.0, 10.0), 10.0, 1.025153212086870580621609e-30),
            ((10.0, 1.0), 1.0, 0.000001013777119630297402),
            ((10.0, 1.0), 10.0, 0.125110035721133298984764),
        ];
        for ((s, r), x, res) in test {
            test_relative(s, r, res, f(x));
        }
        // TODO: test special
        // test_is_nan((10.0, f64::INFINITY), pdf(1.0)); // is this really the behavior we want?
        // TODO: test special
        // (10.0, f64::INFINITY, f64::INFINITY, 0.0, pdf(f64::INFINITY)),];
    }

    #[test]
    fn test_pdf_at_zero() {
        test_relative(1.0, 0.1, 0.1, |x| x.pdf(0.0));
        test_relative(1.0, 0.1, 0.1f64.ln(), |x| x.ln_pdf(0.0));
    }

    #[test]
    fn test_ln_pdf() {
        let f = |arg: f64| move |x: Gamma| x.ln_pdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, -2.40258509299404563405795),
            ((1.0, 0.1), 10.0, -3.30258509299404562850684),
            ((1.0, 1.0), 1.0, -1.0),
            ((1.0, 1.0), 10.0, -10.0),
            ((10.0, 10.0), 1.0, 0.224023449858987228972196),
            ((10.0, 10.0), 10.0, -69.0527107131946016148658),
            ((10.0, 1.0), 1.0, -13.8018274800814696112077),
            ((10.0, 1.0), 10.0, -2.07856164313505845504579),
            ((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY),
        ];
        for ((s, r), x, res) in test {
            test_relative(s, r, res, f(x));
        }
        // TODO: test special
        // test_is_nan((10.0, f64::INFINITY), f(1.0)); // is this really the behavior we want?
    }

    #[test]
    fn test_cdf() {
        let f = |arg: f64| move |x: Gamma| x.cdf(arg);
        let test = [
            ((1.0, 0.1), 1.0, 0.095162581964040431858607),
            ((1.0, 0.1), 10.0, 0.632120558828557678404476),
            ((1.0, 1.0), 1.0, 0.632120558828557678404476),
            ((1.0, 1.0), 10.0, 0.999954600070237515148464),
            ((10.0, 10.0), 1.0, 0.542070285528147791685835),
            ((10.0, 10.0), 10.0, 0.999999999999999999999999),
            ((10.0, 1.0), 1.0, 0.000000111425478338720677),
            ((10.0, 1.0), 10.0, 0.542070285528147791685835),
            ((10.0, f64::INFINITY), 1.0, 0.0),
            ((10.0, f64::INFINITY), 10.0, 1.0),
        ];
        for ((s, r), x, res) in test {
            test_relative(s, r, res, f(x));
        }
    }

    #[test]
    fn test_cdf_at_zero() {
        test_relative(1.0, 0.1, 0.0, |x| x.cdf(0.0));
    }

    #[test]
    fn test_cdf_inverse_identity() {
        let f = |p: f64| move |g: Gamma| g.cdf(g.inverse_cdf(p));
        let params = [
            (1.0, 0.1),
            (1.0, 1.0),
            (10.0, 10.0),
            (10.0, 1.0),
            (100.0, 200.0),
        ];
        for (s, r) in params {
            for n in -5..0 {
                let p = 10.0f64.powi(n);
                test_relative(s, r, p, f(p));
            }
        }

        // test case from issue #200
        {
            let x = 20.5567;
            let f =
|x: f64| move |g: Gamma| g.inverse_cdf(g.cdf(x)); test_relative(3.0, 0.5, x, f(x)) } } #[test] fn test_sf() { let f = |arg: f64| move |x: Gamma| x.sf(arg); let test = [ ((1.0, 0.1), 1.0, 0.9048374180359595), ((1.0, 0.1), 10.0, 0.3678794411714419), ((1.0, 1.0), 1.0, 0.3678794411714419), ((1.0, 1.0), 10.0, 4.539992976249074e-5), ((10.0, 10.0), 1.0, 0.4579297144718528), ((10.0, 10.0), 10.0, 1.1253473960842808e-31), ((10.0, 1.0), 1.0, 0.9999998885745217), ((10.0, 1.0), 10.0, 0.4579297144718528), ((10.0, f64::INFINITY), 1.0, 1.0), ((10.0, f64::INFINITY), 10.0, 0.0), ]; for ((s, r), x, res) in test { test_relative(s, r, res, f(x)); } } #[test] fn test_sf_at_zero() { test_relative(1.0, 0.1, 1.0, |x| x.sf(0.0)); } #[test] fn test_continuous() { test::check_continuous_distribution(&create_ok(1.0, 0.5), 0.0, 20.0); test::check_continuous_distribution(&create_ok(9.0, 2.0), 0.0, 20.0); } } statrs-0.18.0/src/distribution/geometric.rs000064400000000000000000000322331046102023000170670ustar 00000000000000use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; use std::f64; /// Implements the /// [Geometric](https://en.wikipedia.org/wiki/Geometric_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Geometric, Discrete}; /// use statrs::statistics::Distribution; /// /// let n = Geometric::new(0.3).unwrap(); /// assert_eq!(n.mean().unwrap(), 1.0 / 0.3); /// assert_eq!(n.pmf(1), 0.3); /// assert_eq!(n.pmf(2), 0.21); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Geometric { p: f64, } /// Represents the errors that can occur when creating a [`Geometric`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum GeometricError { /// The probability is NaN or not in `(0, 1]`. 
ProbabilityInvalid, } impl std::fmt::Display for GeometricError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { GeometricError::ProbabilityInvalid => write!(f, "Probability is NaN or not in (0, 1]"), } } } impl std::error::Error for GeometricError {} impl Geometric { /// Constructs a new shifted geometric distribution (the number of /// trials up to and including the first success, supported on /// `1, 2, ...`) with success probability `p` /// /// # Errors /// /// Returns an error if `p` is NaN or not in `(0, 1]` /// /// # Examples /// /// ``` /// use statrs::distribution::Geometric; /// /// let mut result = Geometric::new(0.5); /// assert!(result.is_ok()); /// /// result = Geometric::new(0.0); /// assert!(result.is_err()); /// ``` pub fn new(p: f64) -> Result<Geometric, GeometricError> { if p <= 0.0 || p > 1.0 || p.is_nan() { Err(GeometricError::ProbabilityInvalid) } else { Ok(Geometric { p }) } } /// Returns the success probability `p` of the geometric /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Geometric; /// /// let n = Geometric::new(0.5).unwrap(); /// assert_eq!(n.p(), 0.5); /// ``` pub fn p(&self) -> f64 { self.p } } impl std::fmt::Display for Geometric { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Geom({})", self.p) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<u64> for Geometric { fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> u64 { if ulps_eq!(self.p, 1.0) { 1 } else { let x: f64 = r.sample(::rand::distributions::OpenClosed01); // This cast is safe, because the largest finite value this expression can take is when // `x = 1.4e-45` and `1.0 - self.p = 0.9999999999999999`, in which case we get // `930262250532780300`, which when cast to a `u64` is `930262250532780288`.
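// Inverse-transform sampling: for U uniform on (0, 1] and 0 < p < 1,
// ln(1 - p) < 0, so ceil(ln(U) / ln(1 - p)) <= k  <=>  U >= (1 - p)^k.
// Hence P(X <= k) = 1 - (1 - p)^k, which is the geometric CDF implemented below.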
x.log(1.0 - self.p).ceil() as u64 } } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Geometric { fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 { r.sample::<u64, _>(self) as f64 } } impl DiscreteCDF<u64, f64> for Geometric { /// Calculates the cumulative distribution function for the geometric /// distribution at `x` /// /// # Formula /// /// ```text /// 1 - (1 - p) ^ x /// ``` fn cdf(&self, x: u64) -> f64 { if x == 0 { 0.0 } else { // 1 - (1 - p) ^ x = 1 - exp(log(1 - p)*x) // = -expm1(log1p(-p)*x) // = -((-p).ln_1p() * x).exp_m1() -((-self.p).ln_1p() * (x as f64)).exp_m1() } } /// Calculates the survival function for the geometric /// distribution at `x` /// /// # Formula /// /// ```text /// (1 - p) ^ x /// ``` fn sf(&self, x: u64) -> f64 { // (1-p) ^ x = exp(log(1-p)*x) // = exp(log1p(-p) * x) if x == 0 { 1.0 } else { ((-self.p).ln_1p() * (x as f64)).exp() } } } impl Min<u64> for Geometric { /// Returns the minimum value in the domain of the /// geometric distribution representable by a 64-bit /// integer /// /// # Formula /// /// ```text /// 1 /// ``` fn min(&self) -> u64 { 1 } } impl Max<u64> for Geometric { /// Returns the maximum value in the domain of the /// geometric distribution representable by a 64-bit /// unsigned integer /// /// # Formula /// /// ```text /// 2^64 - 1 /// ``` fn max(&self) -> u64 { u64::MAX } } impl Distribution<f64> for Geometric { /// Returns the mean of the geometric distribution /// /// # Formula /// /// ```text /// 1 / p /// ``` fn mean(&self) -> Option<f64> { Some(1.0 / self.p) } /// Returns the variance of the geometric distribution /// /// # Formula /// /// ```text /// (1 - p) / p^2 /// ``` fn variance(&self) -> Option<f64> { Some((1.0 - self.p) / (self.p * self.p)) } /// Returns the entropy of the geometric distribution /// /// # Formula /// /// ```text /// (-(1 - p) * log_2(1 - p) - p * log_2(p)) / p /// ``` fn entropy(&self) -> Option<f64> { let inv = 1.0 / self.p; Some(-inv * (1.
- self.p).log(2.0) + (inv - 1.).log(2.0)) } /// Returns the skewness of the geometric distribution /// /// # Formula /// /// ```text /// (2 - p) / sqrt(1 - p) /// ``` fn skewness(&self) -> Option<f64> { if ulps_eq!(self.p, 1.0) { return Some(f64::INFINITY); }; Some((2.0 - self.p) / (1.0 - self.p).sqrt()) } } impl Mode<Option<u64>> for Geometric { /// Returns the mode of the geometric distribution /// /// # Formula /// /// ```text /// 1 /// ``` fn mode(&self) -> Option<u64> { Some(1) } } impl Median<f64> for Geometric { /// Returns the median of the geometric distribution /// /// # Formula /// /// ```text /// ceil(-1 / log_2(1 - p)) /// ``` fn median(&self) -> f64 { (-f64::consts::LN_2 / (1.0 - self.p).ln()).ceil() } } impl Discrete<u64, f64> for Geometric { /// Calculates the probability mass function for the geometric /// distribution at `x` /// /// # Formula /// /// ```text /// (1 - p)^(x - 1) * p /// ``` fn pmf(&self, x: u64) -> f64 { if x == 0 { 0.0 } else if x - 1 <= i32::MAX as u64 { (1.0 - self.p).powi((x - 1) as i32) * self.p } else { ((x - 1) as f64 * (-self.p).ln_1p()).exp() * self.p } } /// Calculates the log probability mass function for the geometric /// distribution at `x` /// /// # Formula /// /// ```text /// ln((1 - p)^(x - 1) * p) /// ``` fn ln_pmf(&self, x: u64) -> f64 { if x == 0 { f64::NEG_INFINITY } else if ulps_eq!(self.p, 1.0) && x == 1 { 0.0 } else if ulps_eq!(self.p, 1.0) { f64::NEG_INFINITY } else { ((x - 1) as f64 * (1.0 - self.p).ln()) + self.p.ln() } } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(p: f64; Geometric; GeometricError); #[test] fn test_create() { create_ok(0.3); create_ok(1.0); } #[test] fn test_bad_create() { create_err(f64::NAN); create_err(0.0); create_err(-1.0); create_err(2.0); } #[test] fn test_mean() { let mean = |x: Geometric| x.mean().unwrap(); test_exact(0.3, 1.0 / 0.3, mean); test_exact(1.0, 1.0, mean); } #[test] fn test_variance() { let variance = |x: Geometric| x.variance().unwrap(); test_exact(0.3, 0.7 / (0.3 * 0.3), 
variance); test_exact(1.0, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Geometric| x.entropy().unwrap(); test_absolute(0.3, 2.937636330768973333333, 1e-14, entropy); test_is_nan(1.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Geometric| x.skewness().unwrap(); test_absolute(0.3, 2.031888635868469187947, 1e-15, skewness); test_exact(1.0, f64::INFINITY, skewness); } #[test] fn test_median() { let median = |x: Geometric| x.median(); test_exact(0.0001, 6932.0, median); test_exact(0.1, 7.0, median); test_exact(0.3, 2.0, median); test_exact(0.9, 1.0, median); // test_exact(0.99, 1.0, median); test_exact(1.0, 0.0, median); } #[test] fn test_mode() { let mode = |x: Geometric| x.mode().unwrap(); test_exact(0.3, 1, mode); test_exact(1.0, 1, mode); } #[test] fn test_min_max() { let min = |x: Geometric| x.min(); let max = |x: Geometric| x.max(); test_exact(0.3, 1, min); test_exact(0.3, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Geometric| x.pmf(arg); test_exact(0.3, 0.3, pmf(1)); test_exact(0.3, 0.21, pmf(2)); test_exact(1.0, 1.0, pmf(1)); test_exact(1.0, 0.0, pmf(2)); test_absolute(0.5, 0.5, 1e-10, pmf(1)); test_absolute(0.5, 0.25, 1e-10, pmf(2)); } #[test] fn test_pmf_lower_bound() { let pmf = |arg: u64| move |x: Geometric| x.pmf(arg); test_exact(0.3, 0.0, pmf(0)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg); test_absolute(0.3, -1.203972804325935992623, 1e-15, ln_pmf(1)); test_absolute(0.3, -1.560647748264668371535, 1e-15, ln_pmf(2)); test_exact(1.0, 0.0, ln_pmf(1)); test_exact(1.0, f64::NEG_INFINITY, ln_pmf(2)); } #[test] fn test_ln_pmf_lower_bound() { let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg); test_exact(0.3, f64::NEG_INFINITY, ln_pmf(0)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Geometric| x.cdf(arg); test_exact(1.0, 1.0, cdf(1)); test_exact(1.0, 1.0, cdf(2)); test_absolute(0.5, 0.5, 1e-15, cdf(1)); test_absolute(0.5, 0.75, 1e-15, 
cdf(2)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Geometric| x.sf(arg); test_exact(1.0, 0.0, sf(1)); test_exact(1.0, 0.0, sf(2)); test_absolute(0.5, 0.5, 1e-15, sf(1)); test_absolute(0.5, 0.25, 1e-15, sf(2)); } #[test] fn test_cdf_small_p() { // // Expected values were computed with the arbitrary precision // library mpmath in Python, e.g.: // // import mpmath // mpmath.mp.dps = 400 // p = mpmath.mpf(1e-9) // k = 5 // cdf = float(1 - (1 - p)**k) // # cdf is 4.99999999e-09 // let geom = Geometric::new(1e-9f64).unwrap(); let cdf = geom.cdf(5u64); let expected = 4.99999999e-09; assert_relative_eq!(cdf, expected, epsilon = 0.0, max_relative = 1e-15); } #[test] fn test_sf_small_p() { let geom = Geometric::new(1e-9f64).unwrap(); let sf = geom.sf(5u64); let expected = 0.999999995; assert_relative_eq!(sf, expected, epsilon = 0.0, max_relative = 1e-15); } #[test] fn test_cdf_very_small_p() { // // Expected values were computed with the arbitrary precision // library mpmath in Python, e.g.: // // import mpmath // mpmath.mp.dps = 400 // p = mpmath.mpf(1e-17) // k = 100000000000000 // cdf = float(1 - (1 - p)**k) // # cdf is 0.0009995001666250085 // let geom = Geometric::new(1e-17f64).unwrap(); let cdf = geom.cdf(10u64); let expected = 1e-16f64; assert_relative_eq!(cdf, expected, epsilon = 0.0, max_relative = 1e-15); let cdf = geom.cdf(100000000000000u64); let expected = 0.0009995001666250085f64; assert_relative_eq!(cdf, expected, epsilon = 0.0, max_relative = 1e-15); } #[test] fn test_sf_very_small_p() { let geom = Geometric::new(1e-17f64).unwrap(); let sf = geom.sf(10u64); let expected = 0.9999999999999999; assert_relative_eq!(sf, expected, epsilon = 0.0, max_relative = 1e-15); let sf = geom.sf(100000000000000u64); let expected = 0.999000499833375; assert_relative_eq!(sf, expected, epsilon = 0.0, max_relative = 1e-15); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: u64| move |x: Geometric| x.cdf(arg); test_exact(0.3, 0.0, cdf(0)); } #[test] fn 
test_sf_lower_bound() { let sf = |arg: u64| move |x: Geometric| x.sf(arg); test_exact(0.3, 1.0, sf(0)); } #[test] fn test_discrete() { test::check_discrete_distribution(&create_ok(0.3), 100); test::check_discrete_distribution(&create_ok(0.6), 100); test::check_discrete_distribution(&create_ok(1.0), 1); } } statrs-0.18.0/src/distribution/gumbel.rs000064400000000000000000000441501046102023000163650ustar 00000000000000use super::{Continuous, ContinuousCDF}; use crate::consts::EULER_MASCHERONI; use crate::statistics::*; use std::f64::{self, consts::PI}; /// Implements the [Gumbel](https://en.wikipedia.org/wiki/Gumbel_distribution) /// distribution, also known as the type-I generalized extreme value distribution. /// /// # Examples /// /// ``` /// use statrs::distribution::{Gumbel, Continuous}; /// use statrs::{consts::EULER_MASCHERONI, statistics::Distribution}; /// /// let n = Gumbel::new(0.0, 1.0).unwrap(); /// assert_eq!(n.location(), 0.0); /// assert_eq!(n.skewness().unwrap(), 1.13955); /// assert_eq!(n.pdf(0.0), 0.36787944117144233); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Gumbel { location: f64, scale: f64, } /// Represents the errors that can occur when creating a [`Gumbel`] #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum GumbelError { /// The location is invalid (NAN) LocationInvalid, /// The scale is NAN, zero or less than zero ScaleInvalid, } impl std::fmt::Display for GumbelError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { GumbelError::LocationInvalid => write!(f, "Location is NAN"), GumbelError::ScaleInvalid => write!(f, "Scale is NAN, zero or less than zero"), } } } impl std::error::Error for GumbelError {} impl Gumbel { /// Constructs a new Gumbel distribution with the given /// location and scale. 
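///
/// `location` (μ) is also the mode, and `scale` (β) must be positive; the
/// mean is `μ + γβ`, where `γ` is the Euler-Mascheroni constant. A minimal
/// sketch of these relationships (the `1e-12` tolerance is an illustrative
/// choice, not a documented accuracy guarantee):
///
/// ```
/// use statrs::consts::EULER_MASCHERONI;
/// use statrs::distribution::Gumbel;
/// use statrs::statistics::{Distribution, Mode};
///
/// let g = Gumbel::new(2.0, 3.0).unwrap();
/// // The mode equals the location parameter.
/// assert_eq!(g.mode(), 2.0);
/// // The mean is shifted from the mode by γβ.
/// let expected_mean = 2.0 + EULER_MASCHERONI * 3.0;
/// assert!((g.mean().unwrap() - expected_mean).abs() < 1e-12);
/// ```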
/// /// # Errors /// /// Returns an error if location or scale is `NaN` or `scale <= 0.0` /// /// # Examples /// /// ``` /// use statrs::distribution::Gumbel; /// /// let mut result = Gumbel::new(0.0, 1.0); /// assert!(result.is_ok()); /// /// result = Gumbel::new(0.0, -1.0); /// assert!(result.is_err()); /// ``` pub fn new(location: f64, scale: f64) -> Result<Self, GumbelError> { if location.is_nan() { return Err(GumbelError::LocationInvalid); } if scale.is_nan() || scale <= 0.0 { return Err(GumbelError::ScaleInvalid); } Ok(Self { location, scale }) } /// Returns the location of the Gumbel distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Gumbel; /// /// let n = Gumbel::new(0.0, 1.0).unwrap(); /// assert_eq!(n.location(), 0.0); /// ``` pub fn location(&self) -> f64 { self.location } /// Returns the scale of the Gumbel distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Gumbel; /// /// let n = Gumbel::new(0.0, 1.0).unwrap(); /// assert_eq!(n.scale(), 1.0); /// ``` pub fn scale(&self) -> f64 { self.scale } } impl std::fmt::Display for Gumbel { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Gumbel({:?}, {:?})", self.location, self.scale) } } #[cfg(feature = "rand")] impl ::rand::distributions::Distribution<f64> for Gumbel { fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 { self.location - self.scale * (-r.gen::<f64>().ln()).ln() } } impl ContinuousCDF<f64, f64> for Gumbel { /// Calculates the cumulative distribution function for the /// Gumbel distribution at `x` /// /// # Formula /// /// ```text /// e^(-e^(-(x - μ) / β)) /// ``` /// /// where `μ` is the location and `β` is the scale fn cdf(&self, x: f64) -> f64 { (-(-(x - self.location) / self.scale).exp()).exp() } /// Calculates the inverse cumulative distribution function for the /// Gumbel distribution at probability `p` /// /// # Formula /// /// ```text /// μ - β ln(-ln(p)) where 0 < p < 1 /// -INF where p <= 0 /// INF otherwise /// ``` /// /// where `μ` is the location and `β` is the scale
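///
/// # Examples
///
/// A round-trip sketch; the `1e-12` tolerance is an illustrative choice
/// rather than a documented accuracy guarantee:
///
/// ```
/// use statrs::distribution::{ContinuousCDF, Gumbel};
///
/// let g = Gumbel::new(1.0, 2.0).unwrap();
/// // At p = exp(-1), ln(-ln(p)) = 0, so the quantile is the location μ.
/// let q = g.inverse_cdf((-1.0f64).exp());
/// assert!((q - 1.0).abs() < 1e-12);
/// // Inside (0, 1), cdf(inverse_cdf(p)) recovers p.
/// for &p in &[0.1, 0.5, 0.9] {
///     assert!((g.cdf(g.inverse_cdf(p)) - p).abs() < 1e-12);
/// }
/// ```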
fn inverse_cdf(&self, p: f64) -> f64 { if p <= 0.0 { f64::NEG_INFINITY } else if p >= 1.0 { f64::INFINITY } else { self.location - self.scale * ((-(p.ln())).ln()) } } /// Calculates the survival function for the /// Gumbel distribution at `x` /// /// # Formula /// /// ```text /// 1 - e^(-e^(-(x - μ) / β)) /// ``` /// /// where `μ` is the location and `β` is the scale fn sf(&self, x: f64) -> f64 { -(-(-(x - self.location) / self.scale).exp()).exp_m1() } } impl Min<f64> for Gumbel { /// Returns the minimum value in the domain of the Gumbel /// distribution representable by a double precision float /// /// # Formula /// /// ```text /// NEG_INF /// ``` fn min(&self) -> f64 { f64::NEG_INFINITY } } impl Max<f64> for Gumbel { /// Returns the maximum value in the domain of the Gumbel /// distribution representable by a double precision float /// /// # Formula /// /// ```text /// INF /// ``` fn max(&self) -> f64 { f64::INFINITY } } impl Distribution<f64> for Gumbel { /// Returns the entropy of the Gumbel distribution /// /// # Formula /// /// ```text /// ln(β) + γ + 1 /// ``` /// /// where `β` is the scale /// and `γ` is the Euler-Mascheroni constant (approx 0.57721) fn entropy(&self) -> Option<f64> { Some(1.0 + EULER_MASCHERONI + (self.scale).ln()) } /// Returns the mean of the Gumbel distribution /// /// # Formula /// /// ```text /// μ + γβ /// ``` /// /// where `μ` is the location, `β` is the scale /// and `γ` is the Euler-Mascheroni constant (approx 0.57721) fn mean(&self) -> Option<f64> { Some(self.location + (EULER_MASCHERONI * self.scale)) } /// Returns the skewness of the Gumbel distribution, which does not /// depend on the location or scale /// /// # Formula /// /// ```text /// 12 * sqrt(6) * ζ(3) / π^3 ≈ 1.13955 /// ``` /// /// where `ζ(3)` is the Riemann zeta function evaluated at 3 (approx 1.20206) /// and `π` is the constant PI (approx 3.14159). The returned value is this /// constant rounded to five decimal places fn skewness(&self) -> Option<f64> { Some(1.13955) } /// Returns the variance of the Gumbel distribution /// /// # Formula /// /// ```text /// (π^2 / 6) *
β^2 /// ``` /// /// where `β` is the scale and `π` is the constant PI (approx 3.14159) fn variance(&self) -> Option<f64> { Some(((PI * PI) / 6.0) * self.scale * self.scale) } /// Returns the standard deviation of the Gumbel distribution /// /// # Formula /// /// ```text /// β * π / sqrt(6) /// ``` /// /// where `β` is the scale and `π` is the constant PI (approx 3.14159) fn std_dev(&self) -> Option<f64> { Some(self.scale * PI / 6.0_f64.sqrt()) } } impl Median<f64> for Gumbel { /// Returns the median of the Gumbel distribution /// /// # Formula /// /// ```text /// μ - β ln(ln(2)) /// ``` /// /// where `μ` is the location and `β` is the scale fn median(&self) -> f64 { self.location - self.scale * (((2.0_f64).ln()).ln()) } } impl Mode<f64> for Gumbel { /// Returns the mode of the Gumbel distribution /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the location fn mode(&self) -> f64 { self.location } } impl Continuous<f64, f64> for Gumbel { /// Calculates the probability density function for the Gumbel /// distribution at `x` /// /// # Formula /// /// ```text /// (1/β) * exp(-(x - μ)/β) * exp(-exp(-(x - μ)/β)) /// ``` /// /// where `μ` is the location, `β` is the scale fn pdf(&self, x: f64) -> f64 { (1.0_f64 / self.scale) * (-(x - self.location) / (self.scale)).exp() * (-((-(x - self.location) / self.scale).exp())).exp() } /// Calculates the log probability density function for the Gumbel /// distribution at `x` /// /// # Formula /// /// ```text /// ln((1/β) * exp(-(x - μ)/β) * exp(-exp(-(x - μ)/β))) /// ``` /// /// where `μ` is the location, `β` is the scale /// /// This is evaluated as the natural logarithm of the same product used by /// `pdf`, so it returns `-INF` wherever that product underflows to zero fn ln_pdf(&self, x: f64) -> f64 { ((1.0_f64 / self.scale) * (-(x - self.location) / (self.scale)).exp() * (-((-(x - self.location) / self.scale).exp())).exp()) .ln() } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::testing_boiler; testing_boiler!(location: f64, scale: f64; Gumbel; GumbelError); #[test] fn test_create() { create_ok(0.0, 0.1); create_ok(0.0, 1.0); create_ok(0.0, 10.0); 
create_ok(10.0, 11.0); create_ok(-5.0, 100.0); create_ok(0.0, f64::INFINITY); } #[test] fn test_bad_create() { let invalid = [ (f64::NAN, 1.0, GumbelError::LocationInvalid), (1.0, f64::NAN, GumbelError::ScaleInvalid), (f64::NAN, f64::NAN, GumbelError::LocationInvalid), (1.0, 0.0, GumbelError::ScaleInvalid), (0.0, f64::NEG_INFINITY, GumbelError::ScaleInvalid) ]; for (location, scale, err) in invalid { test_create_err(location, scale, err); } } #[test] fn test_min_max() { let min = |x: Gumbel| x.min(); let max = |x:Gumbel| x.max(); test_exact(0.0, 1.0, f64::NEG_INFINITY, min); test_exact(0.0, 1.0, f64::INFINITY, max); } #[test] fn test_entropy() { let entropy = |x: Gumbel| x.entropy().unwrap(); test_exact(0.0, 2.0, 2.270362845461478, entropy); test_exact(0.1, 4.0, 2.9635100260214235, entropy); test_exact(1.0, 10.0, 3.8798007578955787, entropy); test_exact(10.0, 11.0, 3.9751109376999034, entropy); } #[test] fn test_mean() { let mean = |x: Gumbel| x.mean().unwrap(); test_exact(0.0, 2.0, 1.1544313298030658, mean); test_exact(0.1, 4.0, 2.4088626596061316, mean); test_exact(1.0, 10.0, 6.772156649015328, mean); test_exact(10.0, 11.0, 16.34937231391686, mean); test_exact(10.0, f64::INFINITY, f64::INFINITY, mean); } #[test] fn test_skewness() { let skewness = |x: Gumbel| x.skewness().unwrap(); test_exact(0.0, 2.0, 1.13955, skewness); test_exact(0.1, 4.0, 1.13955, skewness); test_exact(1.0, 10.0, 1.13955, skewness); test_exact(10.0, 11.0, 1.13955, skewness); test_exact(10.0, f64::INFINITY, 1.13955, skewness); } #[test] fn test_variance() { let variance = |x: Gumbel| x.variance().unwrap(); test_exact(0.0, 2.0, 6.579736267392906, variance); test_exact(0.1, 4.0, 26.318945069571624, variance); test_exact(1.0, 10.0, 164.49340668482265, variance); test_exact(10.0, 11.0, 199.03702208863538, variance); } #[test] fn test_std_dev() { let std_dev = |x: Gumbel| x.std_dev().unwrap(); test_exact(0.0, 2.0, 2.565099660323728, std_dev); test_exact(0.1, 4.0, 5.130199320647456, std_dev); 
test_exact(1.0, 10.0, 12.82549830161864, std_dev); test_exact(10.0, 11.0, 14.108048131780505, std_dev); } #[test] fn test_median() { let median = |x: Gumbel| x.median(); test_exact(0.0, 2.0, 0.7330258411633287, median); test_exact(0.1, 4.0, 1.5660516823266574, median); test_exact(1.0, 10.0, 4.665129205816644, median); test_exact(10.0, 11.0, 14.031642126398307, median); test_exact(10.0, f64::INFINITY, f64::INFINITY, median); } #[test] fn test_mode() { let mode = |x: Gumbel| x.mode(); test_exact(0.0, 2.0, 0.0, mode); test_exact(0.1, 4.0, 0.1, mode); test_exact(1.0, 10.0, 1.0, mode); test_exact(10.0, 11.0, 10.0, mode); test_exact(10.0, f64::INFINITY, 10.0, mode); } #[test] fn test_cdf() { let cdf = |a: f64| move |x: Gumbel| x.cdf(a); test_exact(0.0, 0.1, 0.0, cdf(-5.0)); test_exact(0.0, 0.1, 0.0, cdf(-1.0)); test_exact(0.0, 0.1, 0.36787944117144233, cdf(0.0)); test_exact(0.0, 0.1, 0.9999546011007987, cdf(1.0)); test_absolute(0.0, 0.1, 0.99999999999999999, 1e-12, cdf(5.0)); test_absolute(0.0, 1.0, 0.06598803584531253, 1e-12, cdf(-1.0)); test_exact(0.0, 1.0, 0.36787944117144233, cdf(0.0)); test_absolute(0.0, 10.0, 0.192295645547964928, 1e-12, cdf(-5.0)); test_absolute(0.0, 10.0, 0.3311542771529088, 1e-12, cdf(-1.0)); test_exact(0.0, 10.0, 0.36787944117144233, cdf(0.0)); test_absolute(0.0, 10.0, 0.4046076616641318, 1e-12, cdf(1.0)); test_absolute(0.0, 10.0, 0.545239211892605, 1e-12, cdf(5.0)); test_exact(-2.0, f64::INFINITY, 0.36787944117144233, cdf(-5.0)); test_exact(-2.0, f64::INFINITY, 0.36787944117144233, cdf(-1.0)); test_exact(-2.0, f64::INFINITY, 0.36787944117144233, cdf(0.0)); test_exact(-2.0, f64::INFINITY, 0.36787944117144233, cdf(1.0)); test_exact(-2.0, f64::INFINITY, 0.36787944117144233, cdf(5.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(-5.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(-1.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(0.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(1.0)); test_exact(f64::INFINITY, 1.0, 0.0, cdf(5.0)); } #[test] fn 
test_inverse_cdf() { let inv_cdf = |a: f64| move |x: Gumbel| x.inverse_cdf(a); test_exact(0.0, 0.1, f64::NEG_INFINITY, inv_cdf(-5.0)); test_exact(0.0, 0.1, f64::NEG_INFINITY, inv_cdf(-1.0)); test_exact(0.0, 0.1, f64::NEG_INFINITY, inv_cdf(0.0)); test_exact(0.0, 0.1, f64::INFINITY, inv_cdf(1.0)); test_exact(0.0, 0.1, f64::INFINITY, inv_cdf(5.0)); test_absolute(0.0, 1.0, -0.8340324452479557, 1e-12, inv_cdf(0.1)); test_absolute(0.0, 10.0, 3.6651292058166436, 1e-12, inv_cdf(0.5)); test_absolute(0.0, 10.0, 22.503673273124456, 1e-12, inv_cdf(0.9)); test_exact(2.0, f64::INFINITY, f64::NEG_INFINITY, inv_cdf(0.1)); test_exact(-2.0, f64::INFINITY, f64::INFINITY, inv_cdf(0.5)); test_exact(f64::INFINITY, 1.0, f64::INFINITY, inv_cdf(0.1)); } #[test] fn test_sf() { let sf = |a: f64| move |x: Gumbel| x.sf(a); test_exact(0.0, 0.1, 1.0, sf(-5.0)); test_exact(0.0, 0.1, 1.0, sf(-1.0)); test_absolute(0.0, 0.1, 0.632120558828557678, 1e-12, sf(0.0)); test_absolute(0.0, 0.1, 0.000045398899201269, 1e-12, sf(1.0)); test_absolute(0.0, 1.0, 0.934011964154687462, 1e-12, sf(-1.0)); test_absolute(0.0, 1.0, 0.632120558828557678, 1e-12, sf(0.0)); test_absolute(0.0, 1.0, 0.3077993724446536, 1e-12, sf(1.0)); test_absolute(0.0, 10.0, 0.66884572284709110, 1e-12, sf(-1.0)); test_absolute(0.0, 10.0, 0.632120558828557678, 1e-12, sf(0.0)); test_absolute(0.0, 10.0, 0.595392338335868174, 1e-12, sf(1.0)); test_exact(-2.0, f64::INFINITY, 0.6321205588285576784, sf(-5.0)); test_exact(-2.0, f64::INFINITY, 0.6321205588285576784, sf(-1.0)); test_exact(-2.0, f64::INFINITY, 0.6321205588285576784, sf(0.0)); test_exact(-2.0, f64::INFINITY, 0.6321205588285576784, sf(1.0)); test_exact(-2.0, f64::INFINITY, 0.6321205588285576784, sf(5.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(-5.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(-1.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(0.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(1.0)); test_exact(f64::INFINITY, 1.0, 1.0, sf(5.0)); test_absolute(0.0, 1.0, 4.248354255291589e-18, 1e-32, 
sf(40.0)); test_absolute(0.0, 1.0, 1.804851387845415e-35, 1e-50, sf(80.0)); } #[test] fn test_pdf() { let pdf = |a: f64| move |x: Gumbel| x.pdf(a); test_exact(0.0, 0.1, 0.0, pdf(-5.0)); test_exact(0.0, 0.1, 3.678794411714423215, pdf(0.0)); test_absolute(0.0, 0.1, 0.0004539786865564, 1e-12, pdf(1.0)); test_absolute(0.0, 1.0, 0.1793740787340171, 1e-12, pdf(-1.0)); test_exact(0.0, 1.0, 0.36787944117144233, pdf(0.0)); test_absolute(0.0, 1.0, 0.25464638004358249, 1e-12, pdf(1.0)); test_absolute(0.0, 10.0, 0.031704192107794217, 1e-12, pdf(-5.0)); test_absolute(0.0, 10.0, 0.0365982076505757, 1e-12, pdf(-1.0)); test_exact(0.0, 10.0, 0.036787944117144233, pdf(0.0)); test_absolute(0.0, 10.0, 0.03661041518977401428, 1e-12, pdf(1.0)); test_absolute(0.0, 10.0, 0.033070429889041, 1e-12, pdf(5.0)); test_exact(-2.0, f64::INFINITY, 0.0, pdf(-5.0)); test_exact(-2.0, f64::INFINITY, 0.0, pdf(-1.0)); test_exact(-2.0, f64::INFINITY, 0.0, pdf(0.0)); test_exact(-2.0, f64::INFINITY, 0.0, pdf(1.0)); test_exact(-2.0, f64::INFINITY, 0.0, pdf(5.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |a: f64| move |x: Gumbel| x.ln_pdf(a); test_exact(0.0, 0.1, 0.0_f64.ln(), ln_pdf(-5.0)); test_exact(0.0, 0.1, 3.678794411714423215_f64.ln(), ln_pdf(0.0)); test_absolute(0.0, 0.1, 0.0004539786865564_f64.ln(), 1e-12, ln_pdf(1.0)); test_absolute(0.0, 1.0, 0.1793740787340171_f64.ln(), 1e-12, ln_pdf(-1.0)); test_exact(0.0, 1.0, 0.36787944117144233_f64.ln(), ln_pdf(0.0)); test_absolute(0.0, 1.0, 0.25464638004358249_f64.ln(), 1e-12, ln_pdf(1.0)); test_absolute(0.0, 10.0, 0.031704192107794217_f64.ln(), 1e-12, ln_pdf(-5.0)); test_absolute(0.0, 10.0, 0.0365982076505757_f64.ln(), 1e-12, ln_pdf(-1.0)); test_exact(0.0, 10.0, 0.036787944117144233_f64.ln(), ln_pdf(0.0)); test_absolute(0.0, 10.0, 0.03661041518977401428_f64.ln(), 1e-12, ln_pdf(1.0)); test_absolute(0.0, 10.0, 0.033070429889041_f64.ln(), 1e-12, ln_pdf(5.0)); test_exact(-2.0, f64::INFINITY, 0.0_f64.ln(), ln_pdf(-5.0)); test_exact(-2.0, f64::INFINITY, 
0.0_f64.ln(), ln_pdf(-1.0)); test_exact(-2.0, f64::INFINITY, 0.0_f64.ln(), ln_pdf(0.0)); test_exact(-2.0, f64::INFINITY, 0.0_f64.ln(), ln_pdf(1.0)); test_exact(-2.0, f64::INFINITY, 0.0_f64.ln(), ln_pdf(5.0)); } } statrs-0.18.0/src/distribution/hypergeometric.rs000064400000000000000000000425441046102023000201450ustar 00000000000000use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::factorial; use crate::statistics::*; use std::cmp; use std::f64; /// Implements the /// [Hypergeometric](http://en.wikipedia.org/wiki/Hypergeometric_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Hypergeometric, Discrete}; /// use statrs::statistics::Distribution; /// use statrs::prec; /// /// let n = Hypergeometric::new(500, 50, 100).unwrap(); /// assert_eq!(n.mean().unwrap(), 10.); /// assert!(prec::almost_eq(n.pmf(10), 0.14736784, 1e-8)); /// assert!(prec::almost_eq(n.pmf(25), 3.537e-7, 1e-10)); /// ``` #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct Hypergeometric { population: u64, successes: u64, draws: u64, } /// Represents the errors that can occur when creating a [`Hypergeometric`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum HypergeometricError { /// The number of successes is greater than the population. TooManySuccesses, /// The number of draws is greater than the population. TooManyDraws, } impl std::fmt::Display for HypergeometricError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { HypergeometricError::TooManySuccesses => write!(f, "successes > population"), HypergeometricError::TooManyDraws => write!(f, "draws > population"), } } } impl std::error::Error for HypergeometricError {} impl Hypergeometric { /// Constructs a new hypergeometric distribution /// with a population (N) of `population`, number /// of successes (K) of `successes`, and number of draws /// (n) of `draws`. 
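///
/// The probability mass function is `C(K, k) * C(N - K, n - k) / C(N, n)`,
/// so the masses over the support sum to one. A small sanity check (the
/// `1e-12` tolerance is an illustrative choice):
///
/// ```
/// use statrs::distribution::{Discrete, Hypergeometric};
///
/// // N = 10, K = 4, n = 3: C(4, k) * C(6, 3 - k) gives 20, 60, 36 and 4
/// // ways for k = 0..=3, out of C(10, 3) = 120.
/// let h = Hypergeometric::new(10, 4, 3).unwrap();
/// let total: f64 = (0..=3).map(|k| h.pmf(k)).sum();
/// assert!((total - 1.0).abs() < 1e-12);
/// assert!((h.pmf(1) - 60.0 / 120.0).abs() < 1e-12);
/// ```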
/// /// # Errors /// /// Returns an error if `successes > population` or `draws > population`. /// /// # Examples /// /// ``` /// use statrs::distribution::Hypergeometric; /// /// let mut result = Hypergeometric::new(2, 2, 2); /// assert!(result.is_ok()); /// /// result = Hypergeometric::new(2, 3, 2); /// assert!(result.is_err()); /// ``` pub fn new( population: u64, successes: u64, draws: u64, ) -> Result<Hypergeometric, HypergeometricError> { if successes > population { return Err(HypergeometricError::TooManySuccesses); } if draws > population { return Err(HypergeometricError::TooManyDraws); } Ok(Hypergeometric { population, successes, draws, }) } /// Returns the population size of the hypergeometric /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Hypergeometric; /// /// let n = Hypergeometric::new(10, 5, 3).unwrap(); /// assert_eq!(n.population(), 10); /// ``` pub fn population(&self) -> u64 { self.population } /// Returns the number of successes in the population of the hypergeometric /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Hypergeometric; /// /// let n = Hypergeometric::new(10, 5, 3).unwrap(); /// assert_eq!(n.successes(), 5); /// ``` pub fn successes(&self) -> u64 { self.successes } /// Returns the number of draws of the hypergeometric /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Hypergeometric; /// /// let n = Hypergeometric::new(10, 5, 3).unwrap(); /// assert_eq!(n.draws(), 3); /// ``` pub fn draws(&self) -> u64 { self.draws } /// Returns population, successes, and draws in that order /// as a tuple of doubles fn values_f64(&self) -> (f64, f64, f64) { ( self.population as f64, self.successes as f64, self.draws as f64, ) } } impl std::fmt::Display for Hypergeometric { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "Hypergeometric({},{},{})", self.population, self.successes, self.draws ) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl 
::rand::distributions::Distribution<u64> for Hypergeometric {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 {
        let mut population = self.population as f64;
        let mut successes = self.successes as f64;
        let mut draws = self.draws;
        let mut x = 0;
        // Draw without replacement one item at a time. Checking `draws > 0`
        // before decrementing avoids underflowing `draws` when it is zero.
        while draws > 0 {
            let p = successes / population;
            let next: f64 = rng.gen();
            if next < p {
                x += 1;
                successes -= 1.0;
            }
            population -= 1.0;
            draws -= 1;
        }
        x
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Hypergeometric {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        rng.sample::<u64, _>(self) as f64
    }
}

impl DiscreteCDF<u64, f64> for Hypergeometric {
    /// Calculates the cumulative distribution function for the hypergeometric
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1,
    /// x+1-K, x+1-n; x+2, N+x+2-K-n; 1)
    /// ```
    ///
    /// where `N` is population, `K` is successes, `n` is draws,
    /// and `p_F_q` is the
    /// [generalized hypergeometric function](https://en.wikipedia.org/wiki/Generalized_hypergeometric_function)
    ///
    /// Calculated by summing the probability mass function over `0..=x`
    fn cdf(&self, x: u64) -> f64 {
        if x < self.min() {
            0.0
        } else if x >= self.max() {
            1.0
        } else {
            let k = x;
            let ln_denom = factorial::ln_binomial(self.population, self.draws);
            (0..k + 1).fold(0.0, |acc, i| {
                acc + (factorial::ln_binomial(self.successes, i)
                    + factorial::ln_binomial(self.population - self.successes, self.draws - i)
                    - ln_denom)
                    .exp()
            })
        }
    }

    /// Calculates the survival function for the hypergeometric
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1,
    /// x+1-K, x+1-n; x+2, N+x+2-K-n; 1)
    /// ```
    ///
    /// where `N` is population, `K` is successes, `n` is draws,
    /// and `p_F_q` is the
    /// [generalized hypergeometric function](https://en.wikipedia.org/wiki/Generalized_hypergeometric_function)
    ///
    /// Calculated
    /// as a sum of the probability mass function over `(x+1)..=max`
    fn sf(&self, x: u64) -> f64 {
        if x < self.min() {
            1.0
        } else if x >= self.max() {
            0.0
        } else {
            let k = x;
            let ln_denom = factorial::ln_binomial(self.population, self.draws);
            (k + 1..=self.max()).fold(0.0, |acc, i| {
                acc + (factorial::ln_binomial(self.successes, i)
                    + factorial::ln_binomial(self.population - self.successes, self.draws - i)
                    - ln_denom)
                    .exp()
            })
        }
    }
}

impl Min<u64> for Hypergeometric {
    /// Returns the minimum value in the domain of the
    /// hypergeometric distribution representable by a 64-bit
    /// integer
    ///
    /// # Formula
    ///
    /// ```text
    /// max(0, n + K - N)
    /// ```
    ///
    /// where `N` is population, `K` is successes, and `n` is draws
    fn min(&self) -> u64 {
        (self.draws + self.successes).saturating_sub(self.population)
    }
}

impl Max<u64> for Hypergeometric {
    /// Returns the maximum value in the domain of the
    /// hypergeometric distribution representable by a 64-bit
    /// integer
    ///
    /// # Formula
    ///
    /// ```text
    /// min(K, n)
    /// ```
    ///
    /// where `K` is successes and `n` is draws
    fn max(&self) -> u64 {
        cmp::min(self.successes, self.draws)
    }
}

impl Distribution<f64> for Hypergeometric {
    /// Returns the mean of the hypergeometric distribution
    ///
    /// # None
    ///
    /// If `N` is `0`
    ///
    /// # Formula
    ///
    /// ```text
    /// K * n / N
    /// ```
    ///
    /// where `N` is population, `K` is successes, and `n` is draws
    fn mean(&self) -> Option<f64> {
        if self.population == 0 {
            None
        } else {
            Some(self.successes as f64 * self.draws as f64 / self.population as f64)
        }
    }

    /// Returns the variance of the hypergeometric distribution
    ///
    /// # None
    ///
    /// If `N <= 1`
    ///
    /// # Formula
    ///
    /// ```text
    /// n * (K / N) * ((N - K) / N) * ((N - n) / (N - 1))
    /// ```
    ///
    /// where `N` is population, `K` is successes, and `n` is draws
    fn variance(&self) -> Option<f64> {
        if self.population <= 1 {
            None
        } else {
            let (population, successes, draws) = self.values_f64();
            let val = draws * successes * (population - draws) * (population -
                successes)
                / (population * population * (population - 1.0));
            Some(val)
        }
    }

    /// Returns the skewness of the hypergeometric distribution
    ///
    /// # None
    ///
    /// If `N <= 2`
    ///
    /// # Formula
    ///
    /// ```text
    /// ((N - 2K) * (N - 1)^(1 / 2) * (N - 2n)) / ([n * K * (N - K) * (N -
    /// n)]^(1 / 2) * (N - 2))
    /// ```
    ///
    /// where `N` is population, `K` is successes, and `n` is draws
    fn skewness(&self) -> Option<f64> {
        if self.population <= 2 {
            None
        } else {
            let (population, successes, draws) = self.values_f64();
            let val = (population - 1.0).sqrt()
                * (population - 2.0 * draws)
                * (population - 2.0 * successes)
                / ((draws * successes * (population - successes) * (population - draws)).sqrt()
                    * (population - 2.0));
            Some(val)
        }
    }
}

impl Mode<Option<u64>> for Hypergeometric {
    /// Returns the mode of the hypergeometric distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// floor((n + 1) * (K + 1) / (N + 2))
    /// ```
    ///
    /// where `N` is population, `K` is successes, and `n` is draws
    fn mode(&self) -> Option<u64> {
        Some(((self.draws + 1) * (self.successes + 1)) / (self.population + 2))
    }
}

impl Discrete<u64, f64> for Hypergeometric {
    /// Calculates the probability mass function for the hypergeometric
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// (K choose x) * (N-K choose n-x) / (N choose n)
    /// ```
    ///
    /// where `N` is population, `K` is successes, and `n` is draws
    fn pmf(&self, x: u64) -> f64 {
        if x > self.draws {
            0.0
        } else {
            factorial::binomial(self.successes, x)
                * factorial::binomial(self.population - self.successes, self.draws - x)
                / factorial::binomial(self.population, self.draws)
        }
    }

    /// Calculates the log probability mass function for the hypergeometric
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// ln((K choose x) * (N-K choose n-x) / (N choose n))
    /// ```
    ///
    /// where `N` is population, `K` is successes, and `n` is draws
    fn ln_pmf(&self, x: u64) -> f64 {
        // Mirror `pmf`: outside the support the log-probability is -inf.
        // Returning early also avoids underflowing `self.draws - x`.
        if x > self.draws {
            return f64::NEG_INFINITY;
        }
        factorial::ln_binomial(self.successes, x)
            + factorial::ln_binomial(self.population - self.successes, self.draws - x)
            -
            factorial::ln_binomial(self.population, self.draws)
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric; HypergeometricError);

    #[test]
    fn test_create() {
        create_ok(0, 0, 0);
        create_ok(1, 1, 1);
        create_ok(2, 1, 1);
        create_ok(2, 2, 2);
        create_ok(10, 1, 1);
        create_ok(10, 5, 3);
    }

    #[test]
    fn test_bad_create() {
        test_create_err(2, 3, 2, HypergeometricError::TooManySuccesses);
        test_create_err(10, 5, 20, HypergeometricError::TooManyDraws);
        create_err(0, 1, 1);
    }

    #[test]
    fn test_mean() {
        let mean = |x: Hypergeometric| x.mean().unwrap();
        test_exact(1, 1, 1, 1.0, mean);
        test_exact(2, 1, 1, 0.5, mean);
        test_exact(2, 2, 2, 2.0, mean);
        test_exact(10, 1, 1, 0.1, mean);
        test_exact(10, 5, 3, 15.0 / 10.0, mean);
    }

    #[test]
    fn test_mean_with_population_0() {
        test_none(0, 0, 0, |dist| dist.mean());
    }

    #[test]
    fn test_variance() {
        let variance = |x: Hypergeometric| x.variance().unwrap();
        test_exact(2, 1, 1, 0.25, variance);
        test_exact(2, 2, 2, 0.0, variance);
        test_exact(10, 1, 1, 81.0 / 900.0, variance);
        test_exact(10, 5, 3, 525.0 / 900.0, variance);
    }

    #[test]
    fn test_variance_with_pop_lte_1() {
        test_none(1, 1, 1, |dist| dist.variance());
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: Hypergeometric| x.skewness().unwrap();
        test_exact(10, 1, 1, 8.0 / 3.0, skewness);
        test_exact(10, 5, 3, 0.0, skewness);
    }

    #[test]
    fn test_skewness_with_pop_lte_2() {
        test_none(2, 2, 2, |dist| dist.skewness());
    }

    #[test]
    fn test_mode() {
        let mode = |x: Hypergeometric| x.mode().unwrap();
        test_exact(0, 0, 0, 0, mode);
        test_exact(1, 1, 1, 1, mode);
        test_exact(2, 1, 1, 1, mode);
        test_exact(2, 2, 2, 2, mode);
        test_exact(10, 1, 1, 0, mode);
        test_exact(10, 5, 3, 2, mode);
    }

    #[test]
    fn test_min() {
        let min = |x: Hypergeometric| x.min();
        test_exact(0, 0, 0, 0, min);
        test_exact(1, 1, 1, 1, min);
        test_exact(2, 1, 1, 0, min);
        test_exact(2, 2, 2, 2, min);
        test_exact(10,
1, 1, 0, min);
        test_exact(10, 5, 3, 0, min);
    }

    #[test]
    fn test_max() {
        let max = |x: Hypergeometric| x.max();
        test_exact(0, 0, 0, 0, max);
        test_exact(1, 1, 1, 1, max);
        test_exact(2, 1, 1, 1, max);
        test_exact(2, 2, 2, 2, max);
        test_exact(10, 1, 1, 1, max);
        test_exact(10, 5, 3, 3, max);
    }

    #[test]
    fn test_pmf() {
        let pmf = |arg: u64| move |x: Hypergeometric| x.pmf(arg);
        test_exact(0, 0, 0, 1.0, pmf(0));
        test_exact(1, 1, 1, 1.0, pmf(1));
        test_exact(2, 1, 1, 0.5, pmf(0));
        test_exact(2, 1, 1, 0.5, pmf(1));
        test_exact(2, 2, 2, 1.0, pmf(2));
        test_exact(10, 1, 1, 0.9, pmf(0));
        test_exact(10, 1, 1, 0.1, pmf(1));
        test_exact(10, 5, 3, 0.41666666666666666667, pmf(1));
        test_exact(10, 5, 3, 0.083333333333333333333, pmf(3));
    }

    #[test]
    fn test_ln_pmf() {
        let ln_pmf = |arg: u64| move |x: Hypergeometric| x.ln_pmf(arg);
        test_exact(0, 0, 0, 0.0, ln_pmf(0));
        test_exact(1, 1, 1, 0.0, ln_pmf(1));
        test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf(0));
        test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf(1));
        test_exact(2, 2, 2, 0.0, ln_pmf(2));
        test_absolute(10, 1, 1, -0.1053605156578263012275, 1e-14, ln_pmf(0));
        test_absolute(10, 1, 1, -2.302585092994045684018, 1e-14, ln_pmf(1));
        test_absolute(10, 5, 3, -0.875468737353899935621, 1e-14, ln_pmf(1));
        test_absolute(10, 5, 3, -2.484906649788000310234, 1e-14, ln_pmf(3));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_exact(2, 1, 1, 0.5, cdf(0));
        test_absolute(10, 1, 1, 0.9, 1e-14, cdf(0));
        test_absolute(10, 5, 3, 0.5, 1e-15, cdf(1));
        test_absolute(10, 5, 3, 11.0 / 12.0, 1e-14, cdf(2));
        test_absolute(10000, 2, 9800, 199.0 / 499950.0, 1e-14, cdf(0));
        test_absolute(10000, 2, 9800, 19799.0 / 499950.0, 1e-12, cdf(1));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_exact(2, 1, 1, 0.5, sf(0));
        test_absolute(10, 1, 1, 0.1, 1e-14, sf(0));
        test_absolute(10, 5, 3, 0.5, 1e-15, sf(1));
        test_absolute(10, 5, 3, 1.0 / 12.0, 1e-14, sf(2));
        test_absolute(10000, 2, 9800, 499751.
/ 499950.0, 1e-10, sf(0));
        test_absolute(10000, 2, 9800, 480151. / 499950.0, 1e-10, sf(1));
    }

    #[test]
    fn test_cdf_arg_too_big() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_exact(0, 0, 0, 1.0, cdf(0));
    }

    #[test]
    fn test_cdf_arg_too_small() {
        let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg);
        test_exact(2, 2, 2, 0.0, cdf(0));
    }

    #[test]
    fn test_sf_arg_too_big() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_exact(0, 0, 0, 0.0, sf(0));
    }

    #[test]
    fn test_sf_arg_too_small() {
        let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg);
        test_exact(2, 2, 2, 1.0, sf(0));
    }

    #[test]
    fn test_discrete() {
        test::check_discrete_distribution(&create_ok(5, 4, 3), 4);
        test::check_discrete_distribution(&create_ok(3, 2, 1), 2);
    }
}
statrs-0.18.0/src/distribution/internal.rs
use num_traits::Num;

/// Implements univariate function bisection searching for criteria
/// ```text
/// smallest k such that f(k) >= z
/// ```
/// Evaluates to `None` if
/// - `z` is not within `[f(lb), f(ub)]`
/// - provided interval has lower bound greater than upper bound
/// - function found not semi-monotone on the provided interval containing `z`
///
/// Evaluates to `Some(k)`, where `k` satisfies the search criteria
pub fn integral_bisection_search<K: Num + Clone, T: Num + PartialOrd>(
    f: impl Fn(&K) -> T,
    z: T,
    lb: K,
    ub: K,
) -> Option<K> {
    if !(f(&lb)..=f(&ub)).contains(&z) {
        return None;
    }
    let two = K::one() + K::one();
    let mut lb = lb;
    let mut ub = ub;
    loop {
        let mid = (lb.clone() + ub.clone()) / two.clone();
        if !(f(&lb)..=f(&ub)).contains(&f(&mid)) {
            return None; // f found not monotone on interval
        } else if f(&lb) == z {
            return Some(lb);
        } else if f(&ub) == z || (lb.clone() + K::one()) == ub {
            return Some(ub); // found or no more integers between
        } else if f(&mid) >= z {
            ub = mid;
        } else {
            lb = mid;
        }
    }
}

#[macro_use]
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::distribution::{Continuous, ContinuousCDF, Discrete, DiscreteCDF};

    #[macro_export]
    macro_rules!
testing_boiler {
        ($($arg_name:ident: $arg_ty:ty),+; $dist:ty; $dist_err:ty) => {
            fn make_param_text($($arg_name: $arg_ty),+) -> String {
                // ""
                let mut param_text = String::new();

                // "shape=10.0, rate=NaN, "
                $(
                    param_text.push_str(
                        &format!(
                            "{}={:?}, ",
                            stringify!($arg_name),
                            $arg_name,
                        )
                    );
                )+

                // "shape=10.0, rate=NaN" (removes trailing comma and whitespace)
                param_text.pop();
                param_text.pop();

                param_text
            }

            /// Creates and returns a distribution with the given parameters,
            /// panicking if `::new` fails.
            fn create_ok($($arg_name: $arg_ty),+) -> $dist {
                match <$dist>::new($($arg_name),+) {
                    Ok(d) => d,
                    Err(e) => panic!(
                        "{}::new was expected to succeed, but failed for {} with error: '{}'",
                        stringify!($dist),
                        make_param_text($($arg_name),+),
                        e
                    )
                }
            }

            /// Returns the error when creating a distribution with the given parameters,
            /// panicking if `::new` succeeds.
            #[allow(dead_code)]
            fn create_err($($arg_name: $arg_ty),+) -> $dist_err {
                match <$dist>::new($($arg_name),+) {
                    Err(e) => e,
                    Ok(d) => panic!(
                        "{}::new was expected to fail, but succeeded for {} with result: {:?}",
                        stringify!($dist),
                        make_param_text($($arg_name),+),
                        d
                    )
                }
            }

            /// Creates a distribution with the given parameters, calls the `get_fn`
            /// function with the new distribution and returns the result of `get_fn`.
            ///
            /// Panics if `::new` fails.
            fn create_and_get<F, T>($($arg_name: $arg_ty),+, get_fn: F) -> T
            where
                F: Fn($dist) -> T,
            {
                let n = create_ok($($arg_name),+);
                get_fn(n)
            }

            /// Creates a distribution with the given parameters, calls the `get_fn`
            /// function with the new distribution and compares the result of `get_fn`
            /// to `expected` exactly.
            ///
            /// Panics if `::new` fails.
            #[allow(dead_code)]
            fn test_exact<F, T>($($arg_name: $arg_ty),+, expected: T, get_fn: F)
            where
                F: Fn($dist) -> T,
                T: ::core::cmp::PartialEq + ::core::fmt::Debug
            {
                let x = create_and_get($($arg_name),+, get_fn);
                if x != expected {
                    panic!(
                        "Expected {:?}, got {:?} for {}",
                        expected,
                        x,
                        make_param_text($($arg_name),+)
                    );
                }
            }

            /// Gets a value for the given parameters by calling `create_and_get`
            /// and compares it to `expected`.
            ///
            /// Allows relative error of up to [`crate::consts::ACC`].
            ///
            /// Panics if `::new` fails.
            #[allow(dead_code)]
            fn test_relative<F>($($arg_name: $arg_ty),+, expected: f64, get_fn: F)
            where
                F: Fn($dist) -> f64,
            {
                let x = create_and_get($($arg_name),+, get_fn);
                let max_relative = $crate::consts::ACC;

                if !::approx::relative_eq!(expected, x, max_relative = max_relative) {
                    panic!(
                        "Expected {:?} to be almost equal to {:?} (max. relative error of {:?}), but wasn't for {}",
                        x,
                        expected,
                        max_relative,
                        make_param_text($($arg_name),+)
                    );
                }
            }

            /// Gets a value for the given parameters by calling `create_and_get`
            /// and compares it to `expected`.
            ///
            /// Allows absolute error of up to `acc`.
            ///
            /// Panics if `::new` fails.
            #[allow(dead_code)]
            fn test_absolute<F>($($arg_name: $arg_ty),+, expected: f64, acc: f64, get_fn: F)
            where
                F: Fn($dist) -> f64,
            {
                let x = create_and_get($($arg_name),+, get_fn);

                // abs_diff_eq! cannot handle infinities, so we manually accept them here
                if expected.is_infinite() && x == expected {
                    return;
                }

                if !::approx::abs_diff_eq!(expected, x, epsilon = acc) {
                    panic!(
                        "Expected {:?} to be almost equal to {:?} (max. absolute error of {:?}), but wasn't for {}",
                        x,
                        expected,
                        acc,
                        make_param_text($($arg_name),+)
                    );
                }
            }

            /// Purposely fails creating a distribution with the given
            /// parameters and compares the returned error to `expected`.
            ///
            /// Panics if `::new` succeeds.
            #[allow(dead_code)]
            fn test_create_err($($arg_name: $arg_ty),+, expected: $dist_err) {
                let err = create_err($($arg_name),+);
                if err != expected {
                    panic!(
                        "{}::new was expected to fail with error {:?}, but failed with error {:?} for {}",
                        stringify!($dist),
                        expected,
                        err,
                        make_param_text($($arg_name),+)
                    )
                }
            }

            /// Gets a value for the given parameters by calling `create_and_get`
            /// and asserts that it is [`f64::NAN`].
            ///
            /// Panics if `::new` fails.
            #[allow(dead_code)]
            fn test_is_nan<F>($($arg_name: $arg_ty),+, get_fn: F)
            where
                F: Fn($dist) -> f64
            {
                let x = create_and_get($($arg_name),+, get_fn);
                assert!(x.is_nan());
            }

            /// Gets a value for the given parameters by calling `create_and_get`
            /// and asserts that it is [`None`].
            ///
            /// Panics if `::new` fails.
            #[allow(dead_code)]
            fn test_none<F, T>($($arg_name: $arg_ty),+, get_fn: F)
            where
                F: Fn($dist) -> Option<T>,
                T: ::core::fmt::Debug,
            {
                let x = create_and_get($($arg_name),+, get_fn);

                if let Some(inner) = x {
                    panic!(
                        "Expected None, got {:?} for {}",
                        inner,
                        make_param_text($($arg_name),+)
                    )
                }
            }

            /// Asserts that the associated error type is Send and Sync
            #[test]
            fn test_error_is_sync_send() {
                fn assert_sync_send<T: Sync + Send>() {}
                assert_sync_send::<$dist_err>();
            }
        };
    }

    pub mod boiler_tests {
        use crate::distribution::{Beta, BetaError};
        use crate::statistics::*;

        testing_boiler!(shape_a: f64, shape_b: f64; Beta; BetaError);

        #[test]
        fn create_ok_success() {
            let b = create_ok(0.8, 1.2);
            assert_eq!(b.shape_a(), 0.8);
            assert_eq!(b.shape_b(), 1.2);
        }

        #[test]
        #[should_panic]
        fn create_err_failure() {
            create_err(0.8, 1.2);
        }

        #[test]
        fn create_err_success() {
            let err = create_err(-0.5, 1.2);
            assert_eq!(err, BetaError::ShapeAInvalid);
        }

        #[test]
        #[should_panic]
        fn create_ok_failure() {
            create_ok(-0.5, 1.2);
        }

        #[test]
        fn test_exact_success() {
            test_exact(1.5, 1.5, 0.5, |dist| dist.mode().unwrap());
        }

        #[test]
        #[should_panic]
        fn test_exact_failure() {
            test_exact(1.2, 1.4, 0.333333333333, |dist| dist.mode().unwrap());
        }

        #[test]
        fn test_relative_success() {
            test_relative(1.2, 1.4, 0.333333333333, |dist| dist.mode().unwrap());
        }

        #[test]
        #[should_panic]
        fn test_relative_failure() {
            test_relative(1.2, 1.4, 0.333, |dist| dist.mode().unwrap());
        }

        #[test]
        fn test_absolute_success() {
            test_absolute(1.2, 1.4, 0.333333333333, 1e-12, |dist| dist.mode().unwrap());
        }

        #[test]
        #[should_panic]
        fn test_absolute_failure() {
            test_absolute(1.2, 1.4, 0.333333333333, 1e-15, |dist| dist.mode().unwrap());
        }

        #[test]
        fn test_create_err_success() {
            test_create_err(0.0, 0.5, BetaError::ShapeAInvalid);
        }

        #[test]
        #[should_panic]
        fn test_create_err_failure() {
            test_create_err(0.0, 0.5, BetaError::ShapeBInvalid);
        }

        #[test]
        fn test_is_nan_success() {
            // Not sure that any Beta API can return a NaN, so we force the issue
            test_is_nan(0.8, 1.2, |_| f64::NAN);
        }

        #[test]
        #[should_panic]
        fn test_is_nan_failure() {
            test_is_nan(0.8, 1.2, |dist| dist.mean().unwrap());
        }

        #[test]
        fn test_is_none_success() {
            test_none(0.5, 1.2, |dist| dist.mode());
        }

        #[test]
        #[should_panic]
        fn test_is_none_failure() {
            test_none(0.8, 1.2, |dist| dist.mean());
        }
    }

    /// cdf should be the integral of the pdf
    fn check_integrate_pdf_is_cdf<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>(
        dist: &D,
        x_min: f64,
        x_max: f64,
        step: f64,
    ) {
        let mut prev_x = x_min;
        let mut prev_density = dist.pdf(x_min);
        let mut sum = 0.0;

        loop {
            let x = prev_x + step;
            let density = dist.pdf(x);

            assert!(density >= 0.0);

            let ln_density = dist.ln_pdf(x);

            assert_almost_eq!(density.ln(), ln_density, 1e-10);

            // trapezoidal rule
            sum += (prev_density + density) * step / 2.0;

            let cdf = dist.cdf(x);
            if (sum - cdf).abs() > 1e-3 {
                println!("Integral of pdf doesn't equal cdf!");
                println!("Integration from {x_min} by {step} to {x} = {sum}");
                println!("cdf = {cdf}");
                panic!();
            }

            if x >= x_max {
                break;
            } else {
                prev_x = x;
                prev_density = density;
            }
        }

        assert!(sum > 0.99);
        assert!(sum <= 1.001);
    }

    /// cdf should be the sum of the pmf
    fn check_sum_pmf_is_cdf<D: DiscreteCDF<u64, f64> + Discrete<u64, f64>>(dist: &D, x_max: u64) {
        let mut sum = 0.0;

        // go slightly beyond x_max to test for off-by-one
        // errors
        for i in 0..x_max + 3 {
            let prob = dist.pmf(i);

            assert!(prob >= 0.0);
            assert!(prob <= 1.0);

            sum += prob;

            if i == x_max {
                assert!(sum > 0.99);
            }

            assert_almost_eq!(sum, dist.cdf(i), 1e-10);
            // assert_almost_eq!(sum, dist.cdf(i as f64), 1e-10);
            // assert_almost_eq!(sum, dist.cdf(i as f64 + 0.1), 1e-10);
            // assert_almost_eq!(sum, dist.cdf(i as f64 + 0.5), 1e-10);
            // assert_almost_eq!(sum, dist.cdf(i as f64 + 0.9), 1e-10);
        }

        assert!(sum > 0.99);
        assert!(sum <= 1.0 + 1e-10);
    }

    /// Does a series of checks that all continuous distributions must obey.
    /// 99% of the probability mass should be between x_min and x_max.
    pub fn check_continuous_distribution<D: ContinuousCDF<f64, f64> + Continuous<f64, f64>>(
        dist: &D,
        x_min: f64,
        x_max: f64,
    ) {
        assert_eq!(dist.pdf(f64::NEG_INFINITY), 0.0);
        assert_eq!(dist.pdf(f64::INFINITY), 0.0);
        assert_eq!(dist.ln_pdf(f64::NEG_INFINITY), f64::NEG_INFINITY);
        assert_eq!(dist.ln_pdf(f64::INFINITY), f64::NEG_INFINITY);
        assert_eq!(dist.cdf(f64::NEG_INFINITY), 0.0);
        assert_eq!(dist.cdf(f64::INFINITY), 1.0);

        check_integrate_pdf_is_cdf(dist, x_min, x_max, (x_max - x_min) / 100000.0);
    }

    /// Does a series of checks that all positive discrete distributions must
    /// obey.
    /// 99% of the probability mass should be between 0 and x_max (inclusive).
    pub fn check_discrete_distribution<D: DiscreteCDF<u64, f64> + Discrete<u64, f64>>(
        dist: &D,
        x_max: u64,
    ) {
        // assert_eq!(dist.cdf(f64::NEG_INFINITY), 0.0);
        // assert_eq!(dist.cdf(-10.0), 0.0);
        // assert_eq!(dist.cdf(-1.0), 0.0);
        // assert_eq!(dist.cdf(-0.01), 0.0);
        // assert_eq!(dist.cdf(f64::INFINITY), 1.0);

        check_sum_pmf_is_cdf(dist, x_max);
    }

    #[test]
    fn test_integer_bisection() {
        fn search(z: usize, data: &[usize]) -> Option<usize> {
            integral_bisection_search(|idx: &usize| data[*idx], z, 0, data.len() - 1)
        }

        let needle = 3;
        let data = (0..5)
            .map(|n| if n >= needle { n + 1 } else { n })
            .collect::<Vec<_>>();

        for i in 0..(data.len()) {
            assert_eq!(search(data[i], &data), Some(i),)
        }

        {
            let infimum = search(needle, &data);
            let found_element = search(needle + 1, &data); // 4 > needle && member of range
            assert_eq!(found_element, Some(needle));
            assert_eq!(infimum, found_element)
        }
    }
}
statrs-0.18.0/src/distribution/inverse_gamma.rs
use crate::distribution::{Continuous, ContinuousCDF};
use crate::function::gamma;
use crate::statistics::*;
use std::f64;

/// Implements the [Inverse
/// Gamma](https://en.wikipedia.org/wiki/Inverse-gamma_distribution)
/// distribution
///
/// # Examples
///
/// ```
/// use statrs::distribution::{InverseGamma, Continuous};
/// use statrs::statistics::Distribution;
/// use statrs::prec;
///
/// let n = InverseGamma::new(1.1, 0.1).unwrap();
/// assert!(prec::almost_eq(n.mean().unwrap(), 1.0, 1e-14));
/// assert_eq!(n.pdf(1.0), 0.07554920138253064);
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct InverseGamma {
    shape: f64,
    rate: f64,
}

/// Represents the errors that can occur when creating an [`InverseGamma`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum InverseGammaError {
    /// The shape is NaN, infinite, zero or less than zero.
    ShapeInvalid,

    /// The rate is NaN, infinite, zero or less than zero.
    RateInvalid,
}

impl std::fmt::Display for InverseGammaError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            InverseGammaError::ShapeInvalid => {
                write!(f, "Shape is NaN, infinite, zero or less than zero")
            }
            InverseGammaError::RateInvalid => {
                write!(f, "Rate is NaN, infinite, zero or less than zero")
            }
        }
    }
}

impl std::error::Error for InverseGammaError {}

impl InverseGamma {
    /// Constructs a new inverse gamma distribution with a shape (α)
    /// of `shape` and a rate (β) of `rate`
    ///
    /// # Errors
    ///
    /// Returns an error if `shape` or `rate` is `NaN` or not in `(0, +inf)`
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::InverseGamma;
    ///
    /// let mut result = InverseGamma::new(3.0, 1.0);
    /// assert!(result.is_ok());
    ///
    /// result = InverseGamma::new(0.0, 0.0);
    /// assert!(result.is_err());
    /// ```
    pub fn new(shape: f64, rate: f64) -> Result<InverseGamma, InverseGammaError> {
        if shape.is_nan() || shape.is_infinite() || shape <= 0.0 {
            return Err(InverseGammaError::ShapeInvalid);
        }

        if rate.is_nan() || rate.is_infinite() || rate <= 0.0 {
            return Err(InverseGammaError::RateInvalid);
        }

        Ok(InverseGamma { shape, rate })
    }

    /// Returns the shape (α) of the inverse gamma distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::InverseGamma;
    ///
    /// let n = InverseGamma::new(3.0, 1.0).unwrap();
    /// assert_eq!(n.shape(), 3.0);
    /// ```
    pub fn shape(&self) -> f64 {
        self.shape
    }

    /// Returns the rate (β) of the inverse gamma distribution
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::InverseGamma;
    ///
    /// let n = InverseGamma::new(3.0, 1.0).unwrap();
    /// assert_eq!(n.rate(), 1.0);
    /// ```
    pub fn rate(&self) -> f64 {
        self.rate
    }
}

impl std::fmt::Display for InverseGamma {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Inv-Gamma({}, {})", self.shape, self.rate)
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs,
doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for InverseGamma {
    fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> f64 {
        1.0 / super::gamma::sample_unchecked(r, self.shape, self.rate)
    }
}

impl ContinuousCDF<f64, f64> for InverseGamma {
    /// Calculates the cumulative distribution function for the inverse gamma
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// Γ(α, β / x) / Γ(α)
    /// ```
    ///
    /// where the numerator is the upper incomplete gamma function,
    /// the denominator is the gamma function, `α` is the shape,
    /// and `β` is the rate
    fn cdf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            0.0
        } else if x.is_infinite() {
            1.0
        } else {
            gamma::gamma_ur(self.shape, self.rate / x)
        }
    }

    /// Calculates the survival function for the inverse gamma
    /// distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// γ(α, β / x) / Γ(α)
    /// ```
    ///
    /// where the numerator is the lower incomplete gamma function,
    /// the denominator is the gamma function, `α` is the shape,
    /// and `β` is the rate
    fn sf(&self, x: f64) -> f64 {
        if x <= 0.0 {
            1.0
        } else if x.is_infinite() {
            0.0
        } else {
            gamma::gamma_lr(self.shape, self.rate / x)
        }
    }
}

impl Min<f64> for InverseGamma {
    /// Returns the minimum value in the domain of the
    /// inverse gamma distribution representable by a double precision
    /// float
    ///
    /// # Formula
    ///
    /// ```text
    /// 0
    /// ```
    fn min(&self) -> f64 {
        0.0
    }
}

impl Max<f64> for InverseGamma {
    /// Returns the maximum value in the domain of the
    /// inverse gamma distribution representable by a double precision
    /// float
    ///
    /// # Formula
    ///
    /// ```text
    /// f64::INFINITY
    /// ```
    fn max(&self) -> f64 {
        f64::INFINITY
    }
}

impl Distribution<f64> for InverseGamma {
    /// Returns the mean of the inverse gamma distribution
    ///
    /// # None
    ///
    /// If `shape <= 1.0`
    ///
    /// # Formula
    ///
    /// ```text
    /// β / (α - 1)
    /// ```
    ///
    /// where `α` is the shape and `β` is the rate
    fn mean(&self) -> Option<f64> {
        if self.shape <= 1.0 {
            None
        } else {
            Some(self.rate / (self.shape - 1.0))
        }
    }

    /// Returns the variance of the inverse
    /// gamma distribution
    ///
    /// # None
    ///
    /// If `shape <= 2.0`
    ///
    /// # Formula
    ///
    /// ```text
    /// β^2 / ((α - 1)^2 * (α - 2))
    /// ```
    ///
    /// where `α` is the shape and `β` is the rate
    fn variance(&self) -> Option<f64> {
        if self.shape <= 2.0 {
            None
        } else {
            let val = self.rate * self.rate
                / ((self.shape - 1.0) * (self.shape - 1.0) * (self.shape - 2.0));
            Some(val)
        }
    }

    /// Returns the entropy of the inverse gamma distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// α + ln(β * Γ(α)) - (1 + α) * ψ(α)
    /// ```
    ///
    /// where `α` is the shape, `β` is the rate, `Γ` is the gamma function,
    /// and `ψ` is the digamma function
    fn entropy(&self) -> Option<f64> {
        let entr = self.shape + self.rate.ln() + gamma::ln_gamma(self.shape)
            - (1.0 + self.shape) * gamma::digamma(self.shape);
        Some(entr)
    }

    /// Returns the skewness of the inverse gamma distribution
    ///
    /// # None
    ///
    /// If `shape <= 3.0`
    ///
    /// # Formula
    ///
    /// ```text
    /// 4 * sqrt(α - 2) / (α - 3)
    /// ```
    ///
    /// where `α` is the shape
    fn skewness(&self) -> Option<f64> {
        if self.shape <= 3.0 {
            None
        } else {
            Some(4.0 * (self.shape - 2.0).sqrt() / (self.shape - 3.0))
        }
    }
}

impl Mode<Option<f64>> for InverseGamma {
    /// Returns the mode of the inverse gamma distribution
    ///
    /// # Formula
    ///
    /// ```text
    /// β / (α + 1)
    /// ```
    ///
    /// where `α` is the shape and `β` is the rate
    fn mode(&self) -> Option<f64> {
        Some(self.rate / (self.shape + 1.0))
    }
}

impl Continuous<f64, f64> for InverseGamma {
    /// Calculates the probability density function for the
    /// inverse gamma distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// (β^α / Γ(α)) * x^(-α - 1) * e^(-β / x)
    /// ```
    ///
    /// where `α` is the shape, `β` is the rate, and `Γ` is the gamma function
    fn pdf(&self, x: f64) -> f64 {
        if x <= 0.0 || x.is_infinite() {
            0.0
        } else if ulps_eq!(self.shape, 1.0) {
            self.rate / (x * x) * (-self.rate / x).exp()
        } else {
            self.rate.powf(self.shape) * x.powf(-self.shape - 1.0) * (-self.rate / x).exp()
                / gamma::gamma(self.shape)
        }
    }

    /// Calculates the log probability density function for the
    /// inverse gamma distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// ln((β^α / Γ(α)) * x^(-α - 1) * e^(-β / x))
    /// ```
    ///
    /// where `α` is the shape, `β` is the rate, and `Γ` is the gamma function
    fn ln_pdf(&self, x: f64) -> f64 {
        self.pdf(x).ln()
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(shape: f64, rate: f64; InverseGamma; InverseGammaError);

    #[test]
    fn test_create() {
        create_ok(0.1, 0.1);
        create_ok(1.0, 1.0);
    }

    #[test]
    fn test_bad_create() {
        test_create_err(0.0, 1.0, InverseGammaError::ShapeInvalid);
        test_create_err(1.0, -1.0, InverseGammaError::RateInvalid);
        create_err(-1.0, 1.0);
        create_err(-100.0, 1.0);
        create_err(f64::NEG_INFINITY, 1.0);
        create_err(f64::NAN, 1.0);
        create_err(1.0, 0.0);
        create_err(1.0, -100.0);
        create_err(1.0, f64::NEG_INFINITY);
        create_err(1.0, f64::NAN);
        create_err(f64::INFINITY, 1.0);
        create_err(1.0, f64::INFINITY);
        create_err(f64::INFINITY, f64::INFINITY);
    }

    #[test]
    fn test_mean() {
        let mean = |x: InverseGamma| x.mean().unwrap();
        test_absolute(1.1, 0.1, 1.0, 1e-14, mean);
        test_absolute(1.1, 1.0, 10.0, 1e-14, mean);
    }

    #[test]
    fn test_mean_with_shape_lte_1() {
        test_none(0.1, 0.1, |dist| dist.mean());
    }

    #[test]
    fn test_variance() {
        let variance = |x: InverseGamma| x.variance().unwrap();
        test_absolute(2.1, 0.1, 0.08264462809917355371901, 1e-15, variance);
        test_absolute(2.1, 1.0, 8.264462809917355371901, 1e-13, variance);
    }

    #[test]
    fn test_variance_with_shape_lte_2() {
        test_none(0.1, 0.1, |dist| dist.variance());
    }

    #[test]
    fn test_entropy() {
        let entropy = |x: InverseGamma| x.entropy().unwrap();
        test_absolute(0.1, 0.1, 11.51625799319234475054, 1e-14, entropy);
        test_absolute(1.0, 1.0, 2.154431329803065721213, 1e-14, entropy);
    }

    #[test]
    fn test_skewness() {
        let skewness = |x: InverseGamma| x.skewness().unwrap();
        test_absolute(3.1, 0.1, 41.95235392680606187966, 1e-13, skewness);
        test_absolute(3.1, 1.0, 41.95235392680606187966, 1e-13,
skewness);
        test_exact(5.0, 0.1, 3.464101615137754587055, skewness);
    }

    #[test]
    fn test_skewness_with_shape_lte_3() {
        test_none(0.1, 0.1, |dist| dist.skewness());
    }

    #[test]
    fn test_mode() {
        let mode = |x: InverseGamma| x.mode().unwrap();
        test_exact(0.1, 0.1, 0.09090909090909090909091, mode);
        test_exact(1.0, 1.0, 0.5, mode);
    }

    #[test]
    fn test_min_max() {
        let min = |x: InverseGamma| x.min();
        let max = |x: InverseGamma| x.max();
        test_exact(1.0, 1.0, 0.0, min);
        test_exact(1.0, 1.0, f64::INFINITY, max);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: f64| move |x: InverseGamma| x.pdf(arg);
        test_absolute(0.1, 0.1, 0.0628591853882328004197, 1e-15, pdf(1.2));
        test_absolute(0.1, 1.0, 0.0297426109178248997426, 1e-15, pdf(2.0));
        test_exact(1.0, 0.1, 0.04157808822362745501024, pdf(1.5));
        test_exact(1.0, 1.0, 0.3018043114632487660842, pdf(1.2));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: InverseGamma| x.ln_pdf(arg);
        test_absolute(0.1, 0.1, 0.0628591853882328004197f64.ln(), 1e-15, ln_pdf(1.2));
        test_absolute(0.1, 1.0, 0.0297426109178248997426f64.ln(), 1e-15, ln_pdf(2.0));
        test_exact(1.0, 0.1, 0.04157808822362745501024f64.ln(), ln_pdf(1.5));
        test_exact(1.0, 1.0, 0.3018043114632487660842f64.ln(), ln_pdf(1.2));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: InverseGamma| x.cdf(arg);
        test_absolute(0.1, 0.1, 0.1862151961946054271994, 1e-14, cdf(1.2));
        test_absolute(0.1, 1.0, 0.05859755410986647796141, 1e-14, cdf(2.0));
        test_exact(1.0, 0.1, 0.9355069850316177377304, cdf(1.5));
        test_absolute(1.0, 1.0, 0.4345982085070782231613, 1e-14, cdf(1.2));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: InverseGamma| x.sf(arg);
        test_absolute(0.1, 0.1, 0.8137848038053936, 1e-14, sf(1.2));
        test_absolute(0.1, 1.0, 0.9414024458901327, 1e-14, sf(2.0));
        test_absolute(1.0, 0.1, 0.0644930149683822, 1e-14, sf(1.5));
        test_absolute(1.0, 1.0, 0.565401791492922, 1e-14, sf(1.2));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(1.0, 0.5), 0.0, 100.0);
test::check_continuous_distribution(&create_ok(9.0, 2.0), 0.0, 100.0); } } statrs-0.18.0/src/distribution/laplace.rs000064400000000000000000000442461046102023000165210ustar 00000000000000use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::{Distribution, Max, Median, Min, Mode}; use std::f64; /// Implements the [Laplace](https://en.wikipedia.org/wiki/Laplace_distribution) /// distribution. /// /// # Examples /// /// ``` /// use statrs::distribution::{Laplace, Continuous}; /// use statrs::statistics::Mode; /// /// let n = Laplace::new(0.0, 1.0).unwrap(); /// assert_eq!(n.mode().unwrap(), 0.0); /// assert_eq!(n.pdf(1.0), 0.18393972058572117); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Laplace { location: f64, scale: f64, } /// Represents the errors that can occur when creating a [`Laplace`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum LaplaceError { /// The location is NaN. LocationInvalid, /// The scale is NaN, zero or less than zero. ScaleInvalid, } impl std::fmt::Display for LaplaceError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { LaplaceError::LocationInvalid => write!(f, "Location is NaN"), LaplaceError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), } } } impl std::error::Error for LaplaceError {} impl Laplace { /// Constructs a new laplace distribution with the given /// location and scale. 
/// /// # Errors /// /// Returns an error if `location` or `scale` is `NaN`, or if `scale <= 0.0` /// /// # Examples /// /// ``` /// use statrs::distribution::Laplace; /// /// let mut result = Laplace::new(0.0, 1.0); /// assert!(result.is_ok()); /// /// result = Laplace::new(0.0, -1.0); /// assert!(result.is_err()); /// ``` pub fn new(location: f64, scale: f64) -> Result<Laplace, LaplaceError> { if location.is_nan() { return Err(LaplaceError::LocationInvalid); } if scale.is_nan() || scale <= 0.0 { return Err(LaplaceError::ScaleInvalid); } Ok(Laplace { location, scale }) } /// Returns the location of the laplace distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Laplace; /// /// let n = Laplace::new(0.0, 1.0).unwrap(); /// assert_eq!(n.location(), 0.0); /// ``` pub fn location(&self) -> f64 { self.location } /// Returns the scale of the laplace distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Laplace; /// /// let n = Laplace::new(0.0, 1.0).unwrap(); /// assert_eq!(n.scale(), 1.0); /// ``` pub fn scale(&self) -> f64 { self.scale } } impl std::fmt::Display for Laplace { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Laplace({}, {})", self.location, self.scale) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Laplace { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen_range(-0.5..0.5); self.location - self.scale * x.signum() * (1. - 2. * x.abs()).ln() } } impl ContinuousCDF<f64, f64> for Laplace { /// Calculates the cumulative distribution function for the /// laplace distribution at `x` /// /// # Formula /// /// ```text /// (1 / 2) * (1 + signum(x - μ)) - (1 / 2) * signum(x - μ) * exp(-|x - μ| / b) /// ``` /// /// where `μ` is the location, `b` is the scale fn cdf(&self, x: f64) -> f64 { let y = (-(x - self.location).abs() / self.scale).exp() / 2.; if x >= self.location { 1.
- y } else { y } } /// Calculates the survival function for the /// laplace distribution at `x` /// /// # Formula /// /// ```text /// 1 - [(1 / 2) * (1 + signum(x - μ)) - (1 / 2) * signum(x - μ) * exp(-|x - μ| / b)] /// ``` /// /// where `μ` is the location, `b` is the scale fn sf(&self, x: f64) -> f64 { let y = (-(x - self.location).abs() / self.scale).exp() / 2.; if x >= self.location { y } else { 1. - y } } /// Calculates the inverse cumulative distribution function for the /// laplace distribution at `p` /// /// # Panics /// /// If `p <= 0.0` or `p >= 1.0` /// /// # Formula /// /// if p <= 1/2 /// ```text /// μ + b * ln(2p) /// ``` /// if p >= 1/2 /// ```text /// μ - b * ln(2 - 2p) /// ``` /// /// where `μ` is the location, `b` is the scale fn inverse_cdf(&self, p: f64) -> f64 { if p <= 0. || 1. <= p { panic!("p must be in (0, 1)"); } if p <= 0.5 { self.location + self.scale * (2. * p).ln() } else { self.location - self.scale * (2. - 2. * p).ln() } } } impl Min<f64> for Laplace { /// Returns the minimum value in the domain of the laplace /// distribution representable by a double precision float /// /// # Formula /// /// ```text /// f64::NEG_INFINITY /// ``` fn min(&self) -> f64 { f64::NEG_INFINITY } } impl Max<f64> for Laplace { /// Returns the maximum value in the domain of the laplace /// distribution representable by a double precision float /// /// # Formula /// /// ```text /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY } } impl Distribution<f64> for Laplace { /// Returns the mean of the laplace distribution /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the location fn mean(&self) -> Option<f64> { Some(self.location) } /// Returns the variance of the laplace distribution /// /// # Formula /// /// ```text /// 2*b^2 /// ``` /// /// where `b` is the scale fn variance(&self) -> Option<f64> { Some(2. * self.scale * self.scale) } /// Returns the entropy of the laplace distribution /// /// # Formula /// /// ```text /// ln(2be) /// ``` /// /// where `b` is the scale fn entropy(&self) -> Option<f64> { Some((2.
* self.scale).ln() + 1.) } /// Returns the skewness of the laplace distribution /// /// # Formula /// /// ```text /// 0 /// ``` fn skewness(&self) -> Option<f64> { Some(0.) } } impl Median<f64> for Laplace { /// Returns the median of the laplace distribution /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the location fn median(&self) -> f64 { self.location } } impl Mode<Option<f64>> for Laplace { /// Returns the mode of the laplace distribution /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the location fn mode(&self) -> Option<f64> { Some(self.location) } } impl Continuous<f64, f64> for Laplace { /// Calculates the probability density function for the laplace /// distribution at `x` /// /// # Formula /// /// ```text /// (1 / (2b)) * exp(-|x - μ| / b) /// ``` /// /// where `μ` is the location and `b` is the scale fn pdf(&self, x: f64) -> f64 { (-(x - self.location).abs() / self.scale).exp() / (2. * self.scale) } /// Calculates the log probability density function for the laplace /// distribution at `x` /// /// # Formula /// /// ```text /// ln((1 / (2b)) * exp(-|x - μ| / b)) /// ``` /// /// where `μ` is the location and `b` is the scale fn ln_pdf(&self, x: f64) -> f64 { ((-(x - self.location).abs() / self.scale).exp() / (2. * self.scale)).ln() } } #[cfg(test)] mod tests { use super::*; use crate::testing_boiler; testing_boiler!(location: f64, scale: f64; Laplace; LaplaceError); // A wrapper for the `assert_relative_eq!` macro from the approx crate. // // `rtol` is the acceptable relative error. This function is for testing // relative tolerance *only*. It should not be used with `expected = 0`.
// fn test_rel_close<F>(location: f64, scale: f64, expected: f64, rtol: f64, get_fn: F) where F: Fn(Laplace) -> f64, { let x = create_and_get(location, scale, get_fn); assert_relative_eq!(expected, x, epsilon = 0.0, max_relative = rtol); } #[test] fn test_create() { create_ok(1.0, 2.0); create_ok(f64::NEG_INFINITY, 0.1); create_ok(-5.0 - 1.0, 1.0); create_ok(0.0, 5.0); create_ok(1.0, 7.0); create_ok(5.0, 10.0); create_ok(f64::INFINITY, f64::INFINITY); } #[test] fn test_bad_create() { test_create_err(2.0, -1.0, LaplaceError::ScaleInvalid); test_create_err(f64::NAN, 1.0, LaplaceError::LocationInvalid); create_err(f64::NAN, -1.0); } #[test] fn test_mean() { let mean = |x: Laplace| x.mean().unwrap(); test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mean); test_exact(-5.0 - 1.0, 1.0, -6.0, mean); test_exact(0.0, 5.0, 0.0, mean); test_exact(1.0, 10.0, 1.0, mean); test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, mean); } #[test] fn test_variance() { let variance = |x: Laplace| x.variance().unwrap(); test_absolute(f64::NEG_INFINITY, 0.1, 0.02, 1E-12, variance); test_absolute(-5.0 - 1.0, 1.0, 2.0, 1E-12, variance); test_absolute(0.0, 5.0, 50.0, 1E-12, variance); test_absolute(1.0, 7.0, 98.0, 1E-12, variance); test_absolute(5.0, 10.0, 200.0, 1E-12, variance); test_absolute(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, variance); } #[test] fn test_entropy() { let entropy = |x: Laplace| x.entropy().unwrap(); test_absolute( f64::NEG_INFINITY, 0.1, (2.0 * f64::consts::E * 0.1).ln(), 1E-12, entropy, ); test_absolute(-6.0, 1.0, (2.0 * f64::consts::E).ln(), 1E-12, entropy); test_absolute(1.0, 7.0, (2.0 * f64::consts::E * 7.0).ln(), 1E-12, entropy); test_absolute(5., 10., (2.
* f64::consts::E * 10.).ln(), 1E-12, entropy); test_absolute(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, entropy); } #[test] fn test_skewness() { let skewness = |x: Laplace| x.skewness().unwrap(); test_exact(f64::NEG_INFINITY, 0.1, 0.0, skewness); test_exact(-6.0, 1.0, 0.0, skewness); test_exact(1.0, 7.0, 0.0, skewness); test_exact(5.0, 10.0, 0.0, skewness); test_exact(f64::INFINITY, f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Laplace| x.mode().unwrap(); test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mode); test_exact(-6.0, 1.0, -6.0, mode); test_exact(1.0, 7.0, 1.0, mode); test_exact(5.0, 10.0, 5.0, mode); test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Laplace| x.median(); test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, median); test_exact(-6.0, 1.0, -6.0, median); test_exact(1.0, 7.0, 1.0, median); test_exact(5.0, 10.0, 5.0, median); test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, median); } #[test] fn test_min() { test_exact(0.0, 1.0, f64::NEG_INFINITY, |l| l.min()); } #[test] fn test_max() { test_exact(0.0, 1.0, f64::INFINITY, |l| l.max()); } #[test] fn test_density() { let pdf = |arg: f64| move |x: Laplace| x.pdf(arg); test_absolute(0.0, 0.1, 1.529511602509129e-06, 1E-12, pdf(1.5)); test_absolute(1.0, 0.1, 7.614989872356341e-08, 1E-12, pdf(2.8)); test_absolute(-1.0, 0.1, 3.8905661205668983e-19, 1E-12, pdf(-5.4)); test_absolute(5.0, 0.1, 5.056107463052243e-43, 1E-12, pdf(-4.9)); test_absolute(-5.0, 0.1, 1.9877248679543235e-30, 1E-12, pdf(2.0)); test_absolute(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(5.5)); test_absolute(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(-0.0)); test_absolute(0.0, 1.0, 0.0, 1E-12, pdf(f64::INFINITY)); test_absolute(1.0, 1.0, 0.00915781944436709, 1E-12, pdf(5.0)); test_absolute(-1.0, 1.0, 0.5, 1E-12, pdf(-1.0)); test_absolute(5.0, 1.0, 0.0012393760883331792, 1E-12, pdf(-1.0)); test_absolute(-5.0, 1.0, 0.0002765421850739168, 
1E-12, pdf(2.5)); test_absolute(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(2.0)); test_absolute(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(15.0)); test_absolute(0.0, f64::INFINITY, 0.0, 1E-12, pdf(89.3)); test_absolute(1.0, f64::INFINITY, 0.0, 1E-12, pdf(-0.1)); test_absolute(-1.0, f64::INFINITY, 0.0, 1E-12, pdf(0.1)); test_absolute(5.0, f64::INFINITY, 0.0, 1E-12, pdf(-6.1)); test_absolute(-5.0, f64::INFINITY, 0.0, 1E-12, pdf(-10.0)); test_is_nan(f64::INFINITY, f64::INFINITY, pdf(2.0)); test_is_nan(f64::NEG_INFINITY, f64::INFINITY, pdf(-5.1)); } #[test] fn test_ln_density() { let ln_pdf = |arg: f64| move |x: Laplace| x.ln_pdf(arg); test_absolute(0.0, 0.1, -13.3905620875659, 1E-12, ln_pdf(1.5)); test_absolute(1.0, 0.1, -16.390562087565897, 1E-12, ln_pdf(2.8)); test_absolute(-1.0, 0.1, -42.39056208756591, 1E-12, ln_pdf(-5.4)); test_absolute(5.0, 0.1, -97.3905620875659, 1E-12, ln_pdf(-4.9)); test_absolute(-5.0, 0.1, -68.3905620875659, 1E-12, ln_pdf(2.0)); test_exact(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(5.5)); test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(-0.0)); test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); test_absolute(1.0, 1.0, -4.693147180559945, 1E-12, ln_pdf(5.0)); test_absolute(-1.0, 1.0, -f64::consts::LN_2, 1E-12, ln_pdf(-1.0)); test_absolute(5.0, 1.0, -6.693147180559945, 1E-12, ln_pdf(-1.0)); test_absolute(-5.0, 1.0, -8.193147180559945, 1E-12, ln_pdf(2.5)); test_exact(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(2.0)); test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(15.0)); test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(89.3)); test_exact(1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-0.1)); test_exact(-1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); test_exact(5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-6.1)); test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-10.0)); test_is_nan(f64::INFINITY, f64::INFINITY, ln_pdf(2.0)); test_is_nan(f64::NEG_INFINITY, f64::INFINITY, ln_pdf(-5.1)); } 
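// A possible additional test, sketched here and not part of the upstream suite.
// `ln_pdf` above is computed as `ln(pdf(x))`; in the range where `exp()` does not
// underflow, it should agree with the closed form `-|x - μ| / b - ln(2b)`.
// The test name and the chosen (location, scale, x) triples are illustrative
// assumptions; `create_ok` is the helper generated by `testing_boiler!`.
#[test]
fn test_ln_pdf_closed_form() {
    for &(location, scale, x) in [(0.0, 1.0, 0.5), (1.0, 7.0, -3.0), (-5.0, 0.1, 2.0)].iter() {
        let d = create_ok(location, scale);
        // Closed-form log density of the Laplace distribution.
        let expected = -(x - location).abs() / scale - (2.0 * scale).ln();
        assert!((d.ln_pdf(x) - expected).abs() < 1e-12, "mismatch at x = {x}");
    }
}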
#[test] fn test_cdf() { let cdf = |arg: f64| move |x: Laplace| x.cdf(arg); let loc = 0.0f64; let scale = 1.0f64; let reltol = 1e-15f64; // Expected value from Wolfram Alpha: CDF[LaplaceDistribution[0, 1], 1/2]. let expected = 0.69673467014368328819810023250440977f64; test_rel_close(loc, scale, expected, reltol, cdf(0.5)); // Wolfram Alpha: CDF[LaplaceDistribution[0, 1], -1/2] let expected = 0.30326532985631671180189976749559023f64; test_rel_close(loc, scale, expected, reltol, cdf(-0.5)); // Wolfram Alpha: CDF[LaplaceDistribution[0, 1], -100] let expected = 1.8600379880104179814798479019315592e-44f64; test_rel_close(loc, scale, expected, reltol, cdf(-100.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Laplace| x.sf(arg); let loc = 0.0f64; let scale = 1.0f64; let reltol = 1e-15f64; // Expected value from Wolfram Alpha: SurvivalFunction[LaplaceDistribution[0, 1], 1/2]. let expected = 0.30326532985631671180189976749559022f64; test_rel_close(loc, scale, expected, reltol, sf(0.5)); // Wolfram Alpha: SurvivalFunction[LaplaceDistribution[0, 1], -1/2] let expected = 0.69673467014368328819810023250440977f64; test_rel_close(loc, scale, expected, reltol, sf(-0.5)); // Wolfram Alpha: SurvivalFunction[LaplaceDistribution[0, 1], 100] let expected = 1.86003798801041798147984790193155916e-44; test_rel_close(loc, scale, expected, reltol, sf(100.0)); } #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Laplace| x.inverse_cdf(arg); let loc = 0.0f64; let scale = 1.0f64; let reltol = 1e-15f64; // Wolfram Alpha: Inverse CDF[LaplaceDistribution[0, 1], 1/10000000000] let expected = -22.3327037493805115307626824253854655f64; test_rel_close(loc, scale, expected, reltol, inverse_cdf(1e-10)); // Wolfram Alpha: Inverse CDF[LaplaceDistribution[0, 1], 1/1000]. 
let expected = -6.2146080984221917426367422425949161f64; test_rel_close(loc, scale, expected, reltol, inverse_cdf(0.001)); // Wolfram Alpha: Inverse CDF[LaplaceDistribution[0, 1], 95/100] let expected = 2.3025850929940456840179914546843642f64; test_rel_close(loc, scale, expected, reltol, inverse_cdf(0.95)); } #[cfg(feature = "rand")] #[test] fn test_sample() { use ::rand::distributions::Distribution; use ::rand::thread_rng; let l = create_ok(0.1, 0.5); l.sample(&mut thread_rng()); } #[cfg(feature = "rand")] #[test] fn test_sample_distribution() { use ::rand::distributions::Distribution; use ::rand::rngs::StdRng; use ::rand::SeedableRng; // sanity check: samples should land above and below the location about equally often let location = 0.0; let scale = 1.0; let n = create_ok(location, scale); let trials = 10_000; let tolerance = 250; for seed in 0..10 { let mut r: StdRng = SeedableRng::seed_from_u64(seed); let result = (0..trials).map(|_| n.sample(&mut r)).fold(0, |sum, val| { if val > 0.0 { sum + 1 } else if val < 0.0 { sum - 1 } else { sum } }); assert!( result > -tolerance && result < tolerance, "Balance is {result} for seed {seed}" ); } } } statrs-0.18.0/src/distribution/log_normal.rs000064400000000000000000001035411046102023000172430ustar 00000000000000use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; use std::f64; /// Implements the /// [Log-normal](https://en.wikipedia.org/wiki/Log-normal_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{LogNormal, Continuous}; /// use statrs::statistics::Distribution; /// use statrs::prec; /// /// let n = LogNormal::new(0.0, 1.0).unwrap(); /// assert_eq!(n.mean().unwrap(), (0.5f64).exp()); /// assert!(prec::almost_eq(n.pdf(1.0), 0.3989422804014326779399, 1e-16)); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct LogNormal { location: f64, scale: f64, } /// Represents the errors that can occur when creating a [`LogNormal`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum LogNormalError { /// The location is NaN. LocationInvalid, /// The scale is NaN, zero or less than zero. ScaleInvalid, } impl std::fmt::Display for LogNormalError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { LogNormalError::LocationInvalid => write!(f, "Location is NaN"), LogNormalError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), } } } impl std::error::Error for LogNormalError {} impl LogNormal { /// Constructs a new log-normal distribution with a location of `location` /// and a scale of `scale` /// /// # Errors /// /// Returns an error if `location` or `scale` is `NaN`, or if `scale <= 0.0` /// /// # Examples /// /// ``` /// use statrs::distribution::LogNormal; /// /// let mut result = LogNormal::new(0.0, 1.0); /// assert!(result.is_ok()); /// /// result = LogNormal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` pub fn new(location: f64, scale: f64) -> Result<LogNormal, LogNormalError> { if location.is_nan() { return Err(LogNormalError::LocationInvalid); } if scale.is_nan() || scale <= 0.0 { return Err(LogNormalError::ScaleInvalid); } Ok(LogNormal { location, scale }) } } impl std::fmt::Display for LogNormal { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "LogNormal({}, {}^2)", self.location, self.scale) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for LogNormal { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { super::normal::sample_unchecked(rng, self.location, self.scale).exp() } } impl ContinuousCDF<f64, f64> for LogNormal { /// Calculates the cumulative distribution function for the log-normal /// distribution at `x` /// /// # Formula /// /// ```text /// (1 / 2) + (1 / 2) * erf((ln(x) - μ) / (sqrt(2) * σ)) /// ``` /// /// where `μ` is the location, `σ` is the scale, and `erf` is the /// error
function fn cdf(&self, x: f64) -> f64 { if x <= 0.0 { 0.0 } else if x.is_infinite() { 1.0 } else { 0.5 * erf::erfc((self.location - x.ln()) / (self.scale * f64::consts::SQRT_2)) } } /// Calculates the survival function for the log-normal /// distribution at `x` /// /// # Formula /// /// ```text /// (1 / 2) + (1 / 2) * erf(-(ln(x) - μ) / (sqrt(2) * σ)) /// ``` /// /// where `μ` is the location, `σ` is the scale, and `erf` is the /// error function /// /// Note that this computes the complement of the cdf by flipping /// the sign of the error function's argument relative to `cdf`. /// /// The normal cdf Φ (and hence the error function) has the following property: /// ```text /// Φ(-x) + Φ(x) = 1 /// Φ(-x) = 1 - Φ(x) /// ``` fn sf(&self, x: f64) -> f64 { if x <= 0.0 { 1.0 } else if x.is_infinite() { 0.0 } else { 0.5 * erf::erfc((x.ln() - self.location) / (self.scale * f64::consts::SQRT_2)) } } /// Calculates the inverse cumulative distribution function for the /// log-normal distribution at `p` /// /// # Panics /// /// If `p < 0.0`, `p > 1.0`, or `p` is `NaN` /// /// # Formula /// /// ```text /// exp(μ - σ * sqrt(2) * erfc_inv(2p)) /// ``` /// /// where `μ` is the location, `σ` is the scale and `erfc_inv` is /// the inverse of the complementary error function fn inverse_cdf(&self, p: f64) -> f64 { if !(0.0..=1.0).contains(&p) { panic!("p must be within [0.0, 1.0]"); } if p == 0.0 { 0.0 } else if p == 1.0 { f64::INFINITY } else { (self.location - (self.scale * f64::consts::SQRT_2 * erf::erfc_inv(2.0 * p))).exp() } } } impl Min<f64> for LogNormal { /// Returns the minimum value in the domain of the log-normal /// distribution representable by a double precision float /// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> f64 { 0.0 } } impl Max<f64> for LogNormal { /// Returns the maximum value in the domain of the log-normal /// distribution representable by a double precision float /// /// # Formula /// /// ```text /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY } } impl Distribution<f64> for
LogNormal { /// Returns the mean of the log-normal distribution /// /// # Formula /// /// ```text /// e^(μ + σ^2 / 2) /// ``` /// /// where `μ` is the location and `σ` is the scale fn mean(&self) -> Option<f64> { Some((self.location + self.scale * self.scale / 2.0).exp()) } /// Returns the variance of the log-normal distribution /// /// # Formula /// /// ```text /// (e^(σ^2) - 1) * e^(2μ + σ^2) /// ``` /// /// where `μ` is the location and `σ` is the scale fn variance(&self) -> Option<f64> { let sigma2 = self.scale * self.scale; Some((sigma2.exp() - 1.0) * (self.location + self.location + sigma2).exp()) } /// Returns the entropy of the log-normal distribution /// /// # Formula /// /// ```text /// ln(σ * e^(μ + 1 / 2) * sqrt(2π)) /// ``` /// /// where `μ` is the location and `σ` is the scale fn entropy(&self) -> Option<f64> { Some(0.5 + self.scale.ln() + self.location + consts::LN_SQRT_2PI) } /// Returns the skewness of the log-normal distribution /// /// # Formula /// /// ```text /// (e^(σ^2) + 2) * sqrt(e^(σ^2) - 1) /// ``` /// /// where `σ` is the scale fn skewness(&self) -> Option<f64> { let expsigma2 = (self.scale * self.scale).exp(); Some((expsigma2 + 2.0) * (expsigma2 - 1.0).sqrt()) } } impl Median<f64> for LogNormal { /// Returns the median of the log-normal distribution /// /// # Formula /// /// ```text /// e^μ /// ``` /// /// where `μ` is the location fn median(&self) -> f64 { self.location.exp() } } impl Mode<Option<f64>> for LogNormal { /// Returns the mode of the log-normal distribution /// /// # Formula /// /// ```text /// e^(μ - σ^2) /// ``` /// /// where `μ` is the location and `σ` is the scale fn mode(&self) -> Option<f64> { Some((self.location - self.scale * self.scale).exp()) } } impl Continuous<f64, f64> for LogNormal { /// Calculates the probability density function for the log-normal /// distribution at `x` /// /// # Formula /// /// ```text /// (1 / (xσ * sqrt(2π))) * e^(-((ln(x) - μ)^2) / (2σ^2)) /// ``` /// /// where `μ` is the location and `σ` is the scale fn pdf(&self, x:
f64) -> f64 { if x <= 0.0 || x.is_infinite() { 0.0 } else { let d = (x.ln() - self.location) / self.scale; (-0.5 * d * d).exp() / (x * consts::SQRT_2PI * self.scale) } } /// Calculates the log probability density function for the log-normal /// distribution at `x` /// /// # Formula /// /// ```text /// ln((1 / (xσ * sqrt(2π))) * e^(-((ln(x) - μ)^2) / (2σ^2))) /// ``` /// /// where `μ` is the location and `σ` is the scale fn ln_pdf(&self, x: f64) -> f64 { if x <= 0.0 || x.is_infinite() { f64::NEG_INFINITY } else { let d = (x.ln() - self.location) / self.scale; (-0.5 * d * d) - consts::LN_SQRT_2PI - (x * self.scale).ln() } } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(location: f64, scale: f64; LogNormal; LogNormalError); #[test] fn test_create() { create_ok(10.0, 0.1); create_ok(-5.0, 1.0); create_ok(0.0, 10.0); create_ok(10.0, 100.0); create_ok(-5.0, f64::INFINITY); } #[test] fn test_bad_create() { test_create_err(f64::NAN, 1.0, LogNormalError::LocationInvalid); test_create_err(1.0, f64::NAN, LogNormalError::ScaleInvalid); create_err(0.0, 0.0); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); } #[test] fn test_mean() { let mean = |x: LogNormal| x.mean().unwrap(); test_exact(-1.0, 0.1, 0.369723444544058982601, mean); test_exact(-1.0, 1.5, 1.133148453066826316829, mean); test_exact(-1.0, 2.5, 8.372897488127264663205, mean); test_exact(-1.0, 5.5, 1362729.18425285481771, mean); test_exact(-0.1, 0.1, 0.9093729344682314204933, mean); test_exact(-0.1, 1.5, 2.787095460565850768514, mean); test_exact(-0.1, 2.5, 20.59400471119602917533, mean); test_absolute(-0.1, 5.5, 3351772.941252693807591, 1e-9, mean); test_exact(0.1, 0.1, 1.110710610355705232259, mean); test_exact(0.1, 1.5, 3.40416608279081898632, mean); test_absolute(0.1, 2.5, 25.15357415581836182776, 1e-14, mean); test_absolute(0.1, 5.5, 4093864.715172665106863, 1e-8, mean); test_absolute(1.5, 0.1, 4.50415363028848413209,
1e-15, mean); test_exact(1.5, 1.5, 13.80457418606709491926, mean); test_exact(1.5, 2.5, 102.0027730826996844534, mean); test_exact(1.5, 5.5, 16601440.05723477471392, mean); test_absolute(2.5, 0.1, 12.24355896580102707724, 1e-14, mean); test_absolute(2.5, 1.5, 37.52472315960099891407, 1e-11, mean); test_exact(2.5, 2.5, 277.2722845231339804081, mean); test_exact(2.5, 5.5, 45127392.83383337999291, mean); test_absolute(5.5, 0.1, 245.9184556788219446833, 1e-13, mean); test_exact(5.5, 1.5, 753.7042125545612656606, mean); test_exact(5.5, 2.5, 5569.162708566004074422, mean); test_exact(5.5, 5.5, 906407915.0111549133446, mean); } #[test] fn test_variance() { let variance = |x: LogNormal| x.variance().unwrap(); test_absolute(-1.0, 0.1, 0.001373811865368952608715, 1e-16, variance); test_exact(-1.0, 1.5, 10.898468544015731954, variance); test_exact(-1.0, 2.5, 36245.39726189994988081, variance); test_absolute(-1.0, 5.5, 2.5481629178024539E+25, 1e10, variance); test_absolute(-0.1, 0.1, 0.008311077467909703803238, 1e-16, variance); test_exact(-0.1, 1.5, 65.93189259328902509552, variance); test_absolute(-0.1, 2.5, 219271.8756420929704707, 1e-10, variance); test_absolute(-0.1, 5.5, 1.541548733459471E+26, 1e12, variance); test_absolute(0.1, 0.1, 0.01239867063063756838894, 1e-15, variance); test_absolute(0.1, 1.5, 98.35882573290010981464, 1e-13, variance); test_absolute(0.1, 2.5, 327115.1995809995715014, 1e-10, variance); test_absolute(0.1, 5.5, 2.299720473192458E+26, 1e12, variance); test_absolute(1.5, 0.1, 0.2038917589520099120699, 1e-14, variance); test_absolute(1.5, 1.5, 1617.476145997433210727, 1e-12, variance); test_absolute(1.5, 2.5, 5379293.910566451644527, 1e-9, variance); test_absolute(1.5, 5.5, 3.7818090853910142E+27, 1e12, variance); test_absolute(2.5, 0.1, 1.506567645006046841936, 1e-13, variance); test_absolute(2.5, 1.5, 11951.62198145717670088, 1e-11, variance); test_exact(2.5, 2.5, 39747904.47781154725843, variance); test_absolute(2.5, 5.5, 2.7943999487399818E+28, 
1e13, variance); test_absolute(5.5, 0.1, 607.7927673399807484235, 1e-11, variance); test_exact(5.5, 1.5, 4821628.436260521100027, variance); test_exact(5.5, 2.5, 16035449147.34799637823, variance); test_exact(5.5, 5.5, 1.127341399856331737823E+31, variance); } #[test] fn test_entropy() { let entropy = |x: LogNormal| x.entropy().unwrap(); test_exact(-1.0, 0.1, -1.8836465597893728867265104870209210873020761202386, entropy); test_exact(-1.0, 1.5, 0.82440364131283712375834285186996677643338789710028, entropy); test_exact(-1.0, 2.5, 1.335229265078827806963856948173628711311498693546, entropy); test_exact(-1.0, 5.5, 2.1236866254430979764250411929125703716076041932149, entropy); test_absolute(-0.1, 0.1, -0.9836465597893728922776256101467037894202344606927, 1e-15, entropy); test_exact(-0.1, 1.5, 1.7244036413128371182072277287441840743152295566462, entropy); test_exact(-0.1, 2.5, 2.2352292650788278014127418250478460091933403530919, entropy); test_exact(-0.1, 5.5, 3.0236866254430979708739260697867876694894458527608, entropy); test_absolute(0.1, 0.1, -0.7836465597893728811753953638951383851839177797845, 1e-15, entropy); test_absolute(0.1, 1.5, 1.9244036413128371293094579749957494785515462375544, 1e-15, entropy); test_exact(0.1, 2.5, 2.4352292650788278125149720712994114134296570340001, entropy); test_exact(0.1, 5.5, 3.223686625443097981976156316038353073725762533669, entropy); test_absolute(1.5, 0.1, 0.6163534402106271132734895129790789126979238797614, 1e-15, entropy); test_exact(1.5, 1.5, 3.3244036413128371237583428518699667764333878971003, entropy); test_exact(1.5, 2.5, 3.835229265078827806963856948173628711311498693546, entropy); test_exact(1.5, 5.5, 4.6236866254430979764250411929125703716076041932149, entropy); test_exact(2.5, 0.1, 1.6163534402106271132734895129790789126979238797614, entropy); test_absolute(2.5, 1.5, 4.3244036413128371237583428518699667764333878971003, 1e-15, entropy); test_exact(2.5, 2.5, 4.835229265078827806963856948173628711311498693546, entropy); 
test_exact(2.5, 5.5, 5.6236866254430979764250411929125703716076041932149, entropy); test_exact(5.5, 0.1, 4.6163534402106271132734895129790789126979238797614, entropy); test_absolute(5.5, 1.5, 7.3244036413128371237583428518699667764333878971003, 1e-15, entropy); test_exact(5.5, 2.5, 7.835229265078827806963856948173628711311498693546, entropy); test_exact(5.5, 5.5, 8.6236866254430979764250411929125703716076041932149, entropy); } #[test] fn test_skewness() { let skewness = |x: LogNormal| x.skewness().unwrap(); test_absolute(-1.0, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); test_exact(-1.0, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); test_absolute(-1.0, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); test_absolute(-1.0, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); test_absolute(-0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); test_exact(-0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); test_absolute(-0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); test_absolute(-0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); test_absolute(0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); test_exact(0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); test_absolute(0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); test_absolute(0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); test_absolute(1.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); test_exact(1.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); test_absolute(1.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); test_absolute(1.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); 
test_absolute(2.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); test_exact(2.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); test_absolute(2.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); test_absolute(2.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); test_absolute(5.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); test_exact(5.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); test_absolute(5.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); test_absolute(5.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); } #[test] fn test_mode() { let mode = |x: LogNormal| x.mode().unwrap(); test_exact(-1.0, 0.1, 0.36421897957152331652213191863106773137983085909534, mode); test_exact(-1.0, 1.5, 0.03877420783172200988689983526759614326014406193602, mode); test_exact(-1.0, 2.5, 0.0007101743888425490635846003705775444086763023873619, mode); test_exact(-1.0, 5.5, 0.000000000000026810038677818032221548731163905979029274677187036, mode); test_exact(-0.1, 0.1, 0.89583413529652823774737070060865897390995185639633, mode); test_exact(-0.1, 1.5, 0.095369162215549610417813418326627245539514227574881, mode); test_exact(-0.1, 2.5, 0.0017467471362611196181003627521060283221112106850165, mode); test_exact(-0.1, 5.5, 0.00000000000006594205454219929159167575814655534255162059017114, mode); test_exact(0.1, 0.1, 1.0941742837052103542285651753780976842292770841345, mode); test_exact(0.1, 1.5, 0.11648415777349696821514223131929465848700730137808, mode); test_exact(0.1, 2.5, 0.0021334817700377079925027678518795817076296484352472, mode); test_exact(0.1, 5.5, 0.000000000000080541807296590798973741710866097756565304960216803, mode); test_exact(1.5, 0.1, 4.4370955190036645692996309927420381428715912422597, mode); test_exact(1.5, 1.5, 0.47236655274101470713804655094326791297020357913648, 
mode); test_exact(1.5, 2.5, 0.008651695203120634177071503957250390848166331197708, mode); test_exact(1.5, 5.5, 0.00000000000032661313427874471360158184468030186601222739665225, mode); test_exact(2.5, 0.1, 12.061276120444720299113038763305617245808510584994, mode); test_exact(2.5, 1.5, 1.2840254166877414840734205680624364583362808652815, mode); test_exact(2.5, 2.5, 0.023517745856009108236151185100432939470067655273072, mode); test_exact(2.5, 5.5, 0.00000000000088782654784596584473099190326928541185172970391855, mode); test_exact(5.5, 0.1, 242.2572068579541371904816252345031593584721473492, mode); test_exact(5.5, 1.5, 25.790339917193062089080107669377221876655268848954, mode); test_exact(5.5, 2.5, 0.47236655274101470713804655094326791297020357913648, mode); test_exact(5.5, 5.5, 0.000000000017832472908146389493511850431527026413424899198327, mode); } #[test] fn test_median() { let median = |x: LogNormal| x.median(); test_exact(-1.0, 0.1, 0.36787944117144232159552377016146086744581113103177, median); test_exact(-1.0, 1.5, 0.36787944117144232159552377016146086744581113103177, median); test_exact(-1.0, 2.5, 0.36787944117144232159552377016146086744581113103177, median); test_exact(-1.0, 5.5, 0.36787944117144232159552377016146086744581113103177, median); test_exact(-0.1, 0.1, 0.90483741803595956814139238421693559530906465375738, median); test_exact(-0.1, 1.5, 0.90483741803595956814139238421693559530906465375738, median); test_exact(-0.1, 2.5, 0.90483741803595956814139238421693559530906465375738, median); test_exact(-0.1, 5.5, 0.90483741803595956814139238421693559530906465375738, median); test_exact(0.1, 0.1, 1.1051709180756476309466388234587796577416634163742, median); test_exact(0.1, 1.5, 1.1051709180756476309466388234587796577416634163742, median); test_exact(0.1, 2.5, 1.1051709180756476309466388234587796577416634163742, median); test_exact(0.1, 5.5, 1.1051709180756476309466388234587796577416634163742, median); test_exact(1.5, 0.1, 
4.4816890703380648226020554601192758190057498683697, median); test_exact(1.5, 1.5, 4.4816890703380648226020554601192758190057498683697, median); test_exact(1.5, 2.5, 4.4816890703380648226020554601192758190057498683697, median); test_exact(1.5, 5.5, 4.4816890703380648226020554601192758190057498683697, median); test_exact(2.5, 0.1, 12.182493960703473438070175951167966183182767790063, median); test_exact(2.5, 1.5, 12.182493960703473438070175951167966183182767790063, median); test_exact(2.5, 2.5, 12.182493960703473438070175951167966183182767790063, median); test_exact(2.5, 5.5, 12.182493960703473438070175951167966183182767790063, median); test_exact(5.5, 0.1, 244.6919322642203879151889495118393501842287101075, median); test_exact(5.5, 1.5, 244.6919322642203879151889495118393501842287101075, median); test_exact(5.5, 2.5, 244.6919322642203879151889495118393501842287101075, median); test_exact(5.5, 5.5, 244.6919322642203879151889495118393501842287101075, median); } #[test] fn test_min_max() { let min = |x: LogNormal| x.min(); let max = |x: LogNormal| x.max(); test_exact(0.0, 0.1, 0.0, min); test_exact(-3.0, 10.0, 0.0, min); test_exact(0.0, 0.1, f64::INFINITY, max); test_exact(-3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: LogNormal| x.pdf(arg); test_absolute(-0.1, 0.1, 1.7968349035073582236359415565799753846986440127816e-104, 1e-118, pdf(0.1)); test_absolute(-0.1, 0.1, 0.00000018288923328441197822391757965928083462391836798722, 1e-21, pdf(0.5)); test_exact(-0.1, 0.1, 2.3363114904470413709866234247494393485647978367885, pdf(0.8)); test_absolute(-0.1, 1.5, 0.90492497850024368541682348133921492204585092983646, 1e-15, pdf(0.1)); test_absolute(-0.1, 1.5, 0.49191985207660942803818797602364034466489243416574, 1e-16, pdf(0.5)); test_exact(-0.1, 1.5, 0.33133347214343229148978298237579567194870525187207, pdf(0.8)); test_exact(-0.1, 2.5, 1.0824698632626565182080576574958317806389057196768, pdf(0.1)); test_absolute(-0.1, 2.5, 
0.31029619474753883558901295436486123689563749784867, 1e-16, pdf(0.5)); test_absolute(-0.1, 2.5, 0.19922929916156673799861939824205622734205083805245, 1e-16, pdf(0.8)); // Test removed because it was causing compiler issues (see issue 31407 for rust) // test_absolute(1.5, 0.1, 4.1070141770545881694056265342787422035256248474059e-313, 1e-322, pdf(0.1)); // test_absolute(1.5, 0.1, 2.8602688726477103843476657332784045661507239533567e-104, 1e-116, pdf(0.5)); test_exact(1.5, 0.1, 1.6670425710002183246335601541889400558525870482613e-64, pdf(0.8)); test_absolute(1.5, 1.5, 0.10698412103361841220076392503406214751353235895732, 1e-16, pdf(0.1)); test_absolute(1.5, 1.5, 0.18266125308224685664142384493330155315630876975024, 1e-16, pdf(0.5)); test_absolute(1.5, 1.5, 0.17185785323404088913982425377565512294017306418953, 1e-16, pdf(0.8)); test_absolute(1.5, 2.5, 0.50186885259059181992025035649158160252576845315332, 1e-15, pdf(0.1)); test_absolute(1.5, 2.5, 0.21721369314437986034957451699565540205404697589349, 1e-16, pdf(0.5)); test_exact(1.5, 2.5, 0.15729636000661278918949298391170443742675565300598, pdf(0.8)); test_exact(2.5, 0.1, 5.6836826548848916385760779034504046896805825555997e-500, pdf(0.1)); test_absolute(2.5, 0.1, 3.1225608678589488061206338085285607881363155340377e-221, 1e-233, pdf(0.5)); test_absolute(2.5, 0.1, 4.6994713794671660918554320071312374073172560048297e-161, 1e-173, pdf(0.8)); test_absolute(2.5, 1.5, 0.015806486291412916772431170442330946677601577502353, 1e-16, pdf(0.1)); test_absolute(2.5, 1.5, 0.055184331257528847223852028950484131834529030116388, 1e-16, pdf(0.5)); test_exact(2.5, 1.5, 0.063982134749859504449658286955049840393511776984362, pdf(0.8)); test_absolute(2.5, 2.5, 0.25212505662402617595900822552548977822542300480086, 1e-15, pdf(0.1)); test_absolute(2.5, 2.5, 0.14117186955911792460646517002386088579088567275401, 1e-16, pdf(0.5)); test_absolute(2.5, 2.5, 0.11021452580363707866161369621432656293405065561317, 1e-16, pdf(0.8)); } #[test] fn 
test_neg_pdf() { let pdf = |arg: f64| move |x: LogNormal| x.pdf(arg); test_exact(0.0, 1.0, 0.0, pdf(0.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: LogNormal| x.ln_pdf(arg); test_exact(-0.1, 0.1, -238.88282294119596467794686179588610665317241097599, ln_pdf(0.1)); test_absolute(-0.1, 0.1, -15.514385149961296196003163062199569075052113039686, 1e-14, ln_pdf(0.5)); test_exact(-0.1, 0.1, 0.84857339958981283964373051826407417105725729082041, ln_pdf(0.8)); test_absolute(-0.1, 1.5, -0.099903235403144611051953094864849327288457482212211, 1e-15, ln_pdf(0.1)); test_absolute(-0.1, 1.5, -0.70943947804316122682964396008813828577195771418027, 1e-15, ln_pdf(0.5)); test_absolute(-0.1, 1.5, -1.1046299420497998262946038709903250420774183529995, 1e-15, ln_pdf(0.8)); test_absolute(-0.1, 2.5, 0.07924534056485078867266307735371665927517517183681, 1e-16, ln_pdf(0.1)); test_exact(-0.1, 2.5, -1.1702279707433794860424967893989374511050637417043, ln_pdf(0.5)); test_exact(-0.1, 2.5, -1.6132988605030400828957768752511536087538109996183, ln_pdf(0.8)); test_exact(1.5, 0.1, -719.29643782024317312262673764204041218720576249741, ln_pdf(0.1)); test_absolute(1.5, 0.1, -238.41793403955250272430898754048547661932857086122, 1e-13, ln_pdf(0.5)); test_exact(1.5, 0.1, -146.85439481068371057247137024006716189469284256628, ln_pdf(0.8)); test_absolute(1.5, 1.5, -2.2350748570877992856465076624973458117562108140674, 1e-15, ln_pdf(0.1)); test_absolute(1.5, 1.5, -1.7001219175524556705452882616787223585705662860012, 1e-15, ln_pdf(0.5)); test_absolute(1.5, 1.5, -1.7610875785399045023354101841009649273236721172008, 1e-15, ln_pdf(0.8)); test_absolute(1.5, 2.5, -0.68941644324162489418137656699398207513321602763104, 1e-15, ln_pdf(0.1)); test_exact(1.5, 2.5, -1.5268736489667254857801287379715477173125628275598, ln_pdf(0.5)); test_exact(1.5, 2.5, -1.8496236096394777662704671479709839674424623547308, ln_pdf(0.8)); test_absolute(2.5, 0.1, -1149.5549471196476523788026360929146688367845019398, 1e-12, 
ln_pdf(0.1)); test_absolute(2.5, 0.1, -507.73265209554698134113704985174959301922196605736, 1e-12, ln_pdf(0.5)); test_absolute(2.5, 0.1, -369.16874994210463740474549611573497379941224077335, 1e-13, ln_pdf(0.8)); test_absolute(2.5, 1.5, -4.1473348984184862316495477617980296904955324113457, 1e-15, ln_pdf(0.1)); test_absolute(2.5, 1.5, -2.8970762200235424747307247601045786110485663457169, 1e-15, ln_pdf(0.5)); test_exact(2.5, 1.5, -2.7491513791239977024488074547907467152956602019989, ln_pdf(0.8)); test_absolute(2.5, 2.5, -1.3778300581206721947424710027422282714793718026513, 1e-15, ln_pdf(0.1)); test_exact(2.5, 2.5, -1.9577771978563167352868858774048559682046428490575, ln_pdf(0.5)); test_exact(2.5, 2.5, -2.2053265778497513183112901654193054111123780652581, ln_pdf(0.8)); } #[test] fn test_neg_ln_pdf() { let ln_pdf = |arg: f64| move |x: LogNormal| x.ln_pdf(arg); test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); } #[test] fn test_cdf() { cdf_tests(false); } #[test] fn test_inverse_cdf() { cdf_tests(true) } // we can reuse the (input, output) pairs from the CDF unit test // and verify that passing an 'output' to .inverse_cdf gives 'input', // except in cases where output would be 0.0 (the inverse_cdf is defined to // always give 0.0 in this case). 
fn cdf_tests(inverse: bool) { let f = |arg: f64| move |x: LogNormal| if inverse { x.inverse_cdf(arg) } else { x.cdf(arg) }; // given some cdf_input and cdf_output, returns a tuple (input, output) where // input is what we will provide to cdf/inverse_cdf, and output is the expected return // value let arrange_input_output = |cdf_input: f64, cdf_output: f64| { if inverse { (cdf_output, cdf_input) } else { (cdf_input, cdf_output) } }; // calls test_absolute after re-arranging the input/output arguments and calling f with input let almost = |mean: f64, std_dev: f64, cdf_input: f64, cdf_output: f64, acc: f64| { let (input, output) = arrange_input_output(cdf_input, cdf_output); test_absolute(mean, std_dev, output, acc, f(input)); }; // calls test_exact after re-arranging the input/output arguments and calling f with input let case = |mean: f64, std_dev: f64, cdf_input: f64, cdf_output: f64| { let (input, output) = arrange_input_output(cdf_input, cdf_output); test_exact(mean, std_dev, output, f(input)); }; // we skip cases where the CDF outputs 0.0 when testing the inverse CDF because // there are multiple inputs to the CDF which give an answer of 0.0, so testing whether // inputting 0.0 to the inverse CDF gives the same answer is not a valid test; // the inverse CDF for log-normal is defined to give 0.0 for input 0.0 if inverse { case(-0.1, 0.1, 0.0, 0.0); } if !inverse { almost(-0.1, 0.1, 0.1, 0.0, 1e-107); } almost(-0.1, 0.1, 0.5, 0.0000000015011556178148777579869633555518882664666520593658, 1e-16); almost(-0.1, 0.1, 0.8, 0.10908001076375810900224507908874442583171381706127, 1e-11); almost(-0.1, 1.5, 0.1, 0.070999149762464508991968731574953594549291668468349, 1e-11); case(-0.1, 1.5, 0.5, 0.34626224992888089297789445771047690175505847991946); case(-0.1, 1.5, 0.8, 0.46728530589487698517090261668589508746353129242404); almost(-0.1, 2.5, 0.1, 0.18914969879695093477606645992572208111152994999076, 1e-10); case(-0.1, 2.5, 0.5,
0.40622798321378106125020505907901206714868922279347); case(-0.1, 2.5, 0.8, 0.48035707589956665425068652807400957345208517749893); // input to inverse would be 0.0 if !inverse { almost(1.5, 0.1, 0.1, 0.0, 1e-315); almost(1.5, 0.1, 0.5, 0.0, 1e-106); almost(1.5, 0.1, 0.8, 0.0, 1e-66); } almost(1.5, 1.5, 0.1, 0.005621455876973168709588070988239748831823850202953, 1e-12); almost(1.5, 1.5, 0.8, 0.12532699044614938400496547188720940854423187977236, 1e-11); almost(1.5, 2.5, 0.1, 0.064125647996943514411570834861724406903677144126117, 1e-11); almost(1.5, 2.5, 0.5, 0.19017302281590810871719754032332631806011441356498, 1e-10); almost(1.5, 2.5, 0.8, 0.24533064397555500690927047163085419096928289095201, 1e-16); // input to inverse would be 0.0 if !inverse { case(2.5, 0.1, 0.1, 0.0); almost(2.5, 0.1, 0.5, 0.0, 1e-223); almost(2.5, 0.1, 0.8, 0.0, 1e-162); } almost(2.5, 1.5, 0.1, 0.00068304052220788502001572635016579586444611070077399, 1e-13); almost(2.5, 1.5, 0.5, 0.016636862816580533038130583128179878924863968664206, 1e-12); almost(2.5, 1.5, 0.8, 0.034729001282904174941366974418836262996834852343018, 1e-11); almost(2.5, 2.5, 0.1, 0.027363708266690978870139978537188410215717307180775, 1e-11); almost(2.5, 2.5, 0.5, 0.10075543423327634536450625420610429181921642201567, 1e-11); almost(2.5, 2.5, 0.8, 0.13802019192453118732001307556787218421918336849121, 1e-11); } #[test] fn test_sf() { let sf = |arg: f64| move |x: LogNormal| x.sf(arg); // Wolfram Alpha: SurvivalFunction[LogNormalDistribution[-0.1, 0.1], 0.1] test_absolute(-0.1, 0.1, 1.0, 1e-107, sf(0.1)); // Wolfram Alpha: SurvivalFunction[LogNormalDistribution[-0.1, 0.1], 0.8] test_absolute(-0.1, 0.1, 0.890919989231123, 1e-14, sf(0.8)); // Wolfram Alpha: SurvivalFunction[LogNormalDistribution[1.5, 1], 0.8] test_absolute(1.5, 1.0, 0.957568715612642, 1e-14, sf(0.8)); // Wolfram Alpha: SurvivalFunction[LogNormalDistribution[2.5, 1.5], 0.1] test_absolute(2.5, 1.5, 0.9993169594777358, 1e-14, sf(0.1)); } #[test] fn test_neg_cdf()
{ let cdf = |arg: f64| move |x: LogNormal| x.cdf(arg); test_exact(0.0, 1.0, 0.0, cdf(0.0)); } #[test] fn test_neg_sf() { let sf = |arg: f64| move |x: LogNormal| x.sf(arg); test_exact(0.0, 1.0, 1.0, sf(0.0)); } #[test] fn test_continuous() { test::check_continuous_distribution(&create_ok(0.0, 0.25), 0.0, 10.0); test::check_continuous_distribution(&create_ok(0.0, 0.5), 0.0, 10.0); } } statrs-0.18.0/src/distribution/mod.rs //! Defines common interfaces for interacting with statistical distributions //! and provides concrete implementations for a variety of distributions. use super::statistics::{Max, Min}; use ::num_traits::{Float, Num}; use num_traits::NumAssignOps; pub use self::bernoulli::Bernoulli; pub use self::beta::{Beta, BetaError}; pub use self::binomial::{Binomial, BinomialError}; pub use self::categorical::{Categorical, CategoricalError}; pub use self::cauchy::{Cauchy, CauchyError}; pub use self::chi::{Chi, ChiError}; pub use self::chi_squared::ChiSquared; pub use self::dirac::{Dirac, DiracError}; #[cfg(feature = "nalgebra")] pub use self::dirichlet::{Dirichlet, DirichletError}; pub use self::discrete_uniform::{DiscreteUniform, DiscreteUniformError}; pub use self::empirical::Empirical; pub use self::erlang::Erlang; pub use self::exponential::{Exp, ExpError}; pub use self::fisher_snedecor::{FisherSnedecor, FisherSnedecorError}; pub use self::gamma::{Gamma, GammaError}; pub use self::geometric::{Geometric, GeometricError}; pub use self::gumbel::{Gumbel, GumbelError}; pub use self::hypergeometric::{Hypergeometric, HypergeometricError}; pub use self::inverse_gamma::{InverseGamma, InverseGammaError}; pub use self::laplace::{Laplace, LaplaceError}; pub use self::log_normal::{LogNormal, LogNormalError}; #[cfg(feature = "nalgebra")] pub use self::multinomial::{Multinomial, MultinomialError}; #[cfg(feature = "nalgebra")] pub use self::multivariate_normal::{MultivariateNormal,
MultivariateNormalError}; #[cfg(feature = "nalgebra")] pub use self::multivariate_students_t::{MultivariateStudent, MultivariateStudentError}; pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError}; pub use self::normal::{Normal, NormalError}; pub use self::pareto::{Pareto, ParetoError}; pub use self::poisson::{Poisson, PoissonError}; pub use self::students_t::{StudentsT, StudentsTError}; pub use self::triangular::{Triangular, TriangularError}; pub use self::uniform::{Uniform, UniformError}; pub use self::weibull::{Weibull, WeibullError}; mod bernoulli; mod beta; mod binomial; mod categorical; mod cauchy; mod chi; mod chi_squared; mod dirac; #[cfg(feature = "nalgebra")] #[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] mod dirichlet; mod discrete_uniform; mod empirical; mod erlang; mod exponential; mod fisher_snedecor; mod gamma; mod geometric; mod gumbel; mod hypergeometric; #[macro_use] mod internal; mod inverse_gamma; mod laplace; mod log_normal; #[cfg(feature = "nalgebra")] #[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] mod multinomial; #[cfg(feature = "nalgebra")] #[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] mod multivariate_normal; #[cfg(feature = "nalgebra")] #[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] mod multivariate_students_t; mod negative_binomial; mod normal; mod pareto; mod poisson; mod students_t; mod triangular; mod uniform; mod weibull; #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] mod ziggurat; #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] mod ziggurat_tables; /// The `ContinuousCDF` trait is used to specify an interface for univariate /// distributions for which cdf float arguments are sensible. pub trait ContinuousCDF: Min + Max { /// Returns the cumulative distribution function calculated /// at `x` for a given distribution. May panic depending /// on the implementor. 
/// /// # Examples /// /// ``` /// use statrs::distribution::{ContinuousCDF, Uniform}; /// /// let n = Uniform::new(0.0, 1.0).unwrap(); /// assert_eq!(0.5, n.cdf(0.5)); /// ``` fn cdf(&self, x: K) -> T; /// Returns the survival function calculated /// at `x` for a given distribution. May panic depending /// on the implementor. /// /// # Examples /// /// ``` /// use statrs::distribution::{ContinuousCDF, Uniform}; /// /// let n = Uniform::new(0.0, 1.0).unwrap(); /// assert_eq!(0.5, n.sf(0.5)); /// ``` fn sf(&self, x: K) -> T { T::one() - self.cdf(x) } /// Due to issues with rounding and floating-point accuracy, the default /// implementation may be ill-behaved. /// Specialized inverse CDFs should be used whenever possible. /// Performs a bisection search on the domain of `cdf` to obtain an approximation /// of `F^-1(p) := inf { x | F(x) >= p }`. The bracket starts at `[-2, 2]`, is doubled /// until `cdf(low) <= p <= cdf(high)`, and is then bisected a fixed 16 times, /// so both performance and precision may be lacking. #[doc(alias = "quantile function")] #[doc(alias = "quantile")] fn inverse_cdf(&self, p: T) -> K { if p == T::zero() { return self.min(); }; if p == T::one() { return self.max(); }; let two = K::one() + K::one(); let mut high = two; let mut low = -high; while self.cdf(low) > p { low = low + low; } while self.cdf(high) < p { high = high + high; } let mut i = 16; while i != 0 { let mid = (high + low) / two; if self.cdf(mid) >= p { high = mid; } else { low = mid; } i -= 1; } (high + low) / two } } /// The `DiscreteCDF` trait is used to specify an interface for univariate /// discrete distributions. pub trait DiscreteCDF: Min + Max { /// Returns the cumulative distribution function calculated /// at `x` for a given distribution. May panic depending /// on the implementor. /// /// # Examples /// /// ``` /// use statrs::distribution::{DiscreteCDF, DiscreteUniform}; /// /// let n = DiscreteUniform::new(1, 10).unwrap(); /// assert_eq!(0.6, n.cdf(6)); /// ``` fn cdf(&self, x: K) -> T; /// Returns the survival function calculated at `x` for /// a given distribution.
May panic depending on the implementor. /// /// # Examples /// /// ``` /// use statrs::distribution::{DiscreteCDF, DiscreteUniform}; /// /// let n = DiscreteUniform::new(1, 10).unwrap(); /// assert_eq!(0.4, n.sf(6)); /// ``` fn sf(&self, x: K) -> T { T::one() - self.cdf(x) } /// Due to issues with rounding and floating-point accuracy, the default implementation may be ill-behaved. /// Specialized inverse CDFs should be used whenever possible. /// /// # Panics /// /// This default implementation panics if `p` is not on the interval `[0.0, 1.0]`. fn inverse_cdf(&self, p: T) -> K { if p == T::zero() { return self.min(); } else if p == T::one() { return self.max(); } else if !(T::zero()..=T::one()).contains(&p) { panic!("p must be on [0, 1]") } let two = K::one() + K::one(); let mut ub = two.clone(); let lb = self.min(); while self.cdf(ub.clone()) < p { ub *= two.clone(); } internal::integral_bisection_search(|p| self.cdf(p.clone()), p, lb, ub).unwrap() } } /// The `Continuous` trait provides an interface for interacting with /// continuous statistical distributions. /// /// # Remarks /// /// All methods provided by the `Continuous` trait are unchecked, meaning /// they can panic if in an invalid state or encountering invalid input /// depending on the implementing distribution. pub trait Continuous { /// Returns the probability density function calculated at `x` for a given /// distribution. /// May panic depending on the implementor. /// /// # Examples /// /// ``` /// use statrs::distribution::{Continuous, Uniform}; /// /// let n = Uniform::new(0.0, 1.0).unwrap(); /// assert_eq!(1.0, n.pdf(0.5)); /// ``` fn pdf(&self, x: K) -> T; /// Returns the log of the probability density function calculated at `x` /// for a given distribution. /// May panic depending on the implementor.
/// /// # Examples /// /// ``` /// use statrs::distribution::{Continuous, Uniform}; /// /// let n = Uniform::new(0.0, 1.0).unwrap(); /// assert_eq!(0.0, n.ln_pdf(0.5)); /// ``` fn ln_pdf(&self, x: K) -> T; } /// The `Discrete` trait provides an interface for interacting with discrete /// statistical distributions. /// /// # Remarks /// /// All methods provided by the `Discrete` trait are unchecked, meaning /// they can panic if in an invalid state or encountering invalid input /// depending on the implementing distribution. pub trait Discrete { /// Returns the probability mass function calculated at `x` for a given /// distribution. /// May panic depending on the implementor. /// /// # Examples /// /// ``` /// use statrs::distribution::{Discrete, Binomial}; /// use statrs::prec; /// /// let n = Binomial::new(0.5, 10).unwrap(); /// assert!(prec::almost_eq(n.pmf(5), 0.24609375, 1e-15)); /// ``` fn pmf(&self, x: K) -> T; /// Returns the log of the probability mass function calculated at `x` for /// a given distribution. /// May panic depending on the implementor.
/// /// # Examples /// /// ``` /// use statrs::distribution::{Discrete, Binomial}; /// use statrs::prec; /// /// let n = Binomial::new(0.5, 10).unwrap(); /// assert!(prec::almost_eq(n.ln_pmf(5), (0.24609375f64).ln(), 1e-15)); /// ``` fn ln_pmf(&self, x: K) -> T; } statrs-0.18.0/src/distribution/multinomial.rs use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; use nalgebra::{Dim, Dyn, OMatrix, OVector}; /// Implements the /// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution) /// distribution, which is a generalization of the /// [Binomial](https://en.wikipedia.org/wiki/Binomial_distribution) /// distribution. /// /// # Examples /// /// ``` /// use statrs::distribution::Multinomial; /// use statrs::statistics::MeanN; /// use nalgebra::vector; /// /// let n = Multinomial::new_from_nalgebra(vector![0.3, 0.7], 5).unwrap(); /// assert_eq!(n.mean().unwrap(), (vector![1.5, 3.5])); /// ``` #[derive(Debug, Clone, PartialEq)] pub struct Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { /// normalized probabilities for each species p: OVector, /// count of trials n: u64, } /// Represents the errors that can occur when creating a [`Multinomial`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum MultinomialError { /// Fewer than two probabilities. NotEnoughProbabilities, /// The sum of all probabilities is zero. ProbabilitySumZero, /// At least one probability is NaN, infinite or less than zero.
ProbabilityInvalid, } impl std::fmt::Display for MultinomialError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { MultinomialError::NotEnoughProbabilities => write!(f, "Fewer than two probabilities"), MultinomialError::ProbabilitySumZero => write!(f, "The probabilities sum up to zero"), MultinomialError::ProbabilityInvalid => write!( f, "At least one probability is NaN, infinity or less than zero" ), } } } impl std::error::Error for MultinomialError {} impl Multinomial { /// Constructs a new multinomial distribution with probabilities `p` /// and `n` number of trials. /// /// # Errors /// /// Returns an error if `p` has fewer than two elements, the sum of the elements /// in `p` is 0, or any element in `p` is less than 0, infinite, or `f64::NAN`. /// /// # Note /// /// The elements in `p` do not need to be normalized; they are rescaled to sum to 1. /// /// # Examples /// /// ``` /// use statrs::distribution::Multinomial; /// /// let mut result = Multinomial::new(vec![0.0, 1.0, 2.0], 3); /// assert!(result.is_ok()); /// /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3); /// assert!(result.is_err()); /// ``` pub fn new(p: Vec, n: u64) -> Result { Self::new_from_nalgebra(p.into(), n) } } impl Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { /// Constructs a new multinomial distribution from an `nalgebra` vector /// of probabilities `p` and `n` number of trials. /// /// # Errors /// /// The same conditions as for [`Multinomial::new`] apply. pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { if p.len() < 2 { return Err(MultinomialError::NotEnoughProbabilities); } let mut sum = 0.0; for &val in &p { if !val.is_finite() || val < 0.0 { return Err(MultinomialError::ProbabilityInvalid); } sum += val; } if sum == 0.0 { return Err(MultinomialError::ProbabilitySumZero); } p.unscale_mut(p.lp_norm(1)); Ok(Self { p, n }) } /// Returns the normalized probabilities of the multinomial /// distribution as a vector /// /// # Examples /// /// ``` /// use statrs::distribution::Multinomial; /// use nalgebra::dvector; /// /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap(); /// assert_eq!(*n.p(), dvector![0.0, 1.0/3.0,
2.0/3.0]); /// ``` pub fn p(&self) -> &OVector { &self.p } /// Returns the number of trials of the multinomial /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Multinomial; /// /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap(); /// assert_eq!(n.n(), 3); /// ``` pub fn n(&self) -> u64 { self.n } } impl std::fmt::Display for Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Multinom({:#?},{})", self.p, self.n) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution> for Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { fn sample(&self, rng: &mut R) -> OVector { sample_generic(self, rng) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution> for Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { fn sample(&self, rng: &mut R) -> OVector { sample_generic(self, rng) } } #[cfg(feature = "rand")] fn sample_generic(dist: &Multinomial, rng: &mut R) -> OVector where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, R: ::rand::Rng + ?Sized, T: ::num_traits::Num + ::nalgebra::Scalar + ::std::ops::AddAssign, { use nalgebra::Const; let p_cdf = super::categorical::prob_mass_to_cdf(dist.p().as_slice()); let mut res = OVector::zeros_generic(dist.p.shape_generic().0, Const::<1>); for _ in 0..dist.n { let i = super::categorical::sample_unchecked(rng, &p_cdf); res[i] += T::one(); } res } impl MeanN> for Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { /// Returns the mean of the multinomial distribution /// /// # Formula /// /// ```text /// n * p_i for i in 1...k /// ``` /// /// where `n` is the number of trials, 
`p_i` is the `i`th probability, /// and `k` is the total number of probabilities fn mean(&self) -> Option> { Some(self.p.map(|x| x * self.n as f64)) } } impl VarianceN> for Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the covariance matrix of the multinomial distribution /// /// # Formula /// /// ```text /// Var(X_i) = n * p_i * (1 - p_i) for i in 1...k /// Cov(X_i, X_j) = -n * p_i * p_j for i != j /// ``` /// /// where `n` is the number of trials, `p_i` is the `i`th probability, /// and `k` is the total number of probabilities fn variance(&self) -> Option> { let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x))); let mut offdiag = |x: usize, y: usize| { let elt = -self.p[x] * self.p[y]; // cov[(x, y)] = elt; cov[(y, x)] = elt; }; for i in 0..self.p.len() { for j in 0..i { offdiag(i, j); } } cov.fill_lower_triangle_with_upper_triangle(); Some(cov.scale(self.n as f64)) } } // impl Skewness> for Multinomial { // /// Returns the skewness of the multinomial distribution // /// // /// # Formula // /// // /// ```text // /// (1 - 2 * p_i) / sqrt(n * p_i * (1 - p_i)) for i in 1...k // /// ``` // /// // /// where `n` is the number of trials, `p_i` is the `i`th probability, // /// and `k` is the total number of probabilities // fn skewness(&self) -> Option> { // Some( // self.p // .iter() // .map(|x| (1.0 - 2.0 * x) / (self.n as f64 * (1.0 - x) * x).sqrt()) // .collect(), // ) // } // } impl Discrete<&OVector, f64> for Multinomial where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { /// Calculates the probability mass function for the multinomial /// distribution at the counts `x`, where `x_i` is the number of outcomes /// in the category with probability `p_i`. Returns 0.0 if the counts /// do not sum to `n`. /// /// # Panics /// /// If the length of `x` is not equal to the length of `p` /// /// # Formula /// /// ```text /// (n! / x_1!...x_k!)
* p_i^x_i for i in 1...k /// ``` /// /// where `n` is the number of trials, `p_i` is the `i`th probability, /// `x_i` is the `i`th `x` value, and `k` is the total number of /// probabilities fn pmf(&self, x: &OVector) -> f64 { if self.p.len() != x.len() { panic!("Expected x and p to have equal lengths."); } if x.iter().sum::() != self.n { return 0.0; } let coeff = factorial::multinomial(self.n, x.as_slice()); let val = coeff * self .p .iter() .zip(x.iter()) .fold(1.0, |acc, (pi, xi)| acc * pi.powf(*xi as f64)); val } /// Calculates the log probability mass function for the multinomial /// distribution at the counts `x`, where `x_i` is the number of outcomes /// in the category with probability `p_i`. Returns `f64::NEG_INFINITY` if /// the counts do not sum to `n`. /// /// # Panics /// /// If the length of `x` is not equal to the length of `p` /// /// # Formula /// /// ```text /// ln((n! / x_1!...x_k!) * p_i^x_i) for i in 1...k /// ``` /// /// where `n` is the number of trials, `p_i` is the `i`th probability, /// `x_i` is the `i`th `x` value, and `k` is the total number of /// probabilities fn ln_pmf(&self, x: &OVector) -> f64 { if self.p.len() != x.len() { panic!("Expected x and p to have equal lengths."); } if x.iter().sum::() != self.n { return f64::NEG_INFINITY; } let coeff = factorial::multinomial(self.n, x.as_slice()).ln(); let val = coeff + self .p .iter() .zip(x.iter()) /* a zero count contributes ln(p_i^0) = 0; skipping it avoids 0 * ln(0) = NaN when p_i = 0 */ .map(|(pi, xi)| if *xi == 0 { 0.0 } else { *xi as f64 * pi.ln() }) .fold(0.0, |acc, x| acc + x); val } } #[rustfmt::skip] #[cfg(test)] mod tests { use crate::{ distribution::{Discrete, Multinomial, MultinomialError}, statistics::{MeanN, VarianceN}, }; use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector}; use std::fmt::{Debug, Display}; fn try_create(p: OVector, n: u64) -> Multinomial where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { let mvn = Multinomial::new_from_nalgebra(p, n); assert!(mvn.is_ok()); mvn.unwrap() } fn bad_create_case(p: OVector, n: u64) -> MultinomialError where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { let dd =
Multinomial::new_from_nalgebra(p, n); assert!(dd.is_err()); dd.unwrap_err() } fn test_almost(p: OVector, n: u64, expected: T, acc: f64, eval: F) where T: Debug + Display + approx::RelativeEq, F: FnOnce(Multinomial) -> T, D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { let dd = try_create(p, n); let x = eval(dd); assert_relative_eq!(expected, x, epsilon = acc); } #[test] fn test_create() { assert_relative_eq!( *try_create(vector![1.0, 1.0, 1.0], 4).p(), vector![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0] ); try_create(dvector![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4); } #[test] fn test_bad_create() { assert_eq!( bad_create_case(vector![0.5], 4), MultinomialError::NotEnoughProbabilities, ); assert_eq!( bad_create_case(vector![-1.0, 2.0], 4), MultinomialError::ProbabilityInvalid, ); assert_eq!( bad_create_case(vector![0.0, 0.0], 4), MultinomialError::ProbabilitySumZero, ); assert_eq!( bad_create_case(vector![1.0, f64::NAN], 4), MultinomialError::ProbabilityInvalid, ); } #[test] fn test_mean() { let mean = |x: Multinomial<_>| x.mean().unwrap(); test_almost(dvector![0.3, 0.7], 5, dvector![1.5, 3.5], 1e-12, mean); test_almost( dvector![0.1, 0.3, 0.6], 10, dvector![1.0, 3.0, 6.0], 1e-12, mean, ); test_almost( dvector![1.0, 3.0, 6.0], 10, dvector![1.0, 3.0, 6.0], 1e-12, mean, ); test_almost( dvector![0.15, 0.35, 0.3, 0.2], 20, dvector![3.0, 7.0, 6.0, 4.0], 1e-12, mean, ); } #[test] fn test_variance() { let variance = |x: Multinomial<_>| x.variance().unwrap(); test_almost( dvector![0.3, 0.7], 5, dmatrix![1.05, -1.05; -1.05, 1.05], 1e-15, variance, ); test_almost( dvector![0.1, 0.3, 0.6], 10, dmatrix![0.9, -0.3, -0.6; -0.3, 2.1, -1.8; -0.6, -1.8, 2.4; ], 1e-15, variance, ); test_almost( dvector![0.15, 0.35, 0.3, 0.2], 20, dmatrix![2.55, -1.05, -0.90, -0.60; -1.05, 4.55, -2.10, -1.40; -0.90, -2.10, 4.20, -1.20; -0.60, -1.40, -1.20, 3.20; ], 1e-15, variance, ); } // // #[test] // // fn test_skewness() { // // let skewness = |x: Multinomial| 
x.skewness().unwrap(); // // test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness); // // test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness); // // test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness); // // } #[test] fn test_pmf() { let pmf = |arg: OVector| move |x: Multinomial<_>| x.pmf(&arg); test_almost( dvector![0.3, 0.7], 10, 0.121060821, 1e-15, pmf(dvector![1, 9]), ); test_almost( dvector![0.1, 0.3, 0.6], 10, 0.105815808, 1e-15, pmf(dvector![1, 3, 6]), ); test_almost( dvector![0.15, 0.35, 0.3, 0.2], 10, 0.000145152, 1e-15, pmf(dvector![1, 1, 1, 7]), ); } #[test] fn test_error_is_sync_send() { fn assert_sync_send() {} assert_sync_send::(); } // #[test] // #[should_panic] // fn test_pmf_x_wrong_length() { // let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); // n.pmf(&[1]); // } // #[test] // #[should_panic] // fn test_pmf_x_wrong_sum() { // let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); // n.pmf(&[1, 3]); // } // #[test] // fn test_ln_pmf() { // let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; // let n = Multinomial::new(large_p, 45).unwrap(); // let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9]; // assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13); // let n2 = Multinomial::new(large_p, 18).unwrap(); // let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3]; // assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13); // let n3 = Multinomial::new(large_p, 51).unwrap(); // let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3]; // assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13); // } // #[test] // #[should_panic] // fn test_ln_pmf_x_wrong_length() { // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); // n.ln_pmf(&[1]); // } // #[test] // #[should_panic] // fn 
test_ln_pmf_x_wrong_sum() { // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); // n.ln_pmf(&[1, 3]); // } } statrs-0.18.0/src/distribution/multivariate_normal.rs000064400000000000000000000532571046102023000212000ustar 00000000000000use crate::distribution::Continuous; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector}; use std::f64; use std::f64::consts::{E, PI}; /// Computes both the normalization and exponential argument in the normal /// distribution, returning `None` on dimension mismatch. pub(super) fn density_normalization_and_exponential( mu: &OVector, cov: &OMatrix, precision: &OMatrix, x: &OVector, ) -> Option<(f64, f64)> where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { Some(( density_distribution_pdf_const(mu, cov)?, density_distribution_exponential(mu, precision, x)?, )) } /// Computes the argument of the exponential term in the normal distribution, /// returning `None` on dimension mismatch. #[inline] fn density_distribution_exponential( mu: &OVector, precision: &OMatrix, x: &OVector, ) -> Option where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { if x.shape_generic().0 != precision.shape_generic().0 || x.shape_generic().0 != mu.shape_generic().0 || !precision.is_square() { return None; } let dv = x - mu; let exp_term: f64 = -0.5 * (precision * &dv).dot(&dv); Some(exp_term) } /// Computes the argument of the normalization term in the normal distribution, /// returning `None` on dimension mismatch. 
#[inline] fn density_distribution_pdf_const(mu: &OVector, cov: &OMatrix) -> Option where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { if cov.shape_generic().0 != mu.shape_generic().0 || !cov.is_square() { return None; } let cov_det = cov.determinant(); Some( ((2. * PI).powi(mu.nrows() as i32) * cov_det.abs()) .recip() .sqrt(), ) } /// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) /// distribution using the "nalgebra" crate for matrix operations /// /// # Examples /// /// ``` /// use statrs::distribution::{MultivariateNormal, Continuous}; /// use nalgebra::{matrix, vector}; /// use statrs::statistics::{MeanN, VarianceN}; /// /// let mvn = MultivariateNormal::new_from_nalgebra(vector![0., 0.], matrix![1., 0.; 0., 1.]).unwrap(); /// assert_eq!(mvn.mean().unwrap(), vector![0., 0.]); /// assert_eq!(mvn.variance().unwrap(), matrix![1., 0.; 0., 1.]); /// assert_eq!(mvn.pdf(&vector![1., 1.]), 0.05854983152431917); /// ``` #[derive(Clone, PartialEq, Debug)] pub struct MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { cov_chol_decomp: OMatrix, mu: OVector, cov: OMatrix, precision: OMatrix, pdf_const: f64, } /// Represents the errors that can occur when creating a [`MultivariateNormal`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum MultivariateNormalError { /// The covariance matrix is asymmetric or contains a NaN. CovInvalid, /// The mean vector contains a NaN. MeanInvalid, /// The number of rows in the mean vector is not equal to the number /// of rows in the covariance matrix. DimensionMismatch, /// After all other validation, computing the Cholesky decomposition failed. 
CholeskyFailed, } impl std::fmt::Display for MultivariateNormalError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { MultivariateNormalError::CovInvalid => { write!(f, "Covariance matrix is asymmetric or contains a NaN") } MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"), MultivariateNormalError::DimensionMismatch => write!( f, "Mean vector and covariance matrix do not have the same number of rows" ), MultivariateNormalError::CholeskyFailed => { write!(f, "Computing the Cholesky decomposition failed") } } } } impl std::error::Error for MultivariateNormalError {} impl MultivariateNormal { /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` /// /// # Errors /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); MultivariateNormal::new_from_nalgebra(mean, cov) } } impl MultivariateNormal where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` using `nalgebra` `OVector` and `OMatrix` /// instead of `Vec` /// /// # Errors /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite pub fn new_from_nalgebra( mean: OVector, cov: OMatrix, ) -> Result { if mean.iter().any(|f| f.is_nan()) { return Err(MultivariateNormalError::MeanInvalid); } if !cov.is_square() || cov.lower_triangle() != cov.upper_triangle().transpose() || cov.iter().any(|f| f.is_nan()) { return Err(MultivariateNormalError::CovInvalid); } // Compare number of rows if mean.shape_generic().0 != cov.shape_generic().0 { return 
Err(MultivariateNormalError::DimensionMismatch); } // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { None => Err(MultivariateNormalError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateNormal { // .unwrap() because prerequisites are already checked above pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(), cov_chol_decomp: cholesky_decomp.unpack(), mu: mean, cov, precision, }) } } } /// Returns the entropy of the multivariate normal distribution /// /// # Formula /// /// ```text /// (1 / 2) * ln(det(2 * π * e * Σ)) /// ``` /// /// where `Σ` is the covariance matrix and `det` is the determinant pub fn entropy(&self) -> Option { Some( 0.5 * self .variance() .unwrap() .scale(2. * PI * E) .determinant() .ln(), ) } } impl std::fmt::Display for MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "N({}, {})", &self.mu, &self.cov) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution> for MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Samples from the multivariate normal distribution /// /// # Formula /// ```text /// L * Z + μ /// ``` /// /// where `L` is the Cholesky decomposition of the covariance matrix, /// `Z` is a vector of normally distributed random variables, and /// `μ` is the mean vector fn sample(&self, rng: &mut R) -> OVector { let d = crate::distribution::Normal::new(0., 1.).unwrap(); let z = OVector::from_distribution_generic(self.mu.shape_generic().0, Const::<1>, &d, rng); (&self.cov_chol_decomp * z) + &self.mu } } impl Min> for MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + 
nalgebra::allocator::Allocator, { /// Returns the minimum value in the domain of the /// multivariate normal distribution represented by a real vector fn min(&self) -> OVector { OMatrix::repeat_generic(self.mu.shape_generic().0, Const::<1>, f64::NEG_INFINITY) } } impl Max> for MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the maximum value in the domain of the /// multivariate normal distribution represented by a real vector fn max(&self) -> OVector { OMatrix::repeat_generic(self.mu.shape_generic().0, Const::<1>, f64::INFINITY) } } impl MeanN> for MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the mean of the normal distribution /// /// # Remarks /// /// This is the same mean used to construct the distribution fn mean(&self) -> Option> { Some(self.mu.clone()) } } impl VarianceN> for MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the covariance matrix of the multivariate normal distribution fn variance(&self) -> Option> { Some(self.cov.clone()) } } impl Mode> for MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the mode of the multivariate normal distribution /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the mean fn mode(&self) -> OVector { self.mu.clone() } } impl Continuous<&OVector, f64> for MultivariateNormal where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Calculates the probability density function for the multivariate /// normal distribution at `x` /// /// # Formula /// /// ```text /// (2 * π) ^ (-k / 2) * det(Σ) ^ (-1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) /// ``` /// /// where `μ` is 
the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant /// of the covariance matrix, and `k` is the dimension of the distribution fn pdf(&self, x: &OVector) -> f64 { self.pdf_const * density_distribution_exponential(&self.mu, &self.precision, x) .unwrap() .exp() } /// Calculates the log probability density function for the multivariate /// normal distribution at `x`. Equivalent to pdf(x).ln(). fn ln_pdf(&self, x: &OVector) -> f64 { self.pdf_const.ln() + density_distribution_exponential(&self.mu, &self.precision, x).unwrap() } } #[rustfmt::skip] #[cfg(test)] mod tests { use core::fmt::Debug; use nalgebra::{dmatrix, dvector, matrix, vector, DimMin, OMatrix, OVector}; use crate::{ distribution::{Continuous, MultivariateNormal}, statistics::{Max, MeanN, Min, Mode, VarianceN}, }; use super::MultivariateNormalError; fn try_create(mean: OVector, covariance: OMatrix) -> MultivariateNormal where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { let mvn = MultivariateNormal::new_from_nalgebra(mean, covariance); assert!(mvn.is_ok()); mvn.unwrap() } fn create_case(mean: OVector, covariance: OMatrix) where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { let mvn = try_create(mean.clone(), covariance.clone()); assert_eq!(mean, mvn.mean().unwrap()); assert_eq!(covariance, mvn.variance().unwrap()); } fn bad_create_case(mean: OVector, covariance: OMatrix) where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { let mvn = MultivariateNormal::new_from_nalgebra(mean, covariance); assert!(mvn.is_err()); } fn test_case( mean: OVector, covariance: OMatrix, expected: T, eval: F, ) where T: Debug + PartialEq, F: FnOnce(MultivariateNormal) -> T, D: DimMin, nalgebra::DefaultAllocator: 
nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { let mvn = try_create(mean, covariance); let x = eval(mvn); assert_eq!(expected, x); } fn test_almost( mean: OVector, covariance: OMatrix, expected: f64, acc: f64, eval: F, ) where F: FnOnce(MultivariateNormal) -> f64, D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { let mvn = try_create(mean, covariance); let x = eval(mvn); assert_almost_eq!(expected, x, acc); } #[test] fn test_create() { create_case(vector![0., 0.], matrix![1., 0.; 0., 1.]); create_case(vector![10., 5.], matrix![2., 1.; 1., 2.]); create_case( vector![4., 5., 6.], matrix![2., 1., 0.; 1., 2., 1.; 0., 1., 2.], ); create_case(dvector![0., f64::INFINITY], dmatrix![1., 0.; 0., 1.]); create_case( dvector![0., 0.], dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], ); } #[test] fn test_bad_create() { // Covariance not symmetric bad_create_case(vector![0., 0.], matrix![1., 1.; 0., 1.]); // Covariance not positive-definite bad_create_case(vector![0., 0.], matrix![1., 2.; 2., 1.]); // NaN in mean bad_create_case(dvector![0., f64::NAN], dmatrix![1., 0.; 0., 1.]); // NaN in Covariance Matrix bad_create_case(dvector![0., 0.], dmatrix![1., 0.; 0., f64::NAN]); } #[test] fn test_variance() { let variance = |x: MultivariateNormal<_>| x.variance().unwrap(); test_case( vector![0., 0.], matrix![1., 0.; 0., 1.], matrix![1., 0.; 0., 1.], variance, ); test_case( vector![0., 0.], matrix![f64::INFINITY, 0.; 0., f64::INFINITY], matrix![f64::INFINITY, 0.; 0., f64::INFINITY], variance, ); } #[test] fn test_entropy() { let entropy = |x: MultivariateNormal<_>| x.entropy().unwrap(); test_case( dvector![0., 0.], dmatrix![1., 0.; 0., 1.], 2.8378770664093453, entropy, ); test_case( dvector![0., 0.], dmatrix![1., 0.5; 0.5, 1.], 2.694036030183455, entropy, ); test_case( dvector![0., 0.], dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], f64::INFINITY, 
entropy, ); } #[test] fn test_mode() { let mode = |x: MultivariateNormal<_>| x.mode(); test_case( vector![0., 0.], matrix![1., 0.; 0., 1.], vector![0., 0.], mode, ); test_case( vector![f64::INFINITY, f64::INFINITY], matrix![1., 0.; 0., 1.], vector![f64::INFINITY, f64::INFINITY], mode, ); } #[test] fn test_min_max() { let min = |x: MultivariateNormal<_>| x.min(); let max = |x: MultivariateNormal<_>| x.max(); test_case( dvector![0., 0.], dmatrix![1., 0.; 0., 1.], dvector![f64::NEG_INFINITY, f64::NEG_INFINITY], min, ); test_case( dvector![0., 0.], dmatrix![1., 0.; 0., 1.], dvector![f64::INFINITY, f64::INFINITY], max, ); test_case( dvector![10., 1.], dmatrix![1., 0.; 0., 1.], dvector![f64::NEG_INFINITY, f64::NEG_INFINITY], min, ); test_case( dvector![-3., 5.], dmatrix![1., 0.; 0., 1.], dvector![f64::INFINITY, f64::INFINITY], max, ); } #[test] fn test_pdf() { let pdf = |arg| move |x: MultivariateNormal<_>| x.pdf(&arg); test_case( vector![0., 0.], matrix![1., 0.; 0., 1.], 0.05854983152431917, pdf(vector![1., 1.]), ); test_almost( vector![0., 0.], matrix![1., 0.; 0., 1.], 0.013064233284684921, 1e-15, pdf(vector![1., 2.]), ); test_almost( vector![0., 0.], matrix![1., 0.; 0., 1.], 1.8618676045881531e-23, 1e-35, pdf(vector![1., 10.]), ); test_almost( vector![0., 0.], matrix![1., 0.; 0., 1.], 5.920684802611216e-45, 1e-58, pdf(vector![10., 10.]), ); test_almost( vector![0., 0.], matrix![1., 0.9; 0.9, 1.], 1.6576716577547003e-05, 1e-18, pdf(vector![1., -1.]), ); test_almost( vector![0., 0.], matrix![1., 0.99; 0.99, 1.], 4.1970621773477824e-44, 1e-54, pdf(vector![1., -1.]), ); test_almost( vector![0.5, -0.2], matrix![2.0, 0.3; 0.3, 0.5], 0.0013075203140666656, 1e-15, pdf(vector![2., 2.]), ); test_case( vector![0., 0.], matrix![f64::INFINITY, 0.; 0., f64::INFINITY], 0.0, pdf(vector![10., 10.]), ); test_case( vector![0., 0.], matrix![f64::INFINITY, 0.; 0., f64::INFINITY], 0.0, pdf(vector![100., 100.]), ); } #[test] fn test_ln_pdf() { let ln_pdf = |arg| move |x: 
MultivariateNormal<_>| x.ln_pdf(&arg); test_case( dvector![0., 0.], dmatrix![1., 0.; 0., 1.], (0.05854983152431917f64).ln(), ln_pdf(dvector![1., 1.]), ); test_almost( dvector![0., 0.], dmatrix![1., 0.; 0., 1.], (0.013064233284684921f64).ln(), 1e-15, ln_pdf(dvector![1., 2.]), ); test_almost( dvector![0., 0.], dmatrix![1., 0.; 0., 1.], (1.8618676045881531e-23f64).ln(), 1e-15, ln_pdf(dvector![1., 10.]), ); test_almost( dvector![0., 0.], dmatrix![1., 0.; 0., 1.], (5.920684802611216e-45f64).ln(), 1e-15, ln_pdf(dvector![10., 10.]), ); test_almost( dvector![0., 0.], dmatrix![1., 0.9; 0.9, 1.], (1.6576716577547003e-05f64).ln(), 1e-14, ln_pdf(dvector![1., -1.]), ); test_almost( dvector![0., 0.], dmatrix![1., 0.99; 0.99, 1.], (4.1970621773477824e-44f64).ln(), 1e-12, ln_pdf(dvector![1., -1.]), ); test_almost( dvector![0.5, -0.2], dmatrix![2.0, 0.3; 0.3, 0.5], (0.0013075203140666656f64).ln(), 1e-15, ln_pdf(dvector![2., 2.]), ); test_case( dvector![0., 0.], dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvector![10., 10.]), ); test_case( dvector![0., 0.], dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvector![100., 100.]), ); } #[test] #[should_panic] fn test_pdf_mismatched_arg_size() { let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.,]).unwrap(); mvn.pdf(&vec![1.].into()); // x.size != mu.size } #[test] fn test_error_is_sync_send() { fn assert_sync_send() {} assert_sync_send::(); } } statrs-0.18.0/src/distribution/multivariate_students_t.rs000064400000000000000000000570221046102023000220760ustar 00000000000000use crate::distribution::Continuous; use crate::function::gamma; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector}; use std::f64::consts::PI; /// Implements the [Multivariate Student's t-distribution](https://en.wikipedia.org/wiki/Multivariate_t-distribution) /// using the "nalgebra" crate for matrix 
operations. /// /// Assumes all the marginal distributions have the same degrees of freedom, ν. /// /// # Examples /// /// ``` /// use statrs::distribution::{MultivariateStudent, Continuous}; /// use nalgebra::{DVector, DMatrix}; /// use statrs::statistics::{MeanN, VarianceN}; /// /// let mvs = MultivariateStudent::new(vec![0., 0.], vec![1., 0., 0., 1.], 4.).unwrap(); /// assert_eq!(mvs.mean().unwrap(), DVector::from_vec(vec![0., 0.])); /// assert_eq!(mvs.variance().unwrap(), DMatrix::from_vec(2, 2, vec![2., 0., 0., 2.])); /// assert_eq!(mvs.pdf(&DVector::from_vec(vec![1., 1.])), 0.04715702017537655); /// ``` #[derive(Debug, Clone, PartialEq)] pub struct MultivariateStudent where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { scale_chol_decomp: OMatrix, location: OVector, scale: OMatrix, freedom: f64, precision: OMatrix, ln_pdf_const: f64, } /// Represents the errors that can occur when creating a [`MultivariateStudent`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum MultivariateStudentError { /// The scale matrix is asymmetric or contains a NaN. ScaleInvalid, /// The location vector contains a NaN. LocationInvalid, /// The degrees of freedom are NaN, zero or less than zero. FreedomInvalid, /// The number of rows in the location vector is not equal to the number /// of rows in the scale matrix. DimensionMismatch, /// After all other validation, computing the Cholesky decomposition failed. /// This means that the scale matrix is not positive-definite. 
CholeskyFailed, } impl std::fmt::Display for MultivariateStudentError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { MultivariateStudentError::ScaleInvalid => { write!(f, "Scale matrix is asymmetric or contains a NaN") } MultivariateStudentError::LocationInvalid => { write!(f, "Location vector contains a NaN") } MultivariateStudentError::FreedomInvalid => { write!(f, "Degrees of freedom are NaN, zero or less than zero") } MultivariateStudentError::DimensionMismatch => write!( f, "Location vector and scale matrix do not have the same number of rows" ), MultivariateStudentError::CholeskyFailed => { write!(f, "Computing the Cholesky decomposition failed") } } } } impl std::error::Error for MultivariateStudentError {} impl MultivariateStudent { /// Constructs a new multivariate Student's t-distribution with a location of `location`, /// scale matrix `scale` and `freedom` degrees of freedom. /// /// # Errors /// /// Returns an error if `location` contains a NaN, if the scale matrix is asymmetric, /// contains a NaN, or is not positive-definite, or if `freedom` is NaN or non-positive. /// See [`MultivariateStudentError`] for the individual variants. pub fn new( location: Vec, scale: Vec, freedom: f64, ) -> Result { let dim = location.len(); Self::new_from_nalgebra(location.into(), DMatrix::from_vec(dim, dim, scale), freedom) } /// Returns the dimension of the distribution. 
pub fn dim(&self) -> usize { self.location.len() } } impl MultivariateStudent where D: DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Constructs a new multivariate Student's t-distribution with a location of `location`, /// scale matrix `scale` and `freedom` degrees of freedom, using `nalgebra` /// `OVector` and `OMatrix` instead of `Vec` /// /// # Errors /// /// Returns an error if `location` contains a NaN, if the scale matrix is asymmetric, /// contains a NaN, or is not positive-definite, if `freedom` is NaN or non-positive, /// or if `location` and `scale` do not have the same number of rows. pub fn new_from_nalgebra( location: OVector, scale: OMatrix, freedom: f64, ) -> Result { let dim = location.len(); if location.iter().any(|f| f.is_nan()) { return Err(MultivariateStudentError::LocationInvalid); } if !scale.is_square() || scale.lower_triangle() != scale.upper_triangle().transpose() || scale.iter().any(|f| f.is_nan()) { return Err(MultivariateStudentError::ScaleInvalid); } if freedom.is_nan() || freedom <= 0.0 { return Err(MultivariateStudentError::FreedomInvalid); } if location.nrows() != scale.nrows() { return Err(MultivariateStudentError::DimensionMismatch); } let scale_det = scale.determinant(); let ln_pdf_const = gamma::ln_gamma(0.5 * (freedom + dim as f64)) - gamma::ln_gamma(0.5 * freedom) - 0.5 * (dim as f64) * (freedom * PI).ln() - 0.5 * scale_det.ln(); match Cholesky::new(scale.clone()) { None => Err(MultivariateStudentError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateStudent { scale_chol_decomp: cholesky_decomp.unpack(), location, scale, freedom, precision, ln_pdf_const, }) } } } /// Returns the Cholesky decomposition of the scale matrix. /// /// Returns the lower-triangular `A` where Σ = AAᵀ. pub fn scale_chol_decomp(&self) -> &OMatrix { &self.scale_chol_decomp } /// Returns the location of the distribution. pub fn location(&self) -> &OVector { &self.location } /// Returns the scale matrix of the distribution. pub fn scale(&self) -> &OMatrix { &self.scale } /// Returns the degrees of freedom of the distribution. pub fn freedom(&self) -> f64 { self.freedom } /// Returns the precision matrix, i.e. the inverse of the scale matrix. pub fn precision(&self) -> &OMatrix { &self.precision } /// Returns the natural logarithm of the constant (normalization) part of the /// probability density function. 
pub fn ln_pdf_const(&self) -> f64 { self.ln_pdf_const } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution> for MultivariateStudent where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Samples from the multivariate Student's t-distribution /// /// # Formula /// /// ```math /// W ⋅ L ⋅ Z + μ /// ``` /// /// where `W = √(ν/Sν)`, `Sν` is chi-squared distributed /// with ν degrees of freedom, /// `L` is the Cholesky decomposition of the scale matrix, /// `Z` is a vector of independent standard normal random variables, and /// `μ` is the location vector fn sample(&self, rng: &mut R) -> OVector { use crate::distribution::{ChiSquared, Normal}; let d = Normal::new(0., 1.).unwrap(); let s = ChiSquared::new(self.freedom).unwrap(); let w = (self.freedom / s.sample(rng)).sqrt(); let (r, c) = self.location.shape_generic(); let z = OVector::::from_distribution_generic(r, c, &d, rng); (w * &self.scale_chol_decomp * z) + &self.location } } impl Min> for MultivariateStudent where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the minimum value in the domain of the /// multivariate Student's t-distribution represented by a real vector fn min(&self) -> OVector { OMatrix::repeat_generic( self.location.shape_generic().0, Const::<1>, f64::NEG_INFINITY, ) } } impl Max> for MultivariateStudent where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the maximum value in the domain of the /// multivariate Student's t-distribution represented by a real vector fn max(&self) -> OVector { OMatrix::repeat_generic(self.location.shape_generic().0, Const::<1>, f64::INFINITY) } } impl MeanN> for MultivariateStudent where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the mean of the 
student distribution. /// /// # Remarks /// /// The mean equals the location vector used to construct the distribution. /// It is only defined if the degrees of freedom are greater than 1; /// otherwise `None` is returned. fn mean(&self) -> Option> { if self.freedom > 1. { Some(self.location.clone()) } else { None } } } impl VarianceN> for MultivariateStudent where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the covariance matrix of the multivariate Student's t-distribution. /// /// # Formula /// /// ```math /// Σ ⋅ ν / (ν - 2) /// ``` /// /// where `Σ` is the scale matrix and `ν` is the degrees of freedom. /// Only defined if the degrees of freedom are greater than 2; otherwise `None` is returned. fn variance(&self) -> Option> { if self.freedom > 2. { Some(self.scale.clone() * self.freedom / (self.freedom - 2.)) } else { None } } } impl Mode> for MultivariateStudent where D: Dim, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Returns the mode of the multivariate Student's t-distribution. /// /// # Formula /// /// ```math /// μ /// ``` /// /// where `μ` is the location. fn mode(&self) -> OVector { self.location.clone() } } impl Continuous<&OVector, f64> for MultivariateStudent where D: Dim + DimMin, nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, { /// Calculates the probability density function for the multivariate /// Student's t-distribution at `x`. /// /// # Formula /// /// ```math /// Γ((ν+p)/2) / [Γ(ν/2) ((ν * π)^p det(Σ))^(1 / 2)] * [1 + 1/ν (x - μ)ᵀ inv(Σ) (x - μ)]^(-(ν+p)/2) /// ``` /// /// where /// - `ν` is the degrees of freedom, /// - `μ` is the location vector, /// - `Γ` is the Gamma function, /// - `inv(Σ)` is the precision matrix, /// - `det(Σ)` is the determinant of the scale matrix, and /// - `p` is the dimension of the distribution. 
fn pdf(&self, x: &OVector) -> f64 { if self.freedom.is_infinite() { use super::multivariate_normal::density_normalization_and_exponential; let (pdf_const, exp_arg) = density_normalization_and_exponential( &self.location, &self.scale, &self.precision, x, ) .unwrap(); return pdf_const * exp_arg.exp(); } let dv = x - &self.location; let exp_arg: f64 = (&self.precision * &dv).dot(&dv); let base_term = 1. + exp_arg / self.freedom; self.ln_pdf_const.exp() * base_term.powf(-(self.freedom + self.location.len() as f64) / 2.) } /// Calculates the log probability density function for the multivariate /// student distribution at `x`. Equivalent to pdf(x).ln(). fn ln_pdf(&self, x: &OVector) -> f64 { if self.freedom.is_infinite() { use super::multivariate_normal::density_normalization_and_exponential; let (pdf_const, exp_arg) = density_normalization_and_exponential( &self.location, &self.scale, &self.precision, x, ) .unwrap(); return pdf_const.ln() + exp_arg; } let dv = x - &self.location; let exp_arg: f64 = (&self.precision * &dv).dot(&dv); let base_term = 1. + exp_arg / self.freedom; self.ln_pdf_const - (self.freedom + self.location.len() as f64) / 2. 
* base_term.ln() } } #[rustfmt::skip] #[cfg(test)] mod tests { use core::fmt::Debug; use approx::RelativeEq; use nalgebra::{DMatrix, DVector, Dyn, OMatrix, OVector, U1, U2}; use crate::{ distribution::{Continuous, MultivariateStudent, MultivariateNormal}, statistics::{Max, MeanN, Min, Mode, VarianceN}, }; use super::MultivariateStudentError; fn try_create(location: Vec, scale: Vec, freedom: f64) -> MultivariateStudent { let mvs = MultivariateStudent::new(location, scale, freedom); assert!(mvs.is_ok()); mvs.unwrap() } fn create_case(location: Vec, scale: Vec, freedom: f64) { let mvs = try_create(location.clone(), scale.clone(), freedom); assert_eq!(DMatrix::from_vec(location.len(), location.len(), scale), mvs.scale); assert_eq!(DVector::from_vec(location), mvs.location); } fn bad_create_case(location: Vec, scale: Vec, freedom: f64) { let mvs = MultivariateStudent::new(location, scale, freedom); assert!(mvs.is_err()); } fn test_case(location: Vec, scale: Vec, freedom: f64, expected: T, eval: F) where T: Debug + PartialEq, F: FnOnce(MultivariateStudent) -> T, { let mvs = try_create(location, scale, freedom); let x = eval(mvs); assert_eq!(expected, x); } fn test_almost( location: Vec, scale: Vec, freedom: f64, expected: f64, acc: f64, eval: F, ) where F: FnOnce(MultivariateStudent) -> f64, { let mvs = try_create(location, scale, freedom); let x = eval(mvs); assert_almost_eq!(expected, x, acc); } fn test_almost_multivariate_normal( location: Vec, scale: Vec, freedom: f64, acc: f64, x: DVector, eval_mvs: F1, eval_mvn: F2, ) where F1: FnOnce(MultivariateStudent, DVector) -> f64, F2: FnOnce(MultivariateNormal, DVector) -> f64, { let mvs = try_create(location.clone(), scale.clone(), freedom); let mvn0 = MultivariateNormal::new(location, scale); assert!(mvn0.is_ok()); let mvn = mvn0.unwrap(); let mvs_x = eval_mvs(mvs, x.clone()); let mvn_x = eval_mvn(mvn, x.clone()); assert!(mvs_x.relative_eq(&mvn_x, acc, acc), "mvn: {mvn_x} =/=\nmvs: {mvs_x}"); // assert_relative_eq!(mvs_x, 
mvn_x, acc); } macro_rules! dvec { ($($x:expr),*) => (DVector::from_vec(vec![$($x),*])); } macro_rules! mat2 { ($x11:expr, $x12:expr, $x21:expr, $x22:expr) => (DMatrix::from_vec(2,2,vec![$x11, $x12, $x21, $x22])); } // macro_rules! mat3 { // ($x11:expr, $x12:expr, $x13:expr, $x21:expr, $x22:expr, $x23:expr, $x31:expr, $x32:expr, $x33:expr) => (DMatrix::from_vec(3,3,vec![$x11, $x12, $x13, $x21, $x22, $x23, $x31, $x32, $x33])); // } #[test] fn test_create() { create_case(vec![0., 0.], vec![1., 0., 0., 1.], 1.); create_case(vec![10., 5.], vec![2., 1., 1., 2.], 3.); create_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.], 14.); create_case(vec![0., f64::INFINITY], vec![1., 0., 0., 1.], f64::INFINITY); create_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.1); } #[test] fn test_bad_create() { // scale not symmetric. bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.], 1.); // scale not positive-definite. bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.], 1.); // NaN in location. bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.], 1.); // NaN in scale Matrix. bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN], 1.); // NaN in freedom. bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], f64::NAN); // Non-positive freedom. bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.); bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], -1.); } #[test] fn test_variance() { let variance = |x: MultivariateStudent| x.variance().unwrap(); test_case(vec![0., 0.], vec![1., 0., 0., 1.], 3., 3. * mat2![1., 0., 0., 1.], variance); test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 3., mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance); } // Variance is only defined for freedom > 2. 
#[test] fn test_bad_variance() { let variance = |x: MultivariateStudent| x.variance(); test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2., None, variance); } #[test] fn test_mode() { let mode = |x: MultivariateStudent| x.mode(); test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![0., 0.], mode); test_case(vec![f64::INFINITY, f64::INFINITY], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], mode); } #[test] fn test_mean() { let mean = |x: MultivariateStudent| x.mean().unwrap(); test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2., dvec![0., 0.], mean); test_case(vec![-1., 1., 3.], vec![1., 0., 0.5, 0., 2.0, 0., 0.5, 0., 3.0], 2., dvec![-1., 1., 3.], mean); } // Mean is only defined if freedom > 1. #[test] fn test_bad_mean() { let mean = |x: MultivariateStudent| x.mean(); test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., None, mean); } #[test] fn test_min_max() { let min = |x: MultivariateStudent| x.min(); let max = |x: MultivariateStudent| x.max(); test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], max); test_case(vec![10., 1.], vec![1., 0., 0., 1.], 1., dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); test_case(vec![-3., 5.], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], max); } #[test] fn test_pdf() { let pdf = |arg: DVector| move |x: MultivariateStudent| x.pdf(&arg); test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.047157020175376416, 1e-15, pdf(dvec![1., 1.])); test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.013972450422333741737457302178882, 1e-15, pdf(dvec![1., 2.])); test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., 0.012992240252399619, 1e-17, pdf(dvec![1., 2.])); test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, 2.639780816598878e-5, 1e-19, pdf(dvec![1., 10.])); test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, 6.438051574348526e-5, 1e-19, pdf(dvec![10., 10.])); // These 
three are crossed checked against both python's scipy.multivariate_t.pdf and octave's mvtpdf. test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.960998836915657e-16, 1e-30, pdf(dvec![0.9718, 0.1298, 0.8134])); test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 7.369987979187023e-16, 1e-30, pdf(dvec![0.4922, 0.5522, 0.7185])); test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8.,6.951631724511314e-16, 1e-30, pdf(dvec![0.3020, 0.1491, 0.5008])); test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., 0., pdf(dvec![10., 10.])); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: DVector| move |x: MultivariateStudent| x.ln_pdf(&arg); test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., -3.0542723907338383, 1e-14, ln_pdf(dvec![1., 1.])); test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., -4.3434030034000815, 1e-14, ln_pdf(dvec![1., 2.])); test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, -10.542229575274265, 1e-14, ln_pdf(dvec![1., 10.])); test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, -9.650699521198622, 1e-14, ln_pdf(dvec![10., 10.])); // test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., f64::NEG_INFINITY, ln_pdf(dvec![10., 10.])); } #[test] fn test_pdf_freedom_large() { let pdf_mvs = |mv: MultivariateStudent, arg: DVector| mv.pdf(&arg); let pdf_mvn = |mv: MultivariateNormal, arg: DVector| mv.pdf(&arg); test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-6, dvec![1., 1.], pdf_mvs, pdf_mvn); test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-7, dvec![1., 1.], pdf_mvs, pdf_mvn); test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); test_almost_multivariate_normal(vec![5., -1.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![5., 1.], pdf_mvs, pdf_mvn); } #[test] fn 
test_ln_pdf_freedom_large() { let pdf_mvs = |mv: MultivariateStudent, arg: DVector| mv.ln_pdf(&arg); let pdf_mvn = |mv: MultivariateNormal, arg: DVector| mv.ln_pdf(&arg); test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn); test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 5e-6, dvec![1., 1.], pdf_mvs, pdf_mvn); test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); } #[test] fn test_immut_field_access() { // init as Dyn let mvs = MultivariateStudent::new(vec![1., 1.], vec![1., 0., 0., 1.], 2.) .expect("hard coded valid construction"); assert_eq!(mvs.freedom(), 2.); assert_relative_eq!(mvs.ln_pdf_const(), std::f64::consts::TAU.recip().ln(), epsilon = 1e-15); // compare to static assert_eq!(mvs.dim(), 2); assert!(mvs.location().eq(&OVector::::new(1., 1.))); assert!(mvs.scale().eq(&OMatrix::::identity())); assert!(mvs.precision().eq(&OMatrix::::identity())); assert!(mvs.scale_chol_decomp().eq(&OMatrix::::identity())); // compare to Dyn assert_eq!(mvs.location(),&OVector::::from_element_generic(Dyn(2), U1, 1.)); assert_eq!(mvs.scale(), &OMatrix::::identity(2, 2)); assert_eq!(mvs.precision(), &OMatrix::::identity(2, 2)); assert_eq!(mvs.scale_chol_decomp(), &OMatrix::::identity(2, 2)); } #[test] fn test_error_is_sync_send() { fn assert_sync_send() {} assert_sync_send::(); } } statrs-0.18.0/src/distribution/negative_binomial.rs000064400000000000000000000364711046102023000205750ustar 00000000000000use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; use std::f64; /// Implements the /// [negative binomial](http://en.wikipedia.org/wiki/Negative_binomial_distribution) /// distribution. 
/// /// *Please note carefully the meaning of the parameters.* As noted in the /// wikipedia article, there are several different commonly used conventions /// for the parameters of the negative binomial distribution. /// /// The negative binomial distribution is a discrete distribution with two /// parameters, `r` and `p`. When `r` is an integer, the negative binomial /// distribution can be interpreted as the distribution of the number of /// failures in a sequence of Bernoulli trials that continue until `r` /// successes occur. `p` is the probability of success in a single Bernoulli /// trial. /// /// `NegativeBinomial` accepts non-integer values for `r`. This is a /// generalization of the more common case where `r` is an integer. /// /// # Examples /// /// ``` /// use statrs::distribution::{NegativeBinomial, Discrete}; /// use statrs::statistics::DiscreteDistribution; /// use statrs::prec::almost_eq; /// /// let r = NegativeBinomial::new(4.0, 0.5).unwrap(); /// assert_eq!(r.mean().unwrap(), 4.0); /// assert!(almost_eq(r.pmf(0), 0.0625, 1e-8)); /// assert!(almost_eq(r.pmf(3), 0.15625, 1e-8)); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct NegativeBinomial { r: f64, p: f64, } /// Represents the errors that can occur when creating a [`NegativeBinomial`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum NegativeBinomialError { /// `r` is NaN or less than zero. RInvalid, /// `p` is NaN or not in `[0, 1]`. PInvalid, } impl std::fmt::Display for NegativeBinomialError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { NegativeBinomialError::RInvalid => write!(f, "r is NaN or less than zero"), NegativeBinomialError::PInvalid => write!(f, "p is NaN or not in [0, 1]"), } } } impl std::error::Error for NegativeBinomialError {} impl NegativeBinomial { /// Constructs a new negative binomial distribution with parameters `r` /// and `p`. 
When `r` is an integer, the negative binomial distribution /// can be interpreted as the distribution of the number of failures in /// a sequence of Bernoulli trials that continue until `r` successes occur. /// `p` is the probability of success in a single Bernoulli trial. /// /// # Errors /// /// Returns an error if `p` is `NaN`, less than `0.0`, /// greater than `1.0`, or if `r` is `NaN` or less than `0.0`. /// /// # Examples /// /// ``` /// use statrs::distribution::NegativeBinomial; /// /// let mut result = NegativeBinomial::new(4.0, 0.5); /// assert!(result.is_ok()); /// /// result = NegativeBinomial::new(-0.5, 5.0); /// assert!(result.is_err()); /// ``` pub fn new(r: f64, p: f64) -> Result<NegativeBinomial, NegativeBinomialError> { if r.is_nan() || r < 0.0 { return Err(NegativeBinomialError::RInvalid); } if p.is_nan() || !(0.0..=1.0).contains(&p) { return Err(NegativeBinomialError::PInvalid); } Ok(NegativeBinomial { r, p }) } /// Returns the probability of success `p` of a single /// Bernoulli trial associated with the negative binomial /// distribution. /// /// # Examples /// /// ``` /// use statrs::distribution::NegativeBinomial; /// /// let r = NegativeBinomial::new(5.0, 0.5).unwrap(); /// assert_eq!(r.p(), 0.5); /// ``` pub fn p(&self) -> f64 { self.p } /// Returns the number of successes `r` of this negative /// binomial distribution.
/// /// # Examples /// /// ``` /// use statrs::distribution::NegativeBinomial; /// /// let r = NegativeBinomial::new(5.0, 0.5).unwrap(); /// assert_eq!(r.r(), 5.0); /// ``` pub fn r(&self) -> f64 { self.r } } impl std::fmt::Display for NegativeBinomial { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "NB({},{})", self.r, self.p) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<u64> for NegativeBinomial { fn sample<R: ::rand::Rng + ?Sized>(&self, r: &mut R) -> u64 { use crate::distribution::{gamma, poisson}; let lambda = gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p); poisson::sample_unchecked(r, lambda).floor() as u64 } } impl DiscreteCDF<u64, f64> for NegativeBinomial { /// Calculates the cumulative distribution function for the /// negative binomial distribution at `x`. /// /// # Formula /// /// ```text /// I_(p)(r, x+1) /// ``` /// /// where `I_(x)(a, b)` is the regularized incomplete beta function. fn cdf(&self, x: u64) -> f64 { beta::beta_reg(self.r, x as f64 + 1.0, self.p) } /// Calculates the survival function for the /// negative binomial distribution at `x`. /// /// Note that because the distribution is extended to the reals /// (allowing positive real values for `r`), the CDF, while still technically /// that of a discrete distribution, behaves more like that of a /// continuous distribution (i.e. a smooth graph rather than a step function). /// /// # Formula /// /// ```text /// I_(1-p)(x+1, r) /// ``` /// /// where `I_(x)(a, b)` is the regularized incomplete beta function. fn sf(&self, x: u64) -> f64 { beta::beta_reg(x as f64 + 1.0, self.r, 1. - self.p) } } impl Min<u64> for NegativeBinomial { /// Returns the minimum value in the domain of the /// negative binomial distribution representable by a 64-bit /// integer.
/// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> u64 { 0 } } impl Max<u64> for NegativeBinomial { /// Returns the maximum value in the domain of the /// negative binomial distribution representable by a 64-bit /// integer. /// /// # Formula /// /// ```text /// u64::MAX /// ``` fn max(&self) -> u64 { u64::MAX } } impl DiscreteDistribution<f64> for NegativeBinomial { /// Returns the mean of the negative binomial distribution. /// /// # Formula /// /// ```text /// r * (1-p) / p /// ``` fn mean(&self) -> Option<f64> { Some(self.r * (1.0 - self.p) / self.p) } /// Returns the variance of the negative binomial distribution. /// /// # Formula /// /// ```text /// r * (1-p) / p^2 /// ``` fn variance(&self) -> Option<f64> { Some(self.r * (1.0 - self.p) / (self.p * self.p)) } /// Returns the skewness of the negative binomial distribution. /// /// # Formula /// /// ```text /// (2-p) / sqrt(r * (1-p)) /// ``` fn skewness(&self) -> Option<f64> { Some((2.0 - self.p) / f64::sqrt(self.r * (1.0 - self.p))) } } impl Mode<Option<f64>> for NegativeBinomial { /// Returns the mode for the negative binomial distribution. /// /// # Formula /// /// ```text /// if r > 1 then /// floor((r - 1) * (1-p) / p) /// else /// 0 /// ``` fn mode(&self) -> Option<f64> { let mode = if self.r > 1.0 { f64::floor((self.r - 1.0) * (1.0 - self.p) / self.p) } else { 0.0 }; Some(mode) } } impl Discrete<u64, f64> for NegativeBinomial { /// Calculates the probability mass function for the negative binomial /// distribution at `x`. /// /// # Formula /// /// When `r` is an integer, the formula is: /// /// ```text /// (x + r - 1 choose x) * (1 - p)^x * p^r /// ``` /// /// The general formula for real `r` is: /// /// ```text /// Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r /// ``` /// /// where Γ(x) is the Gamma function. fn pmf(&self, x: u64) -> f64 { self.ln_pmf(x).exp() } /// Calculates the log probability mass function for the negative binomial /// distribution at `x`.
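///
/// # Examples
///
/// An illustrative check (values chosen to match the `r = 4`, `p = 0.5`
/// example on [`NegativeBinomial`]): `ln_pmf` agrees with the log of `pmf`,
/// and it stays finite far in the tail, where `pmf` itself underflows to zero.
///
/// ```
/// use statrs::distribution::{Discrete, NegativeBinomial};
///
/// let nb = NegativeBinomial::new(4.0, 0.5).unwrap();
///
/// // (3 + 4 - 1 choose 3) * 0.5^3 * 0.5^4 = 20 / 128 = 0.15625
/// assert!((nb.ln_pmf(3) - 0.15625f64.ln()).abs() < 1e-10);
///
/// // At x = 5000 the log-probability is roughly -3400, far below the log of
/// // the smallest positive f64, so exponentiating it gives exactly zero.
/// assert!(nb.ln_pmf(5_000).is_finite());
/// assert_eq!(nb.pmf(5_000), 0.0);
/// ```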
/// /// # Formula /// /// When `r` is an integer, the formula is: /// /// ```text /// ln((x + r - 1 choose x) * (1 - p)^x * p^r) /// ``` /// /// The general formula for real `r` is: /// /// ```text /// ln(Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r) /// ``` /// /// where Γ(x) is the Gamma function. fn ln_pmf(&self, x: u64) -> f64 { let k = x as f64; gamma::ln_gamma(self.r + k) - gamma::ln_gamma(self.r) - gamma::ln_gamma(k + 1.0) + (self.r * self.p.ln()) + (k * (-self.p).ln_1p()) } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::test; use crate::testing_boiler; testing_boiler!(r: f64, p: f64; NegativeBinomial; NegativeBinomialError); #[test] fn test_create() { create_ok(0.0, 0.0); create_ok(0.3, 0.4); create_ok(1.0, 0.3); } #[test] fn test_bad_create() { test_create_err(f64::NAN, 1.0, NegativeBinomialError::RInvalid); test_create_err(0.0, f64::NAN, NegativeBinomialError::PInvalid); create_err(-1.0, 1.0); create_err(2.0, 2.0); } #[test] fn test_mean() { let mean = |x: NegativeBinomial| x.mean().unwrap(); test_exact(4.0, 0.0, f64::INFINITY, mean); test_absolute(3.0, 0.3, 7.0, 1e-15 , mean); test_exact(2.0, 1.0, 0.0, mean); } #[test] fn test_variance() { let variance = |x: NegativeBinomial| x.variance().unwrap(); test_exact(4.0, 0.0, f64::INFINITY, variance); test_absolute(3.0, 0.3, 23.333333333333, 1e-12, variance); test_exact(2.0, 1.0, 0.0, variance); } #[test] fn test_skewness() { let skewness = |x: NegativeBinomial| x.skewness().unwrap(); test_exact(0.0, 0.0, f64::INFINITY, skewness); test_absolute(0.1, 0.3, 6.425396041, 1e-09, skewness); test_exact(1.0, 1.0, f64::INFINITY, skewness); } #[test] fn test_mode() { let mode = |x: NegativeBinomial| x.mode().unwrap(); test_exact(0.0, 0.0, 0.0, mode); test_exact(0.3, 0.0, 0.0, mode); test_exact(1.0, 1.0, 0.0, mode); test_exact(10.0, 0.01, 891.0, mode); } #[test] fn test_min_max() { let min = |x: NegativeBinomial| x.min(); let max = |x: NegativeBinomial| x.max(); 
test_exact(1.0, 0.5, 0, min); test_exact(1.0, 0.3, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: NegativeBinomial| x.pmf(arg); test_absolute(4.0, 0.5, 0.0625, 1e-8, pmf(0)); test_absolute(4.0, 0.5, 0.15625, 1e-8, pmf(3)); test_exact(1.0, 0.0, 0.0, pmf(0)); test_exact(1.0, 0.0, 0.0, pmf(1)); test_absolute(3.0, 0.2, 0.008, 1e-15, pmf(0)); test_absolute(3.0, 0.2, 0.0192, 1e-15, pmf(1)); test_absolute(3.0, 0.2, 0.04096, 1e-15, pmf(3)); test_absolute(10.0, 0.2, 1.024e-07, 1e-07, pmf(0)); test_absolute(10.0, 0.2, 8.192e-07, 1e-07, pmf(1)); test_absolute(10.0, 0.2, 0.001015706852, 1e-07, pmf(10)); test_absolute(1.0, 0.3, 0.3, 1e-15, pmf(0)); test_absolute(1.0, 0.3, 0.21, 1e-15, pmf(1)); test_absolute(3.0, 0.3, 0.027, 1e-15, pmf(0)); test_exact(0.3, 1.0, 0.0, pmf(1)); test_exact(0.3, 1.0, 0.0, pmf(3)); test_is_nan(0.3, 1.0, pmf(0)); test_exact(0.3, 1.0, 0.0, pmf(1)); test_exact(0.3, 1.0, 0.0, pmf(10)); test_is_nan(1.0, 1.0, pmf(0)); test_exact(1.0, 1.0, 0.0, pmf(1)); test_is_nan(3.0, 1.0, pmf(0)); test_exact(3.0, 1.0, 0.0, pmf(1)); test_exact(3.0, 1.0, 0.0, pmf(3)); test_is_nan(10.0, 1.0, pmf(0)); test_exact(10.0, 1.0, 0.0, pmf(1)); test_exact(10.0, 1.0, 0.0, pmf(10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: NegativeBinomial| x.ln_pmf(arg); test_exact(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0)); test_exact(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1)); test_absolute(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0)); test_absolute(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1)); test_absolute(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3)); test_absolute(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0)); test_absolute(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1)); test_absolute(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10)); test_absolute(1.0, 0.3, -1.203972804, 1e-08, ln_pmf(0)); test_absolute(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1)); test_absolute(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0)); test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); test_exact(0.3, 1.0, 
f64::NEG_INFINITY, ln_pmf(3)); test_is_nan(0.3, 1.0, ln_pmf(0)); test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10)); test_is_nan(1.0, 1.0, ln_pmf(0)); test_exact(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); test_is_nan(3.0, 1.0, ln_pmf(0)); test_exact(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); test_exact(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3)); test_is_nan(10.0, 1.0, ln_pmf(0)); test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg); test_absolute(1.0, 0.3, 0.3, 1e-08, cdf(0)); test_absolute(1.0, 0.3, 0.51, 1e-08, cdf(1)); test_absolute(1.0, 0.3, 0.83193, 1e-08, cdf(4)); test_absolute(1.0, 0.3, 0.9802267326, 1e-08, cdf(10)); test_exact(1.0, 1.0, 1.0, cdf(0)); test_exact(1.0, 1.0, 1.0, cdf(1)); test_absolute(10.0, 0.75, 0.05631351471, 1e-08, cdf(0)); test_absolute(10.0, 0.75, 0.1970973015, 1e-08, cdf(1)); test_absolute(10.0, 0.75, 0.9960578583, 1e-08, cdf(10)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg); test_absolute(1.0, 0.3, 0.7, 1e-08, sf(0)); test_absolute(1.0, 0.3, 0.49, 1e-08, sf(1)); test_absolute(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4)); test_absolute(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10)); test_exact(1.0, 1.0, 0.0, sf(0)); test_exact(1.0, 1.0, 0.0, sf(1)); test_absolute(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0)); test_absolute(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1)); test_absolute(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg); test_exact(3.0, 0.5, 1.0, cdf(100)); } #[test] fn test_discrete() { test::check_discrete_distribution(&create_ok(5.0, 0.3), 35); test::check_discrete_distribution(&create_ok(10.0, 0.7), 21); } #[test] fn test_sf_upper_bound() { let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg); 
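        // Illustrative extra check (not part of the upstream suite): cdf(x) is
        // I_p(r, x+1) and sf(x) is I_(1-p)(x+1, r), and the regularized
        // incomplete beta function satisfies I_z(a, b) + I_(1-z)(b, a) = 1,
        // so the two should sum to 1 up to rounding error.
        let nb = create_ok(3.0, 0.5);
        for k in [0u64, 1, 5, 20] {
            assert!((nb.cdf(k) + nb.sf(k) - 1.0).abs() < 1e-10);
        }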
test_absolute(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100)); } } statrs-0.18.0/src/distribution/normal.rs000064400000000000000000000456741046102023000164160ustar 00000000000000use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; use std::f64; /// Implements the [Normal](https://en.wikipedia.org/wiki/Normal_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Normal, Continuous}; /// use statrs::statistics::Distribution; /// /// let n = Normal::new(0.0, 1.0).unwrap(); /// assert_eq!(n.mean().unwrap(), 0.0); /// assert_eq!(n.pdf(1.0), 0.2419707245191433497978); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Normal { mean: f64, std_dev: f64, } /// Represents the errors that can occur when creating a [`Normal`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum NormalError { /// The mean is NaN. MeanInvalid, /// The standard deviation is NaN, zero or less than zero. 
StandardDeviationInvalid, } impl std::fmt::Display for NormalError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { NormalError::MeanInvalid => write!(f, "Mean is NaN"), NormalError::StandardDeviationInvalid => { write!(f, "Standard deviation is NaN, zero or less than zero") } } } } impl std::error::Error for NormalError {} impl Normal { /// Constructs a new normal distribution with a mean of `mean` /// and a standard deviation of `std_dev`. /// /// # Errors /// /// Returns an error if `mean` or `std_dev` is `NaN`, or if /// `std_dev <= 0.0`. /// /// # Examples /// /// ``` /// use statrs::distribution::Normal; /// /// let mut result = Normal::new(0.0, 1.0); /// assert!(result.is_ok()); /// /// result = Normal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` pub fn new(mean: f64, std_dev: f64) -> Result<Normal, NormalError> { if mean.is_nan() { return Err(NormalError::MeanInvalid); } if std_dev.is_nan() || std_dev <= 0.0 { return Err(NormalError::StandardDeviationInvalid); } Ok(Normal { mean, std_dev }) } /// Constructs a new standard normal distribution with a mean of 0 /// and a standard deviation of 1.
/// /// # Examples /// /// ``` /// use statrs::distribution::Normal; /// /// let n = Normal::standard(); /// assert_eq!(n, Normal::new(0.0, 1.0).unwrap()); /// ``` pub fn standard() -> Normal { Normal { mean: 0.0, std_dev: 1.0, } } } impl std::fmt::Display for Normal { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "N({},{})", self.mean, self.std_dev) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Normal { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.mean, self.std_dev) } } impl ContinuousCDF<f64, f64> for Normal { /// Calculates the cumulative distribution function for the /// normal distribution at `x`. /// /// # Formula /// /// ```text /// (1 / 2) * (1 + erf((x - μ) / (σ * sqrt(2)))) /// ``` /// /// where `μ` is the mean, `σ` is the standard deviation, and /// `erf` is the error function. fn cdf(&self, x: f64) -> f64 { cdf_unchecked(x, self.mean, self.std_dev) } /// Calculates the survival function for the /// normal distribution at `x`. /// /// # Formula /// /// ```text /// (1 / 2) * (1 + erf(-(x - μ) / (σ * sqrt(2)))) /// ``` /// /// where `μ` is the mean, `σ` is the standard deviation, and /// `erf` is the error function. /// /// Note that this computes the complement of the CDF directly, by flipping /// the sign of the error function's argument relative to the CDF formula. /// This works because the normal CDF Φ is symmetric (a consequence of the /// error function being odd, `erf(-z) = -erf(z)`): /// ```text /// Φ(-x) + Φ(x) = 1 /// Φ(-x) = 1 - Φ(x) /// ``` fn sf(&self, x: f64) -> f64 { sf_unchecked(x, self.mean, self.std_dev) } /// Calculates the inverse cumulative distribution function for the /// normal distribution at `x`. /// In other languages, such as R, this is known as the quantile function.
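///
/// # Examples
///
/// A round-trip sketch: `inverse_cdf` undoes `cdf` up to floating-point
/// error, and the probability `0.5` maps back to the mean, which for a
/// normal distribution is also the median.
///
/// ```
/// use statrs::distribution::{ContinuousCDF, Normal};
///
/// let n = Normal::new(5.0, 2.0).unwrap();
/// // Map a point to its cumulative probability and back again.
/// let p = n.cdf(4.0);
/// assert!((n.inverse_cdf(p) - 4.0).abs() < 1e-10);
/// // The 50% quantile is the mean.
/// assert!((n.inverse_cdf(0.5) - 5.0).abs() < 1e-12);
/// ```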
/// /// # Panics /// /// If `x` is outside `[0.0, 1.0]` (this includes `NaN`). /// /// # Formula /// /// ```text /// μ - sqrt(2) * σ * erfc_inv(2x) /// ``` /// /// where `μ` is the mean, `σ` is the standard deviation and `erfc_inv` is /// the inverse of the complementary error function. fn inverse_cdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { panic!("x must be in [0, 1]"); } else { self.mean - (self.std_dev * f64::consts::SQRT_2 * erf::erfc_inv(2.0 * x)) } } } impl Min<f64> for Normal { /// Returns the minimum value in the domain of the /// normal distribution representable by a double precision float. /// /// # Formula /// /// ```text /// f64::NEG_INFINITY /// ``` fn min(&self) -> f64 { f64::NEG_INFINITY } } impl Max<f64> for Normal { /// Returns the maximum value in the domain of the /// normal distribution representable by a double precision float. /// /// # Formula /// /// ```text /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY } } impl Distribution<f64> for Normal { /// Returns the mean of the normal distribution. /// /// # Remarks /// /// This is the same mean used to construct the distribution. fn mean(&self) -> Option<f64> { Some(self.mean) } /// Returns the variance of the normal distribution. /// /// # Formula /// /// ```text /// σ^2 /// ``` /// /// where `σ` is the standard deviation. fn variance(&self) -> Option<f64> { Some(self.std_dev * self.std_dev) } /// Returns the standard deviation of the normal distribution. /// /// # Remarks /// /// This is the same standard deviation used to construct the distribution. fn std_dev(&self) -> Option<f64> { Some(self.std_dev) } /// Returns the entropy of the normal distribution. /// /// # Formula /// /// ```text /// (1 / 2) * ln(2σ^2 * π * e) /// ``` /// /// where `σ` is the standard deviation. fn entropy(&self) -> Option<f64> { Some(self.std_dev.ln() + consts::LN_SQRT_2PIE) } /// Returns the skewness of the normal distribution. 
/// /// # Formula /// /// ```text /// 0 /// ``` fn skewness(&self) -> Option<f64> { Some(0.0) } } impl Median<f64> for Normal { /// Returns the median
of the normal distribution. /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the mean. fn median(&self) -> f64 { self.mean } } impl Mode<Option<f64>> for Normal { /// Returns the mode of the normal distribution. /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the mean. fn mode(&self) -> Option<f64> { Some(self.mean) } } impl Continuous<f64, f64> for Normal { /// Calculates the probability density function for the normal distribution /// at `x`. /// /// # Formula /// /// ```text /// (1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / (2σ^2)) /// ``` /// /// where `μ` is the mean and `σ` is the standard deviation. fn pdf(&self, x: f64) -> f64 { pdf_unchecked(x, self.mean, self.std_dev) } /// Calculates the log probability density function for the normal /// distribution at `x`. /// /// # Formula /// /// ```text /// ln((1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / (2σ^2))) /// ``` /// /// where `μ` is the mean and `σ` is the standard deviation. fn ln_pdf(&self, x: f64) -> f64 { ln_pdf_unchecked(x, self.mean, self.std_dev) } } /// Performs an unchecked CDF calculation (the parameters are not validated) /// for a normal distribution with the given mean and standard deviation at `x`. pub fn cdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 { 0.5 * erf::erfc((mean - x) / (std_dev * f64::consts::SQRT_2)) } /// Performs an unchecked survival function calculation for a normal distribution /// with the given mean and standard deviation at `x`. pub fn sf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 { 0.5 * erf::erfc((x - mean) / (std_dev * f64::consts::SQRT_2)) } /// Performs an unchecked PDF calculation for a normal distribution /// with the given mean and standard deviation at `x`. pub fn pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 { let d = (x - mean) / std_dev; (-0.5 * d * d).exp() / (consts::SQRT_2PI * std_dev) } /// Performs an unchecked log(PDF) calculation for a normal distribution /// with the given mean and standard deviation at `x`. 
pub fn ln_pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 { let d = (x - mean) /
std_dev; (-0.5 * d * d) - consts::LN_SQRT_2PI - std_dev.ln() } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] /// Draws a sample from a normal distribution with the given mean and standard /// deviation, using the Ziggurat algorithm to generate the underlying standard normal variate. pub fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, mean: f64, std_dev: f64) -> f64 { use crate::distribution::ziggurat; mean + std_dev * ziggurat::sample_std_normal(rng) } impl std::default::Default for Normal { /// Returns the standard normal distribution with a mean of 0 /// and a standard deviation of 1. fn default() -> Self { Self::standard() } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(mean: f64, std_dev: f64; Normal; NormalError); #[test] fn test_create() { create_ok(10.0, 0.1); create_ok(-5.0, 1.0); create_ok(0.0, 10.0); create_ok(10.0, 100.0); create_ok(-5.0, f64::INFINITY); } #[test] fn test_bad_create() { test_create_err(f64::NAN, 1.0, NormalError::MeanInvalid); test_create_err(1.0, f64::NAN, NormalError::StandardDeviationInvalid); create_err(0.0, 0.0); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); } #[test] fn test_variance() { let variance = |x: Normal| x.variance().unwrap(); test_exact(0.0, 0.1, 0.1 * 0.1, variance); test_exact(0.0, 1.0, 1.0, variance); test_exact(0.0, 10.0, 100.0, variance); test_exact(0.0, f64::INFINITY, f64::INFINITY, variance); } #[test] fn test_entropy() { let entropy = |x: Normal| x.entropy().unwrap(); test_absolute(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy); test_exact(0.0, 1.0, 1.41893853320467274178, entropy); test_exact(0.0, 10.0, 3.721523626198718425798, entropy); test_exact(0.0, f64::INFINITY, f64::INFINITY, entropy); } #[test] fn test_skewness() { let skewness = |x: Normal| x.skewness().unwrap(); test_exact(0.0, 0.1, 0.0, skewness); test_exact(4.0, 1.0, 0.0, skewness); test_exact(0.3, 10.0, 0.0, skewness); 
test_exact(0.0, f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Normal|
x.mode().unwrap(); test_exact(-0.0, 1.0, 0.0, mode); test_exact(0.0, 1.0, 0.0, mode); test_exact(0.1, 1.0, 0.1, mode); test_exact(1.0, 1.0, 1.0, mode); test_exact(-10.0, 1.0, -10.0, mode); test_exact(f64::INFINITY, 1.0, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Normal| x.median(); test_exact(-0.0, 1.0, 0.0, median); test_exact(0.0, 1.0, 0.0, median); test_exact(0.1, 1.0, 0.1, median); test_exact(1.0, 1.0, 1.0, median); test_exact(-0.0, 1.0, -0.0, median); test_exact(f64::INFINITY, 1.0, f64::INFINITY, median); } #[test] fn test_min_max() { let min = |x: Normal| x.min(); let max = |x: Normal| x.max(); test_exact(0.0, 0.1, f64::NEG_INFINITY, min); test_exact(-3.0, 10.0, f64::NEG_INFINITY, min); test_exact(0.0, 0.1, f64::INFINITY, max); test_exact(-3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Normal| x.pdf(arg); test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5)); test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8)); test_absolute(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0)); test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2)); test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(11.5)); test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0)); test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5)); test_absolute(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0)); test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5)); test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0)); test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0)); test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(-2.5)); test_absolute(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0)); test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5)); test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(5.0)); test_absolute(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0)); test_exact(10.0, 100.0, 
0.002178521770325505313831, pdf(-100.0)); test_exact(10.0, 100.0, 0.003969525474770117655105, pdf(0.0)); test_absolute(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0)); test_exact(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0)); test_exact(-5.0, f64::INFINITY, 0.0, pdf(-5.0)); test_exact(-5.0, f64::INFINITY, 0.0, pdf(0.0)); test_exact(-5.0, f64::INFINITY, 0.0, pdf(100.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Normal| x.ln_pdf(arg); test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5)); test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8)); test_absolute(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0)); test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2)); test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5)); test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0)); test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5)); test_absolute(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0)); test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5)); test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0)); test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0)); test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5)); test_exact(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0)); test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5)); test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(5.0)); test_exact(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0)); test_exact(10.0, 100.0, (0.002178521770325505313831f64).ln(), ln_pdf(-100.0)); test_absolute(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0)); test_absolute(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0)); test_absolute(10.0, 
100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0)); test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Normal| x.cdf(arg); test_exact(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY)); test_absolute(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0)); test_absolute(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0)); test_absolute(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0)); test_exact(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0)); test_exact(5.0, 2.0, 0.5, cdf(5.0)); test_exact(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0)); test_absolute(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Normal| x.sf(arg); test_exact(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY)); test_absolute(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0)); test_absolute(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0)); test_absolute(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0)); test_exact(5.0, 2.0, 0.6914624612740131, sf(4.0)); test_exact(5.0, 2.0, 0.5, sf(5.0)); test_exact(5.0, 2.0, 0.3085375387259869, sf(6.0)); test_absolute(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0)); } #[test] fn test_continuous() { test::check_continuous_distribution(&create_ok(0.0, 1.0), -10.0, 10.0); test::check_continuous_distribution(&create_ok(20.0, 0.5), 10.0, 30.0); } #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Normal| x.inverse_cdf(arg); test_exact(5.0, 2.0, f64::NEG_INFINITY, inverse_cdf( 0.0)); test_absolute(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883)); test_absolute(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356)); test_absolute(5.0, 2.0, -0.0, 1e-14, 
inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); test_absolute(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); test_absolute(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207)); test_absolute(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5)); test_absolute(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859)); test_absolute(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078)); test_exact(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0)); } #[test] fn test_default() { let n = Normal::default(); let n_mean = n.mean().unwrap(); let n_std = n.std_dev().unwrap(); // Check that the mean of the distribution is close to 0 assert_almost_eq!(n_mean, 0.0, 1e-15); // Check that the standard deviation of the distribution is close to 1 assert_almost_eq!(n_std, 1.0, 1e-15); } } statrs-0.18.0/src/distribution/pareto.rs000064400000000000000000000361611046102023000164070ustar 00000000000000use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; use std::f64; /// Implements the [Pareto](https://en.wikipedia.org/wiki/Pareto_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Pareto, Continuous}; /// use statrs::statistics::Distribution; /// use statrs::prec; /// /// let p = Pareto::new(1.0, 2.0).unwrap(); /// assert_eq!(p.mean().unwrap(), 2.0); /// assert!(prec::almost_eq(p.pdf(2.0), 0.25, 1e-15)); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Pareto { scale: f64, shape: f64, } /// Represents the errors that can occur when creating a [`Pareto`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum ParetoError { /// The scale is NaN, zero or less than zero. ScaleInvalid, /// The shape is NaN, zero or less than zero. 
ShapeInvalid, } impl std::fmt::Display for ParetoError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { ParetoError::ScaleInvalid => write!(f, "Scale is NaN, zero, or less than zero"), ParetoError::ShapeInvalid => write!(f, "Shape is NaN, zero, or less than zero"), } } } impl std::error::Error for ParetoError {} impl Pareto { /// Constructs a new Pareto distribution with scale `scale` and shape /// `shape`. /// /// # Errors /// /// Returns an error if `scale` or `shape` is `NaN`. /// Returns an error if `scale <= 0.0` or `shape <= 0.0`. /// /// # Examples /// /// ``` /// use statrs::distribution::Pareto; /// /// let mut result = Pareto::new(1.0, 2.0); /// assert!(result.is_ok()); /// /// result = Pareto::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` pub fn new(scale: f64, shape: f64) -> Result<Pareto, ParetoError> { if scale.is_nan() || scale <= 0.0 { return Err(ParetoError::ScaleInvalid); } if shape.is_nan() || shape <= 0.0 { return Err(ParetoError::ShapeInvalid); } Ok(Pareto { scale, shape }) } /// Returns the scale of the Pareto distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Pareto; /// /// let n = Pareto::new(1.0, 2.0).unwrap(); /// assert_eq!(n.scale(), 1.0); /// ``` pub fn scale(&self) -> f64 { self.scale } /// Returns the shape of the Pareto distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Pareto; /// /// let n = Pareto::new(1.0, 2.0).unwrap(); /// assert_eq!(n.shape(), 2.0); /// ``` pub fn shape(&self) -> f64 { self.shape } } impl std::fmt::Display for Pareto { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Pareto({},{})", self.scale, self.shape) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Pareto { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { use rand::distributions::OpenClosed01; // Inverse transform sampling let u: f64 =
rng.sample(OpenClosed01); self.scale * u.powf(-1.0 / self.shape) } } impl ContinuousCDF<f64, f64> for Pareto { /// Calculates the cumulative distribution function for the Pareto /// distribution at `x` /// /// # Formula /// /// ```text /// if x < x_m { /// 0 /// } else { /// 1 - (x_m/x)^α /// } /// ``` /// /// where `x_m` is the scale and `α` is the shape fn cdf(&self, x: f64) -> f64 { if x < self.scale { 0.0 } else { 1.0 - (self.scale / x).powf(self.shape) } } /// Calculates the survival function for the Pareto /// distribution at `x` /// /// # Formula /// /// ```text /// if x < x_m { /// 1 /// } else { /// (x_m/x)^α /// } /// ``` /// /// where `x_m` is the scale and `α` is the shape fn sf(&self, x: f64) -> f64 { if x < self.scale { 1.0 } else { (self.scale / x).powf(self.shape) } } /// Calculates the inverse cumulative distribution function for the Pareto /// distribution at probability `p` /// /// # Panics /// /// If `p` is not in `[0, 1]` /// /// # Formula /// /// ```text /// x_m / (1 - p)^(1 / α) /// ``` /// /// where `x_m` is the scale and `α` is the shape fn inverse_cdf(&self, p: f64) -> f64 { if !(0.0..=1.0).contains(&p) { panic!("p must be in [0, 1]"); } else { self.scale * (1.0 - p).powf(-1.0 / self.shape) } } } impl Min<f64> for Pareto { /// Returns the minimum value in the domain of the Pareto distribution /// representable by a double precision float /// /// # Formula /// /// ```text /// x_m /// ``` /// /// where `x_m` is the scale fn min(&self) -> f64 { self.scale } } impl Max<f64> for Pareto { /// Returns the maximum value in the domain of the Pareto distribution /// representable by a double precision float /// /// # Formula /// /// ```text /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY } } impl Distribution<f64> for Pareto { /// Returns the mean of the Pareto distribution /// /// # None /// /// If `α <= 1` (the mean is infinite) /// /// # Formula /// /// ```text /// (α * x_m)/(α - 1) /// ``` /// /// where `x_m` is the scale and `α` is the shape fn mean(&self) -> Option<f64> { if self.shape <= 1.0 { None } else { Some((self.shape
* self.scale) / (self.shape - 1.0)) } } /// Returns the variance of the Pareto distribution /// /// # None /// /// If `α <= 2` (the variance is infinite or undefined) /// /// # Formula /// /// ```text /// (x_m/(α - 1))^2 * (α/(α - 2)) /// ``` /// /// where `x_m` is the scale and `α` is the shape fn variance(&self) -> Option<f64> { if self.shape <= 2.0 { None } else { let a = self.scale / (self.shape - 1.0); // just a temporary variable Some(a * a * self.shape / (self.shape - 2.0)) } } /// Returns the entropy for the Pareto distribution /// /// # Formula /// /// ```text /// ln(x_m/α) + 1/α + 1 /// ``` /// /// where `x_m` is the scale and `α` is the shape fn entropy(&self) -> Option<f64> { Some(self.scale.ln() - self.shape.ln() + (1.0 / self.shape) + 1.0) } /// Returns the skewness of the Pareto distribution /// /// # None /// /// If `α <= 3.0` /// /// # Formula /// /// ```text /// (2*(α + 1)/(α - 3))*sqrt((α - 2)/α) /// ``` /// /// where `α` is the shape fn skewness(&self) -> Option<f64> { if self.shape <= 3.0 { None } else { Some( (2.0 * (self.shape + 1.0) / (self.shape - 3.0)) * ((self.shape - 2.0) / self.shape).sqrt(), ) } } } impl Median<f64> for Pareto { /// Returns the median of the Pareto distribution /// /// # Formula /// /// ```text /// x_m*2^(1/α) /// ``` /// /// where `x_m` is the scale and `α` is the shape fn median(&self) -> f64 { self.scale * (2f64.powf(1.0 / self.shape)) } } impl Mode<Option<f64>> for Pareto { /// Returns the mode of the Pareto distribution /// /// # Formula /// /// ```text /// x_m /// ``` /// /// where `x_m` is the scale fn mode(&self) -> Option<f64> { Some(self.scale) } } impl Continuous<f64, f64> for Pareto { /// Calculates the probability density function for the Pareto distribution /// at `x` /// /// # Formula /// /// ```text /// if x < x_m { /// 0 /// } else { /// (α * x_m^α)/(x^(α + 1)) /// } /// ``` /// /// where `x_m` is the scale and `α` is the shape fn pdf(&self, x: f64) -> f64 { if x < self.scale { 0.0 } else { (self.shape * self.scale.powf(self.shape)) /
x.powf(self.shape + 1.0) } } /// Calculates the log probability density function for the Pareto /// distribution at `x` /// /// # Formula /// /// ```text /// if x < x_m { /// f64::NEG_INFINITY /// } else { /// ln(α) + α*ln(x_m) - (α + 1)*ln(x) /// } /// ``` /// /// where `x_m` is the scale and `α` is the shape fn ln_pdf(&self, x: f64) -> f64 { if x < self.scale { f64::NEG_INFINITY } else { self.shape.ln() + self.shape * self.scale.ln() - (self.shape + 1.0) * x.ln() } } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(scale: f64, shape: f64; Pareto; ParetoError); #[test] fn test_create() { create_ok(10.0, 0.1); create_ok(5.0, 1.0); create_ok(0.1, 10.0); create_ok(10.0, 100.0); create_ok(1.0, f64::INFINITY); create_ok(f64::INFINITY, f64::INFINITY); } #[test] fn test_bad_create() { test_create_err(1.0, -1.0, ParetoError::ShapeInvalid); test_create_err(-1.0, 1.0, ParetoError::ScaleInvalid); create_err(0.0, 0.0); create_err(-1.0, -1.0); create_err(f64::NAN, 1.0); create_err(1.0, f64::NAN); create_err(f64::NAN, f64::NAN); } #[test] fn test_variance() { let variance = |x: Pareto| x.variance().unwrap(); test_exact(1.0, 3.0, 0.75, variance); test_absolute(10.0, 10.0, 125.0 / 81.0, 1e-13, variance); } #[test] fn test_variance_degen() { test_none(1.0, 1.0, |dist| dist.variance()); // shape <= 2.0 } #[test] fn test_entropy() { let entropy = |x: Pareto| x.entropy().unwrap(); test_exact(0.1, 0.1, 11.0, entropy); test_exact(1.0, 1.0, 2.0, entropy); test_exact(10.0, 10.0, 1.1, entropy); test_exact(3.0, 1.0, 2.0 + 3f64.ln(), entropy); test_exact(1.0, 3.0, 4.0/3.0 - 3f64.ln(), entropy); } #[test] fn test_skewness() { let skewness = |x: Pareto| x.skewness().unwrap(); test_exact(1.0, 4.0, 5.0*2f64.sqrt(), skewness); test_exact(1.0, 100.0, (707.0/485.0)*2f64.sqrt(), skewness); } #[test] fn test_skewness_invalid_shape() { test_none(1.0, 3.0, |dist| dist.skewness()); } #[test] fn test_mode() {
let mode = |x: Pareto| x.mode().unwrap(); test_exact(0.1, 1.0, 0.1, mode); test_exact(2.0, 1.0, 2.0, mode); test_exact(10.0, f64::INFINITY, 10.0, mode); test_exact(f64::INFINITY, 1.0, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Pareto| x.median(); test_exact(0.1, 0.1, 102.4, median); test_exact(1.0, 1.0, 2.0, median); test_exact(10.0, 10.0, 10.0*2f64.powf(0.1), median); test_exact(3.0, 0.5, 12.0, median); test_exact(10.0, f64::INFINITY, 10.0, median); } #[test] fn test_min_max() { let min = |x: Pareto| x.min(); let max = |x: Pareto| x.max(); test_exact(0.2, f64::INFINITY, 0.2, min); test_exact(10.0, f64::INFINITY, 10.0, min); test_exact(f64::INFINITY, 1.0, f64::INFINITY, min); test_exact(1.0, 0.1, f64::INFINITY, max); test_exact(3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Pareto| x.pdf(arg); test_exact(1.0, 1.0, 0.0, pdf(0.1)); test_exact(1.0, 1.0, 1.0, pdf(1.0)); test_exact(1.0, 1.0, 4.0/9.0, pdf(1.5)); test_exact(1.0, 1.0, 1.0/25.0, pdf(5.0)); test_exact(1.0, 1.0, 1.0/2500.0, pdf(50.0)); test_exact(1.0, 4.0, 4.0, pdf(1.0)); test_exact(1.0, 4.0, 128.0/243.0, pdf(1.5)); test_exact(1.0, 4.0, 1.0/78125000.0, pdf(50.0)); test_exact(3.0, 2.0, 2.0/3.0, pdf(3.0)); test_exact(3.0, 2.0, 18.0/125.0, pdf(5.0)); test_absolute(25.0, 100.0, 1.5777218104420236e-30, 1e-50, pdf(50.0)); test_absolute(100.0, 25.0, 6.6003546737276816e-6, 1e-16, pdf(150.0)); test_exact(1.0, 2.0, 0.0, pdf(f64::INFINITY)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Pareto| x.ln_pdf(arg); test_exact(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.1)); test_exact(1.0, 1.0, 0.0, ln_pdf(1.0)); test_absolute(1.0, 1.0, 4f64.ln() - 9f64.ln(), 1e-14, ln_pdf(1.5)); test_absolute(1.0, 1.0, -(25f64.ln()), 1e-14, ln_pdf(5.0)); test_absolute(1.0, 1.0, -(2500f64.ln()), 1e-14, ln_pdf(50.0)); test_absolute(1.0, 4.0, 4f64.ln(), 1e-14, ln_pdf(1.0)); test_absolute(1.0, 4.0, 128f64.ln() - 243f64.ln(), 1e-14, ln_pdf(1.5)); test_absolute(1.0, 4.0, 
-(78125000f64.ln()), 1e-14, ln_pdf(50.0)); test_absolute(3.0, 2.0, 2f64.ln() - 3f64.ln(), 1e-14, ln_pdf(3.0)); test_absolute(3.0, 2.0, 18f64.ln() - 125f64.ln(), 1e-14, ln_pdf(5.0)); test_absolute(25.0, 100.0, 1.5777218104420236e-30f64.ln(), 1e-12, ln_pdf(50.0)); test_absolute(100.0, 25.0, 6.6003546737276816e-6f64.ln(), 1e-12, ln_pdf(150.0)); test_exact(1.0, 2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Pareto| x.cdf(arg); test_exact(0.1, 0.1, 0.0, cdf(0.1)); test_exact(1.0, 1.0, 0.0, cdf(1.0)); test_exact(5.0, 5.0, 0.0, cdf(2.0)); test_exact(7.0, 7.0, 0.9176457, cdf(10.0)); test_exact(10.0, 10.0, 50700551.0/60466176.0, cdf(12.0)); test_exact(5.0, 1.0, 0.5, cdf(10.0)); test_exact(3.0, 10.0, 1023.0/1024.0, cdf(6.0)); test_exact(1.0, 1.0, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Pareto| x.sf(arg); test_exact(0.1, 0.1, 1.0, sf(0.1)); test_exact(1.0, 1.0, 1.0, sf(1.0)); test_exact(5.0, 5.0, 1.0, sf(2.0)); test_absolute(7.0, 7.0, 0.08235429999999999, 1e-14, sf(10.0)); test_absolute(10.0, 10.0, 0.16150558288984573, 1e-14, sf(12.0)); test_exact(5.0, 1.0, 0.5, sf(10.0)); test_absolute(3.0, 10.0, 0.0009765625, 1e-14, sf(6.0)); test_exact(1.0, 1.0, 0.0, sf(f64::INFINITY)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: Pareto| x.inverse_cdf(x.cdf(arg)); test_exact(0.1, 0.1, 0.1, func(0.1)); test_exact(1.0, 1.0, 1.0, func(1.0)); test_exact(7.0, 7.0, 10.0, func(10.0)); test_exact(10.0, 10.0, 12.0, func(12.0)); test_exact(5.0, 1.0, 10.0, func(10.0)); test_exact(3.0, 10.0, 6.0, func(6.0)); } #[test] fn test_continuous() { test::check_continuous_distribution(&create_ok(1.0, 10.0), 1.0, 10.0); test::check_continuous_distribution(&create_ok(0.1, 2.0), 0.1, 100.0); } } statrs-0.18.0/src/distribution/poisson.rs000064400000000000000000000337321046102023000166100ustar 00000000000000use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{factorial, gamma}; 
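For readers without the `rand` feature at hand, the following is a minimal, std-only sketch of the small-λ branch of this file's `sample_unchecked` (Knuth's multiplication method) together with the log-space pmf that `Poisson::pmf` evaluates. The `SplitMix64` generator and the `knuth_poisson` and `poisson_pmf` helpers are hypothetical stand-ins written for illustration only; they are not part of statrs, and the sketch deliberately omits Atkinson's PA rejection branch used for λ ≥ 30.

```rust
/// Hypothetical stand-in for a `rand` RNG: SplitMix64, so the sketch needs no crates.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform f64 in [0, 1) built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
    }
}

/// Knuth's method: multiply uniforms until the running product drops below
/// e^(-lambda); the number of extra factors drawn is the Poisson sample.
fn knuth_poisson(rng: &mut SplitMix64, lambda: f64) -> u64 {
    let limit = (-lambda).exp();
    let mut count = 0u64;
    let mut product = rng.next_f64();
    while product >= limit {
        count += 1;
        product *= rng.next_f64();
    }
    count
}

/// Poisson pmf computed in log space, exp(-λ + k ln λ - ln k!), which avoids
/// overflowing λ^k or k! for large k.
fn poisson_pmf(lambda: f64, k: u64) -> f64 {
    let ln_k_factorial: f64 = (2..=k).map(|i| (i as f64).ln()).sum();
    (-lambda + k as f64 * lambda.ln() - ln_k_factorial).exp()
}

fn main() {
    let mut rng = SplitMix64::new(42);
    let lambda = 4.5;
    let n = 100_000;
    let total: u64 = (0..n).map(|_| knuth_poisson(&mut rng, lambda)).sum();
    let mean = total as f64 / n as f64;
    println!("sample mean for lambda = {lambda}: {mean:.3}");
    println!("pmf(1) for lambda = 1.5: {:.15}", poisson_pmf(1.5, 1));
}
```

Each Knuth sample consumes on average λ + 1 uniforms, so its cost grows linearly with λ; that is one reason to switch to a rejection method for large rates, as the source does at 30.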
use crate::statistics::*; use std::f64; /// Implements the [Poisson](https://en.wikipedia.org/wiki/Poisson_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Poisson, Discrete}; /// use statrs::statistics::Distribution; /// use statrs::prec; /// /// let n = Poisson::new(1.0).unwrap(); /// assert_eq!(n.mean().unwrap(), 1.0); /// assert!(prec::almost_eq(n.pmf(1), 0.367879441171442, 1e-15)); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Poisson { lambda: f64, } /// Represents the errors that can occur when creating a [`Poisson`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum PoissonError { /// The lambda is NaN, zero or less than zero. LambdaInvalid, } impl std::fmt::Display for PoissonError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { PoissonError::LambdaInvalid => write!(f, "Lambda is NaN, zero or less than zero"), } } } impl std::error::Error for PoissonError {} impl Poisson { /// Constructs a new poisson distribution with a rate (λ) /// of `lambda` /// /// # Errors /// /// Returns an error if `lambda` is `NaN` or `lambda <= 0.0` /// /// # Examples /// /// ``` /// use statrs::distribution::Poisson; /// /// let mut result = Poisson::new(1.0); /// assert!(result.is_ok()); /// /// result = Poisson::new(0.0); /// assert!(result.is_err()); /// ``` pub fn new(lambda: f64) -> Result { if lambda.is_nan() || lambda <= 0.0 { Err(PoissonError::LambdaInvalid) } else { Ok(Poisson { lambda }) } } /// Returns the rate (λ) of the poisson distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Poisson; /// /// let n = Poisson::new(1.0).unwrap(); /// assert_eq!(n.lambda(), 1.0); /// ``` pub fn lambda(&self) -> f64 { self.lambda } } impl std::fmt::Display for Poisson { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Pois({})", self.lambda) } } #[cfg(feature = "rand")] 
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<u64> for Poisson { /// Generates one sample from the Poisson distribution, using Knuth's /// method if `lambda < 30.0` and otherwise the rejection method PA of /// A. C. Atkinson, Journal of the Royal Statistical Society, /// Series C (Applied Statistics), Vol. 28, No. 1 (1979), pp. 29-35 fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> u64 { sample_unchecked(rng, self.lambda) as u64 } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Poisson { /// Generates one sample from the Poisson distribution, using Knuth's /// method if `lambda < 30.0` and otherwise the rejection method PA of /// A. C. Atkinson, Journal of the Royal Statistical Society, /// Series C (Applied Statistics), Vol. 28, No. 1 (1979), pp. 29-35 fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.lambda) } } impl DiscreteCDF<u64, f64> for Poisson { /// Calculates the cumulative distribution function for the poisson /// distribution at `x` /// /// # Formula /// /// ```text /// Q(x + 1, λ) /// ``` /// /// where `λ` is the rate and `Q` is the upper regularized gamma function fn cdf(&self, x: u64) -> f64 { gamma::gamma_ur(x as f64 + 1.0, self.lambda) } /// Calculates the survival function for the poisson /// distribution at `x` /// /// # Formula /// /// ```text /// P(x + 1, λ) /// ``` /// /// where `λ` is the rate and `P` is the lower regularized gamma function fn sf(&self, x: u64) -> f64 { gamma::gamma_lr(x as f64 + 1.0, self.lambda) } } impl Min<u64> for Poisson { /// Returns the minimum value in the domain of the poisson distribution /// representable by a 64-bit unsigned integer /// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> u64 { 0 } } impl Max<u64> for Poisson { /// Returns the maximum value in the domain of the poisson distribution /// representable by a 64-bit unsigned integer /// /// # Formula /// /// ```text /// 2^64 - 1 /// ``` fn max(&self) -> u64 {
u64::MAX } } impl Distribution<f64> for Poisson { /// Returns the mean of the poisson distribution /// /// # Formula /// /// ```text /// λ /// ``` /// /// where `λ` is the rate fn mean(&self) -> Option<f64> { Some(self.lambda) } /// Returns the variance of the poisson distribution /// /// # Formula /// /// ```text /// λ /// ``` /// /// where `λ` is the rate fn variance(&self) -> Option<f64> { Some(self.lambda) } /// Returns an approximation of the entropy of the poisson distribution, /// using the leading terms of its asymptotic expansion (most accurate for /// large `λ`) /// /// # Formula /// /// ```text /// (1 / 2) * ln(2πeλ) - 1 / (12λ) - 1 / (24λ^2) - 19 / (360λ^3) /// ``` /// /// where `λ` is the rate fn entropy(&self) -> Option<f64> { Some( 0.5 * (2.0 * f64::consts::PI * f64::consts::E * self.lambda).ln() - 1.0 / (12.0 * self.lambda) - 1.0 / (24.0 * self.lambda * self.lambda) - 19.0 / (360.0 * self.lambda * self.lambda * self.lambda), ) } /// Returns the skewness of the poisson distribution /// /// # Formula /// /// ```text /// λ^(-1/2) /// ``` /// /// where `λ` is the rate fn skewness(&self) -> Option<f64> { Some(1.0 / self.lambda.sqrt()) } } impl Median<f64> for Poisson { /// Returns an approximation of the median of the poisson distribution /// /// # Formula /// /// ```text /// floor(λ + 1 / 3 - 0.02 / λ) /// ``` /// /// where `λ` is the rate fn median(&self) -> f64 { (self.lambda + 1.0 / 3.0 - 0.02 / self.lambda).floor() } } impl Mode<Option<u64>> for Poisson { /// Returns the mode of the poisson distribution /// /// # Formula /// /// ```text /// floor(λ) /// ``` /// /// where `λ` is the rate fn mode(&self) -> Option<u64> { Some(self.lambda.floor() as u64) } } impl Discrete<u64, f64> for Poisson { /// Calculates the probability mass function for the poisson distribution at /// `x` /// /// # Formula /// /// ```text /// (λ^x * e^(-λ)) / x! /// ``` /// /// where `λ` is the rate fn pmf(&self, x: u64) -> f64 { (-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x)).exp() } /// Calculates the log probability mass function for the poisson /// distribution at `x` /// /// # Formula /// /// ```text /// ln((λ^x * e^(-λ)) / x!)
/// ``` /// /// where `λ` is the rate fn ln_pmf(&self, x: u64) -> f64 { -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x) } } /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by /// A. C. Atkinson from the Journal of the Royal Statistical Society /// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35 /// otherwise #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { if lambda < 30.0 { let limit = (-lambda).exp(); let mut count = 0.0; let mut product: f64 = rng.gen(); while product >= limit { count += 1.0; product *= rng.gen::(); } count } else { let c = 0.767 - 3.36 / lambda; let beta = f64::consts::PI / (3.0 * lambda).sqrt(); let alpha = beta * lambda; let k = c.ln() - lambda - beta.ln(); loop { let u: f64 = rng.gen(); let x = (alpha - ((1.0 - u) / u).ln()) / beta; let n = (x + 0.5).floor(); if n < 0.0 { continue; } let v: f64 = rng.gen(); let y = alpha - beta * x; let temp = 1.0 + y.exp(); let lhs = y + (v / (temp * temp)).ln(); let rhs = k + n * lambda.ln() - factorial::ln_factorial(n as u64); if lhs <= rhs { return n; } } } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(lambda: f64; Poisson; PoissonError); #[test] fn test_create() { create_ok(1.5); create_ok(5.4); create_ok(10.8); } #[test] fn test_bad_create() { create_err(f64::NAN); create_err(-1.5); create_err(0.0); } #[test] fn test_mean() { let mean = |x: Poisson| x.mean().unwrap(); test_exact(1.5, 1.5, mean); test_exact(5.4, 5.4, mean); test_exact(10.8, 10.8, mean); } #[test] fn test_variance() { let variance = |x: Poisson| x.variance().unwrap(); test_exact(1.5, 1.5, variance); test_exact(5.4, 5.4, variance); test_exact(10.8, 10.8, variance); } #[test] fn test_entropy() { let entropy = |x: Poisson| x.entropy().unwrap(); 
test_absolute(1.5, 1.531959153102376331946, 1e-15, entropy); test_absolute(5.4, 2.244941839577643504608, 1e-15, entropy); test_exact(10.8, 2.600596429676975222694, entropy); } #[test] fn test_skewness() { let skewness = |x: Poisson| x.skewness().unwrap(); test_absolute(1.5, 0.8164965809277260327324, 1e-15, skewness); test_absolute(5.4, 0.4303314829119352094644, 1e-16, skewness); test_absolute(10.8, 0.3042903097250922852539, 1e-16, skewness); } #[test] fn test_median() { let median = |x: Poisson| x.median(); test_exact(1.5, 1.0, median); test_exact(5.4, 5.0, median); test_exact(10.8, 11.0, median); } #[test] fn test_mode() { let mode = |x: Poisson| x.mode().unwrap(); test_exact(1.5, 1, mode); test_exact(5.4, 5, mode); test_exact(10.8, 10, mode); } #[test] fn test_min_max() { let min = |x: Poisson| x.min(); let max = |x: Poisson| x.max(); test_exact(1.5, 0, min); test_exact(5.4, 0, min); test_exact(10.8, 0, min); test_exact(1.5, u64::MAX, max); test_exact(5.4, u64::MAX, max); test_exact(10.8, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Poisson| x.pmf(arg); test_absolute(1.5, 0.334695240222645000000000000000, 1e-15, pmf(1)); test_absolute(1.5, 0.000003545747740570180000000000, 1e-20, pmf(10)); test_absolute(1.5, 0.000000000000000304971208961018, 1e-30, pmf(20)); test_absolute(5.4, 0.024389537090108400000000000000, 1e-17, pmf(1)); test_absolute(5.4, 0.026241240591792300000000000000, 1e-16, pmf(10)); test_absolute(5.4, 0.000000825202200316548000000000, 1e-20, pmf(20)); test_absolute(10.8, 0.000220314636840657000000000000, 1e-18, pmf(1)); test_absolute(10.8, 0.121365183659420000000000000000, 1e-15, pmf(10)); test_absolute(10.8, 0.003908139778574110000000000000, 1e-16, pmf(20)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Poisson| x.ln_pmf(arg); test_absolute(1.5, -1.09453489189183485135413967177, 1e-15, ln_pmf(1)); test_absolute(1.5, -12.5497614919938728510400000000, 1e-14, ln_pmf(10)); test_absolute(1.5, 
-35.7263142985901000000000000000, 1e-13, ln_pmf(20)); test_exact(5.4, -3.71360104642977159156055355910, ln_pmf(1)); test_absolute(5.4, -3.64042303737322774736223038530, 1e-15, ln_pmf(10)); test_absolute(5.4, -14.0076373893489089949388000000, 1e-14, ln_pmf(20)); test_absolute(10.8, -8.42045386586982559781714423000, 1e-14, ln_pmf(1)); test_absolute(10.8, -2.10895123177378079525424989992, 1e-14, ln_pmf(10)); test_absolute(10.8, -5.54469377815000936289610059500, 1e-14, ln_pmf(20)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Poisson| x.cdf(arg); test_absolute(1.5, 0.5578254003710750000000, 1e-15, cdf(1)); test_absolute(1.5, 0.9999994482467640000000, 1e-15, cdf(10)); test_exact(1.5, 1.0, cdf(20)); test_absolute(5.4, 0.0289061180327211000000, 1e-16, cdf(1)); test_absolute(5.4, 0.9774863006897650000000, 1e-15, cdf(10)); test_absolute(5.4, 0.9999997199928290000000, 1e-15, cdf(20)); test_absolute(10.8, 0.0002407141402518290000, 1e-16, cdf(1)); test_absolute(10.8, 0.4839692359955690000000, 1e-15, cdf(10)); test_absolute(10.8, 0.9961800769608090000000, 1e-15, cdf(20)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Poisson| x.sf(arg); test_absolute(1.5, 0.44217459962892536, 1e-15, sf(1)); test_absolute(1.5, 0.0000005517532358246565, 1e-15, sf(10)); test_absolute(1.5, 2.3372210700347092e-17, 1e-15, sf(20)); test_absolute(5.4, 0.971093881967279, 1e-16, sf(1)); test_absolute(5.4, 0.022513699310235582, 1e-15, sf(10)); test_absolute(5.4, 0.0000002800071708975261, 1e-15, sf(20)); test_absolute(10.8, 0.9997592858597482, 1e-16, sf(1)); test_absolute(10.8, 0.5160307640044303, 1e-15, sf(10)); test_absolute(10.8, 0.003819923039191422, 1e-15, sf(20)); } #[test] fn test_discrete() { test::check_discrete_distribution(&create_ok(0.3), 10); test::check_discrete_distribution(&create_ok(4.5), 30); } } statrs-0.18.0/src/distribution/students_t.rs000064400000000000000000001156651046102023000173200ustar 00000000000000use crate::distribution::{Continuous, ContinuousCDF}; use 
crate::function::{beta, gamma}; use crate::statistics::*; use std::f64; /// Implements the [Student's /// T](https://en.wikipedia.org/wiki/Student%27s_t-distribution) distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{StudentsT, Continuous}; /// use statrs::statistics::Distribution; /// use statrs::prec; /// /// let n = StudentsT::new(0.0, 1.0, 2.0).unwrap(); /// assert_eq!(n.mean().unwrap(), 0.0); /// assert!(prec::almost_eq(n.pdf(0.0), 0.353553390593274, 1e-15)); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct StudentsT { location: f64, scale: f64, freedom: f64, } /// Represents the errors that can occur when creating a [`StudentsT`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum StudentsTError { /// The location is NaN. LocationInvalid, /// The scale is NaN, zero or less than zero. ScaleInvalid, /// The degrees of freedom are NaN, zero or less than zero. FreedomInvalid, } impl std::fmt::Display for StudentsTError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { StudentsTError::LocationInvalid => write!(f, "Location is NaN"), StudentsTError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), StudentsTError::FreedomInvalid => { write!(f, "Degrees of freedom are NaN, zero or less than zero") } } } } impl std::error::Error for StudentsTError {} impl StudentsT { /// Constructs a new student's t-distribution with location `location`, /// scale `scale`, and `freedom` freedom. /// /// # Errors /// /// Returns an error if any of `location`, `scale`, or `freedom` are `NaN`. /// Returns an error if `scale <= 0.0` or `freedom <= 0.0`. 
/// /// # Examples /// /// ``` /// use statrs::distribution::StudentsT; /// /// let mut result = StudentsT::new(0.0, 1.0, 2.0); /// assert!(result.is_ok()); /// /// result = StudentsT::new(0.0, 0.0, 0.0); /// assert!(result.is_err()); /// ``` pub fn new(location: f64, scale: f64, freedom: f64) -> Result { if location.is_nan() { return Err(StudentsTError::LocationInvalid); } if scale.is_nan() || scale <= 0.0 { return Err(StudentsTError::ScaleInvalid); } if freedom.is_nan() || freedom <= 0.0 { return Err(StudentsTError::FreedomInvalid); } Ok(StudentsT { location, scale, freedom, }) } /// Returns the location of the student's t-distribution /// /// # Examples /// /// ``` /// use statrs::distribution::StudentsT; /// /// let n = StudentsT::new(0.0, 1.0, 2.0).unwrap(); /// assert_eq!(n.location(), 0.0); /// ``` pub fn location(&self) -> f64 { self.location } /// Returns the scale of the student's t-distribution /// /// # Examples /// /// ``` /// use statrs::distribution::StudentsT; /// /// let n = StudentsT::new(0.0, 1.0, 2.0).unwrap(); /// assert_eq!(n.scale(), 1.0); /// ``` pub fn scale(&self) -> f64 { self.scale } /// Returns the freedom of the student's t-distribution /// /// # Examples /// /// ``` /// use statrs::distribution::StudentsT; /// /// let n = StudentsT::new(0.0, 1.0, 2.0).unwrap(); /// assert_eq!(n.freedom(), 2.0); /// ``` pub fn freedom(&self) -> f64 { self.freedom } } impl std::fmt::Display for StudentsT { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "t_{}({},{})", self.freedom, self.location, self.scale) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution for StudentsT { fn sample(&self, r: &mut R) -> f64 { // based on method 2, section 5 in chapter 9 of L. 
Devroye's // "Non-Uniform Random Variate Generation" let gamma = super::gamma::sample_unchecked(r, 0.5 * self.freedom, 0.5); super::normal::sample_unchecked( r, self.location, self.scale * (self.freedom / gamma).sqrt(), ) } } impl ContinuousCDF<f64, f64> for StudentsT { /// Calculates the cumulative distribution function for the student's /// t-distribution at `x` /// /// # Formula /// /// ```text /// if x <= μ { /// (1 / 2) * I(t, v / 2, 1 / 2) /// } else { /// 1 - (1 / 2) * I(t, v / 2, 1 / 2) /// } /// ``` /// /// where `t = v / (v + k^2)`, `k = (x - μ) / σ`, `μ` is the location, /// `σ` is the scale, `v` is the freedom, and `I` is the regularized /// incomplete beta function fn cdf(&self, x: f64) -> f64 { if self.freedom.is_infinite() { super::normal::cdf_unchecked(x, self.location, self.scale) } else { let k = (x - self.location) / self.scale; let h = self.freedom / (self.freedom + k * k); let ib = 0.5 * beta::beta_reg(self.freedom / 2.0, 0.5, h); if x <= self.location { ib } else { 1.0 - ib } } } /// Calculates the survival function for the student's /// t-distribution at `x` /// /// # Formula /// /// ```text /// if x <= μ { /// 1 - (1 / 2) * I(t, v / 2, 1 / 2) /// } else { /// (1 / 2) * I(t, v / 2, 1 / 2) /// } /// ``` /// /// where `t = v / (v + k^2)`, `k = (x - μ) / σ`, `μ` is the location, /// `σ` is the scale, `v` is the freedom, and `I` is the regularized /// incomplete beta function fn sf(&self, x: f64) -> f64 { if self.freedom.is_infinite() { super::normal::sf_unchecked(x, self.location, self.scale) } else { let k = (x - self.location) / self.scale; let h = self.freedom / (self.freedom + k * k); let ib = 0.5 * beta::beta_reg(self.freedom / 2.0, 0.5, h); if x <= self.location { 1.0 - ib } else { ib } } } /// Calculates the inverse cumulative distribution function for the /// Student's T-distribution at `x` /// /// # Panics /// /// If `x` is not in `[0, 1]` fn inverse_cdf(&self, x: f64) -> f64 { // first calculate inverse_cdf for normal Student's T assert!((0.0..=1.0).contains(&x)); let x1 =
if x >= 0.5 { 1.0 - x } else { x }; let a = 0.5 * self.freedom; let b = 0.5; let mut y = beta::inv_beta_reg(a, b, 2.0 * x1); y = (self.freedom * (1. - y) / y).sqrt(); y = if x >= 0.5 { y } else { -y }; // generalised Student's T is related to normal Student's T by `Y = μ + σ X` // where `X` is distributed as Student's T, so this result has to be scaled and shifted back // formally: F_Y(t) = P(Y <= t) = P(X <= (t - μ) / σ) = F_X((t - μ) / σ) // F_Y^{-1}(p) = inf { t' | F_Y(t') >= p } = inf { t' = μ + σ t | F_X((t' - μ) / σ) >= p } // because scale is positive: loc + scale * t is strictly monotonic function // = μ + σ inf { t | F_X(t) >= p } = μ + σ F_X^{-1}(p) self.location + self.scale * y } } impl Min for StudentsT { /// Returns the minimum value in the domain of the student's t-distribution /// representable by a double precision float /// /// # Formula /// /// ```text /// f64::NEG_INFINITY /// ``` fn min(&self) -> f64 { f64::NEG_INFINITY } } impl Max for StudentsT { /// Returns the maximum value in the domain of the student's t-distribution /// representable by a double precision float /// /// # Formula /// /// ```text /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY } } impl Distribution for StudentsT { /// Returns the mean of the student's t-distribution /// /// # None /// /// If `freedom <= 1.0` /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the location fn mean(&self) -> Option { if self.freedom <= 1.0 { None } else { Some(self.location) } } /// Returns the variance of the student's t-distribution /// /// # None /// /// If `freedom <= 2.0` /// /// # Formula /// /// ```text /// if v == f64::INFINITY { /// Some(σ^2) /// } else if freedom > 2.0 { /// Some(v * σ^2 / (v - 2)) /// } else { /// None /// } /// ``` /// /// where `σ` is the scale and `v` is the freedom fn variance(&self) -> Option { if self.freedom.is_infinite() { Some(self.scale * self.scale) } else if self.freedom > 2.0 { Some(self.freedom * self.scale * 
self.scale / (self.freedom - 2.0)) } else { None } } /// Returns the entropy for the student's t-distribution /// /// # Formula /// /// ```text /// - ln(σ) + (v + 1) / 2 * (ψ((v + 1) / 2) - ψ(v / 2)) + ln(sqrt(v) * B(v / 2, 1 / /// 2)) /// ``` /// /// where `σ` is the scale, `v` is the freedom, `ψ` is the digamma function, and `B` is the /// beta function fn entropy(&self) -> Option { // generalised Student's T is related to normal Student's T by `Y = μ + σ X` // where `X` is distributed as Student's T, plugging into the definition // of entropy shows scaling affects the entropy by an additive constant `- ln σ` let shift = -self.scale.ln(); let result = (self.freedom + 1.0) / 2.0 * (gamma::digamma((self.freedom + 1.0) / 2.0) - gamma::digamma(self.freedom / 2.0)) + (self.freedom.sqrt() * beta::beta(self.freedom / 2.0, 0.5)).ln(); Some(result + shift) } /// Returns the skewness of the student's t-distribution /// /// # None /// /// If `x <= 3.0` /// /// # Formula /// /// ```text /// 0 /// ``` fn skewness(&self) -> Option { if self.freedom <= 3.0 { None } else { Some(0.0) } } } impl Median for StudentsT { /// Returns the median of the student's t-distribution /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the location fn median(&self) -> f64 { self.location } } impl Mode> for StudentsT { /// Returns the mode of the student's t-distribution /// /// # Formula /// /// ```text /// μ /// ``` /// /// where `μ` is the location fn mode(&self) -> Option { Some(self.location) } } impl Continuous for StudentsT { /// Calculates the probability density function for the student's /// t-distribution /// at `x` /// /// # Formula /// /// ```text /// Γ((v + 1) / 2) / (sqrt(vπ) * Γ(v / 2) * σ) * (1 + k^2 / v)^(-1 / 2 * (v /// + 1)) /// ``` /// /// where `k = (x - μ) / σ`, `μ` is the location, `σ` is the scale, `v` is /// the freedom, /// and `Γ` is the gamma function fn pdf(&self, x: f64) -> f64 { if x.is_infinite() { 0.0 } else if self.freedom >= 1e8 { 
            super::normal::pdf_unchecked(x, self.location, self.scale)
        } else {
            let d = (x - self.location) / self.scale;
            (gamma::ln_gamma((self.freedom + 1.0) / 2.0) - gamma::ln_gamma(self.freedom / 2.0))
                .exp()
                * (1.0 + d * d / self.freedom).powf(-0.5 * (self.freedom + 1.0))
                / (self.freedom * f64::consts::PI).sqrt()
                / self.scale
        }
    }

    /// Calculates the log probability density function for the Student's
    /// t-distribution at `x`
    ///
    /// # Formula
    ///
    /// ```text
    /// ln(Γ((v + 1) / 2) / (sqrt(vπ) * Γ(v / 2) * σ) * (1 + k^2 / v)^(-1 / 2 * (v + 1)))
    /// ```
    ///
    /// where `k = (x - μ) / σ`, `μ` is the location, `σ` is the scale, `v` is
    /// the freedom, and `Γ` is the gamma function
    fn ln_pdf(&self, x: f64) -> f64 {
        if x.is_infinite() {
            f64::NEG_INFINITY
        } else if self.freedom >= 1e8 {
            super::normal::ln_pdf_unchecked(x, self.location, self.scale)
        } else {
            let d = (x - self.location) / self.scale;
            gamma::ln_gamma((self.freedom + 1.0) / 2.0)
                - 0.5 * ((self.freedom + 1.0) * (1.0 + d * d / self.freedom).ln())
                - gamma::ln_gamma(self.freedom / 2.0)
                - 0.5 * (self.freedom * f64::consts::PI).ln()
                - self.scale.ln()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::consts::ACC;
    use crate::distribution::internal::*;
    use crate::testing_boiler;

    testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT; StudentsTError);

    #[test]
    fn test_create() {
        create_ok(0.0, 0.1, 1.0);
        create_ok(0.0, 1.0, 1.0);
        create_ok(-5.0, 1.0, 3.0);
        create_ok(10.0, 10.0, f64::INFINITY);
    }

    // #[test]
    // fn foo() {
    //     let dist = StudentsT::new(0.0,1.0,1.0).unwrap();
    //     dbg!(dist.mean());
    // }

    #[test]
    fn test_bad_create() {
        let invalid = [
            (f64::NAN, 1.0, 1.0, StudentsTError::LocationInvalid),
            (0.0, f64::NAN, 1.0, StudentsTError::ScaleInvalid),
            (0.0, 1.0, f64::NAN, StudentsTError::FreedomInvalid),
            (0.0, -10.0, 1.0, StudentsTError::ScaleInvalid),
            (0.0, 10.0, -1.0, StudentsTError::FreedomInvalid),
        ];

        for (l, s, f, err) in invalid {
            test_create_err(l, s, f, err);
        }
    }

    #[test]
    fn test_mean() {
        let mean = |x: StudentsT| x.mean().unwrap();
        test_relative(0.0, 1.0, 3.0, 0.0, mean);
        test_relative(0.0, 10.0, 2.0, 0.0, mean);
        test_relative(0.0, 10.0, f64::INFINITY, 0.0, mean);
        test_relative(-5.0, 100.0, 1.5, -5.0, mean);

        let mean = |x: StudentsT| x.mean();
        test_none(0.0, 1.0, 1.0, mean);
        test_none(0.0, 0.1, 1.0, mean);
        test_none(0.0, 10.0, 1.0, mean);
        test_none(10.0, 1.0, 1.0, mean);
        test_none(0.0, f64::INFINITY, 1.0, mean);
    }

    #[test]
    fn test_mean_freedom_lte_1() {
        test_none(1.0, 1.0, 0.5, |dist| dist.mean());
    }

    #[test]
    fn test_variance() {
        let variance = |x: StudentsT| x.variance().unwrap();
        test_relative(0.0, 1.0, 3.0, 3.0, variance);
        test_relative(0.0, 10.0, 2.5, 500.0, variance);
        test_relative(10.0, 1.0, 2.5, 5.0, variance);

        let variance = |x: StudentsT| x.variance();
        test_none(0.0, 10.0, 2.0, variance);
        test_none(0.0, 1.0, 1.0, variance);
        test_none(0.0, 0.1, 1.0, variance);
        test_none(0.0, 10.0, 1.0, variance);
        test_none(10.0, 1.0, 1.0, variance);
        test_none(-5.0, 100.0, 1.5, variance);
        test_none(0.0, f64::INFINITY, 1.0, variance);
    }

    #[test]
    fn test_variance_freedom_lte1() {
        test_none(1.0, 1.0, 0.5, |dist| dist.variance());
    }

    // TODO: valid skewness tests
    #[test]
    fn test_skewness_freedom_lte_3() {
        test_none(1.0, 1.0, 1.0, |dist| dist.skewness());
    }

    #[test]
    fn test_mode() {
        let mode = |x: StudentsT| x.mode().unwrap();
        test_relative(0.0, 1.0, 1.0, 0.0, mode);
        test_relative(0.0, 0.1, 1.0, 0.0, mode);
        test_relative(0.0, 1.0, 3.0, 0.0, mode);
        test_relative(0.0, 10.0, 1.0, 0.0, mode);
        test_relative(0.0, 10.0, 2.0, 0.0, mode);
        test_relative(0.0, 10.0, 2.5, 0.0, mode);
        test_relative(0.0, 10.0, f64::INFINITY, 0.0, mode);
        test_relative(10.0, 1.0, 1.0, 10.0, mode);
        test_relative(10.0, 1.0, 2.5, 10.0, mode);
        test_relative(-5.0, 100.0, 1.5, -5.0, mode);
        test_relative(0.0, f64::INFINITY, 1.0, 0.0, mode);
    }

    #[test]
    fn test_median() {
        let median = |x: StudentsT| x.median();
        test_relative(0.0, 1.0, 1.0, 0.0, median);
        test_relative(0.0, 0.1, 1.0, 0.0, median);
        test_relative(0.0, 1.0, 3.0, 0.0, median);
        test_relative(0.0, 10.0, 1.0, 0.0, median);
        test_relative(0.0, 10.0, 2.0, 0.0, median);
        test_relative(0.0, 10.0, 2.5, 0.0, median);
        test_relative(0.0, 10.0, f64::INFINITY, 0.0, median);
        test_relative(10.0, 1.0, 1.0, 10.0, median);
        test_relative(10.0, 1.0, 2.5, 10.0, median);
        test_relative(-5.0, 100.0, 1.5, -5.0, median);
        test_relative(0.0, f64::INFINITY, 1.0, 0.0, median);
    }

    #[test]
    fn test_min_max() {
        let min = |x: StudentsT| x.min();
        let max = |x: StudentsT| x.max();
        test_relative(0.0, 1.0, 1.0, f64::NEG_INFINITY, min);
        test_relative(2.5, 100.0, 1.5, f64::NEG_INFINITY, min);
        test_relative(10.0, f64::INFINITY, 3.5, f64::NEG_INFINITY, min);
        test_relative(0.0, 1.0, 1.0, f64::INFINITY, max);
        test_relative(2.5, 100.0, 1.5, f64::INFINITY, max);
        test_relative(10.0, f64::INFINITY, 5.5, f64::INFINITY, max);
    }

    #[test]
    fn test_pdf() {
        let pdf = |arg: f64| move |x: StudentsT| x.pdf(arg);
        test_relative(0.0, 1.0, 1.0, std::f64::consts::FRAC_1_PI, pdf(0.0));
        test_relative(0.0, 1.0, 1.0, 0.159154943091895, pdf(1.0));
        test_relative(0.0, 1.0, 1.0, 0.159154943091895, pdf(-1.0));
        test_relative(0.0, 1.0, 1.0, 0.063661977236758, pdf(2.0));
        test_relative(0.0, 1.0, 1.0, 0.063661977236758, pdf(-2.0));
        test_relative(0.0, 1.0, 2.0, 0.353553390593274, pdf(0.0));
        test_relative(0.0, 1.0, 2.0, 0.192450089729875, pdf(1.0));
        test_relative(0.0, 1.0, 2.0, 0.192450089729875, pdf(-1.0));
        test_relative(0.0, 1.0, 2.0, 0.068041381743977, pdf(2.0));
        test_relative(0.0, 1.0, 2.0, 0.068041381743977, pdf(-2.0));
        test_relative(0.0, 1.0, f64::INFINITY, 0.398942280401433, pdf(0.0));
        test_relative(0.0, 1.0, f64::INFINITY, 0.241970724519143, pdf(1.0));
        test_relative(0.0, 1.0, f64::INFINITY, 0.053990966513188, pdf(2.0));
    }

    #[test]
    fn test_ln_pdf() {
        let ln_pdf = |arg: f64| move |x: StudentsT| x.ln_pdf(arg);
        test_relative(0.0, 1.0, 1.0, -1.144729885849399, ln_pdf(0.0));
        test_relative(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(1.0));
        test_relative(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(-1.0));
        test_relative(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(2.0));
        test_relative(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(-2.0));
        test_relative(0.0, 1.0, 2.0, -1.039720770839917, ln_pdf(0.0));
        test_relative(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(1.0));
        test_relative(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(-1.0));
        test_relative(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(2.0));
        test_relative(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(-2.0));
        test_relative(0.0, 1.0, f64::INFINITY, -0.918938533204672, ln_pdf(0.0));
        test_relative(0.0, 1.0, f64::INFINITY, -1.418938533204674, ln_pdf(1.0));
        test_relative(0.0, 1.0, f64::INFINITY, -2.918938533204674, ln_pdf(2.0));
    }

    #[test]
    fn test_cdf() {
        let cdf = |arg: f64| move |x: StudentsT| x.cdf(arg);
        test_relative(0.0, 1.0, 1.0, 0.5, cdf(0.0));
        test_relative(0.0, 1.0, 1.0, 0.75, cdf(1.0));
        test_relative(0.0, 1.0, 1.0, 0.25, cdf(-1.0));
        test_relative(0.0, 1.0, 1.0, 0.852416382349567, cdf(2.0));
        test_relative(0.0, 1.0, 1.0, 0.147583617650433, cdf(-2.0));
        test_relative(0.0, 1.0, 2.0, 0.5, cdf(0.0));
        test_relative(0.0, 1.0, 2.0, 0.788675134594813, cdf(1.0));
        test_relative(0.0, 1.0, 2.0, 0.211324865405187, cdf(-1.0));
        test_relative(0.0, 1.0, 2.0, 0.908248290463863, cdf(2.0));
        test_relative(0.0, 1.0, 2.0, 0.091751709536137, cdf(-2.0));
        test_relative(0.0, 1.0, f64::INFINITY, 0.5, cdf(0.0));

        // TODO: these are curiously low accuracy and should be re-examined
        test_relative(0.0, 1.0, f64::INFINITY, 0.841344746068543, cdf(1.0));
        test_relative(0.0, 1.0, f64::INFINITY, 0.977249868051821, cdf(2.0));
    }

    #[test]
    fn test_sf() {
        let sf = |arg: f64| move |x: StudentsT| x.sf(arg);
        test_relative(0.0, 1.0, 1.0, 0.5, sf(0.0));
        test_relative(0.0, 1.0, 1.0, 0.25, sf(1.0));
        test_relative(0.0, 1.0, 1.0, 0.75, sf(-1.0));
        test_relative(0.0, 1.0, 1.0, 0.147583617650433, sf(2.0));
        test_relative(0.0, 1.0, 1.0, 0.852416382349566, sf(-2.0));
        test_relative(0.0, 1.0, 2.0, 0.5, sf(0.0));
        test_relative(0.0, 1.0, 2.0, 0.211324865405186, sf(1.0));
        test_relative(0.0, 1.0, 2.0, 0.788675134594813, sf(-1.0));
        test_relative(0.0, 1.0, 2.0, 0.091751709536137, sf(2.0));
        test_relative(0.0, 1.0, 2.0, 0.908248290463862, sf(-2.0));
        test_relative(0.0, 1.0, f64::INFINITY, 0.5, sf(0.0));

        // TODO: these are curiously low accuracy and should be re-examined
        test_relative(0.0, 1.0, f64::INFINITY, 0.158655253945057, sf(1.0));
        test_relative(0.0, 1.0, f64::INFINITY, 0.022750131947162, sf(2.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&create_ok(0.0, 1.0, 3.0), -30.0, 30.0);
        test::check_continuous_distribution(&create_ok(0.0, 1.0, 10.0), -10.0, 10.0);
        test::check_continuous_distribution(&create_ok(20.0, 0.5, 10.0), 10.0, 30.0);
    }

    #[test]
    fn test_inv_cdf() {
        let test = |x: f64, freedom: f64, expected: f64| {
            use approx::*;
            let d = StudentsT::new(0., 1., freedom).unwrap();
            // Checks that left == right within a relative tolerance of 1e-3, unlike
            // test_almost() which uses decimal places
            assert_relative_eq!(d.inverse_cdf(x), expected, max_relative = 0.001);
        };

        // This test checks our implementation against the whole t-table
        // copied from https://en.wikipedia.org/wiki/Student's_t-distribution

        test(0.75, 1.0, 1.000);
        test(0.8, 1.0, 1.376);
        test(0.85, 1.0, 1.963);
        test(0.9, 1.0, 3.078);
        test(0.95, 1.0, 6.314);
        test(0.975, 1.0, 12.71);
        test(0.99, 1.0, 31.82);
        test(0.995, 1.0, 63.66);
        test(0.9975, 1.0, 127.3);
        test(0.999, 1.0, 318.3);
        test(0.9995, 1.0, 636.6);

        test(0.75, 002.0, 0.816);
        // df = 2 has a closed-form quantile, t = (2p - 1) / sqrt(2p(1 - p)),
        // so t(0.8) = 0.6 / sqrt(0.32) ≈ 1.0607
        test(0.8, 002.0, 1.061);
        test(0.85, 002.0, 1.386);
        test(0.9, 002.0, 1.886);
        test(0.95, 002.0, 2.920);
        test(0.975, 002.0, 4.303);
        test(0.99, 002.0, 6.965);
        test(0.995, 002.0, 9.925);
        test(0.9975, 002.0, 14.09);
        test(0.999, 002.0, 22.33);
        test(0.9995, 002.0, 31.60);

        test(0.75, 003.0, 0.765);
        test(0.8, 003.0, 0.978);
        test(0.85, 003.0, 1.250);
        test(0.9, 003.0, 1.638);
        test(0.95, 003.0, 2.353);
        test(0.975, 003.0, 3.182);
        test(0.99, 003.0, 4.541);
        test(0.995, 003.0, 5.841);
        test(0.9975, 003.0, 7.453);
        test(0.999, 003.0, 10.21);
        test(0.9995, 003.0, 12.92);

        test(0.75, 004.0, 0.741);
        test(0.8, 004.0, 0.941);
        test(0.85, 004.0, 1.190);
        test(0.9, 004.0, 1.533);
        test(0.95, 004.0, 2.132);
        test(0.975, 004.0, 2.776);
        test(0.99, 004.0, 3.747);
        test(0.995, 004.0, 4.604);
        test(0.9975, 004.0, 5.598);
        test(0.999, 004.0, 7.173);
        test(0.9995, 004.0, 8.610);

        test(0.75, 005.0, 0.727);
        test(0.8, 005.0, 0.920);
        test(0.85, 005.0, 1.156);
        test(0.9, 005.0, 1.476);
        test(0.95, 005.0, 2.015);
        test(0.975, 005.0, 2.571);
        test(0.99, 005.0, 3.365);
        test(0.995, 005.0, 4.032);
        test(0.9975, 005.0, 4.773);
        test(0.999, 005.0, 5.893);
        test(0.9995, 005.0, 6.869);

        test(0.75, 006.0, 0.718);
        test(0.8, 006.0, 0.906);
        test(0.85, 006.0, 1.134);
        test(0.9, 006.0, 1.440);
        test(0.95, 006.0, 1.943);
        test(0.975, 006.0, 2.447);
        test(0.99, 006.0, 3.143);
        test(0.995, 006.0, 3.707);
        test(0.9975, 006.0, 4.317);
        test(0.999, 006.0, 5.208);
        test(0.9995, 006.0, 5.959);

        test(0.75, 007.0, 0.711);
        test(0.8, 007.0, 0.896);
        test(0.85, 007.0, 1.119);
        test(0.9, 007.0, 1.415);
        test(0.95, 007.0, 1.895);
        test(0.975, 007.0, 2.365);
        test(0.99, 007.0, 2.998);
        test(0.995, 007.0, 3.499);
        test(0.9975, 007.0, 4.029);
        test(0.999, 007.0, 4.785);
        test(0.9995, 007.0, 5.408);

        test(0.75, 008.0, 0.706);
        test(0.8, 008.0, 0.889);
        test(0.85, 008.0, 1.108);
        test(0.9, 008.0, 1.397);
        test(0.95, 008.0, 1.860);
        test(0.975, 008.0, 2.306);
        test(0.99, 008.0, 2.896);
        test(0.995, 008.0, 3.355);
        test(0.9975, 008.0, 3.833);
        test(0.999, 008.0, 4.501);
        test(0.9995, 008.0, 5.041);

        test(0.75, 009.0, 0.703);
        test(0.8, 009.0, 0.883);
        test(0.85, 009.0, 1.100);
        test(0.9, 009.0, 1.383);
        test(0.95, 009.0, 1.833);
        test(0.975, 009.0, 2.262);
        test(0.99, 009.0, 2.821);
        test(0.995, 009.0, 3.250);
        test(0.9975, 009.0, 3.690);
        test(0.999, 009.0, 4.297);
        test(0.9995, 009.0, 4.781);

        test(0.75, 010.0, 0.700);
        test(0.8, 010.0, 0.879);
        test(0.85, 010.0, 1.093);
        test(0.9, 010.0, 1.372);
        test(0.95, 010.0, 1.812);
        test(0.975, 010.0, 2.228);
        test(0.99, 010.0, 2.764);
        test(0.995, 010.0, 3.169);
        test(0.9975, 010.0, 3.581);
        test(0.999, 010.0, 4.144);
        test(0.9995, 010.0, 4.587);

        test(0.75, 011.0, 0.697);
        test(0.8, 011.0, 0.876);
        test(0.85, 011.0, 1.088);
        test(0.9, 011.0, 1.363);
        test(0.95, 011.0, 1.796);
        test(0.975, 011.0, 2.201);
        // 2.718 is roughly equal to E
        #[allow(clippy::approx_constant)]
        test(0.99, 011.0, 2.718);
        test(0.995, 011.0, 3.106);
        test(0.9975, 011.0, 3.497);
        test(0.999, 011.0, 4.025);
        test(0.9995, 011.0, 4.437);

        test(0.75, 012.0, 0.695);
        test(0.8, 012.0, 0.873);
        test(0.85, 012.0, 1.083);
        test(0.9, 012.0, 1.356);
        test(0.95, 012.0, 1.782);
        test(0.975, 012.0, 2.179);
        test(0.99, 012.0, 2.681);
        test(0.995, 012.0, 3.055);
        test(0.9975, 012.0, 3.428);
        test(0.999, 012.0, 3.930);
        test(0.9995, 012.0, 4.318);

        test(0.75, 013.0, 0.694);
        test(0.8, 013.0, 0.870);
        test(0.85, 013.0, 1.079);
        test(0.9, 013.0, 1.350);
        test(0.95, 013.0, 1.771);
        test(0.975, 013.0, 2.160);
        test(0.99, 013.0, 2.650);
        test(0.995, 013.0, 3.012);
        test(0.9975, 013.0, 3.372);
        test(0.999, 013.0, 3.852);
        test(0.9995, 013.0, 4.221);

        test(0.75, 014.0, 0.692);
        test(0.8, 014.0, 0.868);
        test(0.85, 014.0, 1.076);
        test(0.9, 014.0, 1.345);
        test(0.95, 014.0, 1.761);
        test(0.975, 014.0, 2.145);
        test(0.99, 014.0, 2.624);
        test(0.995, 014.0, 2.977);
        test(0.9975, 014.0, 3.326);
        test(0.999, 014.0, 3.787);
        test(0.9995, 014.0, 4.140);

        test(0.75, 015.0, 0.691);
        test(0.8, 015.0, 0.866);
        test(0.85, 015.0, 1.074);
        test(0.9, 015.0, 1.341);
        test(0.95, 015.0, 1.753);
        test(0.975, 015.0, 2.131);
        test(0.99, 015.0, 2.602);
        test(0.995, 015.0, 2.947);
        test(0.9975, 015.0, 3.286);
        test(0.999, 015.0, 3.733);
        test(0.9995, 015.0, 4.073);

        test(0.75, 016.0, 0.690);
        test(0.8, 016.0, 0.865);
        test(0.85, 016.0, 1.071);
        test(0.9, 016.0, 1.337);
        test(0.95, 016.0, 1.746);
        test(0.975, 016.0, 2.120);
        test(0.99, 016.0, 2.583);
        test(0.995, 016.0, 2.921);
        test(0.9975, 016.0, 3.252);
        test(0.999, 016.0, 3.686);
        test(0.9995, 016.0, 4.015);

        test(0.75, 017.0, 0.689);
        test(0.8, 017.0, 0.863);
        test(0.85, 017.0, 1.069);
        test(0.9, 017.0, 1.333);
        test(0.95, 017.0, 1.740);
        test(0.975, 017.0, 2.110);
        test(0.99, 017.0, 2.567);
        test(0.995, 017.0, 2.898);
        test(0.9975, 017.0, 3.222);
        test(0.999, 017.0, 3.646);
        test(0.9995, 017.0, 3.965);

        test(0.75, 018.0, 0.688);
        test(0.8, 018.0, 0.862);
        test(0.85, 018.0, 1.067);
        test(0.9, 018.0, 1.330);
        test(0.95, 018.0, 1.734);
        test(0.975, 018.0, 2.101);
        test(0.99, 018.0, 2.552);
        test(0.995, 018.0, 2.878);
        test(0.9975, 018.0, 3.197);
        test(0.999, 018.0, 3.610);
        test(0.9995, 018.0, 3.922);

        test(0.75, 019.0, 0.688);
        test(0.8, 019.0, 0.861);
        test(0.85, 019.0, 1.066);
        test(0.9, 019.0, 1.328);
        test(0.95, 019.0, 1.729);
        test(0.975, 019.0, 2.093);
        test(0.99, 019.0, 2.539);
        test(0.995, 019.0, 2.861);
        test(0.9975, 019.0, 3.174);
        test(0.999, 019.0, 3.579);
        test(0.9995, 019.0, 3.883);

        test(0.75, 020.0, 0.687);
        test(0.8, 020.0, 0.860);
        test(0.85, 020.0, 1.064);
        test(0.9, 020.0, 1.325);
        test(0.95, 020.0, 1.725);
        test(0.975, 020.0, 2.086);
        test(0.99, 020.0, 2.528);
        test(0.995, 020.0, 2.845);
        test(0.9975, 020.0, 3.153);
        test(0.999, 020.0, 3.552);
        test(0.9995, 020.0, 3.850);

        test(0.75, 021.0, 0.686);
        test(0.8, 021.0, 0.859);
        test(0.85, 021.0, 1.063);
        test(0.9, 021.0, 1.323);
        test(0.95, 021.0, 1.721);
        test(0.975, 021.0, 2.080);
        test(0.99, 021.0, 2.518);
        test(0.995, 021.0, 2.831);
        test(0.9975, 021.0, 3.135);
        test(0.999, 021.0, 3.527);
        test(0.9995, 021.0, 3.819);

        test(0.75, 022.0, 0.686);
        test(0.8, 022.0, 0.858);
        test(0.85, 022.0, 1.061);
        test(0.9, 022.0, 1.321);
        test(0.95, 022.0, 1.717);
        test(0.975, 022.0, 2.074);
        test(0.99, 022.0, 2.508);
        test(0.995, 022.0, 2.819);
        test(0.9975, 022.0, 3.119);
        test(0.999, 022.0, 3.505);
        test(0.9995, 022.0, 3.792);

        test(0.75, 023.0, 0.685);
        test(0.8, 023.0, 0.858);
        test(0.85, 023.0, 1.060);
        test(0.9, 023.0, 1.319);
        test(0.95, 023.0, 1.714);
        test(0.975, 023.0, 2.069);
        test(0.99, 023.0, 2.500);
        test(0.995, 023.0, 2.807);
        test(0.9975, 023.0, 3.104);
        test(0.999, 023.0, 3.485);
        test(0.9995, 023.0, 3.767);

        test(0.75, 024.0, 0.685);
        test(0.8, 024.0, 0.857);
        test(0.85, 024.0, 1.059);
        test(0.9, 024.0, 1.318);
        test(0.95, 024.0, 1.711);
        test(0.975, 024.0, 2.064);
        test(0.99, 024.0, 2.492);
        test(0.995, 024.0, 2.797);
        test(0.9975, 024.0, 3.091);
        test(0.999, 024.0, 3.467);
        test(0.9995, 024.0, 3.745);

        test(0.75, 025.0, 0.684);
        test(0.8, 025.0, 0.856);
        test(0.85, 025.0, 1.058);
        test(0.9, 025.0, 1.316);
        test(0.95, 025.0, 1.708);
        test(0.975, 025.0, 2.060);
        test(0.99, 025.0, 2.485);
        test(0.995, 025.0, 2.787);
        test(0.9975, 025.0, 3.078);
        test(0.999, 025.0, 3.450);
        test(0.9995, 025.0, 3.725);

        test(0.75, 026.0, 0.684);
        test(0.8, 026.0, 0.856);
        test(0.85, 026.0, 1.058);
        test(0.9, 026.0, 1.315);
        test(0.95, 026.0, 1.706);
        test(0.975, 026.0, 2.056);
        test(0.99, 026.0, 2.479);
        test(0.995, 026.0, 2.779);
        test(0.9975, 026.0, 3.067);
        test(0.999, 026.0, 3.435);
        test(0.9995, 026.0, 3.707);

        test(0.75, 027.0, 0.684);
        test(0.8, 027.0, 0.855);
        test(0.85, 027.0, 1.057);
        test(0.9, 027.0, 1.314);
        test(0.95, 027.0, 1.703);
        test(0.975, 027.0, 2.052);
        test(0.99, 027.0, 2.473);
        test(0.995, 027.0, 2.771);
        test(0.9975, 027.0, 3.057);
        test(0.999, 027.0, 3.421);
        test(0.9995, 027.0, 3.690);

        test(0.75, 028.0, 0.683);
        test(0.8, 028.0, 0.855);
        test(0.85, 028.0, 1.056);
        test(0.9, 028.0, 1.313);
        test(0.95, 028.0, 1.701);
        test(0.975, 028.0, 2.048);
        test(0.99, 028.0, 2.467);
        test(0.995, 028.0, 2.763);
        test(0.9975, 028.0, 3.047);
        test(0.999, 028.0, 3.408);
        test(0.9995, 028.0, 3.674);

        test(0.75, 029.0, 0.683);
        test(0.8, 029.0, 0.854);
        test(0.85, 029.0, 1.055);
        test(0.9, 029.0, 1.311);
        test(0.95, 029.0, 1.699);
        test(0.975, 029.0, 2.045);
        test(0.99, 029.0, 2.462);
        test(0.995, 029.0, 2.756);
        test(0.9975, 029.0, 3.038);
        test(0.999, 029.0, 3.396);
        test(0.9995, 029.0, 3.659);

        test(0.75, 030.0, 0.683);
        test(0.8, 030.0, 0.854);
        test(0.85, 030.0, 1.055);
        test(0.9, 030.0, 1.310);
        test(0.95, 030.0, 1.697);
        test(0.975, 030.0, 2.042);
        test(0.99, 030.0, 2.457);
        test(0.995, 030.0, 2.750);
        test(0.9975, 030.0, 3.030);
        test(0.999, 030.0, 3.385);
        test(0.9995, 030.0, 3.646);

        test(0.75, 040.0, 0.681);
        test(0.8, 040.0, 0.851);
        test(0.85, 040.0, 1.050);
        test(0.9, 040.0, 1.303);
        test(0.95, 040.0, 1.684);
        test(0.975, 040.0, 2.021);
        test(0.99, 040.0, 2.423);
        test(0.995, 040.0, 2.704);
        test(0.9975, 040.0, 2.971);
        test(0.999, 040.0, 3.307);
        test(0.9995, 040.0, 3.551);

        test(0.75, 050.0, 0.679);
        test(0.8, 050.0, 0.849);
        test(0.85, 050.0, 1.047);
        test(0.9, 050.0, 1.299);
        test(0.95, 050.0, 1.676);
        test(0.975, 050.0, 2.009);
        test(0.99, 050.0, 2.403);
        test(0.995, 050.0, 2.678);
        test(0.9975, 050.0, 2.937);
        test(0.999, 050.0, 3.261);
        test(0.9995, 050.0, 3.496);

        test(0.75, 060.0, 0.679);
        test(0.8, 060.0, 0.848);
        test(0.85, 060.0, 1.045);
        test(0.9, 060.0, 1.296);
        test(0.95, 060.0, 1.671);
        test(0.975, 060.0, 2.000);
        test(0.99, 060.0, 2.390);
        test(0.995, 060.0, 2.660);
        test(0.9975, 060.0, 2.915);
        test(0.999, 060.0, 3.232);
        test(0.9995, 060.0, 3.460);

        test(0.75, 080.0, 0.678);
        test(0.8, 080.0, 0.846);
        test(0.85, 080.0, 1.043);
        test(0.9, 080.0, 1.292);
        test(0.95, 080.0, 1.664);
        test(0.975, 080.0, 1.990);
        test(0.99, 080.0, 2.374);
        test(0.995, 080.0, 2.639);
        test(0.9975, 080.0, 2.887);
        test(0.999, 080.0, 3.195);
        test(0.9995, 080.0, 3.416);

        test(0.75, 100.0, 0.677);
        test(0.8, 100.0, 0.845);
        test(0.85, 100.0, 1.042);
        test(0.9, 100.0, 1.290);
        test(0.95, 100.0, 1.660);
        test(0.975, 100.0, 1.984);
        test(0.99, 100.0, 2.364);
        test(0.995, 100.0, 2.626);
        test(0.9975, 100.0, 2.871);
        test(0.999, 100.0, 3.174);
        test(0.9995, 100.0, 3.390);

        test(0.75, 120.0, 0.677);
        test(0.8, 120.0, 0.845);
        test(0.85, 120.0, 1.041);
        test(0.9, 120.0, 1.289);
        test(0.95, 120.0, 1.658);
        test(0.975, 120.0, 1.980);
        test(0.99, 120.0, 2.358);
        test(0.995, 120.0, 2.617);
        test(0.9975, 120.0, 2.860);
        test(0.999, 120.0, 3.160);
        test(0.9995, 120.0, 3.373);
    }

    #[test]
    fn test_inv_cdf_high_precision() {
        let test = |x: f64, freedom: f64, expected: f64| {
            use approx::assert_relative_eq;
            let d = StudentsT::new(0., 1., freedom).unwrap();
            assert_relative_eq!(d.inverse_cdf(x), expected, max_relative = ACC);
        };
        // The data in this table of expected values was generated in
        // Python, using the mpsci package (based on mpmath):
        //
        //     import mpmath
        //     from mpsci.distributions import t
        //
        //     # Set the number of digits of precision
        //     mpmath.mp.dps = 200
        //
        //     ps = [0.001, 0.01, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45]
        //     dfs = [1.0, 10.0, 100.0]
        //
        //     for df in dfs:
        //         for p in ps:
        //             q = t.invcdf(p, df)
        //             print(f"({p:5.3f}, {df:5.1f}, {float(q)}),")
        //
        #[rustfmt::skip]
        let invcdf_data = [
            // p      df     inverse_cdf(p, df)
            (0.001,   1.0, -318.30883898555044),
            (0.010,   1.0, -31.820515953773956),
            (0.100,   1.0, -3.077683537175253),
            (0.150,   1.0, -1.9626105055051506),
            (0.200,   1.0, -1.3763819204711734),
            (0.250,   1.0, -1.0),
            (0.300,   1.0, -0.7265425280053609),
            (0.350,   1.0, -0.5095254494944289),
            (0.400,   1.0, -0.32491969623290623),
            (0.450,   1.0, -0.15838444032453625),
            (0.001,  10.0, -4.143700494046589),
            (0.010,  10.0, -2.763769458112696),
            (0.100,  10.0, -1.3721836411103356),
            (0.150,  10.0, -1.093058073590526),
            (0.200,  10.0, -0.8790578285505887),
            (0.250,  10.0, -0.6998120613124317),
            (0.300,  10.0, -0.5415280387550157),
            (0.350,  10.0, -0.3965914937556218),
            (0.400,  10.0, -0.26018482949208016),
            (0.450,  10.0, -0.12889018929327375),
            (0.001, 100.0, -3.173739493738783),
            (0.010, 100.0, -2.364217366238482),
            (0.100, 100.0, -1.290074761346516),
            (0.150, 100.0, -1.041835900908347),
            (0.200, 100.0, -0.845230424491016),
            (0.250, 100.0, -0.6769510430114715),
            (0.300, 100.0, -0.5260762706003463),
            (0.350, 100.0, -0.3864289804076715),
            (0.400, 100.0, -0.2540221824582278),
            (0.450, 100.0, -0.12598088204153965),
        ];
        for (p, df, expected) in invcdf_data.iter() {
            test(*p, *df, *expected);
            test(1.0 - *p, *df, -*expected);
        }
    }

    #[test]
    fn test_inv_cdf_midpoint() {
        for loc in [0.0, 1.0, -3.5] {
            let d = StudentsT::new(loc, 1.0, 12.0).unwrap();
            // inverse_cdf(p) is a floating point calculation, so using
            // assert_eq here is optimistic. For the given location values,
            // the check passes, so let's use the optimistic check for now.
            assert_eq!(d.inverse_cdf(0.5), loc);
        }
    }

    #[test]
    fn test_inv_cdf_p0() {
        let d = StudentsT::new(0.0, 1.0, 12.0).unwrap();
        assert_eq!(d.inverse_cdf(0.0), f64::NEG_INFINITY);
    }

    #[test]
    fn test_inv_cdf_p1() {
        let d = StudentsT::new(0.0, 1.0, 12.0).unwrap();
        assert_eq!(d.inverse_cdf(1.0), f64::INFINITY);
    }
}
statrs-0.18.0/src/distribution/triangular.rs000064400000000000000000000456761046102023000173000ustar 00000000000000
use crate::distribution::{Continuous, ContinuousCDF};
use crate::statistics::*;
use std::f64;

/// Implements the
/// [Triangular](https://en.wikipedia.org/wiki/Triangular_distribution)
/// distribution
///
/// # Examples
///
/// ```
/// use statrs::distribution::{Triangular, Continuous};
/// use statrs::statistics::Distribution;
///
/// let n = Triangular::new(0.0, 5.0, 2.5).unwrap();
/// assert_eq!(n.mean().unwrap(), 7.5 / 3.0);
/// assert_eq!(n.pdf(2.5), 5.0 / 12.5);
/// ```
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct Triangular {
    min: f64,
    max: f64,
    mode: f64,
}

/// Represents the errors that can occur when creating a [`Triangular`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum TriangularError {
    /// The minimum is NaN or infinite.
    MinInvalid,

    /// The maximum is NaN or infinite.
    MaxInvalid,

    /// The mode is NaN or infinite.
    ModeInvalid,

    /// The mode is less than the minimum or greater than the maximum.
    ModeOutOfRange,

    /// The minimum equals the maximum.
    MinEqualsMax,
}

impl std::fmt::Display for TriangularError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            TriangularError::MinInvalid => write!(f, "Minimum is NaN or infinite."),
            TriangularError::MaxInvalid => write!(f, "Maximum is NaN or infinite."),
            TriangularError::ModeInvalid => write!(f, "Mode is NaN or infinite."),
            TriangularError::ModeOutOfRange => {
                write!(f, "Mode is less than minimum or greater than maximum")
            }
            TriangularError::MinEqualsMax => write!(f, "Minimum equals Maximum"),
        }
    }
}

impl std::error::Error for TriangularError {}

impl Triangular {
    /// Constructs a new triangular distribution with a minimum of `min`,
    /// maximum of `max`, and a mode of `mode`.
    ///
    /// # Errors
    ///
    /// Returns an error if `min`, `max`, or `mode` are `NaN` or `±INF`.
    /// Returns an error if `max < mode`, `mode < min`, or `max == min`.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::distribution::Triangular;
    ///
    /// let mut result = Triangular::new(0.0, 5.0, 2.5);
    /// assert!(result.is_ok());
    ///
    /// result = Triangular::new(2.5, 1.5, 0.0);
    /// assert!(result.is_err());
    /// ```
    pub fn new(min: f64, max: f64, mode: f64) -> Result<Triangular, TriangularError> {
        if !min.is_finite() {
            return Err(TriangularError::MinInvalid);
        }

        if !max.is_finite() {
            return Err(TriangularError::MaxInvalid);
        }

        if !mode.is_finite() {
            return Err(TriangularError::ModeInvalid);
        }

        if max < mode || mode < min {
            return Err(TriangularError::ModeOutOfRange);
        }

        if min == max {
            return Err(TriangularError::MinEqualsMax);
        }

        Ok(Triangular { min, max, mode })
    }
}

impl std::fmt::Display for Triangular {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Triangular([{},{}], {})", self.min, self.max, self.mode)
    }
}

#[cfg(feature = "rand")]
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
impl ::rand::distributions::Distribution<f64> for Triangular {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        sample_unchecked(rng, self.min, self.max, self.mode)
    }
}
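
/// As a sketch of how the piecewise formulas in this impl fit together, the
/// example below picks an arbitrary distribution, `Triangular::new(0.0, 4.0, 1.0)`,
/// and checks one point on each side of the mode: `cdf` matches the closed form,
/// `sf` is its complement, and `inverse_cdf` maps the probability back to `x`.
/// The chosen parameters and tolerance are illustrative assumptions only.
///
/// ```
/// use statrs::distribution::{ContinuousCDF, Triangular};
///
/// let t = Triangular::new(0.0, 4.0, 1.0).unwrap();
///
/// // Left of the mode: (x - min)^2 / ((max - min) * (mode - min)) = 0.25 / 4
/// let p_left = t.cdf(0.5);
/// assert!((p_left - 0.0625).abs() < 1e-12);
/// assert!((t.inverse_cdf(p_left) - 0.5).abs() < 1e-12);
///
/// // Right of the mode: 1 - (max - x)^2 / ((max - min) * (max - mode)) = 1 - 1 / 12
/// let p_right = t.cdf(3.0);
/// assert!((p_right - 11.0 / 12.0).abs() < 1e-12);
/// assert!((t.sf(3.0) - 1.0 / 12.0).abs() < 1e-12);
/// assert!((t.inverse_cdf(p_right) - 3.0).abs() < 1e-12);
/// ```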
impl ContinuousCDF for Triangular { /// Calculates the cumulative distribution function for the triangular /// distribution /// at `x` /// /// # Formula /// /// ```text /// if x == min { /// 0 /// } if min < x <= mode { /// (x - min)^2 / ((max - min) * (mode - min)) /// } else if mode < x < max { /// 1 - (max - x)^2 / ((max - min) * (max - mode)) /// } else { /// 1 /// } /// ``` fn cdf(&self, x: f64) -> f64 { let a = self.min; let b = self.max; let c = self.mode; if x <= a { 0.0 } else if x <= c { (x - a) * (x - a) / ((b - a) * (c - a)) } else if x < b { 1.0 - (b - x) * (b - x) / ((b - a) * (b - c)) } else { 1.0 } } /// Calculates the survival function for the triangular /// distribution at `x` /// /// # Formula /// /// ```text /// if x == min { /// 1 /// } if min < x <= mode { /// 1 - (x - min)^2 / ((max - min) * (mode - min)) /// } else if mode < x < max { /// (max - min)^2 / ((max - min) * (max - mode)) /// } else { /// 0 /// } /// ``` fn sf(&self, x: f64) -> f64 { let a = self.min; let b = self.max; let c = self.mode; if x <= a { 1.0 } else if x <= c { 1.0 - ((x - a) * (x - a) / ((b - a) * (c - a))) } else if x < b { (b - x) * (b - x) / ((b - a) * (b - c)) } else { 0.0 } } /// Calculates the inverse cumulative distribution function for the triangular /// distribution /// at `x` /// /// # Formula /// /// ```text /// if x < (mode - min) / (max - min) { /// min + ((max - min) * (mode - min) * x)^(1 / 2) /// } else { /// max - ((max - min) * (max - mode) * (1 - x))^(1 / 2) /// } /// ``` fn inverse_cdf(&self, p: f64) -> f64 { let a = self.min; let b = self.max; let c = self.mode; if !(0.0..=1.0).contains(&p) { panic!("x must be in [0, 1]"); } if p < (c - a) / (b - a) { a + ((c - a) * (b - a) * p).sqrt() } else { b - ((b - a) * (b - c) * (1.0 - p)).sqrt() } } } impl Min for Triangular { /// Returns the minimum value in the domain of the /// triangular distribution representable by a double precision float /// /// # Remarks /// /// The return value is the same min 
used to construct the distribution fn min(&self) -> f64 { self.min } } impl Max for Triangular { /// Returns the maximum value in the domain of the /// triangular distribution representable by a double precision float /// /// # Remarks /// /// The return value is the same max used to construct the distribution fn max(&self) -> f64 { self.max } } impl Distribution for Triangular { /// Returns the mean of the triangular distribution /// /// # Formula /// /// ```text /// (min + max + mode) / 3 /// ``` fn mean(&self) -> Option { Some((self.min + self.max + self.mode) / 3.0) } /// Returns the variance of the triangular distribution /// /// # Formula /// /// ```text /// (min^2 + max^2 + mode^2 - min * max - min * mode - max * mode) / 18 /// ``` fn variance(&self) -> Option { let a = self.min; let b = self.max; let c = self.mode; Some((a * a + b * b + c * c - a * b - a * c - b * c) / 18.0) } /// Returns the entropy of the triangular distribution /// /// # Formula /// /// ```text /// 1 / 2 + ln((max - min) / 2) /// ``` fn entropy(&self) -> Option { Some(0.5 + ((self.max - self.min) / 2.0).ln()) } /// Returns the skewness of the triangular distribution /// /// # Formula /// /// ```text /// (sqrt(2) * (min + max - 2 * mode) * (2 * min - max - mode) * (min - 2 * /// max + mode)) / /// ( 5 * (min^2 + max^2 + mode^2 - min * max - min * mode - max * mode)^(3 /// / 2)) /// ``` fn skewness(&self) -> Option { let a = self.min; let b = self.max; let c = self.mode; let q = f64::consts::SQRT_2 * (a + b - 2.0 * c) * (2.0 * a - b - c) * (a - 2.0 * b + c); let d = 5.0 * (a * a + b * b + c * c - a * b - a * c - b * c).powf(3.0 / 2.0); Some(q / d) } } impl Median for Triangular { /// Returns the median of the triangular distribution /// /// # Formula /// /// ```text /// if mode >= (min + max) / 2 { /// min + sqrt((max - min) * (mode - min) / 2) /// } else { /// max - sqrt((max - min) * (max - mode) / 2) /// } /// ``` fn median(&self) -> f64 { let a = self.min; let b = self.max; let c = 
self.mode; if c >= (a + b) / 2.0 { a + ((b - a) * (c - a) / 2.0).sqrt() } else { b - ((b - a) * (b - c) / 2.0).sqrt() } } } impl Mode<Option<f64>> for Triangular { /// Returns the mode of the triangular distribution /// /// # Formula /// /// ```text /// mode /// ``` fn mode(&self) -> Option<f64> { Some(self.mode) } } impl Continuous<f64, f64> for Triangular { /// Calculates the probability density function for the triangular /// distribution at `x` /// /// # Formula /// /// ```text /// if x < min { /// 0 /// } else if min <= x < mode { /// 2 * (x - min) / ((max - min) * (mode - min)) /// } else if x == mode { /// 2 / (max - min) /// } else if mode < x <= max { /// 2 * (max - x) / ((max - min) * (max - mode)) /// } else { /// 0 /// } /// ``` fn pdf(&self, x: f64) -> f64 { let a = self.min; let b = self.max; let c = self.mode; if a <= x && x < c { 2.0 * (x - a) / ((b - a) * (c - a)) } else if x == c { 2.0 / (b - a) } else if c < x && x <= b { 2.0 * (b - x) / ((b - a) * (b - c)) } else { 0.0 } } /// Calculates the log probability density function for the triangular /// distribution at `x` /// /// # Remarks /// /// Returns `f64::NEG_INFINITY` if `x` is not in `[min, max]` /// /// # Formula /// /// ```text /// ln(pdf(x)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { self.pdf(x).ln() } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] fn sample_unchecked<R: ::rand::Rng + ?Sized>(rng: &mut R, min: f64, max: f64, mode: f64) -> f64 { let f: f64 = rng.gen(); if f < (mode - min) / (max - min) { min + (f * (max - min) * (mode - min)).sqrt() } else { max - ((1.0 - f) * (max - min) * (max - mode)).sqrt() } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(min: f64, max: f64, mode: f64; Triangular; TriangularError); #[test] fn test_create() { create_ok(-1.0, 1.0, 0.0); create_ok(1.0, 2.0, 1.0); create_ok(5.0, 25.0, 25.0); create_ok(1.0e-5, 1.0e5, 1.0e-3); create_ok(0.0, 1.0,
0.9); create_ok(-4.0, -0.5, -2.0); create_ok(-13.039, 8.42, 1.17); } #[test] fn test_bad_create() { let invalid = [ (0.0, 0.0, 0.0, TriangularError::MinEqualsMax), (0.0, 1.0, -0.1, TriangularError::ModeOutOfRange), (0.0, 1.0, 1.1, TriangularError::ModeOutOfRange), (0.0, -1.0, 0.5, TriangularError::ModeOutOfRange), (2.0, 1.0, 1.5, TriangularError::ModeOutOfRange), (f64::NAN, 1.0, 0.5, TriangularError::MinInvalid), (0.2, f64::NAN, 0.5, TriangularError::MaxInvalid), (0.5, 1.0, f64::NAN, TriangularError::ModeInvalid), (f64::NAN, f64::NAN, f64::NAN, TriangularError::MinInvalid), (f64::NEG_INFINITY, 1.0, 0.5, TriangularError::MinInvalid), (0.0, f64::INFINITY, 0.5, TriangularError::MaxInvalid), ]; for (min, max, mode, err) in invalid { test_create_err(min, max, mode, err); } } #[test] fn test_variance() { let variance = |x: Triangular| x.variance().unwrap(); test_exact(0.0, 1.0, 0.5, 0.75 / 18.0, variance); test_exact(0.0, 1.0, 0.75, 0.8125 / 18.0, variance); test_exact(-5.0, 8.0, -3.5, 151.75 / 18.0, variance); test_exact(-5.0, 8.0, 5.0, 139.0 / 18.0, variance); test_exact(-5.0, -3.0, -4.0, 3.0 / 18.0, variance); test_exact(15.0, 134.0, 21.0, 13483.0 / 18.0, variance); } #[test] fn test_entropy() { let entropy = |x: Triangular| x.entropy().unwrap(); test_absolute(0.0, 1.0, 0.5, -0.1931471805599453094172, 1e-16, entropy); test_absolute(0.0, 1.0, 0.75, -0.1931471805599453094172, 1e-16, entropy); test_exact(-5.0, 8.0, -3.5, 2.371802176901591426636, entropy); test_exact(-5.0, 8.0, 5.0, 2.371802176901591426636, entropy); test_exact(-5.0, -3.0, -4.0, 0.5, entropy); test_exact(15.0, 134.0, 21.0, 4.585976312551584075938, entropy); } #[test] fn test_skewness() { let skewness = |x: Triangular| x.skewness().unwrap(); test_exact(0.0, 1.0, 0.5, 0.0, skewness); test_exact(0.0, 1.0, 0.75, -0.4224039833745502226059, skewness); test_exact(-5.0, 8.0, -3.5, 0.5375093589712976359809, skewness); test_exact(-5.0, 8.0, 5.0, -0.4445991743012595633537, skewness); test_exact(-5.0, -3.0, -4.0, 
0.0, skewness); test_exact(15.0, 134.0, 21.0, 0.5605920922751860613217, skewness); } #[test] fn test_mode() { let mode = |x: Triangular| x.mode().unwrap(); test_exact(0.0, 1.0, 0.5, 0.5, mode); test_exact(0.0, 1.0, 0.75, 0.75, mode); test_exact(-5.0, 8.0, -3.5, -3.5, mode); test_exact(-5.0, 8.0, 5.0, 5.0, mode); test_exact(-5.0, -3.0, -4.0, -4.0, mode); test_exact(15.0, 134.0, 21.0, 21.0, mode); } #[test] fn test_median() { let median = |x: Triangular| x.median(); test_exact(0.0, 1.0, 0.5, 0.5, median); test_exact(0.0, 1.0, 0.75, 0.6123724356957945245493, median); test_absolute(-5.0, 8.0, -3.5, -0.6458082328952913226724, 1e-15, median); test_absolute(-5.0, 8.0, 5.0, 3.062257748298549652367, 1e-15, median); test_exact(-5.0, -3.0, -4.0, -4.0, median); test_absolute(15.0, 134.0, 21.0, 52.00304883716712238797, 1e-14, median); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Triangular| x.pdf(arg); test_exact(0.0, 1.0, 0.5, 0.0, pdf(-1.0)); test_exact(0.0, 1.0, 0.5, 0.0, pdf(1.1)); test_exact(0.0, 1.0, 0.5, 1.0, pdf(0.25)); test_exact(0.0, 1.0, 0.5, 2.0, pdf(0.5)); test_exact(0.0, 1.0, 0.5, 1.0, pdf(0.75)); test_exact(-5.0, 8.0, -3.5, 0.0, pdf(-5.1)); test_exact(-5.0, 8.0, -3.5, 0.0, pdf(8.1)); test_exact(-5.0, 8.0, -3.5, 0.1025641025641025641026, pdf(-4.0)); test_exact(-5.0, 8.0, -3.5, 0.1538461538461538461538, pdf(-3.5)); test_exact(-5.0, 8.0, -3.5, 0.05351170568561872909699, pdf(4.0)); test_exact(-5.0, -3.0, -4.0, 0.0, pdf(-5.1)); test_exact(-5.0, -3.0, -4.0, 0.0, pdf(-2.9)); test_exact(-5.0, -3.0, -4.0, 0.5, pdf(-4.5)); test_exact(-5.0, -3.0, -4.0, 1.0, pdf(-4.0)); test_exact(-5.0, -3.0, -4.0, 0.5, pdf(-3.5)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Triangular| x.ln_pdf(arg); test_exact(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(-1.0)); test_exact(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(1.1)); test_exact(0.0, 1.0, 0.5, 0.0, ln_pdf(0.25)); test_exact(0.0, 1.0, 0.5, 2f64.ln(), ln_pdf(0.5)); test_exact(0.0, 1.0, 0.5, 0.0, ln_pdf(0.75)); 
test_exact(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(-5.1)); test_exact(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(8.1)); test_exact(-5.0, 8.0, -3.5, 0.1025641025641025641026f64.ln(), ln_pdf(-4.0)); test_exact(-5.0, 8.0, -3.5, 0.1538461538461538461538f64.ln(), ln_pdf(-3.5)); test_exact(-5.0, 8.0, -3.5, 0.05351170568561872909699f64.ln(), ln_pdf(4.0)); test_exact(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-5.1)); test_exact(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-2.9)); test_exact(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-4.5)); test_exact(-5.0, -3.0, -4.0, 0.0, ln_pdf(-4.0)); test_exact(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-3.5)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); test_exact(0.0, 1.0, 0.5, 0.125, cdf(0.25)); test_exact(0.0, 1.0, 0.5, 0.5, cdf(0.5)); test_exact(0.0, 1.0, 0.5, 0.875, cdf(0.75)); test_exact(-5.0, 8.0, -3.5, 0.05128205128205128205128, cdf(-4.0)); test_exact(-5.0, 8.0, -3.5, 0.1153846153846153846154, cdf(-3.5)); test_exact(-5.0, 8.0, -3.5, 0.892976588628762541806, cdf(4.0)); test_exact(-5.0, -3.0, -4.0, 0.125, cdf(-4.5)); test_exact(-5.0, -3.0, -4.0, 0.5, cdf(-4.0)); test_exact(-5.0, -3.0, -4.0, 0.875, cdf(-3.5)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); test_exact(0.0, 3.0, 1.5, 0.0, cdf(-1.0)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); test_exact(0.0, 3.0, 1.5, 1.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); test_exact(0.0, 1.0, 0.5, 0.875, sf(0.25)); test_exact(0.0, 1.0, 0.5, 0.5, sf(0.5)); test_exact(0.0, 1.0, 0.5, 0.125, sf(0.75)); test_exact(-5.0, 8.0, -3.5, 0.9487179487179487, sf(-4.0)); test_exact(-5.0, 8.0, -3.5, 0.8846153846153846, sf(-3.5)); test_exact(-5.0, 8.0, -3.5, 0.10702341137123746, sf(4.0)); test_exact(-5.0, -3.0, -4.0, 0.875, sf(-4.5)); test_exact(-5.0, -3.0, -4.0, 0.5, sf(-4.0)); test_exact(-5.0, -3.0, -4.0, 0.125, sf(-3.5)); } #[test] fn 
test_sf_lower_bound() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); test_exact(0.0, 3.0, 1.5, 1.0, sf(-1.0)); } #[test] fn test_sf_upper_bound() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); test_exact(0.0, 3.0, 1.5, 0.0, sf(5.0)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: Triangular| x.inverse_cdf(x.cdf(arg)); test_absolute(0.0, 1.0, 0.5, 0.25, 1e-15, func(0.25)); test_absolute(0.0, 1.0, 0.5, 0.5, 1e-15, func(0.5)); test_absolute(0.0, 1.0, 0.5, 0.75, 1e-15, func(0.75)); test_absolute(-5.0, 8.0, -3.5, -4.0, 1e-15, func(-4.0)); test_absolute(-5.0, 8.0, -3.5, -3.5, 1e-15, func(-3.5)); test_absolute(-5.0, 8.0, -3.5, 4.0, 1e-15, func(4.0)); test_absolute(-5.0, -3.0, -4.0, -4.5, 1e-15, func(-4.5)); test_absolute(-5.0, -3.0, -4.0, -4.0, 1e-15, func(-4.0)); test_absolute(-5.0, -3.0, -4.0, -3.5, 1e-15, func(-3.5)); } #[test] fn test_continuous() { test::check_continuous_distribution(&create_ok(-5.0, 5.0, 0.0), -5.0, 5.0); test::check_continuous_distribution(&create_ok(-15.0, -2.0, -3.0), -15.0, -2.0); } } statrs-0.18.0/src/distribution/uniform.rs000064400000000000000000000353501046102023000165730ustar 00000000000000use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; use std::f64; use std::fmt::Debug; /// Implements the [Continuous /// Uniform](https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Uniform, Continuous}; /// use statrs::statistics::Distribution; /// /// let n = Uniform::new(0.0, 1.0).unwrap(); /// assert_eq!(n.mean().unwrap(), 0.5); /// assert_eq!(n.pdf(0.5), 1.0); /// ``` #[derive(Debug, Copy, Clone, PartialEq)] pub struct Uniform { min: f64, max: f64, } /// Represents the errors that can occur when creating a [`Uniform`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum UniformError { /// The minimum is NaN or infinite. MinInvalid, /// The maximum is NaN or infinite. 
MaxInvalid, /// The maximum is not greater than the minimum. MaxNotGreaterThanMin, } impl std::fmt::Display for UniformError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { UniformError::MinInvalid => write!(f, "Minimum is NaN or infinite"), UniformError::MaxInvalid => write!(f, "Maximum is NaN or infinite"), UniformError::MaxNotGreaterThanMin => { write!(f, "Maximum is not greater than the minimum") } } } } impl std::error::Error for UniformError {} impl Uniform { /// Constructs a new uniform distribution with a min of `min` and a max /// of `max`. /// /// # Errors /// /// Returns an error if `min` or `max` is `NaN` or infinite. /// Returns an error if `min >= max`. /// /// # Examples /// /// ``` /// use statrs::distribution::Uniform; /// use std::f64; /// /// let mut result = Uniform::new(0.0, 1.0); /// assert!(result.is_ok()); /// /// result = Uniform::new(f64::NAN, f64::NAN); /// assert!(result.is_err()); /// /// result = Uniform::new(f64::NEG_INFINITY, 1.0); /// assert!(result.is_err()); /// ``` pub fn new(min: f64, max: f64) -> Result<Uniform, UniformError> { if !min.is_finite() { return Err(UniformError::MinInvalid); } if !max.is_finite() { return Err(UniformError::MaxInvalid); } if min < max { Ok(Uniform { min, max }) } else { Err(UniformError::MaxNotGreaterThanMin) } } /// Constructs a new standard uniform distribution with /// a lower bound of 0 and an upper bound of 1.
/// /// # Examples /// /// ``` /// use statrs::distribution::Uniform; /// /// let uniform = Uniform::standard(); /// ``` pub fn standard() -> Self { Self { min: 0.0, max: 1.0 } } } impl Default for Uniform { fn default() -> Self { Self::standard() } } impl std::fmt::Display for Uniform { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Uni([{},{}])", self.min, self.max) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Uniform { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { let d = rand::distributions::Uniform::new_inclusive(self.min, self.max); rng.sample(d) } } impl ContinuousCDF<f64, f64> for Uniform { /// Calculates the cumulative distribution function for the uniform /// distribution at `x` /// /// # Formula /// /// ```text /// (x - min) / (max - min) /// ``` fn cdf(&self, x: f64) -> f64 { if x <= self.min { 0.0 } else if x >= self.max { 1.0 } else { (x - self.min) / (self.max - self.min) } } /// Calculates the survival function for the uniform /// distribution at `x` /// /// # Formula /// /// ```text /// (max - x) / (max - min) /// ``` fn sf(&self, x: f64) -> f64 { if x <= self.min { 1.0 } else if x >= self.max { 0.0 } else { (self.max - x) / (self.max - self.min) } } /// Calculates the inverse cumulative distribution function for the uniform /// distribution, i.e. the value `x` such that `F(x) = p`, which is /// `min + p * (max - min)` /// /// # Panics /// /// If `p` is not in `[0, 1]` fn inverse_cdf(&self, p: f64) -> f64 { if !(0.0..=1.0).contains(&p) { panic!("p must be in [0, 1], was {p}"); } else if p == 0.0 { self.min } else if p == 1.0 { self.max } else { (self.max - self.min) * p + self.min } } } impl Min<f64> for Uniform { fn min(&self) -> f64 { self.min } } impl Max<f64> for Uniform { fn max(&self) -> f64 { self.max } } impl Distribution<f64> for Uniform { /// Returns the mean for the continuous uniform distribution /// /// # Formula /// /// ```text /// (min + max) / 2 /// ``` fn mean(&self) -> Option<f64> { Some((self.min + self.max) / 2.0) } /// Returns the variance for the continuous uniform distribution /// /// # Formula /// /// ```text /// (max -
min)^2 / 12 /// ``` fn variance(&self) -> Option<f64> { Some((self.max - self.min) * (self.max - self.min) / 12.0) } /// Returns the entropy for the continuous uniform distribution /// /// # Formula /// /// ```text /// ln(max - min) /// ``` fn entropy(&self) -> Option<f64> { Some((self.max - self.min).ln()) } /// Returns the skewness for the continuous uniform distribution /// /// # Formula /// /// ```text /// 0 /// ``` fn skewness(&self) -> Option<f64> { Some(0.0) } } impl Median<f64> for Uniform { /// Returns the median for the continuous uniform distribution /// /// # Formula /// /// ```text /// (min + max) / 2 /// ``` fn median(&self) -> f64 { (self.min + self.max) / 2.0 } } impl Mode<Option<f64>> for Uniform { /// Returns the mode for the continuous uniform distribution /// /// # Remarks /// /// The density is constant on `[min, max]`, so every point in the /// interval is a mode; this returns the midpoint of the interval /// /// # Formula /// /// ```text /// (min + max) / 2 /// ``` fn mode(&self) -> Option<f64> { Some((self.min + self.max) / 2.0) } } impl Continuous<f64, f64> for Uniform { /// Calculates the probability density function for the continuous uniform /// distribution at `x` /// /// # Remarks /// /// Returns `0.0` if `x` is not in `[min, max]` /// /// # Formula /// /// ```text /// 1 / (max - min) /// ``` fn pdf(&self, x: f64) -> f64 { if x < self.min || x > self.max { 0.0 } else { 1.0 / (self.max - self.min) } } /// Calculates the log probability density function for the continuous /// uniform distribution at `x` /// /// # Remarks /// /// Returns `f64::NEG_INFINITY` if `x` is not in `[min, max]` /// /// # Formula /// /// ```text /// ln(1 / (max - min)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { if x < self.min || x > self.max { f64::NEG_INFINITY } else { -(self.max - self.min).ln() } } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(min: f64, max: f64; Uniform; UniformError); #[test] fn test_create() {
create_ok(0.0, 0.1); create_ok(0.0, 1.0); create_ok(-5.0, 11.0); create_ok(-5.0, 100.0); } #[test] fn test_bad_create() { let invalid = [ (0.0, 0.0, UniformError::MaxNotGreaterThanMin), (f64::NAN, 1.0, UniformError::MinInvalid), (1.0, f64::NAN, UniformError::MaxInvalid), (f64::NAN, f64::NAN, UniformError::MinInvalid), (0.0, f64::INFINITY, UniformError::MaxInvalid), (1.0, 0.0, UniformError::MaxNotGreaterThanMin), ]; for (min, max, err) in invalid { test_create_err(min, max, err); } } #[test] fn test_variance() { let variance = |x: Uniform| x.variance().unwrap(); test_exact(-0.0, 2.0, 1.0 / 3.0, variance); test_exact(0.0, 2.0, 1.0 / 3.0, variance); test_absolute(0.1, 4.0, 1.2675, 1e-15, variance); test_exact(10.0, 11.0, 1.0 / 12.0, variance); } #[test] fn test_entropy() { let entropy = |x: Uniform| x.entropy().unwrap(); test_exact(-0.0, 2.0, 0.6931471805599453094172, entropy); test_exact(0.0, 2.0, 0.6931471805599453094172, entropy); test_absolute(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy); test_exact(1.0, 10.0, 2.19722457733621938279, entropy); test_exact(10.0, 11.0, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Uniform| x.skewness().unwrap(); test_exact(-0.0, 2.0, 0.0, skewness); test_exact(0.0, 2.0, 0.0, skewness); test_exact(0.1, 4.0, 0.0, skewness); test_exact(1.0, 10.0, 0.0, skewness); test_exact(10.0, 11.0, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Uniform| x.mode().unwrap(); test_exact(-0.0, 2.0, 1.0, mode); test_exact(0.0, 2.0, 1.0, mode); test_exact(0.1, 4.0, 2.05, mode); test_exact(1.0, 10.0, 5.5, mode); test_exact(10.0, 11.0, 10.5, mode); } #[test] fn test_median() { let median = |x: Uniform| x.median(); test_exact(-0.0, 2.0, 1.0, median); test_exact(0.0, 2.0, 1.0, median); test_exact(0.1, 4.0, 2.05, median); test_exact(1.0, 10.0, 5.5, median); test_exact(10.0, 11.0, 10.5, median); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Uniform| x.pdf(arg); test_exact(0.0, 0.1, 0.0, pdf(-5.0)); test_exact(0.0, 
0.1, 10.0, pdf(0.05)); test_exact(0.0, 0.1, 0.0, pdf(5.0)); test_exact(0.0, 1.0, 0.0, pdf(-5.0)); test_exact(0.0, 1.0, 1.0, pdf(0.5)); test_exact(0.0, 1.0, 0.0, pdf(5.0)); test_exact(0.0, 10.0, 0.0, pdf(-5.0)); test_exact(0.0, 10.0, 0.1, pdf(1.0)); test_exact(0.0, 10.0, 0.1, pdf(5.0)); test_exact(0.0, 10.0, 0.0, pdf(11.0)); test_exact(-5.0, 100.0, 0.0, pdf(-10.0)); test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0)); test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0)); test_exact(-5.0, 100.0, 0.0, pdf(101.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg); test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0)); test_absolute(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05)); test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); test_exact(0.0, 1.0, 0.0, ln_pdf(0.5)); test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(5.0)); test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0)); test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0)); test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0)); test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0)); test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0)); test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0)); test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0)); test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); test_exact(0.0, 0.1, 0.5, cdf(0.05)); test_exact(0.0, 1.0, 0.5, cdf(0.5)); test_exact(0.0, 10.0, 0.1, cdf(1.0)); test_exact(0.0, 10.0, 0.5, cdf(5.0)); test_exact(-5.0, 100.0, 0.0, cdf(-5.0)); test_exact(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0)); } #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg); test_exact(0.0, 0.1, 0.05, inverse_cdf(0.5)); test_exact(0.0, 10.0, 5.0, inverse_cdf(0.5)); test_exact(1.0, 10.0, 1.0,
inverse_cdf(0.0)); test_exact(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0)); test_exact(1.0, 10.0, 10.0, inverse_cdf(1.0)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); test_exact(0.0, 3.0, 0.0, cdf(-1.0)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); test_exact(0.0, 3.0, 1.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); test_exact(0.0, 0.1, 0.5, sf(0.05)); test_exact(0.0, 1.0, 0.5, sf(0.5)); test_exact(0.0, 10.0, 0.9, sf(1.0)); test_exact(0.0, 10.0, 0.5, sf(5.0)); test_exact(-5.0, 100.0, 1.0, sf(-5.0)); test_exact(-5.0, 100.0, 0.9523809523809523, sf(0.0)); } #[test] fn test_sf_lower_bound() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); test_exact(0.0, 3.0, 1.0, sf(-1.0)); } #[test] fn test_sf_upper_bound() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); test_exact(0.0, 3.0, 0.0, sf(5.0)); } #[test] fn test_continuous() { test::check_continuous_distribution(&create_ok(0.0, 10.0), 0.0, 10.0); test::check_continuous_distribution(&create_ok(-2.0, 15.0), -2.0, 15.0); } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] #[test] fn test_samples_in_range() { use rand::rngs::StdRng; use rand::SeedableRng; use rand::distributions::Distribution; let seed = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 ]; let mut r: StdRng = SeedableRng::from_seed(seed); let min = -0.5; let max = 0.5; let num_trials = 10_000; let n = create_ok(min, max); assert!((0..num_trials) .map(|_| n.sample::<StdRng>(&mut r)) .all(|v| (min <= v) && (v <= max)) ); } #[test] fn test_default() { let n = Uniform::default(); let n_mean = n.mean().unwrap(); let n_std = n.std_dev().unwrap(); // Check that the mean of the distribution is close to 1 / 2 assert_almost_eq!(n_mean, 0.5, 1e-15); // Check that the standard deviation of the distribution is close to 1 / sqrt(12) assert_almost_eq!(n_std,
0.288_675_134_594_812_9, 1e-15); } } statrs-0.18.0/src/distribution/weibull.rs000064400000000000000000000427661046102023000165700ustar 00000000000000use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; use std::f64; /// Implements the [Weibull](https://en.wikipedia.org/wiki/Weibull_distribution) /// distribution /// /// # Examples /// /// ``` /// use statrs::distribution::{Weibull, Continuous}; /// use statrs::statistics::Distribution; /// use statrs::prec; /// /// let n = Weibull::new(10.0, 1.0).unwrap(); /// assert!(prec::almost_eq(n.mean().unwrap(), /// 0.95135076986687318362924871772654021925505786260884, 1e-15)); /// assert_eq!(n.pdf(1.0), 3.6787944117144232159552377016146086744581113103177); /// ``` #[derive(Copy, Clone, PartialEq, Debug)] pub struct Weibull { shape: f64, scale: f64, scale_pow_shape_inv: f64, } /// Represents the errors that can occur when creating a [`Weibull`]. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum WeibullError { /// The shape is NaN, zero or less than zero. ShapeInvalid, /// The scale is NaN, zero or less than zero. ScaleInvalid, } impl std::fmt::Display for WeibullError { #[cfg_attr(coverage_nightly, coverage(off))] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { WeibullError::ShapeInvalid => write!(f, "Shape is NaN, zero or less than zero."), WeibullError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero."), } } } impl std::error::Error for WeibullError {} impl Weibull { /// Constructs a new weibull distribution with a shape (k) of `shape` /// and a scale (λ) of `scale` /// /// # Errors /// /// Returns an error if `shape` or `scale` are `NaN`. 
/// Returns an error if `shape <= 0.0` or `scale <= 0.0` /// /// # Examples /// /// ``` /// use statrs::distribution::Weibull; /// /// let mut result = Weibull::new(10.0, 1.0); /// assert!(result.is_ok()); /// /// result = Weibull::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` pub fn new(shape: f64, scale: f64) -> Result<Weibull, WeibullError> { if shape.is_nan() || shape <= 0.0 { return Err(WeibullError::ShapeInvalid); } if scale.is_nan() || scale <= 0.0 { return Err(WeibullError::ScaleInvalid); } Ok(Weibull { shape, scale, scale_pow_shape_inv: scale.powf(-shape), }) } /// Returns the shape of the weibull distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Weibull; /// /// let n = Weibull::new(10.0, 1.0).unwrap(); /// assert_eq!(n.shape(), 10.0); /// ``` pub fn shape(&self) -> f64 { self.shape } /// Returns the scale of the weibull distribution /// /// # Examples /// /// ``` /// use statrs::distribution::Weibull; /// /// let n = Weibull::new(10.0, 1.0).unwrap(); /// assert_eq!(n.scale(), 1.0); /// ``` pub fn scale(&self) -> f64 { self.scale } } impl std::fmt::Display for Weibull { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Weibull({},{})", self.scale, self.shape) } } #[cfg(feature = "rand")] #[cfg_attr(docsrs, doc(cfg(feature = "rand")))] impl ::rand::distributions::Distribution<f64> for Weibull { fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen(); self.scale * (-x.ln()).powf(1.0 / self.shape) } } impl ContinuousCDF<f64, f64> for Weibull { /// Calculates the cumulative distribution function for the weibull /// distribution at `x` /// /// # Formula /// /// ```text /// 1 - e^-((x/λ)^k) /// ``` /// /// where `k` is the shape and `λ` is the scale fn cdf(&self, x: f64) -> f64 { if x < 0.0 { 0.0 } else { -(-x.powf(self.shape) * self.scale_pow_shape_inv).exp_m1() } } /// Calculates the survival function for the weibull /// distribution at `x` /// /// # Formula /// /// ```text /// e^-((x/λ)^k) /// ``` /// /// where `k` is the shape
and `λ` is the scale fn sf(&self, x: f64) -> f64 { if x < 0.0 { 1.0 } else { (-x.powf(self.shape) * self.scale_pow_shape_inv).exp() } } /// Calculates the inverse cumulative distribution function for the weibull /// distribution at probability `p` /// /// # Panics /// /// If `p` is not in `[0, 1]` /// /// # Formula /// /// ```text /// λ (-ln(1 - p))^(1 / k) /// ``` /// /// where `k` is the shape and `λ` is the scale fn inverse_cdf(&self, p: f64) -> f64 { if !(0.0..=1.0).contains(&p) { panic!("p must be in [0, 1], was {p}"); } (-((-p).ln_1p() / self.scale_pow_shape_inv)).powf(1.0 / self.shape) } } impl Min<f64> for Weibull { /// Returns the minimum value in the domain of the weibull /// distribution representable by a double precision float /// /// # Formula /// /// ```text /// 0 /// ``` fn min(&self) -> f64 { 0.0 } } impl Max<f64> for Weibull { /// Returns the maximum value in the domain of the weibull /// distribution representable by a double precision float /// /// # Formula /// /// ```text /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY } } impl Distribution<f64> for Weibull { /// Returns the mean of the weibull distribution /// /// # Formula /// /// ```text /// λΓ(1 + 1 / k) /// ``` /// /// where `k` is the shape, `λ` is the scale, and `Γ` is /// the gamma function fn mean(&self) -> Option<f64> { Some(self.scale * gamma::gamma(1.0 + 1.0 / self.shape)) } /// Returns the variance of the weibull distribution /// /// # Formula /// /// ```text /// λ^2 * (Γ(1 + 2 / k) - Γ(1 + 1 / k)^2) /// ``` /// /// where `k` is the shape, `λ` is the scale, and `Γ` is /// the gamma function fn variance(&self) -> Option<f64> { let mean = self.mean()?; Some(self.scale * self.scale * gamma::gamma(1.0 + 2.0 / self.shape) - mean * mean) } /// Returns the entropy of the weibull distribution /// /// # Formula /// /// ```text /// γ * (1 - 1 / k) + ln(λ / k) + 1 /// ``` /// /// where `k` is the shape, `λ` is the scale, and `γ` is /// the Euler-Mascheroni constant fn entropy(&self) -> Option<f64> { let entr = consts::EULER_MASCHERONI * (1.0 - 1.0 / self.shape) + (self.scale /
self.shape).ln() + 1.0; Some(entr) } /// Returns the skewness of the weibull distribution /// /// # Formula /// /// ```text /// (Γ(1 + 3 / k) * λ^3 - 3μσ^2 - μ^3) / σ^3 /// ``` /// /// where `k` is the shape, `λ` is the scale, `Γ` is /// the gamma function, `μ` is the mean of the distribution, /// and `σ` is the standard deviation of the distribution fn skewness(&self) -> Option<f64> { let mu = self.mean()?; let sigma = self.std_dev()?; let sigma2 = sigma * sigma; let sigma3 = sigma2 * sigma; let skew = (self.scale * self.scale * self.scale * gamma::gamma(1.0 + 3.0 / self.shape) - 3.0 * sigma2 * mu - (mu * mu * mu)) / sigma3; Some(skew) } } impl Median<f64> for Weibull { /// Returns the median of the weibull distribution /// /// # Formula /// /// ```text /// λ(ln(2))^(1 / k) /// ``` /// /// where `k` is the shape and `λ` is the scale fn median(&self) -> f64 { self.scale * f64::consts::LN_2.powf(1.0 / self.shape) } } impl Mode<Option<f64>> for Weibull { /// Returns the mode of the weibull distribution /// /// # Formula /// /// ```text /// if k <= 1 { /// 0 /// } else { /// λ((k - 1) / k)^(1 / k) /// } /// ``` /// /// where `k` is the shape and `λ` is the scale fn mode(&self) -> Option<f64> { let mode = if self.shape <= 1.0 || ulps_eq!(self.shape, 1.0) { 0.0 } else { self.scale * ((self.shape - 1.0) / self.shape).powf(1.0 / self.shape) }; Some(mode) } } impl Continuous<f64, f64> for Weibull { /// Calculates the probability density function for the weibull /// distribution at `x` /// /// # Formula /// /// ```text /// (k / λ) * (x / λ)^(k - 1) * e^(-(x / λ)^k) /// ``` /// /// where `k` is the shape and `λ` is the scale fn pdf(&self, x: f64) -> f64 { if x < 0.0 { 0.0 } else if x == 0.0 && ulps_eq!(self.shape, 1.0) { 1.0 / self.scale } else if x.is_infinite() { 0.0 } else { self.shape * (x / self.scale).powf(self.shape - 1.0) * (-(x.powf(self.shape)) * self.scale_pow_shape_inv).exp() / self.scale } } /// Calculates the log probability density function for the weibull /// distribution at `x` /// /// # Formula /// /// ```text
/// ln((k / λ) * (x / λ)^(k - 1) * e^(-(x / λ)^k)) /// ``` /// /// where `k` is the shape and `λ` is the scale fn ln_pdf(&self, x: f64) -> f64 { if x < 0.0 { f64::NEG_INFINITY } else if x == 0.0 && ulps_eq!(self.shape, 1.0) { 0.0 - self.scale.ln() } else if x.is_infinite() { f64::NEG_INFINITY } else { self.shape.ln() + (self.shape - 1.0) * (x / self.scale).ln() - x.powf(self.shape) * self.scale_pow_shape_inv - self.scale.ln() } } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use crate::distribution::internal::*; use crate::testing_boiler; testing_boiler!(shape: f64, scale: f64; Weibull; WeibullError); #[test] fn test_create() { create_ok(1.0, 0.1); create_ok(10.0, 1.0); create_ok(11.0, 10.0); create_ok(12.0, f64::INFINITY); } #[test] fn test_bad_create() { test_create_err(f64::NAN, 1.0, WeibullError::ShapeInvalid); test_create_err(1.0, f64::NAN, WeibullError::ScaleInvalid); create_err(f64::NAN, f64::NAN); create_err(1.0, -1.0); create_err(-1.0, 1.0); create_err(-1.0, -1.0); create_err(0.0, 0.0); create_err(0.0, 1.0); create_err(1.0, 0.0); } #[test] fn test_mean() { let mean = |x: Weibull| x.mean().unwrap(); test_exact(1.0, 0.1, 0.1, mean); test_exact(1.0, 1.0, 1.0, mean); test_absolute(10.0, 10.0, 9.5135076986687318362924871772654021925505786260884, 1e-14, mean); test_absolute(10.0, 1.0, 0.95135076986687318362924871772654021925505786260884, 1e-15, mean); } #[test] fn test_variance() { let variance = |x: Weibull| x.variance().unwrap(); test_absolute(1.0, 0.1, 0.01, 1e-16, variance); test_absolute(1.0, 1.0, 1.0, 1e-14, variance); test_absolute(10.0, 10.0, 1.3100455073468309147154581687505295026863354547057, 1e-12, variance); test_absolute(10.0, 1.0, 0.013100455073468309147154581687505295026863354547057, 1e-14, variance); } #[test] fn test_entropy() { let entropy = |x: Weibull| x.entropy().unwrap(); test_absolute(1.0, 0.1, -1.302585092994045684018, 1e-15, entropy); test_exact(1.0, 1.0, 1.0, entropy); test_exact(10.0, 10.0, 1.519494098411379574546, 
entropy); test_absolute(10.0, 1.0, -0.783090994582666109472, 1e-15, entropy); } #[test] fn test_skewness() { let skewness = |x: Weibull| x.skewness().unwrap(); test_absolute(1.0, 0.1, 2.0, 1e-13, skewness); test_absolute(1.0, 1.0, 2.0, 1e-13, skewness); test_absolute(10.0, 10.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); test_absolute(10.0, 1.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); } #[test] fn test_median() { let median = |x: Weibull| x.median(); test_exact(1.0, 0.1, 0.069314718055994530941723212145817656807550013436026, median); test_exact(1.0, 1.0, 0.69314718055994530941723212145817656807550013436026, median); test_exact(10.0, 10.0, 9.6401223546778973665856033763604752124634905617583, median); test_exact(10.0, 1.0, 0.96401223546778973665856033763604752124634905617583, median); } #[test] fn test_mode() { let mode = |x: Weibull| x.mode().unwrap(); test_exact(1.0, 0.1, 0.0, mode); test_exact(1.0, 1.0, 0.0, mode); test_exact(10.0, 10.0, 9.8951925820621439264623017041980483215553841533709, mode); test_exact(10.0, 1.0, 0.98951925820621439264623017041980483215553841533709, mode); } #[test] fn test_min_max() { let min = |x: Weibull| x.min(); let max = |x: Weibull| x.max(); test_exact(1.0, 1.0, 0.0, min); test_exact(1.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Weibull| x.pdf(arg); test_exact(1.0, 0.1, 10.0, pdf(0.0)); test_exact(1.0, 0.1, 0.00045399929762484851535591515560550610237918088866565, pdf(1.0)); test_exact(1.0, 0.1, 3.7200759760208359629596958038631183373588922923768e-43, pdf(10.0)); test_exact(1.0, 1.0, 1.0, pdf(0.0)); test_exact(1.0, 1.0, 0.36787944117144232159552377016146086744581113103177, pdf(1.0)); test_exact(1.0, 1.0, 0.000045399929762484851535591515560550610237918088866565, pdf(10.0)); test_exact(10.0, 10.0, 0.0, pdf(0.0)); test_absolute(10.0, 10.0, 9.9999999990000000000499999999983333333333750000000e-10, 1e-24, pdf(1.0)); test_exact(10.0, 10.0,
0.36787944117144232159552377016146086744581113103177, pdf(10.0)); test_exact(10.0, 1.0, 0.0, pdf(0.0)); test_exact(10.0, 1.0, 3.6787944117144232159552377016146086744581113103177, pdf(1.0)); test_exact(10.0, 1.0, 0.0, pdf(10.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Weibull| x.ln_pdf(arg); test_absolute(1.0, 0.1, 2.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(0.0)); test_absolute(1.0, 0.1, -7.6974149070059543159820085453156357923988985113712, 1e-15, ln_pdf(1.0)); test_exact(1.0, 0.1, -97.697414907005954315982008545315635792398898511371, ln_pdf(10.0)); test_exact(1.0, 1.0, 0.0, ln_pdf(0.0)); test_exact(1.0, 1.0, -1.0, ln_pdf(1.0)); test_exact(1.0, 1.0, -10.0, ln_pdf(10.0)); test_exact(10.0, 10.0, f64::NEG_INFINITY, ln_pdf(0.0)); test_absolute(10.0, 10.0, -20.723265837046411156161923092159277868409913397659, 1e-14, ln_pdf(1.0)); test_exact(10.0, 10.0, -1.0, ln_pdf(10.0)); test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); test_absolute(10.0, 1.0, 1.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(1.0)); test_exact(10.0, 1.0, -9.999999976974149070059543159820085453156357923988985113712e9, ln_pdf(10.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Weibull| x.cdf(arg); test_exact(1.0, 0.1, 0.0, cdf(0.0)); test_exact(1.0, 0.1, 0.99995460007023751514846440848443944938976208191113, cdf(1.0)); test_exact(1.0, 0.1, 0.99999999999999999999999999999999999999999996279924, cdf(10.0)); test_exact(1.0, 1.0, 0.0, cdf(0.0)); test_exact(1.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); test_exact(1.0, 1.0, 0.99995460007023751514846440848443944938976208191113, cdf(10.0)); test_exact(10.0, 10.0, 0.0, cdf(0.0)); test_absolute(10.0, 10.0, 9.9999999995000000000166666666662500000000083333333e-11, 1e-25, cdf(1.0)); test_exact(10.0, 10.0, 0.63212055882855767840447622983853913255418886896823, cdf(10.0)); test_exact(10.0, 1.0, 0.0, cdf(0.0)); test_exact(10.0, 1.0, 
0.63212055882855767840447622983853913255418886896823, cdf(1.0)); test_exact(10.0, 1.0, 1.0, cdf(10.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Weibull| x.sf(arg); test_exact(1.0, 0.1, 1.0, sf(0.0)); test_exact(1.0, 0.1, 4.5399929762484854e-5, sf(1.0)); test_exact(1.0, 0.1, 3.720075976020836e-44, sf(10.0)); test_exact(1.0, 1.0, 1.0, sf(0.0)); test_exact(1.0, 1.0, 0.36787944117144233, sf(1.0)); test_exact(1.0, 1.0, 4.5399929762484854e-5, sf(10.0)); test_exact(10.0, 10.0, 1.0, sf(0.0)); test_absolute(10.0, 10.0, 0.9999999999, 1e-25, sf(1.0)); test_exact(10.0, 10.0, 0.36787944117144233, sf(10.0)); test_exact(10.0, 1.0, 1.0, sf(0.0)); test_exact(10.0, 1.0, 0.36787944117144233, sf(1.0)); test_exact(10.0, 1.0, 0.0, sf(10.0)); } #[test] fn test_inverse_cdf() { let func = |arg: f64| move |x: Weibull| x.inverse_cdf(x.cdf(arg)); test_exact(1.0, 0.1, 0.0, func(0.0)); test_absolute(1.0, 0.1, 1.0, 1e-13, func(1.0)); test_exact(1.0, 1.0, 0.0, func(0.0)); test_exact(1.0, 1.0, 1.0, func(1.0)); test_absolute(1.0, 1.0, 10.0, 1e-10, func(10.0)); test_exact(10.0, 10.0, 0.0, func(0.0)); test_absolute(10.0, 10.0, 1.0, 1e-5, func(1.0)); test_absolute(10.0, 10.0, 10.0, 1e-10, func(10.0)); test_exact(10.0, 1.0, 0.0, func(0.0)); test_exact(10.0, 1.0, 1.0, func(1.0)); } #[test] fn test_continuous() { test::check_continuous_distribution(&create_ok(1.0, 0.2), 0.0, 10.0); } } statrs-0.18.0/src/distribution/ziggurat.rs000064400000000000000000000051361046102023000167470ustar 00000000000000use super::ziggurat_tables; use rand::distributions::Open01; use rand::Rng; pub fn sample_std_normal<R: Rng + ?Sized>(rng: &mut R) -> f64 { #[inline] fn pdf(x: f64) -> f64 { (-x * x / 2.0).exp() } #[inline] fn zero_case<R: Rng + ?Sized>(rng: &mut R, u: f64) -> f64 { let mut x = 1.0f64; let mut y = 0.0f64; while -2.0 * y < x * x { let x_: f64 = rng.sample(Open01); let y_: f64 = rng.sample(Open01); x = x_.ln() / ziggurat_tables::ZIG_NORM_R; y = y_.ln(); } if u < 0.0 { x - ziggurat_tables::ZIG_NORM_R } else {
ziggurat_tables::ZIG_NORM_R - x } } ziggurat( rng, true, &ziggurat_tables::ZIG_NORM_X, &ziggurat_tables::ZIG_NORM_F, pdf, zero_case, ) } pub fn sample_exp_1<R: Rng + ?Sized>(rng: &mut R) -> f64 { #[inline] fn pdf(x: f64) -> f64 { (-x).exp() } #[inline] fn zero_case<R: Rng + ?Sized>(rng: &mut R, _u: f64) -> f64 { ziggurat_tables::ZIG_EXP_R - rng.gen::<f64>().ln() } ziggurat( rng, false, &ziggurat_tables::ZIG_EXP_X, &ziggurat_tables::ZIG_EXP_F, pdf, zero_case, ) } // Ziggurat method for sampling a random number based on the ZIGNOR // variant from Doornik 2005. Code borrowed from // https://github.com/rust-lang-nursery/rand/blob/master/src/distributions/mod.rs#L223 #[inline(always)] fn ziggurat<R: Rng + ?Sized, P, Z>( rng: &mut R, symmetric: bool, x_tab: ziggurat_tables::ZigTable, f_tab: ziggurat_tables::ZigTable, mut pdf: P, mut zero_case: Z, ) -> f64 where P: FnMut(f64) -> f64, Z: FnMut(&mut R, f64) -> f64, { const SCALE: f64 = (1u64 << 53) as f64; loop { let bits: u64 = rng.gen(); let i = (bits & 0xff) as usize; let f = (bits >> 11) as f64 / SCALE; // u is either U(-1, 1) or U(0, 1) depending on whether this is a // symmetric distribution or not. let u = if symmetric { 2.0 * f - 1.0 } else { f }; let x = u * x_tab[i]; let test_x = if symmetric { x.abs() } else { x }; // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < // x_tab[i+1]/x_tab[i]) if test_x < x_tab[i + 1] { return x; } if i == 0 { return zero_case(rng, u); } // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) { return x; } } } statrs-0.18.0/src/distribution/ziggurat_tables.rs000064400000000000000000000575341046102023000203120ustar 00000000000000//! Generated ziggurat tables borrowed from //!
`https://github.com/rust-lang-nursery/rand/blob/master/src/distributions/ziggurat_tables.rs` pub type ZigTable = &'static [f64; 257]; pub const ZIG_NORM_R: f64 = 3.654152885361008796; #[rustfmt::skip] pub static ZIG_NORM_X: [f64; 257] = [3.910757959537090045, 3.654152885361008796, 3.449278298560964462, 3.320244733839166074, 3.224575052047029100, 3.147889289517149969, 3.083526132001233044, 3.027837791768635434, 2.978603279880844834, 2.934366867207854224, 2.894121053612348060, 2.857138730872132548, 2.822877396825325125, 2.790921174000785765, 2.760944005278822555, 2.732685359042827056, 2.705933656121858100, 2.680514643284522158, 2.656283037575502437, 2.633116393630324570, 2.610910518487548515, 2.589575986706995181, 2.569035452680536569, 2.549221550323460761, 2.530075232158516929, 2.511544441625342294, 2.493583041269680667, 2.476149939669143318, 2.459208374333311298, 2.442725318198956774, 2.426670984935725972, 2.411018413899685520, 2.395743119780480601, 2.380822795170626005, 2.366237056715818632, 2.351967227377659952, 2.337996148795031370, 2.324308018869623016, 2.310888250599850036, 2.297723348901329565, 2.284800802722946056, 2.272108990226823888, 2.259637095172217780, 2.247375032945807760, 2.235313384928327984, 2.223443340090905718, 2.211756642882544366, 2.200245546609647995, 2.188902771624720689, 2.177721467738641614, 2.166695180352645966, 2.155817819875063268, 2.145083634046203613, 2.134487182844320152, 2.124023315687815661, 2.113687150684933957, 2.103474055713146829, 2.093379631137050279, 2.083399693996551783, 2.073530263516978778, 2.063767547809956415, 2.054107931648864849, 2.044547965215732788, 2.035084353727808715, 2.025713947862032960, 2.016433734904371722, 2.007240830558684852, 1.998132471356564244, 1.989106007615571325, 1.980158896898598364, 1.971288697931769640, 1.962493064942461896, 1.953769742382734043, 1.945116560006753925, 1.936531428273758904, 1.928012334050718257, 1.919557336591228847, 1.911164563769282232, 1.902832208548446369, 1.894558525668710081, 
1.886341828534776388, 1.878180486290977669, 1.870072921069236838, 1.862017605397632281, 1.854013059758148119, 1.846057850283119750, 1.838150586580728607, 1.830289919680666566, 1.822474540091783224, 1.814703175964167636, 1.806974591348693426, 1.799287584547580199, 1.791640986550010028, 1.784033659547276329, 1.776464495522344977, 1.768932414909077933, 1.761436365316706665, 1.753975320315455111, 1.746548278279492994, 1.739154261283669012, 1.731792314050707216, 1.724461502945775715, 1.717160915015540690, 1.709889657069006086, 1.702646854797613907, 1.695431651932238548, 1.688243209434858727, 1.681080704722823338, 1.673943330923760353, 1.666830296159286684, 1.659740822855789499, 1.652674147080648526, 1.645629517902360339, 1.638606196773111146, 1.631603456932422036, 1.624620582830568427, 1.617656869570534228, 1.610711622367333673, 1.603784156023583041, 1.596873794420261339, 1.589979870021648534, 1.583101723393471438, 1.576238702733332886, 1.569390163412534456, 1.562555467528439657, 1.555733983466554893, 1.548925085471535512, 1.542128153226347553, 1.535342571438843118, 1.528567729435024614, 1.521803020758293101, 1.515047842773992404, 1.508301596278571965, 1.501563685112706548, 1.494833515777718391, 1.488110497054654369, 1.481394039625375747, 1.474683555695025516, 1.467978458615230908, 1.461278162507407830, 1.454582081885523293, 1.447889631277669675, 1.441200224845798017, 1.434513276002946425, 1.427828197027290358, 1.421144398672323117, 1.414461289772464658, 1.407778276843371534, 1.401094763676202559, 1.394410150925071257, 1.387723835686884621, 1.381035211072741964, 1.374343665770030531, 1.367648583594317957, 1.360949343030101844, 1.354245316759430606, 1.347535871177359290, 1.340820365893152122, 1.334098153216083604, 1.327368577624624679, 1.320630975217730096, 1.313884673146868964, 1.307128989027353860, 1.300363230327433728, 1.293586693733517645, 1.286798664489786415, 1.279998415710333237, 1.273185207661843732, 1.266358287014688333, 1.259516886060144225, 
1.252660221891297887, 1.245787495544997903, 1.238897891102027415, 1.231990574742445110, 1.225064693752808020, 1.218119375481726552, 1.211153726239911244, 1.204166830140560140, 1.197157747875585931, 1.190125515422801650, 1.183069142678760732, 1.175987612011489825, 1.168879876726833800, 1.161744859441574240, 1.154581450355851802, 1.147388505416733873, 1.140164844363995789, 1.132909248648336975, 1.125620459211294389, 1.118297174115062909, 1.110938046009249502, 1.103541679420268151, 1.096106627847603487, 1.088631390649514197, 1.081114409698889389, 1.073554065787871714, 1.065948674757506653, 1.058296483326006454, 1.050595664586207123, 1.042844313139370538, 1.035040439828605274, 1.027181966030751292, 1.019266717460529215, 1.011292417434978441, 1.003256679539591412, 0.995156999629943084, 0.986990747093846266, 0.978755155288937750, 0.970447311058864615, 0.962064143217605250, 0.953602409875572654, 0.945058684462571130, 0.936429340280896860, 0.927710533396234771, 0.918898183643734989, 0.909987953490768997, 0.900975224455174528, 0.891855070726792376, 0.882622229578910122, 0.873271068082494550, 0.863795545546826915, 0.854189171001560554, 0.844444954902423661, 0.834555354079518752, 0.824512208745288633, 0.814306670128064347, 0.803929116982664893, 0.793369058833152785, 0.782615023299588763, 0.771654424216739354, 0.760473406422083165, 0.749056662009581653, 0.737387211425838629, 0.725446140901303549, 0.713212285182022732, 0.700661841097584448, 0.687767892786257717, 0.674499822827436479, 0.660822574234205984, 0.646695714884388928, 0.632072236375024632, 0.616896989996235545, 0.601104617743940417, 0.584616766093722262, 0.567338257040473026, 0.549151702313026790, 0.529909720646495108, 0.509423329585933393, 0.487443966121754335, 0.463634336771763245, 0.437518402186662658, 0.408389134588000746, 0.375121332850465727, 0.335737519180459465, 0.286174591747260509, 0.215241895913273806, 0.000000000000000000]; #[rustfmt::skip] pub static ZIG_NORM_F: [f64; 257] = [0.000477467764586655, 
0.001260285930498598, 0.002609072746106363, 0.004037972593371872, 0.005522403299264754, 0.007050875471392110, 0.008616582769422917, 0.010214971439731100, 0.011842757857943104, 0.013497450601780807, 0.015177088307982072, 0.016880083152595839, 0.018605121275783350, 0.020351096230109354, 0.022117062707379922, 0.023902203305873237, 0.025705804008632656, 0.027527235669693315, 0.029365939758230111, 0.031221417192023690, 0.033093219458688698, 0.034980941461833073, 0.036884215688691151, 0.038802707404656918, 0.040736110656078753, 0.042684144916619378, 0.044646552251446536, 0.046623094902089664, 0.048613553216035145, 0.050617723861121788, 0.052635418276973649, 0.054666461325077916, 0.056710690106399467, 0.058767952921137984, 0.060838108349751806, 0.062921024437977854, 0.065016577971470438, 0.067124653828023989, 0.069245144397250269, 0.071377949059141965, 0.073522973714240991, 0.075680130359194964, 0.077849336702372207, 0.080030515814947509, 0.082223595813495684, 0.084428509570654661, 0.086645194450867782, 0.088873592068594229, 0.091113648066700734, 0.093365311913026619, 0.095628536713353335, 0.097903279039215627, 0.100189498769172020, 0.102487158942306270, 0.104796225622867056, 0.107116667775072880, 0.109448457147210021, 0.111791568164245583, 0.114145977828255210, 0.116511665626037014, 0.118888613443345698, 0.121276805485235437, 0.123676228202051403, 0.126086870220650349, 0.128508722280473636, 0.130941777174128166, 0.133386029692162844, 0.135841476571757352, 0.138308116449064322, 0.140785949814968309, 0.143274978974047118, 0.145775208006537926, 0.148286642733128721, 0.150809290682410169, 0.153343161060837674, 0.155888264725064563, 0.158444614156520225, 0.161012223438117663, 0.163591108232982951, 0.166181285765110071, 0.168782774801850333, 0.171395595638155623, 0.174019770082499359, 0.176655321444406654, 0.179302274523530397, 0.181960655600216487, 0.184630492427504539, 0.187311814224516926, 0.190004651671193070, 0.192709036904328807, 0.195425003514885592, 
0.198152586546538112, 0.200891822495431333, 0.203642749311121501, 0.206405406398679298, 0.209179834621935651, 0.211966076307852941, 0.214764175252008499, 0.217574176725178370, 0.220396127481011589, 0.223230075764789593, 0.226076071323264877, 0.228934165415577484, 0.231804410825248525, 0.234686861873252689, 0.237581574432173676, 0.240488605941449107, 0.243408015423711988, 0.246339863502238771, 0.249284212419516704, 0.252241126056943765, 0.255210669955677150, 0.258192911338648023, 0.261187919133763713, 0.264195763998317568, 0.267216518344631837, 0.270250256366959984, 0.273297054069675804, 0.276356989296781264, 0.279430141762765316, 0.282516593084849388, 0.285616426816658109, 0.288729728483353931, 0.291856585618280984, 0.294997087801162572, 0.298151326697901342, 0.301319396102034120, 0.304501391977896274, 0.307697412505553769, 0.310907558127563710, 0.314131931597630143, 0.317370638031222396, 0.320623784958230129, 0.323891482377732021, 0.327173842814958593, 0.330470981380537099, 0.333783015832108509, 0.337110066638412809, 0.340452257045945450, 0.343809713148291340, 0.347182563958251478, 0.350570941482881204, 0.353974980801569250, 0.357394820147290515, 0.360830600991175754, 0.364282468130549597, 0.367750569780596226, 0.371235057669821344, 0.374736087139491414, 0.378253817247238111, 0.381788410875031348, 0.385340034841733958, 0.388908860020464597, 0.392495061461010764, 0.396098818517547080, 0.399720314981931668, 0.403359739222868885, 0.407017284331247953, 0.410693148271983222, 0.414387534042706784, 0.418100649839684591, 0.421832709231353298, 0.425583931339900579, 0.429354541031341519, 0.433144769114574058, 0.436954852549929273, 0.440785034667769915, 0.444635565397727750, 0.448506701509214067, 0.452398706863882505, 0.456311852680773566, 0.460246417814923481, 0.464202689050278838, 0.468180961407822172, 0.472181538469883255, 0.476204732721683788, 0.480250865911249714, 0.484320269428911598, 0.488413284707712059, 0.492530263646148658, 0.496671569054796314, 
0.500837575128482149, 0.505028667945828791, 0.509245245998136142, 0.513487720749743026, 0.517756517232200619, 0.522052074674794864, 0.526374847174186700, 0.530725304406193921, 0.535103932383019565, 0.539511234259544614, 0.543947731192649941, 0.548413963257921133, 0.552910490428519918, 0.557437893621486324, 0.561996775817277916, 0.566587763258951771, 0.571211506738074970, 0.575868682975210544, 0.580559996103683473, 0.585286179266300333, 0.590047996335791969, 0.594846243770991268, 0.599681752622167719, 0.604555390700549533, 0.609468064928895381, 0.614420723892076803, 0.619414360609039205, 0.624450015550274240, 0.629528779928128279, 0.634651799290960050, 0.639820277456438991, 0.645035480824251883, 0.650298743114294586, 0.655611470583224665, 0.660975147780241357, 0.666391343912380640, 0.671861719900766374, 0.677388036222513090, 0.682972161648791376, 0.688616083008527058, 0.694321916130032579, 0.700091918140490099, 0.705928501336797409, 0.711834248882358467, 0.717811932634901395, 0.723864533472881599, 0.729995264565802437, 0.736207598131266683, 0.742505296344636245, 0.748892447223726720, 0.755373506511754500, 0.761953346841546475, 0.768637315803334831, 0.775431304986138326, 0.782341832659861902, 0.789376143571198563, 0.796542330428254619, 0.803849483176389490, 0.811307874318219935, 0.818929191609414797, 0.826726833952094231, 0.834716292992930375, 0.842915653118441077, 0.851346258465123684, 0.860033621203008636, 0.869008688043793165, 0.878309655816146839, 0.887984660763399880, 0.898095921906304051, 0.908726440060562912, 0.919991505048360247, 0.932060075968990209, 0.945198953453078028, 0.959879091812415930, 0.977101701282731328, 1.000000000000000000]; pub const ZIG_EXP_R: f64 = 7.697117470131050077; #[rustfmt::skip] pub static ZIG_EXP_X: [f64; 257] = [8.697117470131052741, 7.697117470131050077, 6.941033629377212577, 6.478378493832569696, 6.144164665772472667, 5.882144315795399869, 5.666410167454033697, 5.482890627526062488, 5.323090505754398016, 5.181487281301500047, 
5.054288489981304089, 4.938777085901250530, 4.832939741025112035, 4.735242996601741083, 4.644491885420085175, 4.559737061707351380, 4.480211746528421912, 4.405287693473573185, 4.334443680317273007, 4.267242480277365857, 4.203313713735184365, 4.142340865664051464, 4.084051310408297830, 4.028208544647936762, 3.974606066673788796, 3.923062500135489739, 3.873417670399509127, 3.825529418522336744, 3.779270992411667862, 3.734528894039797375, 3.691201090237418825, 3.649195515760853770, 3.608428813128909507, 3.568825265648337020, 3.530315889129343354, 3.492837654774059608, 3.456332821132760191, 3.420748357251119920, 3.386035442460300970, 3.352149030900109405, 3.319047470970748037, 3.286692171599068679, 3.255047308570449882, 3.224079565286264160, 3.193757903212240290, 3.164053358025972873, 3.134938858084440394, 3.106389062339824481, 3.078380215254090224, 3.050890016615455114, 3.023897504455676621, 2.997382949516130601, 2.971327759921089662, 2.945714394895045718, 2.920526286512740821, 2.895747768600141825, 2.871364012015536371, 2.847360965635188812, 2.823725302450035279, 2.800444370250737780, 2.777506146439756574, 2.754899196562344610, 2.732612636194700073, 2.710636095867928752, 2.688959688741803689, 2.667573980773266573, 2.646469963151809157, 2.625639026797788489, 2.605072938740835564, 2.584763820214140750, 2.564704126316905253, 2.544886627111869970, 2.525304390037828028, 2.505950763528594027, 2.486819361740209455, 2.467904050297364815, 2.449198932978249754, 2.430698339264419694, 2.412396812688870629, 2.394289099921457886, 2.376370140536140596, 2.358635057409337321, 2.341079147703034380, 2.323697874390196372, 2.306486858283579799, 2.289441870532269441, 2.272558825553154804, 2.255833774367219213, 2.239262898312909034, 2.222842503111036816, 2.206569013257663858, 2.190438966723220027, 2.174449009937774679, 2.158595893043885994, 2.142876465399842001, 2.127287671317368289, 2.111826546019042183, 2.096490211801715020, 2.081275874393225145, 2.066180819490575526, 
2.051202409468584786, 2.036338080248769611, 2.021585338318926173, 2.006941757894518563, 1.992404978213576650, 1.977972700957360441, 1.963642687789548313, 1.949412758007184943, 1.935280786297051359, 1.921244700591528076, 1.907302480018387536, 1.893452152939308242, 1.879691795072211180, 1.866019527692827973, 1.852433515911175554, 1.838931967018879954, 1.825513128903519799, 1.812175288526390649, 1.798916770460290859, 1.785735935484126014, 1.772631179231305643, 1.759600930889074766, 1.746643651946074405, 1.733757834985571566, 1.720942002521935299, 1.708194705878057773, 1.695514524101537912, 1.682900062917553896, 1.670349953716452118, 1.657862852574172763, 1.645437439303723659, 1.633072416535991334, 1.620766508828257901, 1.608518461798858379, 1.596327041286483395, 1.584191032532688892, 1.572109239386229707, 1.560080483527888084, 1.548103603714513499, 1.536177455041032092, 1.524300908219226258, 1.512472848872117082, 1.500692176842816750, 1.488957805516746058, 1.477268661156133867, 1.465623682245745352, 1.454021818848793446, 1.442462031972012504, 1.430943292938879674, 1.419464582769983219, 1.408024891569535697, 1.396623217917042137, 1.385258568263121992, 1.373929956328490576, 1.362636402505086775, 1.351376933258335189, 1.340150580529504643, 1.328956381137116560, 1.317793376176324749, 1.306660610415174117, 1.295557131686601027, 1.284481990275012642, 1.273434238296241139, 1.262412929069615330, 1.251417116480852521, 1.240445854334406572, 1.229498195693849105, 1.218573192208790124, 1.207669893426761121, 1.196787346088403092, 1.185924593404202199, 1.175080674310911677, 1.164254622705678921, 1.153445466655774743, 1.142652227581672841, 1.131873919411078511, 1.121109547701330200, 1.110358108727411031, 1.099618588532597308, 1.088889961938546813, 1.078171191511372307, 1.067461226479967662, 1.056759001602551429, 1.046063435977044209, 1.035373431790528542, 1.024687873002617211, 1.014005623957096480, 1.003325527915696735, 0.992646405507275897, 0.981967053085062602, 
0.971286240983903260, 0.960602711668666509, 0.949915177764075969, 0.939222319955262286, 0.928522784747210395, 0.917815182070044311, 0.907098082715690257, 0.896370015589889935, 0.885629464761751528, 0.874874866291025066, 0.864104604811004484, 0.853317009842373353, 0.842510351810368485, 0.831682837734273206, 0.820832606554411814, 0.809957724057418282, 0.799056177355487174, 0.788125868869492430, 0.777164609759129710, 0.766170112735434672, 0.755139984181982249, 0.744071715500508102, 0.732962673584365398, 0.721810090308756203, 0.710611050909655040, 0.699362481103231959, 0.688061132773747808, 0.676703568029522584, 0.665286141392677943, 0.653804979847664947, 0.642255960424536365, 0.630634684933490286, 0.618936451394876075, 0.607156221620300030, 0.595288584291502887, 0.583327712748769489, 0.571267316532588332, 0.559100585511540626, 0.546820125163310577, 0.534417881237165604, 0.521885051592135052, 0.509211982443654398, 0.496388045518671162, 0.483401491653461857, 0.470239275082169006, 0.456886840931420235, 0.443327866073552401, 0.429543940225410703, 0.415514169600356364, 0.401214678896277765, 0.386617977941119573, 0.371692145329917234, 0.356399760258393816, 0.340696481064849122, 0.324529117016909452, 0.307832954674932158, 0.290527955491230394, 0.272513185478464703, 0.253658363385912022, 0.233790483059674731, 0.212671510630966620, 0.189958689622431842, 0.165127622564187282, 0.137304980940012589, 0.104838507565818778, 0.063852163815001570, 0.000000000000000000]; #[rustfmt::skip] pub static ZIG_EXP_F: [f64; 257] = [0.000167066692307963, 0.000454134353841497, 0.000967269282327174, 0.001536299780301573, 0.002145967743718907, 0.002788798793574076, 0.003460264777836904, 0.004157295120833797, 0.004877655983542396, 0.005619642207205489, 0.006381905937319183, 0.007163353183634991, 0.007963077438017043, 0.008780314985808977, 0.009614413642502212, 0.010464810181029981, 0.011331013597834600, 0.012212592426255378, 0.013109164931254991, 0.014020391403181943, 0.014945968011691148, 
0.015885621839973156, 0.016839106826039941, 0.017806200410911355, 0.018786700744696024, 0.019780424338009740, 0.020787204072578114, 0.021806887504283581, 0.022839335406385240, 0.023884420511558174, 0.024942026419731787, 0.026012046645134221, 0.027094383780955803, 0.028188948763978646, 0.029295660224637411, 0.030414443910466622, 0.031545232172893622, 0.032687963508959555, 0.033842582150874358, 0.035009037697397431, 0.036187284781931443, 0.037377282772959382, 0.038578995503074871, 0.039792391023374139, 0.041017441380414840, 0.042254122413316254, 0.043502413568888197, 0.044762297732943289, 0.046033761076175184, 0.047316792913181561, 0.048611385573379504, 0.049917534282706379, 0.051235237055126281, 0.052564494593071685, 0.053905310196046080, 0.055257689676697030, 0.056621641283742870, 0.057997175631200659, 0.059384305633420280, 0.060783046445479660, 0.062193415408541036, 0.063615431999807376, 0.065049117786753805, 0.066494496385339816, 0.067951593421936643, 0.069420436498728783, 0.070901055162371843, 0.072393480875708752, 0.073897746992364746, 0.075413888734058410, 0.076941943170480517, 0.078481949201606435, 0.080033947542319905, 0.081597980709237419, 0.083174093009632397, 0.084762330532368146, 0.086362741140756927, 0.087975374467270231, 0.089600281910032886, 0.091237516631040197, 0.092887133556043569, 0.094549189376055873, 0.096223742550432825, 0.097910853311492213, 0.099610583670637132, 0.101322997425953631, 0.103048160171257702, 0.104786139306570145, 0.106537004050001632, 0.108300825451033755, 0.110077676405185357, 0.111867631670056283, 0.113670767882744286, 0.115487163578633506, 0.117316899211555525, 0.119160057175327641, 0.121016721826674792, 0.122886979509545108, 0.124770918580830933, 0.126668629437510671, 0.128580204545228199, 0.130505738468330773, 0.132445327901387494, 0.134399071702213602, 0.136367070926428829, 0.138349428863580176, 0.140346251074862399, 0.142357645432472146, 0.144383722160634720, 0.146424593878344889, 0.148480375643866735, 
0.150551185001039839, 0.152637142027442801, 0.154738369384468027, 0.156854992369365148, 0.158987138969314129, 0.161134939917591952, 0.163298528751901734, 0.165478041874935922, 0.167673618617250081, 0.169885401302527550, 0.172113535315319977, 0.174358169171353411, 0.176619454590494829, 0.178897546572478278, 0.181192603475496261, 0.183504787097767436, 0.185834262762197083, 0.188181199404254262, 0.190545769663195363, 0.192928149976771296, 0.195328520679563189, 0.197747066105098818, 0.200183974691911210, 0.202639439093708962, 0.205113656293837654, 0.207606827724221982, 0.210119159388988230, 0.212650861992978224, 0.215202151075378628, 0.217773247148700472, 0.220364375843359439, 0.222975768058120111, 0.225607660116683956, 0.228260293930716618, 0.230933917169627356, 0.233628783437433291, 0.236345152457059560, 0.239083290262449094, 0.241843469398877131, 0.244625969131892024, 0.247431075665327543, 0.250259082368862240, 0.253110290015629402, 0.255985007030415324, 0.258883549749016173, 0.261806242689362922, 0.264753418835062149, 0.267725419932044739, 0.270722596799059967, 0.273745309652802915, 0.276793928448517301, 0.279868833236972869, 0.282970414538780746, 0.286099073737076826, 0.289255223489677693, 0.292439288161892630, 0.295651704281261252, 0.298892921015581847, 0.302163400675693528, 0.305463619244590256, 0.308794066934560185, 0.312155248774179606, 0.315547685227128949, 0.318971912844957239, 0.322428484956089223, 0.325917972393556354, 0.329440964264136438, 0.332998068761809096, 0.336589914028677717, 0.340217149066780189, 0.343880444704502575, 0.347580494621637148, 0.351318016437483449, 0.355093752866787626, 0.358908472948750001, 0.362762973354817997, 0.366658079781514379, 0.370594648435146223, 0.374573567615902381, 0.378595759409581067, 0.382662181496010056, 0.386773829084137932, 0.390931736984797384, 0.395136981833290435, 0.399390684475231350, 0.403694012530530555, 0.408048183152032673, 0.412454465997161457, 0.416914186433003209, 0.421428728997616908, 
0.425999541143034677, 0.430628137288459167, 0.435316103215636907, 0.440065100842354173, 0.444876873414548846, 0.449753251162755330, 0.454696157474615836, 0.459707615642138023, 0.464789756250426511, 0.469944825283960310, 0.475175193037377708, 0.480483363930454543, 0.485871987341885248, 0.491343869594032867, 0.496901987241549881, 0.502549501841348056, 0.508289776410643213, 0.514126393814748894, 0.520063177368233931, 0.526104213983620062, 0.532253880263043655, 0.538516872002862246, 0.544898237672440056, 0.551403416540641733, 0.558038282262587892, 0.564809192912400615, 0.571723048664826150, 0.578787358602845359, 0.586010318477268366, 0.593400901691733762, 0.600968966365232560, 0.608725382079622346, 0.616682180915207878, 0.624852738703666200, 0.633251994214366398, 0.641896716427266423, 0.650805833414571433, 0.660000841079000145, 0.669506316731925177, 0.679350572264765806, 0.689566496117078431, 0.700192655082788606, 0.711274760805076456, 0.722867659593572465, 0.735038092431424039, 0.747868621985195658, 0.761463388849896838, 0.775956852040116218, 0.791527636972496285, 0.808421651523009044, 0.826993296643051101, 0.847785500623990496, 0.871704332381204705, 0.900469929925747703, 0.938143680862176477, 1.000000000000000000]; statrs-0.18.0/src/euclid.rs000064400000000000000000000023631046102023000136400ustar 00000000000000//! Provides number theory utility functions /// Provides a trait for the canonical modulus operation since % is technically /// the remainder operation pub trait Modulus { /// Performs a canonical modulus operation between `self` and `divisor`. 
/// /// # Examples /// /// ``` /// use statrs::euclid::Modulus; /// /// let x = 4i64.modulus(5); /// assert_eq!(x, 4); /// /// // parentheses are required: `-4i64.modulus(5)` parses as `-(4i64.modulus(5))` /// let y = (-4i64).modulus(5); /// assert_eq!(y, 1); // whereas `-4 % 5 == -4` /// ``` fn modulus(self, divisor: Self) -> Self; } impl Modulus for f64 { fn modulus(self, divisor: f64) -> f64 { ((self % divisor) + divisor) % divisor } } impl Modulus for f32 { fn modulus(self, divisor: f32) -> f32 { ((self % divisor) + divisor) % divisor } } impl Modulus for i64 { fn modulus(self, divisor: i64) -> i64 { ((self % divisor) + divisor) % divisor } } impl Modulus for i32 { fn modulus(self, divisor: i32) -> i32 { ((self % divisor) + divisor) % divisor } } impl Modulus for u64 { fn modulus(self, divisor: u64) -> u64 { self % divisor } } impl Modulus for u32 { fn modulus(self, divisor: u32) -> u32 { self % divisor } } statrs-0.18.0/src/function/beta.rs000064400000000000000000000463351046102023000151420ustar 00000000000000//! Provides the [beta](https://en.wikipedia.org/wiki/Beta_function) and related //! functions use crate::function::gamma; use crate::prec; use std::f64; /// Represents the errors that can occur when computing the natural logarithm /// of the beta function or the regularized lower incomplete beta function. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum BetaFuncError { /// `a` is zero or less than zero. ANotGreaterThanZero, /// `b` is zero or less than zero. BNotGreaterThanZero, /// `x` is not in `[0, 1]`.
XOutOfRange, } impl std::fmt::Display for BetaFuncError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { BetaFuncError::ANotGreaterThanZero => write!(f, "a is zero or less than zero"), BetaFuncError::BNotGreaterThanZero => write!(f, "b is zero or less than zero"), BetaFuncError::XOutOfRange => write!(f, "x is not in [0, 1]"), } } } impl std::error::Error for BetaFuncError {} /// Computes the natural logarithm /// of the beta function /// where `a` is the first beta parameter /// and `b` is the second beta parameter /// and `a > 0`, `b > 0`. /// /// # Panics /// /// if `a <= 0.0` or `b <= 0.0` pub fn ln_beta(a: f64, b: f64) -> f64 { checked_ln_beta(a, b).unwrap() } /// Computes the natural logarithm /// of the beta function /// where `a` is the first beta parameter /// and `b` is the second beta parameter /// and `a > 0`, `b > 0`. /// /// # Errors /// /// if `a <= 0.0` or `b <= 0.0` pub fn checked_ln_beta(a: f64, b: f64) -> Result<f64, BetaFuncError> { if a <= 0.0 { Err(BetaFuncError::ANotGreaterThanZero) } else if b <= 0.0 { Err(BetaFuncError::BNotGreaterThanZero) } else { Ok(gamma::ln_gamma(a) + gamma::ln_gamma(b) - gamma::ln_gamma(a + b)) } } /// Computes the beta function /// where `a` is the first beta parameter /// and `b` is the second beta parameter. /// /// # Panics /// /// if `a <= 0.0` or `b <= 0.0` pub fn beta(a: f64, b: f64) -> f64 { checked_beta(a, b).unwrap() } /// Computes the beta function /// where `a` is the first beta parameter /// and `b` is the second beta parameter.
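///
/// # Examples
///
/// An illustrative check with hand-computed values (not taken from the
/// crate's test suite): `B(2, 3) = Γ(2)Γ(3)/Γ(5) = 2/24 = 1/12`, and a
/// non-positive parameter is reported as an `Err` rather than a panic.
///
/// ```
/// use statrs::function::beta::{checked_beta, BetaFuncError};
///
/// // B(2, 3) = 1/12, computed here as exp(ln B(2, 3))
/// let b = checked_beta(2.0, 3.0).unwrap();
/// assert!((b - 1.0 / 12.0).abs() < 1e-12);
///
/// // a <= 0 is rejected with the matching error variant
/// assert_eq!(checked_beta(0.0, 3.0), Err(BetaFuncError::ANotGreaterThanZero));
/// ```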
/// /// # Errors /// /// if `a <= 0.0` or `b <= 0.0` pub fn checked_beta(a: f64, b: f64) -> Result<f64, BetaFuncError> { checked_ln_beta(a, b).map(|x| x.exp()) } /// Computes the lower incomplete (unregularized) beta function /// `B(a,b,x) = int(t^(a-1)*(1-t)^(b-1),t=0..x)` for `a > 0, b > 0, 1 >= x >= 0` /// where `a` is the first beta parameter, `b` is the second beta parameter, and /// `x` is the upper limit of the integral. /// /// # Panics /// /// If `a <= 0.0`, `b <= 0.0`, `x < 0.0`, or `x > 1.0` pub fn beta_inc(a: f64, b: f64, x: f64) -> f64 { checked_beta_inc(a, b, x).unwrap() } /// Computes the lower incomplete (unregularized) beta function /// `B(a,b,x) = int(t^(a-1)*(1-t)^(b-1),t=0..x)` for `a > 0, b > 0, 1 >= x >= 0` /// where `a` is the first beta parameter, `b` is the second beta parameter, and /// `x` is the upper limit of the integral. /// /// # Errors /// /// If `a <= 0.0`, `b <= 0.0`, `x < 0.0`, or `x > 1.0` pub fn checked_beta_inc(a: f64, b: f64, x: f64) -> Result<f64, BetaFuncError> { checked_beta_reg(a, b, x).and_then(|x| checked_beta(a, b).map(|y| x * y)) } /// Computes the regularized lower incomplete beta function /// `I_x(a,b) = 1/Beta(a,b) * int(t^(a-1)*(1-t)^(b-1), t=0..x)` /// for `a > 0`, `b > 0`, `1 >= x >= 0` where `a` is the first beta parameter, /// `b` is the second beta parameter, and `x` is the upper limit of the /// integral. /// /// # Panics /// /// if `a <= 0.0`, `b <= 0.0`, `x < 0.0`, or `x > 1.0` pub fn beta_reg(a: f64, b: f64, x: f64) -> f64 { checked_beta_reg(a, b, x).unwrap() } /// Computes the regularized lower incomplete beta function /// `I_x(a,b) = 1/Beta(a,b) * int(t^(a-1)*(1-t)^(b-1), t=0..x)` /// for `a > 0`, `b > 0`, `1 >= x >= 0` where `a` is the first beta parameter, /// `b` is the second beta parameter, and `x` is the upper limit of the /// integral.
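///
/// # Examples
///
/// A minimal usage sketch. The continued-fraction evaluation is approximate,
/// so the result is compared against a tolerance.
///
/// ```
/// use statrs::function::beta::{checked_beta_reg, BetaFuncError};
///
/// // For a == b the integrand is symmetric about 1/2, so I_{0.5}(a, a) = 0.5.
/// let p = checked_beta_reg(2.5, 2.5, 0.5).unwrap();
/// assert!((p - 0.5).abs() < 1e-12);
///
/// // An upper limit outside [0, 1] is reported as an error.
/// assert_eq!(checked_beta_reg(1.0, 1.0, 1.5), Err(BetaFuncError::XOutOfRange));
/// ```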
/// /// # Errors /// /// if `a <= 0.0`, `b <= 0.0`, `x < 0.0`, or `x > 1.0` pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result<f64, BetaFuncError> { if a <= 0.0 { return Err(BetaFuncError::ANotGreaterThanZero); } if b <= 0.0 { return Err(BetaFuncError::BNotGreaterThanZero); } if !(0.0..=1.0).contains(&x) { return Err(BetaFuncError::XOutOfRange); } let bt = if x == 0.0 || ulps_eq!(x, 1.0) { 0.0 } else { (gamma::ln_gamma(a + b) - gamma::ln_gamma(a) - gamma::ln_gamma(b) + a * x.ln() + b * (1.0 - x).ln()) .exp() }; let symm_transform = x >= (a + 1.0) / (a + b + 2.0); let eps = prec::F64_PREC; let fpmin = f64::MIN_POSITIVE / eps; let mut a = a; let mut b = b; let mut x = x; if symm_transform { let swap = a; x = 1.0 - x; a = b; b = swap; } let qab = a + b; let qap = a + 1.0; let qam = a - 1.0; let mut c = 1.0; let mut d = 1.0 - qab * x / qap; if d.abs() < fpmin { d = fpmin; } d = 1.0 / d; let mut h = d; for m in 1..141 { let m = f64::from(m); let m2 = m * 2.0; let mut aa = m * (b - m) * x / ((qam + m2) * (a + m2)); d = 1.0 + aa * d; if d.abs() < fpmin { d = fpmin; } c = 1.0 + aa / c; if c.abs() < fpmin { c = fpmin; } d = 1.0 / d; h = h * d * c; aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2)); d = 1.0 + aa * d; if d.abs() < fpmin { d = fpmin; } c = 1.0 + aa / c; if c.abs() < fpmin { c = fpmin; } d = 1.0 / d; let del = d * c; h *= del; if (del - 1.0).abs() <= eps { return if symm_transform { Ok(1.0 - bt * h / a) } else { Ok(bt * h / a) }; } } if symm_transform { Ok(1.0 - bt * h / a) } else { Ok(bt * h / a) } } /// Computes the inverse of the regularized incomplete beta function // This code is based on the implementation in the ["special"][1] crate, // which in turn is based on a [C implementation][2] by John Burkardt. The // original algorithm was published in Applied Statistics and is known as // [Algorithm AS 64][3] and [Algorithm AS 109][4].
// // [1]: https://docs.rs/special/0.8.1/ // [2]: http://people.sc.fsu.edu/~jburkardt/c_src/asa109/asa109.html // [3]: http://www.jstor.org/stable/2346798 // [4]: http://www.jstor.org/stable/2346887 // // > Copyright 2014–2019 The special Developers // > // > Permission is hereby granted, free of charge, to any person obtaining a copy of // > this software and associated documentation files (the “Software”), to deal in // > the Software without restriction, including without limitation the rights to // > use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of // > the Software, and to permit persons to whom the Software is furnished to do so, // > subject to the following conditions: // > // > The above copyright notice and this permission notice shall be included in all // > copies or substantial portions of the Software. // > // > THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // > IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS // > FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR // > COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER // > IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN // > CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. pub fn inv_beta_reg(mut a: f64, mut b: f64, mut x: f64) -> f64 { // Algorithm AS 64 // http://www.jstor.org/stable/2346798 // // An approximation x₀ to x is found from (cf. Scheffé and Tukey, 1944) // // 1 + x₀ 4p + 2q - 2 // ------ = ----------- // 1 - x₀ χ²(α) // // where χ²(α) is the upper α point of the χ² distribution with 2q // degrees of freedom and is obtained from Wilson and Hilferty’s // approximation (cf. Wilson and Hilferty, 1931) // // χ²(α) = 2q (1 - 1 / (9q) + y(α) sqrt(1 / (9q)))^3, // // y(α) being Hastings’ approximation (cf. Hastings, 1955) for the upper // α point of the standard normal distribution.
If χ²(α) < 0, then // // x₀ = 1 - ((1 - α)q B(p, q))^(1 / q). // // Again, if (4p + 2q - 2) / χ²(α) does not exceed 1, x₀ is obtained from // // x₀ = (αp B(p, q))^(1 / p). // // The final solution is obtained by the Newton–Raphson method from the // relation // // f(x[i - 1]) // x[i] = x[i - 1] - ------------ // f'(x[i - 1]) // // where // // f(x) = I(x, p, q) - α. let ln_beta = ln_beta(a, b); // Remark AS R83 // http://www.jstor.org/stable/2347779 const SAE: i32 = -30; const FPU: f64 = 1e-30; // 10^SAE debug_assert!((0.0..=1.0).contains(&x) && a > 0.0 && b > 0.0); if x == 0.0 { return 0.0; } if x == 1.0 { return 1.0; } let mut p; let mut q; let flip = 0.5 < x; if flip { p = a; a = b; b = p; x = 1.0 - x; } p = (-(x * x).ln()).sqrt(); q = p - (2.30753 + 0.27061 * p) / (1.0 + (0.99229 + 0.04481 * p) * p); if 1.0 < a && 1.0 < b { // Remark AS R19 and Algorithm AS 109 // http://www.jstor.org/stable/2346887 // // For a and b > 1, the approximation given by Carter (1947), which // improves the Fisher–Cochran formula, is generally better. For // other values of a and b, an empirical investigation has shown that // the approximation given in AS 64 is adequate.
let r = (q * q - 3.0) / 6.0; let s = 1.0 / (2.0 * a - 1.0); let t = 1.0 / (2.0 * b - 1.0); let h = 2.0 / (s + t); let w = q * (h + r).sqrt() / h - (t - s) * (r + 5.0 / 6.0 - 2.0 / (3.0 * h)); p = a / (a + b * (2.0 * w).exp()); } else { let mut t = 1.0 / (9.0 * b); t = 2.0 * b * (1.0 - t + q * t.sqrt()).powf(3.0); if t <= 0.0 { p = 1.0 - ((((1.0 - x) * b).ln() + ln_beta) / b).exp(); } else { t = 2.0 * (2.0 * a + b - 1.0) / t; if t <= 1.0 { p = (((x * a).ln() + ln_beta) / a).exp(); } else { p = 1.0 - 2.0 / (t + 1.0); } } } p = p.clamp(0.0001, 0.9999); // Remark AS R83 // http://www.jstor.org/stable/2347779 let e = (-5.0 / a / a - 1.0 / x.powf(0.2) - 13.0) as i32; let acu = if e > SAE { f64::powi(10.0, e) } else { FPU }; let mut pnext; let mut qprev = 0.0; let mut sq = 1.0; let mut prev = 1.0; 'outer: loop { // Remark AS R19 and Algorithm AS 109 // http://www.jstor.org/stable/2346887 q = beta_reg(a, b, p); q = (q - x) * (ln_beta + (1.0 - a) * p.ln() + (1.0 - b) * (1.0 - p).ln()).exp(); // Remark AS R83 // http://www.jstor.org/stable/2347779 if q * qprev <= 0.0 { prev = if sq > FPU { sq } else { FPU }; } // Remark AS R19 and Algorithm AS 109 // http://www.jstor.org/stable/2346887 let mut g = 1.0; loop { loop { let adj = g * q; sq = adj * adj; if sq < prev { pnext = p - adj; if (0.0..=1.0).contains(&pnext) { break; } } g /= 3.0; } if prev <= acu || q * q <= acu { p = pnext; break 'outer; } if pnext != 0.0 && pnext != 1.0 { break; } g /= 3.0; } if pnext == p { break; } p = pnext; qprev = q; } if flip { 1.0 - p } else { p } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; #[test] fn test_ln_beta() { assert_almost_eq!(super::ln_beta(0.5, 0.5), 1.144729885849400174144, 1e-15); assert_almost_eq!(super::ln_beta(1.0, 0.5), 0.6931471805599453094172, 1e-14); assert_almost_eq!(super::ln_beta(2.5, 0.5), 0.163900632837673937284, 1e-15); assert_almost_eq!(super::ln_beta(0.5, 1.0), 0.6931471805599453094172, 1e-14); assert_almost_eq!(super::ln_beta(1.0, 1.0), 0.0, 1e-15); 
assert_almost_eq!(super::ln_beta(2.5, 1.0), -0.9162907318741550651835, 1e-14); assert_almost_eq!(super::ln_beta(0.5, 2.5), 0.163900632837673937284, 1e-15); assert_almost_eq!(super::ln_beta(1.0, 2.5), -0.9162907318741550651835, 1e-14); assert_almost_eq!(super::ln_beta(2.5, 2.5), -2.608688089402107300388, 1e-14); } #[test] #[should_panic] fn test_ln_beta_a_lte_0() { super::ln_beta(0.0, 0.5); } #[test] #[should_panic] fn test_ln_beta_b_lte_0() { super::ln_beta(0.5, 0.0); } #[test] fn test_checked_ln_beta_a_lte_0() { assert!(super::checked_ln_beta(0.0, 0.5).is_err()); } #[test] fn test_checked_ln_beta_b_lte_0() { assert!(super::checked_ln_beta(0.5, 0.0).is_err()); } #[test] #[should_panic] fn test_beta_a_lte_0() { super::beta(0.0, 0.5); } #[test] #[should_panic] fn test_beta_b_lte_0() { super::beta(0.5, 0.0); } #[test] fn test_checked_beta_a_lte_0() { assert!(super::checked_beta(0.0, 0.5).is_err()); } #[test] fn test_checked_beta_b_lte_0() { assert!(super::checked_beta(0.5, 0.0).is_err()); } #[test] fn test_beta() { assert_almost_eq!(super::beta(0.5, 0.5), 3.141592653589793238463, 1e-15); assert_almost_eq!(super::beta(1.0, 0.5), 2.0, 1e-14); assert_almost_eq!(super::beta(2.5, 0.5), 1.17809724509617246442, 1e-15); assert_almost_eq!(super::beta(0.5, 1.0), 2.0, 1e-14); assert_almost_eq!(super::beta(1.0, 1.0), 1.0, 1e-15); assert_almost_eq!(super::beta(2.5, 1.0), 0.4, 1e-14); assert_almost_eq!(super::beta(0.5, 2.5), 1.17809724509617246442, 1e-15); assert_almost_eq!(super::beta(1.0, 2.5), 0.4, 1e-14); assert_almost_eq!(super::beta(2.5, 2.5), 0.073631077818510779026, 1e-15); } #[test] fn test_beta_inc() { assert_almost_eq!(super::beta_inc(0.5, 0.5, 0.5), 1.570796326794896619231, 1e-14); assert_almost_eq!(super::beta_inc(0.5, 0.5, 1.0), 3.141592653589793238463, 1e-15); assert_almost_eq!(super::beta_inc(1.0, 0.5, 0.5), 0.5857864376269049511983, 1e-15); assert_almost_eq!(super::beta_inc(1.0, 0.5, 1.0), 2.0, 1e-14); assert_almost_eq!(super::beta_inc(2.5, 0.5, 0.5), 
0.0890486225480862322117, 1e-16); assert_almost_eq!(super::beta_inc(2.5, 0.5, 1.0), 1.17809724509617246442, 1e-15); assert_almost_eq!(super::beta_inc(0.5, 1.0, 0.5), 1.414213562373095048802, 1e-14); assert_almost_eq!(super::beta_inc(0.5, 1.0, 1.0), 2.0, 1e-14); assert_almost_eq!(super::beta_inc(1.0, 1.0, 0.5), 0.5, 1e-15); assert_almost_eq!(super::beta_inc(1.0, 1.0, 1.0), 1.0, 1e-15); assert_eq!(super::beta_inc(2.5, 1.0, 0.5), 0.0707106781186547524401); assert_almost_eq!(super::beta_inc(2.5, 1.0, 1.0), 0.4, 1e-14); assert_almost_eq!(super::beta_inc(0.5, 2.5, 0.5), 1.08904862254808623221, 1e-15); assert_almost_eq!(super::beta_inc(0.5, 2.5, 1.0), 1.17809724509617246442, 1e-15); assert_almost_eq!(super::beta_inc(1.0, 2.5, 0.5), 0.32928932188134524756, 1e-14); assert_almost_eq!(super::beta_inc(1.0, 2.5, 1.0), 0.4, 1e-14); assert_almost_eq!(super::beta_inc(2.5, 2.5, 0.5), 0.03681553890925538951323, 1e-15); assert_almost_eq!(super::beta_inc(2.5, 2.5, 1.0), 0.073631077818510779026, 1e-15); } #[test] #[should_panic] fn test_beta_inc_a_lte_0() { super::beta_inc(0.0, 1.0, 1.0); } #[test] #[should_panic] fn test_beta_inc_b_lte_0() { super::beta_inc(1.0, 0.0, 1.0); } #[test] #[should_panic] fn test_beta_inc_x_lt_0() { super::beta_inc(1.0, 1.0, -1.0); } #[test] #[should_panic] fn test_beta_inc_x_gt_1() { super::beta_inc(1.0, 1.0, 2.0); } #[test] fn test_checked_beta_inc_a_lte_0() { assert!(super::checked_beta_inc(0.0, 1.0, 1.0).is_err()); } #[test] fn test_checked_beta_inc_b_lte_0() { assert!(super::checked_beta_inc(1.0, 0.0, 1.0).is_err()); } #[test] fn test_checked_beta_inc_x_lt_0() { assert!(super::checked_beta_inc(1.0, 1.0, -1.0).is_err()); } #[test] fn test_checked_beta_inc_x_gt_1() { assert!(super::checked_beta_inc(1.0, 1.0, 2.0).is_err()); } #[test] fn test_beta_reg() { assert_almost_eq!(super::beta_reg(0.5, 0.5, 0.5), 0.5, 1e-15); assert_eq!(super::beta_reg(0.5, 0.5, 1.0), 1.0); assert_almost_eq!(super::beta_reg(1.0, 0.5, 0.5), 0.292893218813452475599, 1e-15); 
assert_eq!(super::beta_reg(1.0, 0.5, 1.0), 1.0); assert_almost_eq!(super::beta_reg(2.5, 0.5, 0.5), 0.07558681842161243795, 1e-16); assert_eq!(super::beta_reg(2.5, 0.5, 1.0), 1.0); assert_almost_eq!(super::beta_reg(0.5, 1.0, 0.5), 0.7071067811865475244, 1e-15); assert_eq!(super::beta_reg(0.5, 1.0, 1.0), 1.0); assert_almost_eq!(super::beta_reg(1.0, 1.0, 0.5), 0.5, 1e-15); assert_eq!(super::beta_reg(1.0, 1.0, 1.0), 1.0); assert_almost_eq!(super::beta_reg(2.5, 1.0, 0.5), 0.1767766952966368811, 1e-15); assert_eq!(super::beta_reg(2.5, 1.0, 1.0), 1.0); assert_eq!(super::beta_reg(0.5, 2.5, 0.5), 0.92441318157838756205); assert_eq!(super::beta_reg(0.5, 2.5, 1.0), 1.0); assert_almost_eq!(super::beta_reg(1.0, 2.5, 0.5), 0.8232233047033631189, 1e-15); assert_eq!(super::beta_reg(1.0, 2.5, 1.0), 1.0); assert_almost_eq!(super::beta_reg(2.5, 2.5, 0.5), 0.5, 1e-15); assert_eq!(super::beta_reg(2.5, 2.5, 1.0), 1.0); } #[test] #[should_panic] fn test_beta_reg_a_lte_0() { super::beta_reg(0.0, 1.0, 1.0); } #[test] #[should_panic] fn test_beta_reg_b_lte_0() { super::beta_reg(1.0, 0.0, 1.0); } #[test] #[should_panic] fn test_beta_reg_x_lt_0() { super::beta_reg(1.0, 1.0, -1.0); } #[test] #[should_panic] fn test_beta_reg_x_gt_1() { super::beta_reg(1.0, 1.0, 2.0); } #[test] fn test_checked_beta_reg_a_lte_0() { assert!(super::checked_beta_reg(0.0, 1.0, 1.0).is_err()); } #[test] fn test_checked_beta_reg_b_lte_0() { assert!(super::checked_beta_reg(1.0, 0.0, 1.0).is_err()); } #[test] fn test_checked_beta_reg_x_lt_0() { assert!(super::checked_beta_reg(1.0, 1.0, -1.0).is_err()); } #[test] fn test_checked_beta_reg_x_gt_1() { assert!(super::checked_beta_reg(1.0, 1.0, 2.0).is_err()); } #[test] fn test_error_is_sync_send() { fn assert_sync_send<T: Sync + Send>() {} assert_sync_send::<BetaFuncError>(); } } statrs-0.18.0/src/function/erf.rs000064400000000000000000000666261046102023000150100ustar 00000000000000//! Provides the [error](https://en.wikipedia.org/wiki/Error_function) and //!
related functions use crate::function::evaluate; use std::f64; /// `erf` calculates the error function at `x`. pub fn erf(x: f64) -> f64 { if x.is_nan() { f64::NAN } else if x >= 0.0 && x.is_infinite() { 1.0 } else if x <= 0.0 && x.is_infinite() { -1.0 } else if x == 0.0 { 0.0 } else { erf_impl(x, false) } } /// `erf_inv` calculates the inverse error function /// at `x`. pub fn erf_inv(x: f64) -> f64 { if x == 0.0 { 0.0 } else if x >= 1.0 { f64::INFINITY } else if x <= -1.0 { f64::NEG_INFINITY } else if x < 0.0 { erf_inv_impl(-x, 1.0 + x, -1.0) } else { erf_inv_impl(x, 1.0 - x, 1.0) } } /// `erfc` calculates the complementary error function /// at `x`. pub fn erfc(x: f64) -> f64 { if x.is_nan() { f64::NAN } else if x == f64::INFINITY { 0.0 } else if x == f64::NEG_INFINITY { 2.0 } else { erf_impl(x, true) } } /// `erfc_inv` calculates the inverse of the complementary /// error function at `x`. pub fn erfc_inv(x: f64) -> f64 { if x <= 0.0 { f64::INFINITY } else if x >= 2.0 { f64::NEG_INFINITY } else if x > 1.0 { erf_inv_impl(-1.0 + x, 2.0 - x, -1.0) } else { erf_inv_impl(1.0 - x, x, 1.0) } } // ********************************************************** // ********** Coefficients for erf_impl polynomial ********** // ********************************************************** /// Polynomial coefficients for a numerator of `erf_impl` /// in the interval [1e-10, 0.5].
const ERF_IMPL_AN: &[f64] = &[ 0.00337916709551257388990745, -0.00073695653048167948530905, -0.374732337392919607868241, 0.0817442448733587196071743, -0.0421089319936548595203468, 0.0070165709512095756344528, -0.00495091255982435110337458, 0.000871646599037922480317225, ]; /// Polynomial coefficients for a denominator of `erf_impl` /// in the interval [1e-10, 0.5] const ERF_IMPL_AD: &[f64] = &[ 1.0, -0.218088218087924645390535, 0.412542972725442099083918, -0.0841891147873106755410271, 0.0655338856400241519690695, -0.0120019604454941768171266, 0.00408165558926174048329689, -0.000615900721557769691924509, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [0.5, 0.75]. const ERF_IMPL_BN: &[f64] = &[ -0.0361790390718262471360258, 0.292251883444882683221149, 0.281447041797604512774415, 0.125610208862766947294894, 0.0274135028268930549240776, 0.00250839672168065762786937, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [0.5, 0.75]. const ERF_IMPL_BD: &[f64] = &[ 1.0, 1.8545005897903486499845, 1.43575803037831418074962, 0.582827658753036572454135, 0.124810476932949746447682, 0.0113724176546353285778481, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [0.75, 1.25]. const ERF_IMPL_CN: &[f64] = &[ -0.0397876892611136856954425, 0.153165212467878293257683, 0.191260295600936245503129, 0.10276327061989304213645, 0.029637090615738836726027, 0.0046093486780275489468812, 0.000307607820348680180548455, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [0.75, 1.25]. const ERF_IMPL_CD: &[f64] = &[ 1.0, 1.95520072987627704987886, 1.64762317199384860109595, 0.768238607022126250082483, 0.209793185936509782784315, 0.0319569316899913392596356, 0.00213363160895785378615014, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [1.25, 2.25]. 
const ERF_IMPL_DN: &[f64] = &[ -0.0300838560557949717328341, 0.0538578829844454508530552, 0.0726211541651914182692959, 0.0367628469888049348429018, 0.00964629015572527529605267, 0.00133453480075291076745275, 0.778087599782504251917881e-4, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [1.25, 2.25]. const ERF_IMPL_DD: &[f64] = &[ 1.0, 1.75967098147167528287343, 1.32883571437961120556307, 0.552528596508757581287907, 0.133793056941332861912279, 0.0179509645176280768640766, 0.00104712440019937356634038, -0.106640381820357337177643e-7, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [2.25, 3.5]. const ERF_IMPL_EN: &[f64] = &[ -0.0117907570137227847827732, 0.014262132090538809896674, 0.0202234435902960820020765, 0.00930668299990432009042239, 0.00213357802422065994322516, 0.00025022987386460102395382, 0.120534912219588189822126e-4, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [2.25, 3.5]. const ERF_IMPL_ED: &[f64] = &[ 1.0, 1.50376225203620482047419, 0.965397786204462896346934, 0.339265230476796681555511, 0.0689740649541569716897427, 0.00771060262491768307365526, 0.000371421101531069302990367, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [3.5, 5.25]. const ERF_IMPL_FN: &[f64] = &[ -0.00546954795538729307482955, 0.00404190278731707110245394, 0.0054963369553161170521356, 0.00212616472603945399437862, 0.000394984014495083900689956, 0.365565477064442377259271e-4, 0.135485897109932323253786e-5, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [3.5, 5.25]. const ERF_IMPL_FD: &[f64] = &[ 1.0, 1.21019697773630784832251, 0.620914668221143886601045, 0.173038430661142762569515, 0.0276550813773432047594539, 0.00240625974424309709745382, 0.891811817251336577241006e-4, -0.465528836283382684461025e-11, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [5.25, 8]. 
const ERF_IMPL_GN: &[f64] = &[ -0.00270722535905778347999196, 0.0013187563425029400461378, 0.00119925933261002333923989, 0.00027849619811344664248235, 0.267822988218331849989363e-4, 0.923043672315028197865066e-6, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [5.25, 8]. const ERF_IMPL_GD: &[f64] = &[ 1.0, 0.814632808543141591118279, 0.268901665856299542168425, 0.0449877216103041118694989, 0.00381759663320248459168994, 0.000131571897888596914350697, 0.404815359675764138445257e-11, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [8, 11.5]. const ERF_IMPL_HN: &[f64] = &[ -0.00109946720691742196814323, 0.000406425442750422675169153, 0.000274499489416900707787024, 0.465293770646659383436343e-4, 0.320955425395767463401993e-5, 0.778286018145020892261936e-7, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [8, 11.5]. const ERF_IMPL_HD: &[f64] = &[ 1.0, 0.588173710611846046373373, 0.139363331289409746077541, 0.0166329340417083678763028, 0.00100023921310234908642639, 0.24254837521587225125068e-4, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [11.5, 17]. const ERF_IMPL_IN: &[f64] = &[ -0.00056907993601094962855594, 0.000169498540373762264416984, 0.518472354581100890120501e-4, 0.382819312231928859704678e-5, 0.824989931281894431781794e-7, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [11.5, 17]. const ERF_IMPL_ID: &[f64] = &[ 1.0, 0.339637250051139347430323, 0.043472647870310663055044, 0.00248549335224637114641629, 0.535633305337152900549536e-4, -0.117490944405459578783846e-12, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [17, 24]. 
const ERF_IMPL_JN: &[f64] = &[ -0.000241313599483991337479091, 0.574224975202501512365975e-4, 0.115998962927383778460557e-4, 0.581762134402593739370875e-6, 0.853971555085673614607418e-8, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [17, 24]. const ERF_IMPL_JD: &[f64] = &[ 1.0, 0.233044138299687841018015, 0.0204186940546440312625597, 0.000797185647564398289151125, 0.117019281670172327758019e-4, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [24, 38]. const ERF_IMPL_KN: &[f64] = &[ -0.000146674699277760365803642, 0.162666552112280519955647e-4, 0.269116248509165239294897e-5, 0.979584479468091935086972e-7, 0.101994647625723465722285e-8, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [24, 38]. const ERF_IMPL_KD: &[f64] = &[ 1.0, 0.165907812944847226546036, 0.0103361716191505884359634, 0.000286593026373868366935721, 0.298401570840900340874568e-5, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [38, 60]. const ERF_IMPL_LN: &[f64] = &[ -0.583905797629771786720406e-4, 0.412510325105496173512992e-5, 0.431790922420250949096906e-6, 0.993365155590013193345569e-8, 0.653480510020104699270084e-10, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [38, 60]. const ERF_IMPL_LD: &[f64] = &[ 1.0, 0.105077086072039915406159, 0.00414278428675475620830226, 0.726338754644523769144108e-4, 0.477818471047398785369849e-6, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [60, 85]. const ERF_IMPL_MN: &[f64] = &[ -0.196457797609229579459841e-4, 0.157243887666800692441195e-5, 0.543902511192700878690335e-7, 0.317472492369117710852685e-9, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [60, 85]. 
const ERF_IMPL_MD: &[f64] = &[ 1.0, 0.052803989240957632204885, 0.000926876069151753290378112, 0.541011723226630257077328e-5, 0.535093845803642394908747e-15, ]; /// Polynomial coefficients for a numerator in `erf_impl` /// in the interval [85, 110]. const ERF_IMPL_NN: &[f64] = &[ -0.789224703978722689089794e-5, 0.622088451660986955124162e-6, 0.145728445676882396797184e-7, 0.603715505542715364529243e-10, ]; /// Polynomial coefficients for a denominator in `erf_impl` /// in the interval [85, 110]. const ERF_IMPL_ND: &[f64] = &[ 1.0, 0.0375328846356293715248719, 0.000467919535974625308126054, 0.193847039275845656900547e-5, ]; // ********************************************************** // ********** Coefficients for erf_inv_impl polynomial ****** // ********************************************************** /// Polynomial coefficients for a numerator of `erf_inv_impl` /// in the interval [0, 0.5]. const ERF_INV_IMPL_AN: &[f64] = &[ -0.000508781949658280665617, -0.00836874819741736770379, 0.0334806625409744615033, -0.0126926147662974029034, -0.0365637971411762664006, 0.0219878681111168899165, 0.00822687874676915743155, -0.00538772965071242932965, ]; /// Polynomial coefficients for a denominator of `erf_inv_impl` /// in the interval [0, 0.5]. const ERF_INV_IMPL_AD: &[f64] = &[ 1.0, -0.970005043303290640362, -1.56574558234175846809, 1.56221558398423026363, 0.662328840472002992063, -0.71228902341542847553, -0.0527396382340099713954, 0.0795283687341571680018, -0.00233393759374190016776, 0.000886216390456424707504, ]; /// Polynomial coefficients for a numerator of `erf_inv_impl` /// in the interval [0.5, 0.75]. const ERF_INV_IMPL_BN: &[f64] = &[ -0.202433508355938759655, 0.105264680699391713268, 8.37050328343119927838, 17.6447298408374015486, -18.8510648058714251895, -44.6382324441786960818, 17.445385985570866523, 21.1294655448340526258, -3.67192254707729348546, ]; /// Polynomial coefficients for a denominator of `erf_inv_impl` /// in the interval [0.5, 0.75]. 
const ERF_INV_IMPL_BD: &[f64] = &[ 1.0, 6.24264124854247537712, 3.9713437953343869095, -28.6608180499800029974, -20.1432634680485188801, 48.5609213108739935468, 10.8268667355460159008, -22.6436933413139721736, 1.72114765761200282724, ]; /// Polynomial coefficients for a numerator of `erf_inv_impl` /// in the interval [0.75, 1] with x less than 3. const ERF_INV_IMPL_CN: &[f64] = &[ -0.131102781679951906451, -0.163794047193317060787, 0.117030156341995252019, 0.387079738972604337464, 0.337785538912035898924, 0.142869534408157156766, 0.0290157910005329060432, 0.00214558995388805277169, -0.679465575181126350155e-6, 0.285225331782217055858e-7, -0.681149956853776992068e-9, ]; /// Polynomial coefficients for a denominator of `erf_inv_impl` /// in the interval [0.75, 1] with x less than 3. const ERF_INV_IMPL_CD: &[f64] = &[ 1.0, 3.46625407242567245975, 5.38168345707006855425, 4.77846592945843778382, 2.59301921623620271374, 0.848854343457902036425, 0.152264338295331783612, 0.01105924229346489121, ]; /// Polynomial coefficients for a numerator of `erf_inv_impl` /// in the interval [0.75, 1] with x between 3 and 6. const ERF_INV_IMPL_DN: &[f64] = &[ -0.0350353787183177984712, -0.00222426529213447927281, 0.0185573306514231072324, 0.00950804701325919603619, 0.00187123492819559223345, 0.000157544617424960554631, 0.460469890584317994083e-5, -0.230404776911882601748e-9, 0.266339227425782031962e-11, ]; /// Polynomial coefficients for a denominator of `erf_inv_impl` /// in the interval [0.75, 1] with x between 3 and 6. const ERF_INV_IMPL_DD: &[f64] = &[ 1.0, 1.3653349817554063097, 0.762059164553623404043, 0.220091105764131249824, 0.0341589143670947727934, 0.00263861676657015992959, 0.764675292302794483503e-4, ]; /// Polynomial coefficients for a numerator of `erf_inv_impl` /// in the interval [0.75, 1] with x between 6 and 18. 
const ERF_INV_IMPL_EN: &[f64] = &[ -0.0167431005076633737133, -0.00112951438745580278863, 0.00105628862152492910091, 0.000209386317487588078668, 0.149624783758342370182e-4, 0.449696789927706453732e-6, 0.462596163522878599135e-8, -0.281128735628831791805e-13, 0.99055709973310326855e-16, ]; /// Polynomial coefficients for a denominator of `erf_inv_impl` /// in the interval [0.75, 1] with x between 6 and 18. const ERF_INV_IMPL_ED: &[f64] = &[ 1.0, 0.591429344886417493481, 0.138151865749083321638, 0.0160746087093676504695, 0.000964011807005165528527, 0.275335474764726041141e-4, 0.282243172016108031869e-6, ]; /// Polynomial coefficients for a numerator of `erf_inv_impl` /// in the interval [0.75, 1] with x between 18 and 44. const ERF_INV_IMPL_FN: &[f64] = &[ -0.0024978212791898131227, -0.779190719229053954292e-5, 0.254723037413027451751e-4, 0.162397777342510920873e-5, 0.396341011304801168516e-7, 0.411632831190944208473e-9, 0.145596286718675035587e-11, -0.116765012397184275695e-17, ]; /// Polynomial coefficients for a denominator of `erf_inv_impl` /// in the interval [0.75, 1] with x between 18 and 44. const ERF_INV_IMPL_FD: &[f64] = &[ 1.0, 0.207123112214422517181, 0.0169410838120975906478, 0.000690538265622684595676, 0.145007359818232637924e-4, 0.144437756628144157666e-6, 0.509761276599778486139e-9, ]; /// Polynomial coefficients for a numerator of `erf_inv_impl` /// in the interval [0.75, 1] with x greater than 44. const ERF_INV_IMPL_GN: &[f64] = &[ -0.000539042911019078575891, -0.28398759004727721098e-6, 0.899465114892291446442e-6, 0.229345859265920864296e-7, 0.225561444863500149219e-9, 0.947846627503022684216e-12, 0.135880130108924861008e-14, -0.348890393399948882918e-21, ]; /// Polynomial coefficients for a denominator of `erf_inv_impl` /// in the interval [0.75, 1] with x greater than 44. 
const ERF_INV_IMPL_GD: &[f64] = &[ 1.0, 0.0845746234001899436914, 0.00282092984726264681981, 0.468292921940894236786e-4, 0.399968812193862100054e-6, 0.161809290887904476097e-8, 0.231558608310259605225e-11, ]; /// `erf_impl` computes the error function at `z`. /// If `inv` is true, `1 - erf` is calculated as opposed to `erf` fn erf_impl(z: f64, inv: bool) -> f64 { if z < 0.0 { if !inv { return -erf_impl(-z, false); } if z < -0.5 { return 2.0 - erf_impl(-z, true); } return 1.0 + erf_impl(-z, false); } let result = if z < 0.5 { if z < 1e-10 { z * 1.125 + z * 0.003379167095512573896158903121545171688 } else { z * 1.125 + z * evaluate::polynomial(z, ERF_IMPL_AN) / evaluate::polynomial(z, ERF_IMPL_AD) } } else if z < 110.0 { let (r, b) = if z < 0.75 { ( evaluate::polynomial(z - 0.5, ERF_IMPL_BN) / evaluate::polynomial(z - 0.5, ERF_IMPL_BD), 0.3440242112, ) } else if z < 1.25 { ( evaluate::polynomial(z - 0.75, ERF_IMPL_CN) / evaluate::polynomial(z - 0.75, ERF_IMPL_CD), 0.419990927, ) } else if z < 2.25 { ( evaluate::polynomial(z - 1.25, ERF_IMPL_DN) / evaluate::polynomial(z - 1.25, ERF_IMPL_DD), 0.4898625016, ) } else if z < 3.5 { ( evaluate::polynomial(z - 2.25, ERF_IMPL_EN) / evaluate::polynomial(z - 2.25, ERF_IMPL_ED), 0.5317370892, ) } else if z < 5.25 { ( evaluate::polynomial(z - 3.5, ERF_IMPL_FN) / evaluate::polynomial(z - 3.5, ERF_IMPL_FD), 0.5489973426, ) } else if z < 8.0 { ( evaluate::polynomial(z - 5.25, ERF_IMPL_GN) / evaluate::polynomial(z - 5.25, ERF_IMPL_GD), 0.5571740866, ) } else if z < 11.5 { ( evaluate::polynomial(z - 8.0, ERF_IMPL_HN) / evaluate::polynomial(z - 8.0, ERF_IMPL_HD), 0.5609807968, ) } else if z < 17.0 { ( evaluate::polynomial(z - 11.5, ERF_IMPL_IN) / evaluate::polynomial(z - 11.5, ERF_IMPL_ID), 0.5626493692, ) } else if z < 24.0 { ( evaluate::polynomial(z - 17.0, ERF_IMPL_JN) / evaluate::polynomial(z - 17.0, ERF_IMPL_JD), 0.5634598136, ) } else if z < 38.0 { ( evaluate::polynomial(z - 24.0, ERF_IMPL_KN) / evaluate::polynomial(z - 24.0, 
ERF_IMPL_KD), 0.5638477802, ) } else if z < 60.0 { ( evaluate::polynomial(z - 38.0, ERF_IMPL_LN) / evaluate::polynomial(z - 38.0, ERF_IMPL_LD), 0.5640528202, ) } else if z < 85.0 { ( evaluate::polynomial(z - 60.0, ERF_IMPL_MN) / evaluate::polynomial(z - 60.0, ERF_IMPL_MD), 0.5641309023, ) } else { ( evaluate::polynomial(z - 85.0, ERF_IMPL_NN) / evaluate::polynomial(z - 85.0, ERF_IMPL_ND), 0.5641584396, ) }; let g = (-z * z).exp() / z; g * b + g * r } else { 0.0 }; if inv && z >= 0.5 { result } else if z >= 0.5 || inv { 1.0 - result } else { result } } // `erf_inv_impl` computes the inverse error function where // `p`,`q`, and `s` are the first, second, and third intermediate // parameters respectively fn erf_inv_impl(p: f64, q: f64, s: f64) -> f64 { let result = if p <= 0.5 { let y = 0.0891314744949340820313; let g = p * (p + 10.0); let r = evaluate::polynomial(p, ERF_INV_IMPL_AN) / evaluate::polynomial(p, ERF_INV_IMPL_AD); g * y + g * r } else if q >= 0.25 { let y = 2.249481201171875; let g = (-2.0 * q.ln()).sqrt(); let xs = q - 0.25; let r = evaluate::polynomial(xs, ERF_INV_IMPL_BN) / evaluate::polynomial(xs, ERF_INV_IMPL_BD); g / (y + r) } else { let x = (-q.ln()).sqrt(); if x < 3.0 { let y = 0.807220458984375; let xs = x - 1.125; let r = evaluate::polynomial(xs, ERF_INV_IMPL_CN) / evaluate::polynomial(xs, ERF_INV_IMPL_CD); y * x + r * x } else if x < 6.0 { let y = 0.93995571136474609375; let xs = x - 3.0; let r = evaluate::polynomial(xs, ERF_INV_IMPL_DN) / evaluate::polynomial(xs, ERF_INV_IMPL_DD); y * x + r * x } else if x < 18.0 { let y = 0.98362827301025390625; let xs = x - 6.0; let r = evaluate::polynomial(xs, ERF_INV_IMPL_EN) / evaluate::polynomial(xs, ERF_INV_IMPL_ED); y * x + r * x } else if x < 44.0 { let y = 0.99714565277099609375; let xs = x - 18.0; let r = evaluate::polynomial(xs, ERF_INV_IMPL_FN) / evaluate::polynomial(xs, ERF_INV_IMPL_FD); y * x + r * x } else { let y = 0.99941349029541015625; let xs = x - 44.0; let r = evaluate::polynomial(xs, 
ERF_INV_IMPL_GN) / evaluate::polynomial(xs, ERF_INV_IMPL_GD); y * x + r * x } }; s * result } #[rustfmt::skip] #[cfg(test)] mod tests { use std::f64; #[test] fn test_erf() { assert!(super::erf(f64::NAN).is_nan()); assert_almost_eq!(super::erf(-1.0), -0.84270079294971486934122063508260925929606699796630291, 1e-11); assert_eq!(super::erf(0.0), 0.0); assert_eq!(super::erf(1e-15), 0.0000000000000011283791670955126615773132947717431253912942469337536); assert_eq!(super::erf(0.1), 0.1124629160182848984047122510143040617233925185058162); assert_almost_eq!(super::erf(0.2), 0.22270258921047846617645303120925671669511570710081967, 1e-16); assert_eq!(super::erf(0.3), 0.32862675945912741618961798531820303325847175931290341); assert_eq!(super::erf(0.4), 0.42839235504666847645410962730772853743532927705981257); assert_almost_eq!(super::erf(0.5), 0.5204998778130465376827466538919645287364515757579637, 1e-9); assert_almost_eq!(super::erf(1.0), 0.84270079294971486934122063508260925929606699796630291, 1e-11); assert_almost_eq!(super::erf(1.5), 0.96610514647531072706697626164594785868141047925763678, 1e-11); assert_almost_eq!(super::erf(2.0), 0.99532226501895273416206925636725292861089179704006008, 1e-11); assert_almost_eq!(super::erf(2.5), 0.99959304798255504106043578426002508727965132259628658, 1e-13); assert_almost_eq!(super::erf(3.0), 0.99997790950300141455862722387041767962015229291260075, 1e-11); assert_eq!(super::erf(4.0), 0.99999998458274209971998114784032651311595142785474641); assert_eq!(super::erf(5.0), 0.99999999999846254020557196514981165651461662110988195); assert_eq!(super::erf(6.0), 0.99999999999999997848026328750108688340664960081261537); assert_eq!(super::erf(f64::INFINITY), 1.0); assert_eq!(super::erf(f64::NEG_INFINITY), -1.0); } #[test] fn test_erfc() { assert!(super::erfc(f64::NAN).is_nan()); assert_almost_eq!(super::erfc(-1.0), 1.8427007929497148693412206350826092592960669979663028, 1e-11); assert_eq!(super::erfc(0.0), 1.0); assert_almost_eq!(super::erfc(0.1), 
0.88753708398171510159528774898569593827660748149418343, 1e-15); assert_eq!(super::erfc(0.2), 0.77729741078952153382354696879074328330488429289918085); assert_eq!(super::erfc(0.3), 0.67137324054087258381038201468179696674152824068709621); assert_almost_eq!(super::erfc(0.4), 0.57160764495333152354589037269227146256467072294018715, 1e-15); assert_almost_eq!(super::erfc(0.5), 0.47950012218695346231725334610803547126354842424203654, 1e-9); assert_almost_eq!(super::erfc(1.0), 0.15729920705028513065877936491739074070393300203369719, 1e-11); assert_almost_eq!(super::erfc(1.5), 0.033894853524689272933023738354052141318589520742363247, 1e-11); assert_almost_eq!(super::erfc(2.0), 0.0046777349810472658379307436327470713891082029599399245, 1e-11); assert_almost_eq!(super::erfc(2.5), 0.00040695201744495893956421573997491272034867740371342016, 1e-13); assert_almost_eq!(super::erfc(3.0), 0.00002209049699858544137277612958232037984770708739924966, 1e-11); assert_almost_eq!(super::erfc(4.0), 0.000000015417257900280018852159673486884048572145253589191167, 1e-18); assert_almost_eq!(super::erfc(5.0), 0.0000000000015374597944280348501883434853833788901180503147233804, 1e-22); assert_almost_eq!(super::erfc(6.0), 2.1519736712498913116593350399187384630477514061688559e-17, 1e-26); assert_almost_eq!(super::erfc(10.0), 2.0884875837625447570007862949577886115608181193211634e-45, 1e-55); assert_almost_eq!(super::erfc(15.0), 7.2129941724512066665650665586929271099340909298253858e-100, 1e-109); assert_almost_eq!(super::erfc(20.0), 5.3958656116079009289349991679053456040882726709236071e-176, 1e-186); assert_eq!(super::erfc(30.0), 2.5646562037561116000333972775014471465488897227786155e-393); assert_eq!(super::erfc(50.0), 2.0709207788416560484484478751657887929322509209953988e-1088); assert_eq!(super::erfc(80.0), 2.3100265595063985852034904366341042118385080919280966e-2782); assert_eq!(super::erfc(f64::INFINITY), 0.0); assert_eq!(super::erfc(f64::NEG_INFINITY), 2.0); } #[test] fn test_erf_inv() { 
assert!(super::erf_inv(f64::NAN).is_nan()); assert_eq!(super::erf_inv(-1.0), f64::NEG_INFINITY); assert_eq!(super::erf_inv(0.0), 0.0); assert_almost_eq!(super::erf_inv(1e-15), 8.86226925452758013649e-16, 1e-30); assert_eq!(super::erf_inv(0.1), 0.08885599049425768701574); assert_almost_eq!(super::erf_inv(0.2), 0.1791434546212916764927, 1e-15); assert_eq!(super::erf_inv(0.3), 0.272462714726754355622); assert_eq!(super::erf_inv(0.4), 0.3708071585935579290582); assert_eq!(super::erf_inv(0.5), 0.4769362762044698733814); assert_eq!(super::erf_inv(1.0), f64::INFINITY); assert_eq!(super::erf_inv(f64::INFINITY), f64::INFINITY); assert_eq!(super::erf_inv(f64::NEG_INFINITY), f64::NEG_INFINITY); } #[test] fn test_erfc_inv() { assert_eq!(super::erfc_inv(0.0), f64::INFINITY); assert_almost_eq!(super::erfc_inv(1e-100), 15.065574702593, 1e-11); assert_almost_eq!(super::erfc_inv(1e-30), 8.1486162231699, 1e-12); assert_almost_eq!(super::erfc_inv(1e-20), 6.6015806223551, 1e-13); assert_almost_eq!(super::erfc_inv(1e-10), 4.5728249585449249378479309946884581365517663258840893, 1e-7); assert_almost_eq!(super::erfc_inv(1e-5), 3.1234132743415708640270717579666062107939039971365252, 1e-11); assert_almost_eq!(super::erfc_inv(0.1), 1.1630871536766741628440954340547000483801487126688552, 1e-14); assert_almost_eq!(super::erfc_inv(0.2), 0.90619380243682330953597079527631536107443494091638384, 1e-15); assert_eq!(super::erfc_inv(0.5), 0.47693627620446987338141835364313055980896974905947083); assert_eq!(super::erfc_inv(1.0), 0.0); assert_eq!(super::erfc_inv(1.5), -0.47693627620446987338141835364313055980896974905947083); assert_eq!(super::erfc_inv(2.0), f64::NEG_INFINITY); } } statrs-0.18.0/src/function/evaluate.rs000064400000000000000000000037501046102023000160270ustar 00000000000000//! Provides helper routines for evaluating expressions numerically //! (e.g.
evaluation of a polynomial) /// Evaluates a polynomial at `z` using Horner's method, where `coeff` holds /// the coefficients in ascending order of power: the `i`th element (counting /// from 0) is the coefficient of `z^i`, so a slice of length `k` describes a /// polynomial of degree `k - 1`. E.g. `[3, -1, 2]` equates to /// `2z^2 - z + 3` /// /// # Remarks /// /// Returns `0.0` for an empty coefficient slice pub fn polynomial(z: f64, coeff: &[f64]) -> f64 { let n = coeff.len(); if n == 0 { return 0.0; } let mut sum = *coeff.last().unwrap(); for c in coeff[0..n - 1].iter().rev() { sum = *c + z * sum; } sum } #[rustfmt::skip] #[cfg(test)] mod tests { use std::f64; // these tests probably could be more robust #[test] fn test_polynomial() { let empty: [f64; 0] = []; assert_eq!(super::polynomial(2.0, &empty), 0.0); let zero = [0.0]; assert_eq!(super::polynomial(2.0, &zero), 0.0); let mut coeff = [1.0, 0.0, 5.0]; assert_eq!(super::polynomial(2.0, &coeff), 21.0); coeff = [-5.0, -2.0, 3.0]; assert_eq!(super::polynomial(2.0, &coeff), 3.0); assert_eq!(super::polynomial(-2.0, &coeff), 11.0); let large_coeff = [-1.35e3, 2.5e2, 8.0, -4.0, 1e2, 3.0]; assert_eq!(super::polynomial(5.0, &large_coeff), 71475.0); assert_eq!(super::polynomial(-5.0, &large_coeff), 51225.0); coeff = [f64::INFINITY, -2.0, 3.0]; assert_eq!(super::polynomial(2.0, &coeff), f64::INFINITY); assert_eq!(super::polynomial(-2.0, &coeff), f64::INFINITY); coeff = [f64::NEG_INFINITY, -2.0, 3.0]; assert_eq!(super::polynomial(2.0, &coeff), f64::NEG_INFINITY); assert_eq!(super::polynomial(-2.0, &coeff), f64::NEG_INFINITY); coeff = [f64::NAN, -2.0, 3.0]; assert!(super::polynomial(2.0, &coeff).is_nan()); assert!(super::polynomial(-2.0, &coeff).is_nan()); } } statrs-0.18.0/src/function/exponential.rs000064400000000000000000000074731046102023000165550ustar 00000000000000//! 
Provides functions related to exponential calculations use crate::consts; /// Computes the generalized Exponential Integral function /// where `x` is the argument and `n` is the integer power of the /// denominator term. /// /// Returns `None` if `x < 0.0` or the computation could not /// converge after 100 iterations /// /// # Remarks /// /// This implementation follows the derivation in /// /// _"Handbook of Mathematical Functions, Applied Mathematics Series, Volume /// 55"_ - Abramowitz, M., and Stegun, I.A. 1964 /// /// AND /// /// _"Advanced mathematical methods for scientists and engineers"_ - Bender, /// Carl M.; Steven A. Orszag (1978). page 253 /// /// The continued fraction approach is used for `x > 1.0`, while the Taylor /// series expansion is used for `0.0 < x <= 1.0`. // TODO: Add examples pub fn integral(x: f64, n: u64) -> Option<f64> { let eps = 0.00000000000000001; let max_iter = 100; let nf64 = n as f64; let near_f64min = 1e-100; // needs very small value that is not quite as small as f64 min // special cases if n == 0 { return Some((-1.0 * x).exp() / x); } if x == 0.0 { return Some(1.0 / (nf64 - 1.0)); } if x > 1.0 { let mut b = x + nf64; let mut c = 1.0 / near_f64min; let mut d = 1.0 / b; let mut h = d; for i in 1..max_iter + 1 { let a = -1.0 * i as f64 * (nf64 - 1.0 + i as f64); b += 2.0; d = 1.0 / (a * d + b); c = b + a / c; let del = c * d; h *= del; if (del - 1.0).abs() < eps { return Some(h * (-x).exp()); } } None } else { let mut factorial = 1.0; let mut result = if n - 1 != 0 { 1.0 / (nf64 - 1.0) } else { -1.0 * x.ln() - consts::EULER_MASCHERONI }; for i in 1..max_iter + 1 { factorial *= -1.0 * x / i as f64; let del = if i != n - 1 { -factorial / (i as f64 - nf64 + 1.0) } else { let mut psi = -1.0 * consts::EULER_MASCHERONI; for ii in 1..n { psi += 1.0 / ii as f64; } factorial * (-1.0 * x.ln() + psi) }; result += del; if del.abs() < result.abs() * eps { return Some(result); } } None } } #[rustfmt::skip] #[cfg(test)] mod tests { #[test] fn 
test_integral() { assert_eq!(super::integral(0.001, 1).unwrap(), 6.33153936413614904); assert_almost_eq!(super::integral(0.1, 1).unwrap(), 1.82292395841939059, 1e-15); assert_eq!(super::integral(1.0, 1).unwrap(), 0.219383934395520286); assert_almost_eq!(super::integral(2.0, 1).unwrap(), 0.0489005107080611248, 1e-15); assert_almost_eq!(super::integral(2.5, 1).unwrap(), 0.0249149178702697399, 1e-15); assert_almost_eq!(super::integral(10.0, 1).unwrap(), 4.15696892968532464e-06, 1e-20); assert_eq!(super::integral(0.001, 2).unwrap(), 0.992668960469238915); assert_almost_eq!(super::integral(0.1, 2).unwrap(), 0.722545022194020392, 1e-15); assert_almost_eq!(super::integral(1.0, 2).unwrap(), 0.148495506775922048, 1e-16); assert_almost_eq!(super::integral(2.0, 2).unwrap(), 0.0375342618204904527, 1e-16); assert_almost_eq!(super::integral(10.0, 2).unwrap(), 3.830240465631608e-06, 1e-20); assert_eq!(super::integral(0.001, 0).unwrap(), 999.000499833375); assert_eq!(super::integral(0.1, 0).unwrap(), 9.048374180359595); assert_almost_eq!(super::integral(1.0, 0).unwrap(), 0.3678794411714423, 1e-16); assert_eq!(super::integral(2.0, 0).unwrap(), 0.06766764161830635); assert_eq!(super::integral(10.0, 0).unwrap(), 4.539992976248485e-06); } } statrs-0.18.0/src/function/factorial.rs000064400000000000000000000116361046102023000161670ustar 00000000000000//! Provides functions related to factorial calculations (e.g. binomial //! coefficient, factorial, multinomial) use crate::function::gamma; /// The maximum factorial representable /// by a 64-bit floating point without /// overflowing pub const MAX_FACTORIAL: usize = 170; /// Computes the factorial function `x -> x!` for /// `170 >= x >= 0`. All factorials larger than `170!` /// will overflow an `f64`. 
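/// /// # Examples /// /// A minimal usage sketch. `5!` compares exactly equal to `120.0` because `120` /// and every intermediate product leading to it are exactly representable in an /// `f64`. Arguments above `MAX_FACTORIAL` (170) return infinity: /// /// ``` /// use statrs::function::factorial::factorial; /// /// // 5! = 1 * 2 * 3 * 4 * 5, computed exactly /// assert_eq!(factorial(5), 120.0); /// // beyond `MAX_FACTORIAL` the result overflows to infinity /// assert_eq!(factorial(171), f64::INFINITY); /// ```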
/// /// # Remarks /// /// Returns `f64::INFINITY` if `x > 170` pub fn factorial(x: u64) -> f64 { let x = x as usize; FCACHE.get(x).map_or(f64::INFINITY, |&fac| fac) } /// Computes the logarithmic factorial function `x -> ln(x!)` /// for `x >= 0`. /// /// # Remarks /// /// Returns `0.0` if `x <= 1` pub fn ln_factorial(x: u64) -> f64 { let x = x as usize; FCACHE .get(x) .map_or_else(|| gamma::ln_gamma(x as f64 + 1.0), |&fac| fac.ln()) } /// Computes the binomial coefficient `n choose k` /// where `k` and `n` are non-negative values. /// /// # Remarks /// /// Returns `0.0` if `k > n` pub fn binomial(n: u64, k: u64) -> f64 { if k > n { 0.0 } else { (0.5 + (ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k)).exp()).floor() } } /// Computes the natural logarithm of the binomial coefficient /// `ln(n choose k)` where `k` and `n` are non-negative values. /// /// # Remarks /// /// Returns `f64::NEG_INFINITY` if `k > n` pub fn ln_binomial(n: u64, k: u64) -> f64 { if k > n { f64::NEG_INFINITY } else { ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k) } } /// Computes the multinomial coefficient: `n choose n1, n2, n3, ...` /// /// # Panics /// /// If the elements in `ni` do not sum to `n` pub fn multinomial(n: u64, ni: &[u64]) -> f64 { checked_multinomial(n, ni).unwrap() } /// Computes the multinomial coefficient: `n choose n1, n2, n3, ...` /// /// Returns `None` if the elements in `ni` do not sum to `n`. pub fn checked_multinomial(n: u64, ni: &[u64]) -> Option<f64> { let (sum, ret) = ni.iter().fold((0, ln_factorial(n)), |acc, &x| { (acc.0 + x, acc.1 - ln_factorial(x)) }); if sum == n { Some((0.5 + ret.exp()).floor()) } else { None } } // Initialization for pre-computed cache of 171 factorial // values 0!...170!
const FCACHE: [f64; MAX_FACTORIAL + 1] = { let mut fcache = [1.0; MAX_FACTORIAL + 1]; // `const` contexts only allow `while` loops let mut i = 1; while i < MAX_FACTORIAL + 1 { fcache[i] = fcache[i - 1] * i as f64; i += 1; } fcache }; #[cfg(test)] mod tests { use super::*; #[test] fn test_fcache() { assert!((FCACHE[0] - 1.0).abs() < f64::EPSILON); assert!((FCACHE[1] - 1.0).abs() < f64::EPSILON); assert!((FCACHE[2] - 2.0).abs() < f64::EPSILON); assert!((FCACHE[3] - 6.0).abs() < f64::EPSILON); assert!((FCACHE[4] - 24.0).abs() < f64::EPSILON); assert!((FCACHE[70] - 1197857166996989e85).abs() < f64::EPSILON); assert!((FCACHE[170] - 7257415615307994e291).abs() < f64::EPSILON); } #[test] fn test_factorial_and_ln_factorial() { let mut fac = 1.0; assert_eq!(factorial(0), fac); for i in 1..171 { fac *= i as f64; assert_eq!(factorial(i), fac); assert_eq!(ln_factorial(i), fac.ln()); } } #[test] fn test_factorial_overflow() { assert_eq!(factorial(172), f64::INFINITY); assert_eq!(factorial(u64::MAX), f64::INFINITY); } #[test] fn test_ln_factorial_does_not_overflow() { assert_eq!(ln_factorial(1 << 10), 6078.2118847500501140); assert_almost_eq!(ln_factorial(1 << 12), 29978.648060844048236, 1e-11); assert_eq!(ln_factorial(1 << 15), 307933.81973375485425); assert_eq!(ln_factorial(1 << 17), 1413421.9939462073242); } #[test] fn test_binomial() { assert_eq!(binomial(1, 1), 1.0); assert_eq!(binomial(5, 2), 10.0); assert_eq!(binomial(7, 3), 35.0); assert_eq!(binomial(1, 0), 1.0); assert_eq!(binomial(0, 1), 0.0); assert_eq!(binomial(5, 7), 0.0); } #[test] fn test_ln_binomial() { assert_eq!(ln_binomial(1, 1), 1f64.ln()); assert_almost_eq!(ln_binomial(5, 2), 10f64.ln(), 1e-14); assert_almost_eq!(ln_binomial(7, 3), 35f64.ln(), 1e-14); assert_eq!(ln_binomial(1, 0), 1f64.ln()); assert_eq!(ln_binomial(0, 1), 0f64.ln()); assert_eq!(ln_binomial(5, 7), 0f64.ln()); } #[test] fn test_multinomial() { assert_eq!(1.0, multinomial(1, &[1, 0])); assert_eq!(10.0, multinomial(5, &[3, 2])); assert_eq!(10.0, 
multinomial(5, &[2, 3])); assert_eq!(35.0, multinomial(7, &[3, 4])); } #[test] #[should_panic] fn test_multinomial_bad_ni() { multinomial(1, &[1, 1]); } #[test] fn test_checked_multinomial_bad_ni() { assert!(checked_multinomial(1, &[1, 1]).is_none()); } } statrs-0.18.0/src/function/gamma.rs000064400000000000000000001030641046102023000153020ustar 00000000000000//! Provides the [gamma](https://en.wikipedia.org/wiki/Gamma_function) and //! related functions use crate::consts; use crate::prec; use std::f64; /// Represents the errors that can occur when computing any of the incomplete /// gamma functions. #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[non_exhaustive] pub enum GammaFuncError { /// `a` is infinite, zero or less than zero. AInvalid, /// `x` is infinite, zero or less than zero. XInvalid, } impl std::fmt::Display for GammaFuncError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { GammaFuncError::AInvalid => write!(f, "a is infinite, zero or less than zero"), GammaFuncError::XInvalid => write!(f, "x is infinite, zero or less than zero"), } } } impl std::error::Error for GammaFuncError {} /// Auxiliary variable (the Lanczos parameter `r`) used when evaluating the /// `ln_gamma` and `gamma` functions const GAMMA_R: f64 = 10.900511; /// Coefficients `d_k` of the Lanczos series used when approximating the /// `ln_gamma` and `gamma` functions const GAMMA_DK: &[f64] = &[ 2.48574089138753565546e-5, 1.05142378581721974210, -3.45687097222016235469, 4.51227709466894823700, -2.98285225323576655721, 1.05639711577126713077, -1.95428773191645869583e-1, 1.70970543404441224307e-2, -5.71926117404305781283e-4, 4.63399473359905636708e-6, -2.71994908488607703910e-9, ]; /// Computes the logarithm of the gamma function /// with an accuracy of 16 floating point digits. /// The implementation is derived from /// "An Analysis of the Lanczos Gamma Approximation", /// Glendon Ralph Pugh, 2004 p.
116 pub fn ln_gamma(x: f64) -> f64 { if x < 0.5 { let s = GAMMA_DK .iter() .enumerate() .skip(1) .fold(GAMMA_DK[0], |s, t| s + t.1 / (t.0 as f64 - x)); consts::LN_PI - (f64::consts::PI * x).sin().ln() - s.ln() - consts::LN_2_SQRT_E_OVER_PI - (0.5 - x) * ((0.5 - x + GAMMA_R) / f64::consts::E).ln() } else { let s = GAMMA_DK .iter() .enumerate() .skip(1) .fold(GAMMA_DK[0], |s, t| s + t.1 / (x + t.0 as f64 - 1.0)); s.ln() + consts::LN_2_SQRT_E_OVER_PI + (x - 0.5) * ((x - 0.5 + GAMMA_R) / f64::consts::E).ln() } } /// Computes the gamma function with an accuracy /// of 16 floating point digits. The implementation /// is derived from "An Analysis of the Lanczos Gamma Approximation", /// Glendon Ralph Pugh, 2004 p. 116 pub fn gamma(x: f64) -> f64 { if x < 0.5 { let s = GAMMA_DK .iter() .enumerate() .skip(1) .fold(GAMMA_DK[0], |s, t| s + t.1 / (t.0 as f64 - x)); f64::consts::PI / ((f64::consts::PI * x).sin() * s * consts::TWO_SQRT_E_OVER_PI * ((0.5 - x + GAMMA_R) / f64::consts::E).powf(0.5 - x)) } else { let s = GAMMA_DK .iter() .enumerate() .skip(1) .fold(GAMMA_DK[0], |s, t| s + t.1 / (x + t.0 as f64 - 1.0)); s * consts::TWO_SQRT_E_OVER_PI * ((x - 0.5 + GAMMA_R) / f64::consts::E).powf(x - 0.5) } } /// Computes the upper incomplete gamma function /// `Gamma(a,x) = int(exp(-t)t^(a-1), t=x..inf) for a > 0, x > 0` /// where `a` is the argument for the gamma function and /// `x` is the lower integral limit. /// /// # Panics /// /// if `a` or `x` are not in `(0, +inf)` pub fn gamma_ui(a: f64, x: f64) -> f64 { checked_gamma_ui(a, x).unwrap() } /// Computes the upper incomplete gamma function /// `Gamma(a,x) = int(exp(-t)t^(a-1), t=x..inf) for a > 0, x > 0` /// where `a` is the argument for the gamma function and /// `x` is the lower integral limit.
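/// /// # Examples /// /// A minimal sketch based on the identity `Gamma(1, x) = exp(-x)`. The `1e-12` /// tolerance is an illustrative assumption, not a documented accuracy bound: /// /// ``` /// use statrs::function::gamma::{checked_gamma_ui, GammaFuncError}; /// /// // Gamma(1, 2) = exp(-2) /// let v = checked_gamma_ui(1.0, 2.0).unwrap(); /// assert!((v - (-2.0f64).exp()).abs() < 1e-12); /// // an `a` outside `(0, +inf)` yields an error instead of a panic /// assert_eq!(checked_gamma_ui(0.0, 2.0), Err(GammaFuncError::AInvalid)); /// ```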
/// /// # Errors /// /// if `a` or `x` are not in `(0, +inf)` pub fn checked_gamma_ui(a: f64, x: f64) -> Result<f64, GammaFuncError> { checked_gamma_ur(a, x).map(|x| x * gamma(a)) } /// Computes the lower incomplete gamma function /// `gamma(a,x) = int(exp(-t)t^(a-1), t=0..x) for a > 0, x > 0` /// where `a` is the argument for the gamma function and `x` /// is the upper integral limit. /// /// # Panics /// /// if `a` or `x` are not in `(0, +inf)` pub fn gamma_li(a: f64, x: f64) -> f64 { checked_gamma_li(a, x).unwrap() } /// Computes the lower incomplete gamma function /// `gamma(a,x) = int(exp(-t)t^(a-1), t=0..x) for a > 0, x > 0` /// where `a` is the argument for the gamma function and `x` /// is the upper integral limit. /// /// # Errors /// /// if `a` or `x` are not in `(0, +inf)` pub fn checked_gamma_li(a: f64, x: f64) -> Result<f64, GammaFuncError> { checked_gamma_lr(a, x).map(|x| x * gamma(a)) } /// Computes the upper incomplete regularized gamma function /// `Q(a,x) = 1 / Gamma(a) * int(exp(-t)t^(a-1), t=x..inf) for a > 0, x > 0` /// where `a` is the argument for the gamma function and /// `x` is the lower integral limit. /// /// # Remarks /// /// Returns `f64::NAN` if either argument is `f64::NAN` /// /// # Panics /// /// if `a` or `x` are not in `(0, +inf)` pub fn gamma_ur(a: f64, x: f64) -> f64 { checked_gamma_ur(a, x).unwrap() } /// Computes the upper incomplete regularized gamma function /// `Q(a,x) = 1 / Gamma(a) * int(exp(-t)t^(a-1), t=x..inf) for a > 0, x > 0` /// where `a` is the argument for the gamma function and /// `x` is the lower integral limit.
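/// /// # Examples /// /// A minimal sketch that checks the complement identity `Q(a, x) + P(a, x) = 1` /// against `checked_gamma_lr`. The tolerance is an illustrative assumption: /// /// ``` /// use statrs::function::gamma::{checked_gamma_lr, checked_gamma_ur}; /// /// let q = checked_gamma_ur(1.5, 2.0).unwrap(); /// let p = checked_gamma_lr(1.5, 2.0).unwrap(); /// assert!((p + q - 1.0).abs() < 1e-12); /// // `NaN` arguments propagate as `Ok(NaN)` rather than an error /// assert!(checked_gamma_ur(f64::NAN, 1.0).unwrap().is_nan()); /// ```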
/// /// # Remarks /// /// Returns `f64::NAN` if either argument is `f64::NAN` /// /// # Errors /// /// if `a` or `x` are not in `(0, +inf)` pub fn checked_gamma_ur(a: f64, x: f64) -> Result<f64, GammaFuncError> { if a.is_nan() || x.is_nan() { return Ok(f64::NAN); } if a <= 0.0 || a == f64::INFINITY { return Err(GammaFuncError::AInvalid); } if x <= 0.0 || x == f64::INFINITY { return Err(GammaFuncError::XInvalid); } let eps = 0.000000000000001; let big = 4503599627370496.0; let big_inv = 2.22044604925031308085e-16; // `big` = 2^52 and `big_inv` = 2^-52 rescale the continued-fraction terms below to avoid overflow if x < 1.0 || x <= a { return Ok(1.0 - gamma_lr(a, x)); } let mut ax = a * x.ln() - x - ln_gamma(a); if ax < -709.78271289338399 { return if a < x { Ok(0.0) } else { Ok(1.0) }; } ax = ax.exp(); let mut y = 1.0 - a; let mut z = x + y + 1.0; let mut c = 0.0; let mut pkm2 = 1.0; let mut qkm2 = x; let mut pkm1 = x + 1.0; let mut qkm1 = z * x; let mut ans = pkm1 / qkm1; loop { y += 1.0; z += 2.0; c += 1.0; let yc = y * c; let pk = pkm1 * z - pkm2 * yc; let qk = qkm1 * z - qkm2 * yc; pkm2 = pkm1; pkm1 = pk; qkm2 = qkm1; qkm1 = qk; if pk.abs() > big { pkm2 *= big_inv; pkm1 *= big_inv; qkm2 *= big_inv; qkm1 *= big_inv; } if qk != 0.0 { let r = pk / qk; let t = ((ans - r) / r).abs(); ans = r; if t <= eps { break; } } } Ok(ans * ax) } /// Computes the lower incomplete regularized gamma function /// `P(a,x) = 1 / Gamma(a) * int(exp(-t)t^(a-1), t=0..x) for real a > 0, x > 0` /// where `a` is the argument for the gamma function and `x` is the upper /// integral limit. /// /// # Remarks /// /// Returns `f64::NAN` if either argument is `f64::NAN` /// /// # Panics /// /// if `a` or `x` are not in `(0, +inf)` pub fn gamma_lr(a: f64, x: f64) -> f64 { checked_gamma_lr(a, x).unwrap() } /// Computes the lower incomplete regularized gamma function /// `P(a,x) = 1 / Gamma(a) * int(exp(-t)t^(a-1), t=0..x) for real a > 0, x > 0` /// where `a` is the argument for the gamma function and `x` is the upper /// integral limit.
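/// /// # Examples /// /// A minimal sketch based on the identity `P(1, x) = 1 - exp(-x)`. The `1e-12` /// tolerance is an illustrative assumption: /// /// ``` /// use statrs::function::gamma::{checked_gamma_lr, GammaFuncError}; /// /// // P(1, 0.5) = 1 - exp(-0.5) /// let p = checked_gamma_lr(1.0, 0.5).unwrap(); /// assert!((p - (1.0 - (-0.5f64).exp())).abs() < 1e-12); /// // an `x` outside `(0, +inf)` yields an error instead of a panic /// assert_eq!(checked_gamma_lr(1.0, f64::INFINITY), Err(GammaFuncError::XInvalid)); /// ```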
/// /// # Remarks /// /// Returns `f64::NAN` if either argument is `f64::NAN` /// /// # Errors /// /// if `a` or `x` are not in `(0, +inf)` pub fn checked_gamma_lr(a: f64, x: f64) -> Result<f64, GammaFuncError> { if a.is_nan() || x.is_nan() { return Ok(f64::NAN); } if a <= 0.0 || a == f64::INFINITY { return Err(GammaFuncError::AInvalid); } if x <= 0.0 || x == f64::INFINITY { return Err(GammaFuncError::XInvalid); } let eps = 0.000000000000001; let big = 4503599627370496.0; let big_inv = 2.22044604925031308085e-16; if prec::almost_eq(a, 0.0, prec::DEFAULT_F64_ACC) { return Ok(1.0); } if prec::almost_eq(x, 0.0, prec::DEFAULT_F64_ACC) { return Ok(0.0); } let ax = a * x.ln() - x - ln_gamma(a); if ax < -709.78271289338399 { if a < x { return Ok(1.0); } return Ok(0.0); } // power series for `x <= 1` or `x <= a` if x <= 1.0 || x <= a { let mut r2 = a; let mut c2 = 1.0; let mut ans2 = 1.0; loop { r2 += 1.0; c2 *= x / r2; ans2 += c2; if c2 / ans2 <= eps { break; } } return Ok(ax.exp() * ans2 / a); } // otherwise evaluate the continued fraction for Q(a, x) and return 1 - Q(a, x) let mut y = 1.0 - a; let mut z = x + y + 1.0; let mut c = 0; let mut p3 = 1.0; let mut q3 = x; let mut p2 = x + 1.0; let mut q2 = z * x; let mut ans = p2 / q2; loop { y += 1.0; z += 2.0; c += 1; let yc = y * f64::from(c); let p = p2 * z - p3 * yc; let q = q2 * z - q3 * yc; p3 = p2; p2 = p; q3 = q2; q2 = q; if p.abs() > big { p3 *= big_inv; p2 *= big_inv; q3 *= big_inv; q2 *= big_inv; } if q != 0.0 { let nextans = p / q; let error = ((ans - nextans) / nextans).abs(); ans = nextans; if error <= eps { break; } } } Ok(1.0 - ax.exp() * ans) } /// Computes the Digamma function which is defined as the derivative of /// the log of the gamma function.
The implementation is based on /// "Algorithm AS 103", Jose Bernardo, Applied Statistics, Volume 25, Number 3 /// 1976, pages 315 - 317 pub fn digamma(x: f64) -> f64 { let c = 12.0; let d1 = -0.57721566490153286; let d2 = 1.6449340668482264365; let s = 1e-6; let s3 = 1.0 / 12.0; let s4 = 1.0 / 120.0; let s5 = 1.0 / 252.0; let s6 = 1.0 / 240.0; let s7 = 1.0 / 132.0; if x == f64::NEG_INFINITY || x.is_nan() { return f64::NAN; } if x <= 0.0 && ulps_eq!(x.floor(), x) { return f64::NEG_INFINITY; } if x < 0.0 { return digamma(1.0 - x) + f64::consts::PI / (-f64::consts::PI * x).tan(); } if x <= s { return d1 - 1.0 / x + d2 * x; } let mut result = 0.0; let mut z = x; while z < c { result -= 1.0 / z; z += 1.0; } if z >= c { let mut r = 1.0 / z; result += z.ln() - 0.5 * r; r *= r; result -= r * (s3 - r * (s4 - r * (s5 - r * (s6 - r * s7)))); } result } /// Computes the inverse of the digamma function, i.e. an approximation of the /// positive `y` for which `digamma(y) == x`. The search starts at `exp(x)` /// and moves towards the root in steps that are halved from `1.0` down to /// roughly `1e-15`. /// /// # Remarks /// /// Returns `f64::NAN` if `x` is `f64::NAN`, `0.0` if `x` is /// `f64::NEG_INFINITY` and `f64::INFINITY` if `x` is `f64::INFINITY` pub fn inv_digamma(x: f64) -> f64 { if x.is_nan() { return f64::NAN; } if x == f64::NEG_INFINITY { return 0.0; } if x == f64::INFINITY { return f64::INFINITY; } let mut y = x.exp(); let mut i = 1.0; while i > 1e-15 { y += i * signum(x - digamma(y)); i /= 2.0; } y } // modified signum that returns 0.0 if x == 0.0.
Used // by inv_digamma, may consider extracting into a public // method fn signum(x: f64) -> f64 { if x == 0.0 { 0.0 } else { x.signum() } } #[rustfmt::skip] #[cfg(test)] mod tests { use super::*; use std::f64::consts; #[test] fn test_gamma() { assert!(super::gamma(f64::NAN).is_nan()); assert_almost_eq!(super::gamma(1.000001e-35), 9.9999900000099999900000099999899999522784235098567139293e+34, 1e20); assert_almost_eq!(super::gamma(1.000001e-10), 9.99998999943278432519738283781280989934496494539074049002e+9, 1e-5); assert_almost_eq!(super::gamma(1.000001e-5), 99999.32279432557746387178953902739303931424932435387031653234, 1e-10); assert_almost_eq!(super::gamma(1.000001e-2), 99.43248512896257405886134437203369035261893114349805309870831, 1e-13); assert_almost_eq!(super::gamma(-4.8), -0.06242336135475955314181664931547009890495158793105543559676, 1e-13); assert_almost_eq!(super::gamma(-1.5), 2.363271801207354703064223311121526910396732608163182837618410, 1e-13); assert_almost_eq!(super::gamma(-0.5), -3.54490770181103205459633496668229036559509891224477425642761, 1e-13); assert_almost_eq!(super::gamma(1.0e-5 + 1.0e-16), 99999.42279322556767360213300482199406241771308740302819426480, 1e-9); assert_almost_eq!(super::gamma(0.1), 9.513507698668731836292487177265402192550578626088377343050000, 1e-14); assert_eq!(super::gamma(1.0 - 1.0e-14), 1.000000000000005772156649015427511664653698987042926067639529); assert_almost_eq!(super::gamma(1.0), 1.0, 1e-15); assert_almost_eq!(super::gamma(1.0 + 1.0e-14), 0.99999999999999422784335098477029953441189552403615306268023, 1e-15); assert_almost_eq!(super::gamma(1.5), 0.886226925452758013649083741670572591398774728061193564106903, 1e-14); assert_almost_eq!(super::gamma(consts::PI/2.0), 0.890560890381539328010659635359121005933541962884758999762766, 1e-15); assert_eq!(super::gamma(2.0), 1.0); assert_almost_eq!(super::gamma(2.5), 1.329340388179137020473625612505858887098162092091790346160355, 1e-13); assert_almost_eq!(super::gamma(3.0), 
2.0, 1e-14); assert_almost_eq!(super::gamma(consts::PI), 2.288037795340032417959588909060233922889688153356222441199380, 1e-13); assert_almost_eq!(super::gamma(3.5), 3.323350970447842551184064031264647217745405230229475865400889, 1e-14); assert_almost_eq!(super::gamma(4.0), 6.0, 1e-13); assert_almost_eq!(super::gamma(4.5), 11.63172839656744892914422410942626526210891830580316552890311, 1e-12); assert_almost_eq!(super::gamma(5.0 - 1.0e-14), 23.99999999999963853175957637087420162718107213574617032780374, 1e-13); assert_almost_eq!(super::gamma(5.0), 24.0, 1e-12); assert_almost_eq!(super::gamma(5.0 + 1.0e-14), 24.00000000000036146824042363510111050137786752408660789873592, 1e-12); assert_almost_eq!(super::gamma(5.5), 52.34277778455352018114900849241819367949013237611424488006401, 1e-12); assert_almost_eq!(super::gamma(10.1), 454760.7514415859508673358368319076190405047458218916492282448, 1e-7); assert_almost_eq!(super::gamma(150.0 + 1.0e-12), 3.8089226376496421386707466577615064443807882167327097140e+260, 1e248); } #[test] fn test_ln_gamma() { assert!(super::ln_gamma(f64::NAN).is_nan()); assert_eq!(super::ln_gamma(1.000001e-35), 80.59047725479209894029636783061921392709972287131139201585211); assert_almost_eq!(super::ln_gamma(1.000001e-10), 23.02584992988323521564308637407936081168344192865285883337793, 1e-14); assert_almost_eq!(super::ln_gamma(1.000001e-5), 11.51291869289055371493077240324332039045238086972508869965363, 1e-14); assert_eq!(super::ln_gamma(1.000001e-2), 4.599478872433667224554543378460164306444416156144779542513592); assert_almost_eq!(super::ln_gamma(0.1), 2.252712651734205959869701646368495118615627222294953765041739, 1e-14); assert_almost_eq!(super::ln_gamma(1.0 - 1.0e-14), 5.772156649015410852768463312546533565566459794933360600e-15, 1e-15); assert_almost_eq!(super::ln_gamma(1.0), 0.0, 1e-15); assert_almost_eq!(super::ln_gamma(1.0 + 1.0e-14), -5.77215664901524635936177848990288632404978978079827014e-15, 1e-15); assert_almost_eq!(super::ln_gamma(1.5), 
-0.12078223763524522234551844578164721225185272790259946836386, 1e-14); assert_almost_eq!(super::ln_gamma(consts::PI/2.0), -0.11590380084550241329912089415904874214542604767006895, 1e-14); assert_eq!(super::ln_gamma(2.0), 0.0); assert_almost_eq!(super::ln_gamma(2.5), 0.284682870472919159632494669682701924320137695559894729250145, 1e-13); assert_almost_eq!(super::ln_gamma(3.0), 0.693147180559945309417232121458176568075500134360255254120680, 1e-14); assert_almost_eq!(super::ln_gamma(consts::PI), 0.82769459232343710152957855845235995115350173412073715, 1e-13); assert_almost_eq!(super::ln_gamma(3.5), 1.200973602347074224816021881450712995770238915468157197042113, 1e-14); assert_almost_eq!(super::ln_gamma(4.0), 1.791759469228055000812477358380702272722990692183004705855374, 1e-14); assert_almost_eq!(super::ln_gamma(4.5), 2.453736570842442220504142503435716157331823510689763131380823, 1e-13); assert_almost_eq!(super::ln_gamma(5.0 - 1.0e-14), 3.178053830347930558470257283303394288448414225994179545985931, 1e-14); assert_almost_eq!(super::ln_gamma(5.0), 3.178053830347945619646941601297055408873990960903515214096734, 1e-14); assert_almost_eq!(super::ln_gamma(5.0 + 1.0e-14), 3.178053830347960680823625919312848824873279228348981287761046, 1e-13); assert_almost_eq!(super::ln_gamma(5.5), 3.957813967618716293877400855822590998551304491975006780729532, 1e-14); assert_almost_eq!(super::ln_gamma(10.1), 13.02752673863323795851370097886835481188051062306253294740504, 1e-14); assert_almost_eq!(super::ln_gamma(150.0 + 1.0e-12), 600.0094705553324354062157737572509902987070089159051628001813, 1e-12); assert_almost_eq!(super::ln_gamma(1.001e+7), 1.51342135323817913130119829455205139905331697084416059779e+8, 1e-13); } #[test] fn test_gamma_lr() { assert!(super::gamma_lr(f64::NAN, f64::NAN).is_nan()); assert_almost_eq!(super::gamma_lr(0.1, 1.0), 0.97587265627367222115949155252812057714751052498477013, 1e-14); assert_eq!(super::gamma_lr(0.1, 2.0), 
0.99432617602018847196075251078067514034772764693462125); assert_eq!(super::gamma_lr(0.1, 8.0), 0.99999507519205198048686442150578226823401842046310854); assert_almost_eq!(super::gamma_lr(1.5, 1.0), 0.42759329552912016600095238564127189392715996802703368, 1e-13); assert_almost_eq!(super::gamma_lr(1.5, 2.0), 0.73853587005088937779717792402407879809718939080920993, 1e-15); assert_eq!(super::gamma_lr(1.5, 8.0), 0.99886601571021467734329986257903021041757398191304284); assert_almost_eq!(super::gamma_lr(2.5, 1.0), 0.15085496391539036377410688601371365034788861473418704, 1e-13); assert_almost_eq!(super::gamma_lr(2.5, 2.0), 0.45058404864721976739416885516693969548484517509263197, 1e-14); assert_almost_eq!(super::gamma_lr(2.5, 8.0), 0.99315592607757956900093935107222761316136944145439676, 1e-15); assert_almost_eq!(super::gamma_lr(5.5, 1.0), 0.0015041182825838038421585211353488839717739161316985392, 1e-15); assert_almost_eq!(super::gamma_lr(5.5, 2.0), 0.030082976121226050615171484772387355162056796585883967, 1e-14); assert_almost_eq!(super::gamma_lr(5.5, 8.0), 0.85886911973294184646060071855669224657735916933487681, 1e-14); assert_almost_eq!(super::gamma_lr(100.0, 0.5), 0.0, 1e-188); assert_almost_eq!(super::gamma_lr(100.0, 1.5), 0.0, 1e-141); assert_almost_eq!(super::gamma_lr(100.0, 90.0), 0.1582209891864301681049696996709105316998233457433473, 1e-13); assert_almost_eq!(super::gamma_lr(100.0, 100.0), 0.5132987982791486648573142565640291634709251499279450, 1e-13); assert_almost_eq!(super::gamma_lr(100.0, 110.0), 0.8417213299399129061982996209829688531933500308658222, 1e-13); assert_almost_eq!(super::gamma_lr(100.0, 200.0), 1.0, 1e-14); assert_eq!(super::gamma_lr(500.0, 0.5), 0.0); assert_eq!(super::gamma_lr(500.0, 1.5), 0.0); assert_almost_eq!(super::gamma_lr(500.0, 200.0), 0.0, 1e-70); assert_almost_eq!(super::gamma_lr(500.0, 450.0), 0.0107172380912897415573958770655204965434869949241480, 1e-14); assert_almost_eq!(super::gamma_lr(500.0, 500.0), 
0.5059471461707603580470479574412058032802735425634263, 1e-13); assert_almost_eq!(super::gamma_lr(500.0, 550.0), 0.9853855918737048059548470006900844665580616318702748, 1e-14); assert_almost_eq!(super::gamma_lr(500.0, 700.0), 1.0, 1e-15); assert_eq!(super::gamma_lr(1000.0, 10000.0), 1.0); assert_eq!(super::gamma_lr(1e+50, 1e+48), 0.0); assert_eq!(super::gamma_lr(1e+50, 1e+52), 1.0); } #[test] #[should_panic] fn test_gamma_lr_a_lower_bound() { super::gamma_lr(-1.0, 1.0); } #[test] #[should_panic] fn test_gamma_lr_a_upper_bound() { super::gamma_lr(f64::INFINITY, 1.0); } #[test] #[should_panic] fn test_gamma_lr_x_lower_bound() { super::gamma_lr(1.0, -1.0); } #[test] #[should_panic] fn test_gamma_lr_x_upper_bound() { super::gamma_lr(1.0, f64::INFINITY); } #[test] fn test_checked_gamma_lr_a_lower_bound() { assert!(super::checked_gamma_lr(-1.0, 1.0).is_err()); } #[test] fn test_checked_gamma_lr_a_upper_bound() { assert!(super::checked_gamma_lr(f64::INFINITY, 1.0).is_err()); } #[test] fn test_checked_gamma_lr_x_lower_bound() { assert!(super::checked_gamma_lr(1.0, -1.0).is_err()); } #[test] fn test_checked_gamma_lr_x_upper_bound() { assert!(super::checked_gamma_lr(1.0, f64::INFINITY).is_err()); } #[test] fn test_gamma_li() { assert!(super::gamma_li(f64::NAN, f64::NAN).is_nan()); assert_almost_eq!(super::gamma_li(0.1, 1.0), 9.2839720283798852469443229940217320532607158711056334, 1e-14); assert_almost_eq!(super::gamma_li(0.1, 2.0), 9.4595297305559030536119885480983751098528458886962883, 1e-14); assert_almost_eq!(super::gamma_li(0.1, 8.0), 9.5134608464704033372127589212547718314010339263844976, 1e-13); assert_almost_eq!(super::gamma_li(1.5, 1.0), 0.37894469164098470380394366597039213790868855578083847, 1e-15); assert_almost_eq!(super::gamma_li(1.5, 2.0), 0.65451037345177732033319477475056262302270310457635612, 1e-14); assert_almost_eq!(super::gamma_li(1.5, 8.0), 0.88522195804210983776635107858848816480298923071075222, 1e-13); assert_almost_eq!(super::gamma_li(2.5, 1.0), 
0.20053759629003473411039172879412733941722170263949, 1e-16); assert_almost_eq!(super::gamma_li(2.5, 2.0), 0.59897957413602228465664030130712917348327070206302442, 1e-15); assert_almost_eq!(super::gamma_li(2.5, 8.0), 1.3202422842943799358198434659248530581833764879301293, 1e-14); assert_almost_eq!(super::gamma_li(5.5, 1.0), 0.078729729026968321691794205337720556329618007004848672, 1e-16); assert_almost_eq!(super::gamma_li(5.5, 2.0), 1.5746265342113649473739798668921124454837064926448459, 1e-15); assert_almost_eq!(super::gamma_li(5.5, 8.0), 44.955595480196465884619737757794960132425035578313584, 1e-12); } #[test] #[should_panic] fn test_gamma_li_a_lower_bound() { super::gamma_li(-1.0, 1.0); } #[test] #[should_panic] fn test_gamma_li_a_upper_bound() { super::gamma_li(f64::INFINITY, 1.0); } #[test] #[should_panic] fn test_gamma_li_x_lower_bound() { super::gamma_li(1.0, -1.0); } #[test] #[should_panic] fn test_gamma_li_x_upper_bound() { super::gamma_li(1.0, f64::INFINITY); } #[test] fn test_checked_gamma_li_a_lower_bound() { assert!(super::checked_gamma_li(-1.0, 1.0).is_err()); } #[test] fn test_checked_gamma_li_a_upper_bound() { assert!(super::checked_gamma_li(f64::INFINITY, 1.0).is_err()); } #[test] fn test_checked_gamma_li_x_lower_bound() { assert!(super::checked_gamma_li(1.0, -1.0).is_err()); } #[test] fn test_checked_gamma_li_x_upper_bound() { assert!(super::checked_gamma_li(1.0, f64::INFINITY).is_err()); } // TODO: precision testing could be more accurate, borrowed wholesale from Math.NET #[test] fn test_gamma_ur() { assert!(super::gamma_ur(f64::NAN, f64::NAN).is_nan()); assert_almost_eq!(super::gamma_ur(0.1, 1.0), 0.0241273437263277773829694356333550393309597428392044, 1e-13); assert_almost_eq!(super::gamma_ur(0.1, 2.0), 0.0056738239798115280392474892193248596522723530653781, 1e-13); assert_almost_eq!(super::gamma_ur(0.1, 8.0), 0.0000049248079480195131355784942177317659815795368919702, 1e-13); assert_almost_eq!(super::gamma_ur(1.5, 1.0), 
0.57240670447087983399904761435872810607284003197297, 1e-13); assert_almost_eq!(super::gamma_ur(1.5, 2.0), 0.26146412994911062220282207597592120190281060919079, 1e-13); assert_almost_eq!(super::gamma_ur(1.5, 8.0), 0.0011339842897853226567001374209697895824260180869567, 1e-13); assert_almost_eq!(super::gamma_ur(2.5, 1.0), 0.84914503608460963622589311398628634965211138526581, 1e-13); assert_almost_eq!(super::gamma_ur(2.5, 2.0), 0.54941595135278023260583114483306030451515482490737, 1e-13); assert_almost_eq!(super::gamma_ur(2.5, 8.0), 0.0068440739224204309990606489277723868386305585456026, 1e-13); assert_almost_eq!(super::gamma_ur(5.5, 1.0), 0.9984958817174161961578414788646511160282260838683, 1e-13); assert_almost_eq!(super::gamma_ur(5.5, 2.0), 0.96991702387877394938482851522761264483794320341412, 1e-13); assert_almost_eq!(super::gamma_ur(5.5, 8.0), 0.14113088026705815353939928144330775342264083066512, 1e-13); assert_almost_eq!(super::gamma_ur(100.0, 0.5), 1.0, 1e-14); assert_almost_eq!(super::gamma_ur(100.0, 1.5), 1.0, 1e-14); assert_almost_eq!(super::gamma_ur(100.0, 90.0), 0.8417790108135698318950303003290894683001766542566526, 1e-12); assert_almost_eq!(super::gamma_ur(100.0, 100.0), 0.4867012017208513351426857434359708365290748500720549, 1e-12); assert_almost_eq!(super::gamma_ur(100.0, 110.0), 0.1582786700600870938017003790170311468066499691341777, 1e-12); assert_almost_eq!(super::gamma_ur(100.0, 200.0), 0.0, 1e-14); assert_almost_eq!(super::gamma_ur(500.0, 0.5), 1.0, 1e-14); assert_almost_eq!(super::gamma_ur(500.0, 1.5), 1.0, 1e-14); assert_almost_eq!(super::gamma_ur(500.0, 200.0), 1.0, 1e-14); assert_almost_eq!(super::gamma_ur(500.0, 450.0), 0.9892827619087102584426041229344795034565130050758519, 1e-12); assert_almost_eq!(super::gamma_ur(500.0, 500.0), 0.4940528538292396419529520425587941967197264574365736, 1e-12); assert_almost_eq!(super::gamma_ur(500.0, 550.0), 0.0146144081262951940451529993099155334419383681297251, 1e-12); 
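// Added cross-check (a sketch, not part of the Math.NET reference data above):
// the regularized lower and upper incomplete gamma functions are complements,
// P(a, x) + Q(a, x) = 1, so for each (a, x) pair already exercised in these
// tests the two results should sum to 1. The 1e-11 tolerance is an assumption,
// chosen to cover the combined per-function tolerances used in test_gamma_lr
// and test_gamma_ur.
for &(a, x) in &[(0.1, 1.0), (1.5, 2.0), (5.5, 8.0), (100.0, 90.0), (500.0, 500.0)] {
    assert_almost_eq!(super::gamma_lr(a, x) + super::gamma_ur(a, x), 1.0, 1e-11);
}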
assert_almost_eq!(super::gamma_ur(500.0, 700.0), 0.0, 1e-14); assert_almost_eq!(super::gamma_ur(1000.0, 10000.0), 0.0, 1e-14); assert_almost_eq!(super::gamma_ur(1e+50, 1e+48), 1.0, 1e-14); assert_almost_eq!(super::gamma_ur(1e+50, 1e+52), 0.0, 1e-14); } #[test] #[should_panic] fn test_gamma_ur_a_lower_bound() { super::gamma_ur(-1.0, 1.0); } #[test] #[should_panic] fn test_gamma_ur_a_upper_bound() { super::gamma_ur(f64::INFINITY, 1.0); } #[test] #[should_panic] fn test_gamma_ur_x_lower_bound() { super::gamma_ur(1.0, -1.0); } #[test] #[should_panic] fn test_gamma_ur_x_upper_bound() { super::gamma_ur(1.0, f64::INFINITY); } #[test] fn test_checked_gamma_ur_a_lower_bound() { assert!(super::checked_gamma_ur(-1.0, 1.0).is_err()); } #[test] fn test_checked_gamma_ur_a_upper_bound() { assert!(super::checked_gamma_ur(f64::INFINITY, 1.0).is_err()); } #[test] fn test_checked_gamma_ur_x_lower_bound() { assert!(super::checked_gamma_ur(1.0, -1.0).is_err()); } #[test] fn test_checked_gamma_ur_x_upper_bound() { assert!(super::checked_gamma_ur(1.0, f64::INFINITY).is_err()); } #[test] fn test_gamma_ui() { assert!(super::gamma_ui(f64::NAN, f64::NAN).is_nan()); assert_almost_eq!(super::gamma_ui(0.1, 1.0), 0.2295356702888460382790772147651768201739736396141314, 1e-14); assert_almost_eq!(super::gamma_ui(0.1, 2.0), 0.053977968112828232195991347726857391060870217694027, 1e-15); assert_almost_eq!(super::gamma_ui(0.1, 8.0), 0.000046852198327948595220974570460669512682180005810156, 1e-19); assert_almost_eq!(super::gamma_ui(1.5, 1.0), 0.50728223381177330984514007570018045349008617228036, 1e-14); assert_almost_eq!(super::gamma_ui(1.5, 2.0), 0.23171655200098069331588896692000996837607162348484, 1e-15); assert_almost_eq!(super::gamma_ui(1.5, 8.0), 0.0010049674106481758827326630820844265957854973504417, 1e-17); assert_almost_eq!(super::gamma_ui(2.5, 1.0), 1.1288027918891022863632338837117315476809403894523, 1e-14); assert_almost_eq!(super::gamma_ui(2.5, 2.0), 
0.73036081404311473581698531119872971361489139002877, 1e-14); assert_almost_eq!(super::gamma_ui(2.5, 8.0), 0.0090981038847570846537821465810058289147856041616617, 1e-17); assert_almost_eq!(super::gamma_ui(5.5, 1.0), 52.264048055526551859457214287080473123160514369109, 1e-12); assert_almost_eq!(super::gamma_ui(5.5, 2.0), 50.768151250342155233775028625526081234006425883469, 1e-12); assert_almost_eq!(super::gamma_ui(5.5, 8.0), 7.3871823043570542965292707346232335470650967978006, 1e-13); } #[test] #[should_panic] fn test_gamma_ui_a_lower_bound() { super::gamma_ui(-1.0, 1.0); } #[test] #[should_panic] fn test_gamma_ui_a_upper_bound() { super::gamma_ui(f64::INFINITY, 1.0); } #[test] #[should_panic] fn test_gamma_ui_x_lower_bound() { super::gamma_ui(1.0, -1.0); } #[test] #[should_panic] fn test_gamma_ui_x_upper_bound() { super::gamma_ui(1.0, f64::INFINITY); } #[test] fn test_checked_gamma_ui_a_lower_bound() { assert!(super::checked_gamma_ui(-1.0, 1.0).is_err()); } #[test] fn test_checked_gamma_ui_a_upper_bound() { assert!(super::checked_gamma_ui(f64::INFINITY, 1.0).is_err()); } #[test] fn test_checked_gamma_ui_x_lower_bound() { assert!(super::checked_gamma_ui(1.0, -1.0).is_err()); } #[test] fn test_checked_gamma_ui_x_upper_bound() { assert!(super::checked_gamma_ui(1.0, f64::INFINITY).is_err()); } // TODO: precision testing could be more accurate #[test] fn test_digamma() { assert!(super::digamma(f64::NAN).is_nan()); assert_almost_eq!(super::digamma(-1.5), 0.70315664064524318722569033366791109947350706200623256, 1e-14); assert_almost_eq!(super::digamma(-0.5), 0.036489973978576520559023667001244432806840395339565891, 1e-14); assert_almost_eq!(super::digamma(0.1), -10.423754940411076232100295314502760886768558023951363, 1e-14); assert_almost_eq!(super::digamma(1.0), -0.57721566490153286060651209008240243104215933593992359, 1e-14); assert_almost_eq!(super::digamma(1.5), 0.036489973978576520559023667001244432806840395339565888, 1e-14); 
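// Added cross-check (a sketch, not from the tabulated values): digamma satisfies
// the recurrence psi(x + 1) = psi(x) + 1/x, so consecutive evaluations can be
// checked against each other without reference constants. The sample points
// and the 1e-13 tolerance are assumptions that allow for the error of two
// separate evaluations.
for &x in &[0.5, 1.0, 1.5, 2.5, 4.0] {
    assert_almost_eq!(super::digamma(x + 1.0), super::digamma(x) + 1.0 / x, 1e-13);
}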
assert_almost_eq!(super::digamma(consts::PI / 2.0), 0.10067337642740238636795561404029690452798358068944001, 1e-14); assert_almost_eq!(super::digamma(2.0), 0.42278433509846713939348790991759756895784066406007641, 1e-14); assert_almost_eq!(super::digamma(2.5), 0.70315664064524318722569033366791109947350706200623255, 1e-14); assert_almost_eq!(super::digamma(3.0), 0.92278433509846713939348790991759756895784066406007641, 1e-14); assert_almost_eq!(super::digamma(consts::PI), 0.97721330794200673329206948640618234364083460999432603, 1e-14); assert_almost_eq!(super::digamma(3.5), 1.1031566406452431872256903336679110994735070620062326, 1e-14); assert_almost_eq!(super::digamma(4.0), 1.2561176684318004727268212432509309022911739973934097, 1e-14); assert_almost_eq!(super::digamma(4.5), 1.3888709263595289015114046193821968137592213477205183, 1e-14); assert_almost_eq!(super::digamma(5.0), 1.5061176684318004727268212432509309022911739973934097, 1e-14); assert_almost_eq!(super::digamma(5.5), 1.6110931485817511237336268416044190359814435699427405, 1e-14); assert_almost_eq!(super::digamma(10.1), 2.2622143570941481235561593642219403924532310597356171, 1e-14); } #[test] fn test_inv_digamma() { assert!(super::inv_digamma(f64::NAN).is_nan()); assert_eq!(super::inv_digamma(f64::NEG_INFINITY), 0.0); assert_almost_eq!(super::inv_digamma(-10.423754940411076232100295314502760886768558023951363), 0.1, 1e-15); assert_almost_eq!(super::inv_digamma(-0.57721566490153286060651209008240243104215933593992359), 1.0, 1e-14); assert_almost_eq!(super::inv_digamma(0.036489973978576520559023667001244432806840395339565888), 1.5, 1e-14); assert_almost_eq!(super::inv_digamma(0.10067337642740238636795561404029690452798358068944001), consts::PI / 2.0, 1e-14); assert_almost_eq!(super::inv_digamma(0.42278433509846713939348790991759756895784066406007641), 2.0, 1e-14); assert_almost_eq!(super::inv_digamma(0.70315664064524318722569033366791109947350706200623255), 2.5, 1e-14); 
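// Added cross-check (a sketch): inv_digamma is meant to invert digamma, so a
// round trip through both functions should recover the input. The sample
// points are taken from values already checked above. The looser 1e-12
// tolerance is an assumption covering the error of both functions combined.
for &x in &[1.0, 2.0, 3.5, 5.0] {
    assert_almost_eq!(super::inv_digamma(super::digamma(x)), x, 1e-12);
}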
assert_almost_eq!(super::inv_digamma(0.92278433509846713939348790991759756895784066406007641), 3.0, 1e-14); assert_almost_eq!(super::inv_digamma(0.97721330794200673329206948640618234364083460999432603), consts::PI, 1e-14); assert_almost_eq!(super::inv_digamma(1.1031566406452431872256903336679110994735070620062326), 3.5, 1e-14); assert_almost_eq!(super::inv_digamma(1.2561176684318004727268212432509309022911739973934097), 4.0, 1e-14); assert_almost_eq!(super::inv_digamma(1.3888709263595289015114046193821968137592213477205183), 4.5, 1e-14); assert_almost_eq!(super::inv_digamma(1.5061176684318004727268212432509309022911739973934097), 5.0, 1e-14); assert_almost_eq!(super::inv_digamma(1.6110931485817511237336268416044190359814435699427405), 5.5, 1e-14); assert_almost_eq!(super::inv_digamma(2.2622143570941481235561593642219403924532310597356171), 10.1, 1e-13); } #[test] fn test_error_is_sync_send() { fn assert_sync_send() {} assert_sync_send::(); } } statrs-0.18.0/src/function/harmonic.rs000064400000000000000000000043011046102023000160120ustar 00000000000000//! Provides functions for calculating //! [harmonic](https://en.wikipedia.org/wiki/Harmonic_number) //! numbers use crate::consts; use crate::function::gamma; /// Computes the `t`-th harmonic number /// /// # Remarks /// /// Returns `1` as a special case when `t == 0` pub fn harmonic(t: u64) -> f64 { match t { 0 => 1.0, _ => consts::EULER_MASCHERONI + gamma::digamma(t as f64 + 1.0), } } /// Computes the generalized harmonic number of order `n` of `m` /// e.g. `(1 + 1/2^m + 1/3^m + ... 
+ 1/n^m)` /// /// # Remarks /// /// Returns `1` as a special case when `n == 0` pub fn gen_harmonic(n: u64, m: f64) -> f64 { match n { 0 => 1.0, _ => (0..n).fold(0.0, |acc, x| acc + (x as f64 + 1.0).powf(-m)), } } #[rustfmt::skip] #[cfg(test)] mod tests { use std::f64; #[test] fn test_harmonic() { assert_eq!(super::harmonic(0), 1.0); assert_almost_eq!(super::harmonic(1), 1.0, 1e-14); assert_almost_eq!(super::harmonic(2), 1.5, 1e-14); assert_almost_eq!(super::harmonic(4), 2.083333333333333333333, 1e-14); assert_almost_eq!(super::harmonic(8), 2.717857142857142857143, 1e-14); assert_almost_eq!(super::harmonic(16), 3.380728993228993228993, 1e-14); } #[test] fn test_gen_harmonic() { assert_eq!(super::gen_harmonic(0, 0.0), 1.0); assert_eq!(super::gen_harmonic(0, f64::INFINITY), 1.0); assert_eq!(super::gen_harmonic(0, f64::NEG_INFINITY), 1.0); assert_eq!(super::gen_harmonic(1, 0.0), 1.0); assert_eq!(super::gen_harmonic(1, f64::INFINITY), 1.0); assert_eq!(super::gen_harmonic(1, f64::NEG_INFINITY), 1.0); assert_eq!(super::gen_harmonic(2, 1.0), 1.5); assert_eq!(super::gen_harmonic(2, 3.0), 1.125); assert_eq!(super::gen_harmonic(2, f64::INFINITY), 1.0); assert_eq!(super::gen_harmonic(2, f64::NEG_INFINITY), f64::INFINITY); assert_almost_eq!(super::gen_harmonic(4, 1.0), 2.083333333333333333333, 1e-14); assert_eq!(super::gen_harmonic(4, 3.0), 1.177662037037037037037); assert_eq!(super::gen_harmonic(4, f64::INFINITY), 1.0); assert_eq!(super::gen_harmonic(4, f64::NEG_INFINITY), f64::INFINITY); } } statrs-0.18.0/src/function/logistic.rs000064400000000000000000000051501046102023000160320ustar 00000000000000//! Provides the [logistic](http://en.wikipedia.org/wiki/Logistic_function) and //! 
related functions /// Computes the logistic function pub fn logistic(p: f64) -> f64 { 1.0 / ((-p).exp() + 1.0) } /// Computes the logit function /// /// # Panics /// /// If `p < 0.0` or `p > 1.0` pub fn logit(p: f64) -> f64 { checked_logit(p).unwrap() } /// Computes the logit function, returning `None` if `p < 0.0` or `p > 1.0`. pub fn checked_logit(p: f64) -> Option<f64> { if (0.0..=1.0).contains(&p) { Some((p / (1.0 - p)).ln()) } else { None } } #[rustfmt::skip] #[cfg(test)] mod tests { use std::f64; #[test] fn test_logistic() { assert_eq!(super::logistic(f64::NEG_INFINITY), 0.0); assert_eq!(super::logistic(-11.512915464920228103874353849992239636376994324587), 0.00001); assert_almost_eq!(super::logistic(-6.9067547786485535272274487616830597875179908939086), 0.001, 1e-18); assert_almost_eq!(super::logistic(-2.1972245773362193134015514347727700402304323440139), 0.1, 1e-16); assert_eq!(super::logistic(0.0), 0.5); assert_almost_eq!(super::logistic(2.1972245773362195801634726294284168954491240598975), 0.9, 1e-15); assert_almost_eq!(super::logistic(6.9067547786485526081487245019905638981131702804661), 0.999, 1e-15); assert_eq!(super::logistic(11.512915464924779098232747799811946290419057060965), 0.99999); assert_eq!(super::logistic(f64::INFINITY), 1.0); } #[test] fn test_logit() { assert_eq!(super::logit(0.0), f64::NEG_INFINITY); assert_eq!(super::logit(0.00001), -11.512915464920228103874353849992239636376994324587); assert_eq!(super::logit(0.001), -6.9067547786485535272274487616830597875179908939086); assert_eq!(super::logit(0.1), -2.1972245773362193134015514347727700402304323440139); assert_eq!(super::logit(0.5), 0.0); assert_eq!(super::logit(0.9), 2.1972245773362195801634726294284168954491240598975); assert_eq!(super::logit(0.999), 6.9067547786485526081487245019905638981131702804661); assert_eq!(super::logit(0.99999), 11.512915464924779098232747799811946290419057060965); assert_eq!(super::logit(1.0), f64::INFINITY); } #[test] #[should_panic] fn test_logit_p_lt_0() { 
super::logit(-1.0); } #[test] #[should_panic] fn test_logit_p_gt_1() { super::logit(2.0); } #[test] fn test_checked_logit_p_lt_0() { assert!(super::checked_logit(-1.0).is_none()); } #[test] fn test_checked_logit_p_gt_1() { assert!(super::checked_logit(2.0).is_none()); } } statrs-0.18.0/src/function/mod.rs000064400000000000000000000003611046102023000147730ustar 00000000000000//! Provides a host of special statistical functions (e.g. the beta function or //! the error function) pub mod beta; pub mod erf; pub mod evaluate; pub mod exponential; pub mod factorial; pub mod gamma; pub mod harmonic; pub mod logistic; statrs-0.18.0/src/generate.rs000064400000000000000000000242261046102023000141670ustar 00000000000000//! Provides utility functions for generating data sequences use crate::euclid::Modulus; use std::f64::consts; /// Generates a base 10 log spaced vector of the given length between the /// specified decade exponents (inclusive). Equivalent to MATLAB's `logspace`. /// /// # Examples /// /// ``` /// use statrs::generate; /// /// let x = generate::log_spaced(5, 0.0, 4.0); /// assert_eq!(x, [1.0, 10.0, 100.0, 1000.0, 10000.0]); /// ``` pub fn log_spaced(length: usize, start_exp: f64, stop_exp: f64) -> Vec<f64> { match length { 0 => Vec::new(), 1 => vec![10f64.powf(stop_exp)], _ => { let step = (stop_exp - start_exp) / (length - 1) as f64; let mut vec = (0..length) .map(|x| 10f64.powf(start_exp + (x as f64) * step)) .collect::<Vec<f64>>(); vec[length - 1] = 10f64.powf(stop_exp); vec } } } /// Infinite iterator returning floats that form a periodic wave #[derive(Clone, Copy, PartialEq, Debug)] pub struct InfinitePeriodic { amplitude: f64, step: f64, phase: f64, k: f64, } impl InfinitePeriodic { /// Constructs a new infinite periodic wave generator /// /// # Examples /// /// ``` /// use statrs::generate::InfinitePeriodic; /// /// let x = InfinitePeriodic::new(8.0, 2.0, 10.0, 1.0, /// 2).take(10).collect::<Vec<f64>>(); /// assert_eq!(x, [6.0, 8.5, 1.0, 3.5, 6.0, 8.5, 1.0, 3.5, 6.0, 8.5]); /// ``` 
pub fn new( sampling_rate: f64, frequency: f64, amplitude: f64, phase: f64, delay: i64, ) -> InfinitePeriodic { let step = frequency / sampling_rate * amplitude; InfinitePeriodic { amplitude, step, phase: (phase - delay as f64 * step).modulus(amplitude), k: 0.0, } } /// Constructs a default infinite periodic wave generator /// /// # Examples /// /// ``` /// use statrs::generate::InfinitePeriodic; /// /// let x = InfinitePeriodic::default(8.0, /// 2.0).take(10).collect::<Vec<f64>>(); /// assert_eq!(x, [0.0, 0.25, 0.5, 0.75, 0.0, 0.25, 0.5, 0.75, 0.0, 0.25]); /// ``` pub fn default(sampling_rate: f64, frequency: f64) -> InfinitePeriodic { Self::new(sampling_rate, frequency, 1.0, 0.0, 0) } } impl std::fmt::Display for InfinitePeriodic { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{self:#?}") } } impl Iterator for InfinitePeriodic { type Item = f64; fn next(&mut self) -> Option<f64> { let mut x = self.phase + self.k * self.step; if x >= self.amplitude { x %= self.amplitude; self.phase = x; self.k = 0.0; } self.k += 1.0; Some(x) } } /// Infinite iterator returning floats that form a sinusoidal wave #[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSinusoidal { amplitude: f64, mean: f64, step: f64, phase: f64, i: usize, } impl InfiniteSinusoidal { /// Constructs a new infinite sinusoidal wave generator /// /// # Examples /// /// ``` /// use statrs::generate::InfiniteSinusoidal; /// /// let x = InfiniteSinusoidal::new(8.0, 2.0, 1.0, 5.0, 2.0, /// 1).take(10).collect::<Vec<f64>>(); /// assert_eq!(x, /// [5.416146836547142, 5.909297426825682, 4.583853163452858, /// 4.090702573174318, 5.416146836547142, 5.909297426825682, /// 4.583853163452858, 4.090702573174318, 5.416146836547142, /// 5.909297426825682]); /// ``` pub fn new( sampling_rate: f64, frequency: f64, amplitude: f64, mean: f64, phase: f64, delay: i64, ) -> InfiniteSinusoidal { let pi2 = consts::PI * 2.0; let step = frequency / sampling_rate * pi2; InfiniteSinusoidal { amplitude, mean, step, 
phase: (phase - delay as f64 * step) % pi2, i: 0, } } /// Constructs a default infinite sinusoidal wave generator /// /// # Examples /// /// ``` /// use statrs::generate::InfiniteSinusoidal; /// /// let x = InfiniteSinusoidal::default(8.0, 2.0, /// 1.0).take(10).collect::<Vec<f64>>(); /// assert_eq!(x, /// [0.0, 1.0, 0.00000000000000012246467991473532, /// -1.0, -0.00000000000000024492935982947064, 1.0, /// 0.00000000000000036739403974420594, -1.0, /// -0.0000000000000004898587196589413, 1.0]); /// ``` pub fn default(sampling_rate: f64, frequency: f64, amplitude: f64) -> InfiniteSinusoidal { Self::new(sampling_rate, frequency, amplitude, 0.0, 0.0, 0) } } impl std::fmt::Display for InfiniteSinusoidal { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:#?}", &self) } } impl Iterator for InfiniteSinusoidal { type Item = f64; fn next(&mut self) -> Option<f64> { let x = self.mean + self.amplitude * (self.phase + self.i as f64 * self.step).sin(); self.i += 1; if self.i == 1000 { self.i = 0; self.phase = (self.phase + 1000.0 * self.step) % (consts::PI * 2.0); } Some(x) } } /// Infinite iterator returning floats forming a square wave starting /// with the high phase #[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSquare { periodic: InfinitePeriodic, high_duration: f64, high_value: f64, low_value: f64, } impl InfiniteSquare { /// Constructs a new infinite square wave generator /// /// # Examples /// /// ``` /// use statrs::generate::InfiniteSquare; /// /// let x = InfiniteSquare::new(3, 7, 1.0, -1.0, /// 1).take(12).collect::<Vec<f64>>(); /// assert_eq!(x, [-1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, /// -1.0, 1.0]); /// ``` pub fn new( high_duration: i64, low_duration: i64, high_value: f64, low_value: f64, delay: i64, ) -> InfiniteSquare { let duration = (high_duration + low_duration) as f64; InfiniteSquare { periodic: InfinitePeriodic::new(1.0, 1.0 / duration, duration, 0.0, delay), high_duration: high_duration as f64, high_value, 
low_value, } } } impl std::fmt::Display for InfiniteSquare { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:#?}", &self) } } impl Iterator for InfiniteSquare { type Item = f64; fn next(&mut self) -> Option<f64> { self.periodic.next().map(|x| { if x < self.high_duration { self.high_value } else { self.low_value } }) } } /// Infinite iterator returning floats forming a triangle wave starting with /// the rising phase from the lowest sample #[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteTriangle { periodic: InfinitePeriodic, raise_duration: f64, raise: f64, fall: f64, high_value: f64, low_value: f64, } impl InfiniteTriangle { /// Constructs a new infinite triangle wave generator /// /// # Examples /// /// ``` /// #[macro_use] /// extern crate statrs; /// /// use statrs::generate::InfiniteTriangle; /// /// # fn main() { /// let x = InfiniteTriangle::new(4, 7, 1.0, -1.0, /// 1).take(12).collect::<Vec<f64>>(); /// let expected: [f64; 12] = [-0.714, -1.0, -0.5, 0.0, 0.5, 1.0, 0.714, /// 0.429, 0.143, -0.143, -0.429, -0.714]; /// for (&left, &right) in x.iter().zip(expected.iter()) { /// assert_almost_eq!(left, right, 1e-3); /// } /// # } /// ``` pub fn new( raise_duration: i64, fall_duration: i64, high_value: f64, low_value: f64, delay: i64, ) -> InfiniteTriangle { let duration = (raise_duration + fall_duration) as f64; let height = high_value - low_value; InfiniteTriangle { periodic: InfinitePeriodic::new(1.0, 1.0 / duration, duration, 0.0, delay), raise_duration: raise_duration as f64, raise: height / raise_duration as f64, fall: height / fall_duration as f64, high_value, low_value, } } } impl std::fmt::Display for InfiniteTriangle { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:#?}", &self) } } impl Iterator for InfiniteTriangle { type Item = f64; fn next(&mut self) -> Option<f64> { self.periodic.next().map(|x| { if x < self.raise_duration { self.low_value + x * self.raise } else { self.high_value - (x - 
self.raise_duration) * self.fall } }) } } /// Infinite iterator returning floats forming a sawtooth wave /// starting with the lowest sample #[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSawtooth { periodic: InfinitePeriodic, low_value: f64, } impl InfiniteSawtooth { /// Constructs a new infinite sawtooth wave generator /// /// # Examples /// /// ``` /// use statrs::generate::InfiniteSawtooth; /// /// let x = InfiniteSawtooth::new(5, 1.0, -1.0, /// 1).take(12).collect::<Vec<f64>>(); /// assert_eq!(x, [1.0, -1.0, -0.5, 0.0, 0.5, 1.0, -1.0, -0.5, 0.0, 0.5, /// 1.0, -1.0]); /// ``` pub fn new(period: i64, high_value: f64, low_value: f64, delay: i64) -> InfiniteSawtooth { let height = high_value - low_value; let period = period as f64; InfiniteSawtooth { periodic: InfinitePeriodic::new( 1.0, 1.0 / period, height * period / (period - 1.0), 0.0, delay, ), low_value, } } } impl std::fmt::Display for InfiniteSawtooth { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:#?}", &self) } } impl Iterator for InfiniteSawtooth { type Item = f64; fn next(&mut self) -> Option<f64> { self.periodic.next().map(|x| x + self.low_value) } } statrs-0.18.0/src/lib.rs000064400000000000000000000054031046102023000131370ustar 00000000000000//! This crate aims to be a functional port of the Math.NET Numerics //! Distribution package and, in doing so, to provide the Rust numerical computing //! community with a robust, well-tested statistical distribution package. This //! crate also ports over some of the special statistical functions from //! Math.NET insofar as they are used in the computation of distribution //! values. When the `rand` feature is enabled, this crate uses the `rand` crate to provide RNG. //! //! # Sampling //! A common use case is to set up a distribution and sample from it, which requires the `rand` crate (enabled through the `rand` feature) for random number generation. 
#![cfg_attr(feature = "rand", doc = "```")] #![cfg_attr(not(feature = "rand"), doc = "```ignore")] //! 
use statrs::distribution::Exp; //! use rand::distributions::Distribution; //! let mut r = rand::rngs::OsRng; //! let n = Exp::new(0.5).unwrap(); //! print!("{}", n.sample(&mut r)); //! ``` //! //! # Introspecting distributions //! Statrs also comes with a number of useful utility traits for more detailed introspection of distributions. //! ``` //! use statrs::distribution::{Exp, Continuous, ContinuousCDF}; // `cdf` and `pdf` //! use statrs::statistics::Distribution; // statistical moments and entropy //! //! let n = Exp::new(1.0).unwrap(); //! assert_eq!(n.mean(), Some(1.0)); //! assert_eq!(n.variance(), Some(1.0)); //! assert_eq!(n.entropy(), Some(1.0)); //! assert_eq!(n.skewness(), Some(2.0)); //! assert_eq!(n.cdf(1.0), 0.6321205588285576784045); //! assert_eq!(n.pdf(1.0), 0.3678794411714423215955); //! ``` //! //! # Utility functions //! Statrs also provides special functions such as `erf`, `gamma`, `ln_gamma`, and `beta` in the `function` module. The statistics traits return `Option`, so an undefined moment, such as the variance of `FisherSnedecor::new(1.0, 1.0)`, comes back as `None`: //! //! ``` //! use statrs::distribution::FisherSnedecor; //! use statrs::statistics::Distribution; //! //! let n = FisherSnedecor::new(1.0, 1.0).unwrap(); //! assert!(n.variance().is_none()); //! ``` //! ## Distributions implemented //! Statrs comes with a number of commonly used distributions, including Normal, Gamma, Student's T, Exponential, and Weibull. See the `distribution` module for the full list. #![crate_type = "lib"] #![crate_name = "statrs"] #![allow(clippy::excessive_precision)] #![allow(clippy::many_single_char_names)] #![forbid(unsafe_code)] #![cfg_attr(coverage_nightly, feature(coverage_attribute))] #![cfg_attr(docsrs, feature(doc_cfg))] #[macro_use] extern crate approx; #[macro_export] macro_rules! assert_almost_eq { ($a:expr, $b:expr, $prec:expr $(,)?) 
=> { if !$crate::prec::almost_eq($a, $b, $prec) { panic!( "assertion failed: `abs(left - right) < {:e}`, (left: `{}`, right: `{}`)", $prec, $a, $b ); } }; } pub mod consts; #[macro_use] pub mod distribution; pub mod euclid; pub mod function; pub mod generate; pub mod prec; pub mod statistics; pub mod stats_tests; statrs-0.18.0/src/prec.rs000064400000000000000000000020051046102023000133150ustar 00000000000000//! Provides utility functions for working with floating point precision use approx::AbsDiffEq; /// Standard epsilon, maximum relative precision of IEEE 754 double-precision /// floating point numbers (64 bit), i.e. `2^-53` pub const F64_PREC: f64 = 0.00000000000000011102230246251565; /// Default accuracy for `f64`, equivalent to `10.0 * F64_PREC` pub const DEFAULT_F64_ACC: f64 = 0.0000000000000011102230246251565; /// Checks whether two floats are close via `approx::abs_diff_eq` /// using a maximum absolute difference (epsilon) of `acc`. pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool { if a.is_infinite() && b.is_infinite() { return a == b; } a.abs_diff_eq(&b, acc) } /// Checks whether two floats are close via `approx::relative_eq!` /// and `crate::consts::ACC` relative precision. 
/// Updates the first argument to the value of the second argument and returns whether the two values were close. pub fn convergence(x: &mut f64, x_new: f64) -> bool { let res = approx::relative_eq!(*x, x_new, max_relative = crate::consts::ACC); *x = x_new; res } statrs-0.18.0/src/statistics/iter_statistics.rs000064400000000000000000000163341046102023000200050ustar 00000000000000use crate::statistics::*; use std::borrow::Borrow; use std::f64; impl<T> Statistics<f64> for T where T: IntoIterator, T::Item: Borrow<f64>, { fn min(self) -> f64 { let mut iter = self.into_iter(); match iter.next() { None => f64::NAN, Some(x) => iter.map(|x| *x.borrow()).fold(*x.borrow(), |acc, x| { if x < acc || x.is_nan() { x } else { acc } }), } } fn max(self) -> f64 { let mut iter = self.into_iter(); match iter.next() { None => f64::NAN, Some(x) => iter.map(|x| *x.borrow()).fold(*x.borrow(), |acc, x| { if x > acc || x.is_nan() { x } else { acc } }), } } fn abs_min(self) -> f64 { let mut iter = self.into_iter(); match iter.next() { None => f64::NAN, Some(init) => iter .map(|x| x.borrow().abs()) .fold(init.borrow().abs(), |acc, x| { if x < acc || x.is_nan() { x } else { acc } }), } } fn abs_max(self) -> f64 { let mut iter = self.into_iter(); match iter.next() { None => f64::NAN, Some(init) => iter .map(|x| x.borrow().abs()) .fold(init.borrow().abs(), |acc, x| { if x > acc || x.is_nan() { x } else { acc } }), } } fn mean(self) -> f64 { let mut i = 0.0; let mut mean = 0.0; for x in self { i += 1.0; mean += (x.borrow() - mean) / i; } if i > 0.0 { mean } else { f64::NAN } } fn geometric_mean(self) -> f64 { let mut i = 0.0; let mut sum = 0.0; for x in self { i += 1.0; sum += x.borrow().ln(); } if i > 0.0 { (sum / i).exp() } else { f64::NAN } } fn harmonic_mean(self) -> f64 { let mut i = 0.0; let mut sum = 0.0; for x in self { i += 1.0; let borrow = *x.borrow(); if borrow < 0f64 { return f64::NAN; } sum += 1.0 / borrow; } if i > 0.0 { i / sum } else { f64::NAN } } fn variance(self) -> f64 { let mut iter = self.into_iter(); let mut sum = match iter.next() { None => 
f64::NAN, Some(x) => *x.borrow(), }; let mut i = 1.0; let mut variance = 0.0; for x in iter { i += 1.0; let borrow = *x.borrow(); sum += borrow; let diff = i * borrow - sum; variance += diff * diff / (i * (i - 1.0)) } if i > 1.0 { variance / (i - 1.0) } else { f64::NAN } } fn std_dev(self) -> f64 { self.variance().sqrt() } fn population_variance(self) -> f64 { let mut iter = self.into_iter(); let mut sum = match iter.next() { None => return f64::NAN, Some(x) => *x.borrow(), }; let mut i = 1.0; let mut variance = 0.0; for x in iter { i += 1.0; let borrow = *x.borrow(); sum += borrow; let diff = i * borrow - sum; variance += diff * diff / (i * (i - 1.0)); } variance / i } fn population_std_dev(self) -> f64 { self.population_variance().sqrt() } fn covariance(self, other: Self) -> f64 { let mut n = 0.0; let mut mean1 = 0.0; let mut mean2 = 0.0; let mut comoment = 0.0; let mut iter = other.into_iter(); for x in self { let borrow = *x.borrow(); let borrow2 = match iter.next() { None => panic!("Iterators must have the same length"), Some(x) => *x.borrow(), }; let old_mean2 = mean2; n += 1.0; mean1 += (borrow - mean1) / n; mean2 += (borrow2 - mean2) / n; comoment += (borrow - mean1) * (borrow2 - old_mean2); } if iter.next().is_some() { panic!("Iterators must have the same length"); } if n > 1.0 { comoment / (n - 1.0) } else { f64::NAN } } fn population_covariance(self, other: Self) -> f64 { let mut n = 0.0; let mut mean1 = 0.0; let mut mean2 = 0.0; let mut comoment = 0.0; let mut iter = other.into_iter(); for x in self { let borrow = *x.borrow(); let borrow2 = match iter.next() { None => panic!("Iterators must have the same length"), Some(x) => *x.borrow(), }; let old_mean2 = mean2; n += 1.0; mean1 += (borrow - mean1) / n; mean2 += (borrow2 - mean2) / n; comoment += (borrow - mean1) * (borrow2 - old_mean2); } if iter.next().is_some() { panic!("Iterators must have the same length") } if n > 0.0 { comoment / n } else { f64::NAN } } fn quadratic_mean(self) -> f64 { let mut i 
= 0.0;
        let mut mean = 0.0;
        for x in self {
            let borrow = *x.borrow();
            i += 1.0;
            mean += (borrow * borrow - mean) / i;
        }
        if i > 0.0 {
            mean.sqrt()
        } else {
            f64::NAN
        }
    }
}

#[rustfmt::skip]
#[cfg(test)]
mod tests {
    use std::f64::consts;
    use crate::statistics::Statistics;
    use crate::generate::{InfinitePeriodic, InfiniteSinusoidal};

    #[test]
    fn test_empty_data_returns_nan() {
        let data = [0.0; 0];
        assert!(data.min().is_nan());
        assert!(data.max().is_nan());
        assert!(data.mean().is_nan());
        assert!(data.quadratic_mean().is_nan());
        assert!(data.variance().is_nan());
        assert!(data.population_variance().is_nan());
    }

    // TODO: test github issue 137 (Math.NET)

    #[test]
    fn test_large_samples() {
        let shorter = InfinitePeriodic::default(4.0, 1.0).take(4*4096).collect::<Vec<f64>>();
        let longer = InfinitePeriodic::default(4.0, 1.0).take(4*32768).collect::<Vec<f64>>();
        assert_almost_eq!((&shorter).mean(), 0.375, 1e-14);
        assert_almost_eq!((&longer).mean(), 0.375, 1e-14);
        assert_almost_eq!((&shorter).quadratic_mean(), (0.21875f64).sqrt(), 1e-14);
        assert_almost_eq!((&longer).quadratic_mean(), (0.21875f64).sqrt(), 1e-14);
    }

    #[test]
    fn test_quadratic_mean_of_sinusoidal() {
        let data = InfiniteSinusoidal::default(64.0, 16.0, 2.0).take(128).collect::<Vec<f64>>();
        assert_almost_eq!((&data).quadratic_mean(), 2.0 / consts::SQRT_2, 1e-15);
    }
}
statrs-0.18.0/src/statistics/mod.rs
//! Provides traits for statistical computation

pub use self::order_statistics::*;
pub use self::slice_statistics::*;
pub use self::statistics::*;
pub use self::traits::*;

mod iter_statistics;
mod order_statistics;
// TODO: fix later
mod slice_statistics;
#[allow(clippy::module_inception)]
mod statistics;
mod traits;
statrs-0.18.0/src/statistics/order_statistics.rs
use super::RankTieBreaker;

/// The `OrderStatistics` trait provides statistical utilities
/// having to do with ordering.
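The iterator implementation above computes the sample variance in a single pass. It keeps a running sum `s_i` and accumulates `(i * x_i - s_i)^2 / (i * (i - 1))`, then divides by `N - 1`. The quadratic mean uses a running mean of squares. The sketch below reproduces both recurrences over a plain `&[f64]` using only `std`. The free functions `sample_variance` and `rms` are hypothetical names used for illustration and are not part of the statrs API.

```rust
/// One-pass sample variance (N - 1 normalizer), mirroring the running-sum
/// recurrence in the iterator implementation. Hypothetical helper.
fn sample_variance(data: &[f64]) -> f64 {
    let mut iter = data.iter();
    // An empty input has no variance.
    let mut sum = match iter.next() {
        None => return f64::NAN,
        Some(&x) => x,
    };
    let mut i = 1.0;
    let mut m2 = 0.0;
    for &x in iter {
        i += 1.0;
        sum += x;
        // i * x_i - s_i, where s_i already includes x_i
        let diff = i * x - sum;
        m2 += diff * diff / (i * (i - 1.0));
    }
    // Fewer than two entries leaves the estimate undefined.
    if i > 1.0 { m2 / (i - 1.0) } else { f64::NAN }
}

/// Root mean square via a running mean of squares. Hypothetical helper.
fn rms(data: &[f64]) -> f64 {
    let mut i = 0.0;
    let mut mean = 0.0;
    for &x in data {
        i += 1.0;
        mean += (x * x - mean) / i;
    }
    if i > 0.0 { mean.sqrt() } else { f64::NAN }
}
```

For `[0.0, 3.0, -2.0]`, this gives `19/3` for the variance and `sqrt(13/3)` for the RMS. These match the values in the doc examples for `variance` and `quadratic_mean`.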
/// All the algorithms are in-place, thus requiring
/// a mutable borrow.
pub trait OrderStatistics<T> {
    /// Returns the order statistic `(order 1..N)` from the data
    ///
    /// # Remarks
    ///
    /// No sorting is assumed. Order must be one-based (between `1` and `N`
    /// inclusive).
    /// Returns `f64::NAN` if order is outside the viable range or data is
    /// empty.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::OrderStatistics;
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let mut x = Data::new(x);
    /// assert!(x.order_statistic(1).is_nan());
    ///
    /// let y = [0.0, 3.0, -2.0];
    /// let mut y = Data::new(y);
    /// assert!(y.order_statistic(0).is_nan());
    /// assert!(y.order_statistic(4).is_nan());
    /// assert_eq!(y.order_statistic(2), 0.0);
    /// assert!(y != Data::new([0.0, 3.0, -2.0]));
    /// ```
    fn order_statistic(&mut self, order: usize) -> T;

    /// Returns the median value from the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::OrderStatistics;
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let mut x = Data::new(x);
    /// assert!(x.median().is_nan());
    ///
    /// let y = [0.0, 3.0, -2.0];
    /// let mut y = Data::new(y);
    /// assert_eq!(y.median(), 0.0);
    /// assert!(y != Data::new([0.0, 3.0, -2.0]));
    /// ```
    fn median(&mut self) -> T;

    /// Estimates the tau-th quantile from the data. The tau-th quantile
    /// is the data value where the cumulative distribution function crosses
    /// tau.
    ///
    /// # Remarks
    ///
    /// No sorting is assumed. Tau must be between `0` and `1` inclusive.
    /// Returns `f64::NAN` if data is empty or tau is outside the inclusive
    /// range.
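The `quantile` implementation for `Data` in `slice_statistics.rs` works as follows. It computes `h = (N + 1/3) * tau + 1/3` and interpolates linearly between the order statistics of rank `floor(h)` and `floor(h) + 1`. Near the ends, it clamps to the minimum or maximum. This appears to match the median-unbiased estimator that Hyndman and Fan list as type 8.

The sketch below applies that rule to a sorted copy instead of doing in-place selection, which is a deliberate simplification. The name `quantile_type8` is hypothetical, and using `f64::total_cmp` for sorting is an assumption of the sketch.

```rust
/// Type-8 style quantile estimate on a sorted copy of the data.
/// Returns NaN for empty input or tau outside [0, 1]. Hypothetical helper.
fn quantile_type8(data: &[f64], tau: f64) -> f64 {
    if !(0.0..=1.0).contains(&tau) || data.is_empty() {
        return f64::NAN;
    }
    let mut sorted = data.to_vec();
    // total_cmp imposes a total order, so sorting never panics.
    sorted.sort_by(|a, b| a.total_cmp(b));
    let n = sorted.len();
    let h = (n as f64 + 1.0 / 3.0) * tau + 1.0 / 3.0;
    let hf = h.floor();
    // Below the first order statistic: clamp to the minimum.
    if hf < 1.0 {
        return sorted[0];
    }
    // At or beyond the last order statistic: clamp to the maximum.
    if hf >= n as f64 {
        return sorted[n - 1];
    }
    // hf is a one-based rank, so the lower neighbour is sorted[k - 1].
    let k = hf as usize;
    let (a, b) = (sorted[k - 1], sorted[k]);
    a + (h - hf) * (b - a)
}
```

Under the same rule, `percentile(p)` is simply `quantile(p / 100)`. The quartiles are the quantiles at `0.25` and `0.75`.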
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::OrderStatistics;
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let mut x = Data::new(x);
    /// assert!(x.quantile(0.5).is_nan());
    ///
    /// let y = [0.0, 3.0, -2.0];
    /// let mut y = Data::new(y);
    /// assert!(y.quantile(-1.0).is_nan());
    /// assert!(y.quantile(2.0).is_nan());
    /// assert_eq!(y.quantile(0.5), 0.0);
    /// assert!(y != Data::new([0.0, 3.0, -2.0]));
    /// ```
    fn quantile(&mut self, tau: f64) -> T;

    /// Estimates the `p`-th percentile value from the data.
    ///
    /// # Remarks
    ///
    /// Use `quantile` for non-integer percentiles. `p` must be between `0` and
    /// `100` inclusive.
    /// Returns `f64::NAN` if data is empty or `p` is outside the inclusive
    /// range.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::OrderStatistics;
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let mut x = Data::new(x);
    /// assert!(x.percentile(0).is_nan());
    ///
    /// let y = [1.0, 5.0, 3.0, 4.0, 10.0, 9.0, 6.0, 7.0, 8.0, 2.0];
    /// let mut y = Data::new(y);
    /// assert_eq!(y.percentile(0), 1.0);
    /// assert_eq!(y.percentile(50), 5.5);
    /// assert_eq!(y.percentile(100), 10.0);
    /// assert!(y.percentile(105).is_nan());
    /// assert!(y != Data::new([1.0, 5.0, 3.0, 4.0, 10.0, 9.0, 6.0, 7.0, 8.0, 2.0]));
    /// ```
    fn percentile(&mut self, p: usize) -> T;

    /// Estimates the first quartile value from the data.
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use statrs::statistics::OrderStatistics;
    /// use statrs::statistics::Data;
    ///
    /// # fn main() {
    /// let x = [];
    /// let mut x = Data::new(x);
    /// assert!(x.lower_quartile().is_nan());
    ///
    /// let y = [2.0, 1.0, 3.0, 4.0];
    /// let mut y = Data::new(y);
    /// assert_almost_eq!(y.lower_quartile(), 1.416666666666666, 1e-15);
    /// assert!(y != Data::new([2.0, 1.0, 3.0, 4.0]));
    /// # }
    /// ```
    fn lower_quartile(&mut self) -> T;

    /// Estimates the third quartile value from the data.
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use statrs::statistics::OrderStatistics;
    /// use statrs::statistics::Data;
    ///
    /// # fn main() {
    /// let x = [];
    /// let mut x = Data::new(x);
    /// assert!(x.upper_quartile().is_nan());
    ///
    /// let y = [2.0, 1.0, 3.0, 4.0];
    /// let mut y = Data::new(y);
    /// assert_almost_eq!(y.upper_quartile(), 3.5833333333333333, 1e-15);
    /// assert!(y != Data::new([2.0, 1.0, 3.0, 4.0]));
    /// # }
    /// ```
    fn upper_quartile(&mut self) -> T;

    /// Estimates the inter-quartile range from the data.
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use statrs::statistics::Data;
    /// use statrs::statistics::OrderStatistics;
    ///
    /// # fn main() {
    /// let x = [];
    /// let mut x = Data::new(x);
    /// assert!(x.interquartile_range().is_nan());
    ///
    /// let y = [2.0, 1.0, 3.0, 4.0];
    /// let mut y = Data::new(y);
    /// assert_almost_eq!(y.interquartile_range(), 2.166666666666667, 1e-15);
    /// assert!(y != Data::new([2.0, 1.0, 3.0, 4.0]));
    /// # }
    /// ```
    fn interquartile_range(&mut self) -> T;

    /// Evaluates the rank of each entry of the data.
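Ranking sorts the entries and then gives each run of equal values a rank according to the tie-breaking strategy. For a run occupying zero-based sorted positions `start..end`, the ranks are computed as follows:

- `Average` assigns `(start + 1 + end) / 2`.
- `Min` assigns `start + 1`.
- `Max` assigns `end`.
- `First` hands out `start + 1, start + 2, ...` in order of appearance.

The sketch below expresses this with a stable index sort. The enum `TieBreak` is a hypothetical stand-in for `RankTieBreaker`, and panicking on `NaN` is an assumption chosen to mirror the `unwrap` in the `Data` implementation.

```rust
/// Hypothetical mirror of the tie-breaking strategies.
#[derive(Clone, Copy, Debug)]
enum TieBreak {
    Average,
    Min,
    Max,
    First,
}

/// Returns one-based ranks aligned with the input order.
/// Panics if the data contains NaN. Hypothetical helper.
fn ranks(data: &[f64], tie: TieBreak) -> Vec<f64> {
    let n = data.len();
    // Stable sort of indices: equal values keep their input order.
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by(|&a, &b| data[a].partial_cmp(&data[b]).expect("NaN in data"));

    let mut out = vec![0.0; n];
    let mut start = 0;
    while start < n {
        // Extend `end` past every entry equal to the run's first value.
        let mut end = start + 1;
        while end < n && data[idx[end]] == data[idx[start]] {
            end += 1;
        }
        for (offset, &i) in idx[start..end].iter().enumerate() {
            out[i] = match tie {
                // Mean of the one-based ranks start+1 ..= end
                TieBreak::Average => (start + 1 + end) as f64 / 2.0,
                TieBreak::Min => (start + 1) as f64,
                TieBreak::Max => end as f64,
                TieBreak::First => (start + offset + 1) as f64,
            };
        }
        start = end;
    }
    out
}
```

For `[1.0, 3.0, 2.0, 2.0]`, `Average` yields `[1.0, 4.0, 2.5, 2.5]` and `Min` yields `[1.0, 4.0, 2.0, 2.0]`. These agree with the `ranks` doc example.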
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::{OrderStatistics, RankTieBreaker};
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let mut x = Data::new(x);
    /// assert_eq!(x.ranks(RankTieBreaker::Average).len(), 0);
    ///
    /// let y = Data::new([1.0, 3.0, 2.0, 2.0]);
    /// assert_eq!(y.clone().ranks(RankTieBreaker::Average), [1.0, 4.0, 2.5, 2.5]);
    /// assert_eq!(y.clone().ranks(RankTieBreaker::Min), [1.0, 4.0, 2.0, 2.0]);
    /// ```
    fn ranks(&mut self, tie_breaker: RankTieBreaker) -> Vec<T>;
}
statrs-0.18.0/src/statistics/slice_statistics.rs
use crate::statistics::*;
use core::ops::{Index, IndexMut};

#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct Data<D>(D);

impl<D, I> std::fmt::Display for Data<D>
where
    D: Clone + IntoIterator<Item = I>,
    I: Clone + std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tee = self.0.clone().into_iter();
        write!(f, "Data([")?;
        if let Some(v) = tee.next() {
            write!(f, "{v}")?;
        }
        for _ in 1..5 {
            if let Some(v) = tee.next() {
                write!(f, ", {v}")?;
            }
        }
        if tee.next().is_some() {
            write!(f, "...")?;
        }
        write!(f, "])")
    }
}

impl<D: AsRef<[f64]>> Index<usize> for Data<D> {
    type Output = f64;
    fn index(&self, i: usize) -> &f64 {
        &self.0.as_ref()[i]
    }
}

impl<D: AsMut<[f64]> + AsRef<[f64]>> IndexMut<usize> for Data<D> {
    fn index_mut(&mut self, i: usize) -> &mut f64 {
        &mut self.0.as_mut()[i]
    }
}

impl<D: AsMut<[f64]> + AsRef<[f64]>> Data<D> {
    pub fn new(data: D) -> Self {
        Data(data)
    }

    pub fn swap(&mut self, i: usize, j: usize) {
        self.0.as_mut().swap(i, j)
    }

    pub fn len(&self) -> usize {
        self.0.as_ref().len()
    }

    pub fn is_empty(&self) -> bool {
        self.0.as_ref().len() == 0
    }

    pub fn iter(&self) -> core::slice::Iter<'_, f64> {
        self.0.as_ref().iter()
    }

    // Selection algorithm from Numerical Recipes
    // See: https://en.wikipedia.org/wiki/Selection_algorithm
    fn select_inplace(&mut self, rank: usize) -> f64 {
        if rank == 0 {
            return self.min();
        }
        if rank > self.len() - 1 {
            return self.max();
        }

        let mut low = 0;
        let mut high = self.len() - 1;
        loop {
            if high <= low + 1 {
                if high == low + 1 && self[high] < self[low] {
                    self.swap(low, high)
                }
                return self[rank];
            }

            let middle = (low + high) / 2;
            self.swap(middle, low + 1);

            if self[low] > self[high] {
                self.swap(low, high);
            }
            if self[low + 1] > self[high] {
                self.swap(low + 1, high);
            }
            if self[low] > self[low + 1] {
                self.swap(low, low + 1);
            }

            let mut begin = low + 1;
            let mut end = high;
            let pivot = self[begin];
            loop {
                loop {
                    begin += 1;
                    if self[begin] >= pivot {
                        break;
                    }
                }
                loop {
                    end -= 1;
                    if self[end] <= pivot {
                        break;
                    }
                }
                if end < begin {
                    break;
                }
                self.swap(begin, end);
            }

            self[low + 1] = self[end];
            self[end] = pivot;

            if end >= rank {
                high = end - 1;
            }
            if end <= rank {
                low = begin;
            }
        }
    }
}

#[cfg(feature = "rand")]
impl<D: AsRef<[f64]>> ::rand::distributions::Distribution<f64> for Data<D> {
    fn sample<R: ::rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        use rand::prelude::SliceRandom;
        *self.0.as_ref().choose(rng).unwrap()
    }
}

impl<D: AsMut<[f64]> + AsRef<[f64]>> OrderStatistics<f64> for Data<D> {
    fn order_statistic(&mut self, order: usize) -> f64 {
        let n = self.len();
        match order {
            1 => self.min(),
            _ if order == n => self.max(),
            _ if order < 1 || order > n => f64::NAN,
            _ => self.select_inplace(order - 1),
        }
    }

    fn median(&mut self) -> f64 {
        let k = self.len() / 2;
        if self.len() % 2 != 0 {
            self.select_inplace(k)
        } else {
            (self.select_inplace(k.saturating_sub(1)) + self.select_inplace(k)) / 2.0
        }
    }

    fn quantile(&mut self, tau: f64) -> f64 {
        if !(0.0..=1.0).contains(&tau) || self.is_empty() {
            return f64::NAN;
        }

        let h = (self.len() as f64 + 1.0 / 3.0) * tau + 1.0 / 3.0;
        let hf = h as i64;

        if hf <= 0 || tau == 0.0 {
            return self.min();
        }
        if hf >= self.len() as i64 || ulps_eq!(tau, 1.0) {
            return self.max();
        }

        let a = self.select_inplace((hf as usize).saturating_sub(1));
        let b = self.select_inplace(hf as usize);
        a + (h - hf as f64) * (b - a)
    }

    fn percentile(&mut self, p: usize) -> f64 {
        self.quantile(p as f64 / 100.0)
    }

    fn lower_quartile(&mut self) -> f64 {
        self.quantile(0.25)
    }

    fn upper_quartile(&mut self) -> f64 {
        self.quantile(0.75)
    }

    fn interquartile_range(&mut self) -> f64 {
        self.upper_quartile() - self.lower_quartile()
    }

    fn ranks(&mut self, tie_breaker: RankTieBreaker) -> Vec<f64> {
        let n = self.len();
        let mut ranks: Vec<f64> = vec![0.0; n];
        let mut enumerated: Vec<_> = self.iter().enumerate().collect();
        enumerated.sort_by(|(_, el_a), (_, el_b)| el_a.partial_cmp(el_b).unwrap());
        match tie_breaker {
            RankTieBreaker::First => {
                for (i, idx) in enumerated.into_iter().map(|(idx, _)| idx).enumerate() {
                    ranks[idx] = (i + 1) as f64
                }
                ranks
            }
            _ => {
                let mut prev = 0;
                let mut prev_idx = 0;
                let mut prev_elt = 0.0;
                for (i, (idx, elt)) in enumerated.iter().cloned().enumerate() {
                    if i == 0 {
                        prev_idx = idx;
                        prev_elt = *elt;
                    }
                    if (*elt - prev_elt).abs() <= 0.0 {
                        continue;
                    }
                    if i == prev + 1 {
                        ranks[prev_idx] = i as f64;
                    } else {
                        handle_rank_ties(&mut ranks, &enumerated, prev, i, tie_breaker);
                    }
                    prev = i;
                    prev_idx = idx;
                    prev_elt = *elt;
                }

                handle_rank_ties(&mut ranks, &enumerated, prev, n, tie_breaker);
                ranks
            }
        }
    }
}

impl<D: AsMut<[f64]> + AsRef<[f64]>> Min<f64> for Data<D> {
    /// Returns the minimum value in the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Min;
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let x = Data::new(x);
    /// assert!(x.min().is_nan());
    ///
    /// let y = [0.0, f64::NAN, 3.0, -2.0];
    /// let y = Data::new(y);
    /// assert!(y.min().is_nan());
    ///
    /// let z = [0.0, 3.0, -2.0];
    /// let z = Data::new(z);
    /// assert_eq!(z.min(), -2.0);
    /// ```
    fn min(&self) -> f64 {
        Statistics::min(self.iter())
    }
}

impl<D: AsMut<[f64]> + AsRef<[f64]>> Max<f64> for Data<D> {
    /// Returns the maximum value in the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Max;
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let x = Data::new(x);
    /// assert!(x.max().is_nan());
    ///
    /// let y =
    ///     [0.0, f64::NAN, 3.0, -2.0];
    /// let y = Data::new(y);
    /// assert!(y.max().is_nan());
    ///
    /// let z = [0.0, 3.0, -2.0];
    /// let z = Data::new(z);
    /// assert_eq!(z.max(), 3.0);
    /// ```
    fn max(&self) -> f64 {
        Statistics::max(self.iter())
    }
}

impl<D: AsMut<[f64]> + AsRef<[f64]>> Distribution<f64> for Data<D> {
    /// Evaluates the sample mean, an estimate of the population
    /// mean.
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use statrs::statistics::Distribution;
    /// use statrs::statistics::Data;
    ///
    /// # fn main() {
    /// let x = [];
    /// let x = Data::new(x);
    /// assert!(x.mean().unwrap().is_nan());
    ///
    /// let y = [0.0, f64::NAN, 3.0, -2.0];
    /// let y = Data::new(y);
    /// assert!(y.mean().unwrap().is_nan());
    ///
    /// let z = [0.0, 3.0, -2.0];
    /// let z = Data::new(z);
    /// assert_almost_eq!(z.mean().unwrap(), 1.0 / 3.0, 1e-15);
    /// # }
    /// ```
    fn mean(&self) -> Option<f64> {
        Some(Statistics::mean(self.iter()))
    }

    /// Estimates the unbiased population variance from the provided samples
    ///
    /// # Remarks
    ///
    /// On a dataset of size `N`, `N-1` is used as a normalizer (Bessel's
    /// correction).
    ///
    /// Returns `f64::NAN` if data has fewer than two entries or if any entry is
    /// `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Distribution;
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let x = Data::new(x);
    /// assert!(x.variance().unwrap().is_nan());
    ///
    /// let y = [0.0, f64::NAN, 3.0, -2.0];
    /// let y = Data::new(y);
    /// assert!(y.variance().unwrap().is_nan());
    ///
    /// let z = [0.0, 3.0, -2.0];
    /// let z = Data::new(z);
    /// assert_eq!(z.variance().unwrap(), 19.0 / 3.0);
    /// ```
    fn variance(&self) -> Option<f64> {
        Some(Statistics::variance(self.iter()))
    }
}

impl<D: AsMut<[f64]> + AsRef<[f64]> + Clone> Median<f64> for Data<D> {
    /// Returns the median value from the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Median;
    /// use statrs::statistics::Data;
    ///
    /// let x = [];
    /// let x = Data::new(x);
    /// assert!(x.median().is_nan());
    ///
    /// let y = [0.0, 3.0, -2.0];
    /// let y = Data::new(y);
    /// assert_eq!(y.median(), 0.0);
    /// ```
    fn median(&self) -> f64 {
        let mut v = self.clone();
        OrderStatistics::median(&mut v)
    }
}

fn handle_rank_ties(
    ranks: &mut [f64],
    index: &[(usize, &f64)],
    a: usize,
    b: usize,
    tie_breaker: RankTieBreaker,
) {
    let rank = match tie_breaker {
        // equivalent to (b + a - 1) as f64 / 2.0 + 1.0 but with fewer overflow issues
        RankTieBreaker::Average => b as f64 / 2.0 + a as f64 / 2.0 + 0.5,
        RankTieBreaker::Min => (a + 1) as f64,
        RankTieBreaker::Max => b as f64,
        RankTieBreaker::First => unreachable!(),
    };
    for i in &index[a..b] {
        ranks[i.0] = rank
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::statistics::*;

    #[test]
    fn test_order_statistic_short() {
        let data = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0, 1.0, 6.0];
        let mut data = Data::new(data);
        assert!(data.order_statistic(0).is_nan());
        assert_eq!(data.order_statistic(1), -3.0);
        assert_eq!(data.order_statistic(2), -1.0);
        assert_eq!(data.order_statistic(3), -0.5);
        assert_eq!(data.order_statistic(7), 5.0);
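`order_statistic` delegates to `select_inplace`, which is a Numerical Recipes style partition loop that reorders the data in place. As a point of comparison, the same one-based order statistic can be obtained from the standard library's `slice::select_nth_unstable_by`. This is a swapped-in technique, not what statrs uses. The free function `order_statistic` below is hypothetical. Ordering by `f64::total_cmp` is an assumption of the sketch, so `NaN` entries sort after all numbers here rather than propagating.

```rust
/// One-based order statistic via std's in-place selection.
/// Returns NaN when `order` is outside 1..=len. Hypothetical helper.
fn order_statistic(data: &mut [f64], order: usize) -> f64 {
    if order < 1 || order > data.len() {
        return f64::NAN;
    }
    // Partially reorders `data` so that index order - 1 holds the value it
    // would have after a full sort.
    let (_, kth, _) = data.select_nth_unstable_by(order - 1, |a, b| a.total_cmp(b));
    *kth
}
```

Like the `Data` methods, this mutates its input. Callers that need the original order should pass a copy.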
        assert_eq!(data.order_statistic(8), 6.0);
        assert_eq!(data.order_statistic(9), 10.0);
        assert!(data.order_statistic(10).is_nan());
    }

    #[test]
    fn test_quantile_short() {
        let data = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0, 0.2, 1.0, 6.0];
        let mut data = Data::new(data);
        assert_eq!(data.quantile(0.0), -3.0);
        assert_eq!(data.quantile(1.0), 10.0);
        assert_almost_eq!(data.quantile(0.5), 3.0 / 5.0, 1e-15);
        assert_almost_eq!(data.quantile(0.2), -4.0 / 5.0, 1e-15);
        assert_eq!(data.quantile(0.7), 137.0 / 30.0);
        assert_eq!(data.quantile(0.01), -3.0);
        assert_eq!(data.quantile(0.99), 10.0);
        assert_almost_eq!(data.quantile(0.52), 287.0 / 375.0, 1e-15);
        assert_almost_eq!(data.quantile(0.325), -37.0 / 240.0, 1e-15);
    }

    #[test]
    fn test_ranks() {
        let sorted_distinct = [1.0, 2.0, 4.0, 7.0, 8.0, 9.0, 10.0, 12.0];
        let mut sorted_distinct = Data::new(sorted_distinct);
        let sorted_ties = [1.0, 2.0, 2.0, 7.0, 9.0, 9.0, 10.0, 12.0];
        let mut sorted_ties = Data::new(sorted_ties);
        assert_eq!(
            sorted_distinct.ranks(RankTieBreaker::Average),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_ties.ranks(RankTieBreaker::Average),
            [1.0, 2.5, 2.5, 4.0, 5.5, 5.5, 7.0, 8.0]
        );
        assert_eq!(
            sorted_distinct.ranks(RankTieBreaker::Min),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_ties.ranks(RankTieBreaker::Min),
            [1.0, 2.0, 2.0, 4.0, 5.0, 5.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_distinct.ranks(RankTieBreaker::Max),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_ties.ranks(RankTieBreaker::Max),
            [1.0, 3.0, 3.0, 4.0, 6.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_distinct.ranks(RankTieBreaker::First),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
        assert_eq!(
            sorted_ties.ranks(RankTieBreaker::First),
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );

        let distinct = [1.0, 8.0, 12.0, 7.0, 2.0, 9.0, 10.0, 4.0];
        let distinct = Data::new(distinct);
        let ties = [1.0, 9.0, 12.0, 7.0, 2.0, 9.0, 10.0, 2.0];
        let ties = Data::new(ties);

        assert_eq!(
            distinct.clone().ranks(RankTieBreaker::Average),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            ties.clone().ranks(RankTieBreaker::Average),
            [1.0, 5.5, 8.0, 4.0, 2.5, 5.5, 7.0, 2.5]
        );
        assert_eq!(
            distinct.clone().ranks(RankTieBreaker::Min),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            ties.clone().ranks(RankTieBreaker::Min),
            [1.0, 5.0, 8.0, 4.0, 2.0, 5.0, 7.0, 2.0]
        );
        assert_eq!(
            distinct.clone().ranks(RankTieBreaker::Max),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            ties.clone().ranks(RankTieBreaker::Max),
            [1.0, 6.0, 8.0, 4.0, 3.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            distinct.clone().ranks(RankTieBreaker::First),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
        assert_eq!(
            ties.clone().ranks(RankTieBreaker::First),
            [1.0, 5.0, 8.0, 4.0, 2.0, 6.0, 7.0, 3.0]
        );
    }

    #[test]
    fn test_median_short() {
        let even = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0, 0.2, 1.0, 6.0];
        let even = Data::new(even);
        assert_eq!(even.median(), 0.6);

        let odd = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0, 0.2, 1.0];
        let odd = Data::new(odd);
        assert_eq!(odd.median(), 0.2);
    }

    #[test]
    fn test_median_long_constant_seq() {
        let even = vec![2.0; 100000];
        let even = Data::new(even);
        assert_eq!(2.0, even.median());

        let odd = vec![2.0; 100001];
        let odd = Data::new(odd);
        assert_eq!(2.0, odd.median());
    }

    // TODO: test codeplex issue 5667 (Math.NET)

    #[test]
    fn test_median_robust_on_infinities() {
        let data3 = [2.0, f64::NEG_INFINITY, f64::INFINITY];
        let data3 = Data::new(data3);
        assert_eq!(data3.median(), 2.0);
        assert_eq!(data3.median(), 2.0);

        let data3 = [f64::NEG_INFINITY, 2.0, f64::INFINITY];
        let data3 = Data::new(data3);
        assert_eq!(data3.median(), 2.0);
        assert_eq!(data3.median(), 2.0);

        let data3 = [f64::NEG_INFINITY, f64::INFINITY, 2.0];
        let data3 = Data::new(data3);
        assert_eq!(data3.median(), 2.0);
        assert_eq!(data3.median(), 2.0);

        let data4 = [f64::NEG_INFINITY, 2.0, 3.0, f64::INFINITY];
        let data4 = Data::new(data4);
        assert_eq!(data4.median(), 2.5);
        assert_eq!(data4.median(), 2.5);
    }

    #[test]
    fn test_foo() {
        let arr = [0.0, 1.0, 2.0, 3.0];
        let mut arr = Data::new(arr);
        arr.order_statistic(2);
    }
}
statrs-0.18.0/src/statistics/statistics.rs
/// Enumeration of possible tie-breaking strategies
/// when computing ranks
#[derive(Copy, Clone, Debug)]
pub enum RankTieBreaker {
    /// Assigns tied entries the average of their ranks
    Average,
    /// Assigns tied entries the smallest of their ranks
    Min,
    /// Assigns tied entries the largest of their ranks
    Max,
    /// Assigns tied entries increasing ranks in order of appearance
    First,
}

/// The `Statistics` trait provides a host of statistical utilities for
/// analyzing data sets
pub trait Statistics<T> {
    /// Returns the minimum value in the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// let x = &[];
    /// assert!(Statistics::min(x).is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(Statistics::min(y).is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// assert_eq!(Statistics::min(z), -2.0);
    /// ```
    fn min(self) -> T;

    /// Returns the maximum value in the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// let x = &[];
    /// assert!(Statistics::max(x).is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(Statistics::max(y).is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// assert_eq!(Statistics::max(z), 3.0);
    /// ```
    fn max(self) -> T;

    /// Returns the minimum absolute value in the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// let x = &[];
    /// assert!(x.abs_min().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.abs_min().is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// assert_eq!(z.abs_min(), 0.0);
    /// ```
    fn abs_min(self) -> T;

    /// Returns the maximum absolute value in the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// let x = &[];
    /// assert!(x.abs_max().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.abs_max().is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0, -8.0];
    /// assert_eq!(z.abs_max(), 8.0);
    /// ```
    fn abs_max(self) -> T;

    /// Evaluates the sample mean, an estimate of the population
    /// mean.
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// # fn main() {
    /// let x = &[];
    /// assert!(x.mean().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.mean().is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// assert_almost_eq!(z.mean(), 1.0 / 3.0, 1e-15);
    /// # }
    /// ```
    fn mean(self) -> T;

    /// Evaluates the geometric mean of the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`.
    /// Returns `f64::NAN` if an entry is less than `0`. Returns `0`
    /// if no entry is less than `0` but there are entries equal to `0`.
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// # fn main() {
    /// let x = &[];
    /// assert!(x.geometric_mean().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.geometric_mean().is_nan());
    ///
    /// let mut z = &[0.0, 3.0, -2.0];
    /// assert!(z.geometric_mean().is_nan());
    ///
    /// z = &[0.0, 3.0, 2.0];
    /// assert_eq!(z.geometric_mean(), 0.0);
    ///
    /// z = &[1.0, 2.0, 3.0];
    /// // test value from online calculator, could be more accurate
    /// assert_almost_eq!(z.geometric_mean(), 1.81712, 1e-5);
    /// # }
    /// ```
    fn geometric_mean(self) -> T;

    /// Evaluates the harmonic mean of the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty, if an entry is `f64::NAN`, or if
    /// any entry is less than `0`. Returns `0` if no entry is less than `0`
    /// but there are entries equal to `0`.
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// # fn main() {
    /// let x = &[];
    /// assert!(x.harmonic_mean().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.harmonic_mean().is_nan());
    ///
    /// let mut z = &[0.0, 3.0, -2.0];
    /// assert!(z.harmonic_mean().is_nan());
    ///
    /// z = &[0.0, 3.0, 2.0];
    /// assert_eq!(z.harmonic_mean(), 0.0);
    ///
    /// z = &[1.0, 2.0, 3.0];
    /// // test value from online calculator, could be more accurate
    /// assert_almost_eq!(z.harmonic_mean(), 1.63636, 1e-5);
    /// # }
    /// ```
    fn harmonic_mean(self) -> T;

    /// Estimates the unbiased population variance from the provided samples
    ///
    /// # Remarks
    ///
    /// On a dataset of size `N`, `N-1` is used as a normalizer (Bessel's
    /// correction).
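The edge cases documented for `geometric_mean` and `harmonic_mean` follow naturally from IEEE 754 arithmetic:

- For the geometric mean, computed as `exp(mean(ln x))`, `ln(0)` is negative infinity, so the result is `0`. The logarithm of a negative number is `NaN`, so the result is `NaN`.
- For the harmonic mean, computed as `N / sum(1 / x)`, `1 / 0` is positive infinity, so the result is `0`.

The sketch below shows both formulas over a `&[f64]`. The free functions are hypothetical helpers. The explicit negative-entry check in `harmonic_mean` is an assumption made so the sketch matches the documented `NaN` result, because `1 / x` alone would not produce `NaN` for negative entries.

```rust
/// Geometric mean as exp(mean(ln x)). Hypothetical helper.
fn geometric_mean(data: &[f64]) -> f64 {
    if data.is_empty() {
        return f64::NAN;
    }
    // ln(0) = -inf drives the result to 0; ln(x < 0) = NaN propagates.
    let log_sum: f64 = data.iter().map(|x| x.ln()).sum();
    (log_sum / data.len() as f64).exp()
}

/// Harmonic mean as N / sum(1 / x). Hypothetical helper.
fn harmonic_mean(data: &[f64]) -> f64 {
    // Negative entries are rejected explicitly to match the documented NaN.
    if data.is_empty() || data.iter().any(|&x| x < 0.0) {
        return f64::NAN;
    }
    // 1 / 0 = +inf drives the result to 0; a NaN entry propagates.
    let recip_sum: f64 = data.iter().map(|x| 1.0 / x).sum();
    data.len() as f64 / recip_sum
}
```

For `[1.0, 2.0, 3.0]`, the exact values are the cube root of 6 and `18/11`. These are the quantities the doc examples approximate as `1.81712` and `1.63636`.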
    ///
    /// Returns `f64::NAN` if data has fewer than two entries or if any entry is
    /// `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// let x = &[];
    /// assert!(x.variance().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.variance().is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// assert_eq!(z.variance(), 19.0 / 3.0);
    /// ```
    fn variance(self) -> T;

    /// Estimates the population standard deviation from the provided
    /// samples as the square root of the unbiased variance estimate
    ///
    /// # Remarks
    ///
    /// On a dataset of size `N`, `N-1` is used as a normalizer (Bessel's
    /// correction).
    ///
    /// Returns `f64::NAN` if data has fewer than two entries or if any entry is
    /// `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// let x = &[];
    /// assert!(x.std_dev().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.std_dev().is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// assert_eq!(z.std_dev(), (19f64 / 3.0).sqrt());
    /// ```
    fn std_dev(self) -> T;

    /// Evaluates the population variance from a full population.
    ///
    /// # Remarks
    ///
    /// On a dataset of size `N`, `N` is used as a normalizer and would thus
    /// be biased if applied to a subset
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// let x = &[];
    /// assert!(x.population_variance().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.population_variance().is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// assert_eq!(z.population_variance(), 38.0 / 9.0);
    /// ```
    fn population_variance(self) -> T;

    /// Evaluates the population standard deviation from a full population.
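The `covariance` and `population_covariance` methods of this trait share one pass over both inputs. Each step updates the running means and adds `(x - new_mean_x) * (y - old_mean_y)` to a co-moment. The co-moment is then divided by `N - 1` for the sample estimate or by `N` for the population value. The sketch below mirrors that update for two slices and returns both normalizations at once. The function name `covariances` and its tuple return are hypothetical choices for illustration. The length check panics, as the documented behaviour does.

```rust
/// Returns (sample covariance, population covariance) from one pass.
/// Panics if the slices differ in length. Hypothetical helper.
fn covariances(xs: &[f64], ys: &[f64]) -> (f64, f64) {
    assert_eq!(xs.len(), ys.len(), "Iterators must have the same length");
    let (mut n, mut mean_x, mut mean_y, mut comoment) = (0.0, 0.0, 0.0, 0.0);
    for (&x, &y) in xs.iter().zip(ys) {
        let old_mean_y = mean_y;
        n += 1.0;
        mean_x += (x - mean_x) / n;
        mean_y += (y - mean_y) / n;
        // Updated mean of x paired with the previous mean of y
        comoment += (x - mean_x) * (y - old_mean_y);
    }
    let sample = if n > 1.0 { comoment / (n - 1.0) } else { f64::NAN };
    let population = if n > 0.0 { comoment / n } else { f64::NAN };
    (sample, population)
}
```

With `[0.0, 3.0, -2.0]` and `[-5.0, 4.0, 10.0]`, the co-moment is `-11`. This gives `-5.5` and `-11/3`, the values used in the doc examples.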
    ///
    /// # Remarks
    ///
    /// On a dataset of size `N`, `N` is used as a normalizer and would thus
    /// be biased if applied to a subset
    ///
    /// Returns `f64::NAN` if data is empty or an entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// let x = &[];
    /// assert!(x.population_std_dev().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.population_std_dev().is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// assert_eq!(z.population_std_dev(), (38f64 / 9.0).sqrt());
    /// ```
    fn population_std_dev(self) -> T;

    /// Estimates the unbiased population covariance between the two provided
    /// samples
    ///
    /// # Remarks
    ///
    /// On a dataset of size `N`, `N-1` is used as a normalizer (Bessel's
    /// correction).
    ///
    /// Returns `f64::NAN` if data has fewer than two entries or if any entry is
    /// `f64::NAN`
    ///
    /// # Panics
    ///
    /// Panics if the two sample containers do not contain the same number of
    /// elements
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// # fn main() {
    /// let x = &[];
    /// assert!(x.covariance(&[]).is_nan());
    ///
    /// let y1 = &[0.0, f64::NAN, 3.0, -2.0];
    /// let y2 = &[-5.0, 4.0, 10.0, f64::NAN];
    /// assert!(y1.covariance(y2).is_nan());
    ///
    /// let z1 = &[0.0, 3.0, -2.0];
    /// let z2 = &[-5.0, 4.0, 10.0];
    /// assert_almost_eq!(z1.covariance(z2), -5.5, 1e-14);
    /// # }
    /// ```
    fn covariance(self, other: Self) -> T;

    /// Evaluates the population covariance between the two provided populations
    ///
    /// # Remarks
    ///
    /// On a dataset of size `N`, `N` is used as a normalizer and would thus be
    /// biased if applied to a subset
    ///
    /// Returns `f64::NAN` if data is empty or any entry is `f64::NAN`
    ///
    /// # Panics
    ///
    /// Panics if the two sample containers do not contain the same number of
    /// elements
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use std::f64;
    /// use
    ///     statrs::statistics::Statistics;
    ///
    /// # fn main() {
    /// let x = &[];
    /// assert!(x.population_covariance(&[]).is_nan());
    ///
    /// let y1 = &[0.0, f64::NAN, 3.0, -2.0];
    /// let y2 = &[-5.0, 4.0, 10.0, f64::NAN];
    /// assert!(y1.population_covariance(y2).is_nan());
    ///
    /// let z1 = &[0.0, 3.0, -2.0];
    /// let z2 = &[-5.0, 4.0, 10.0];
    /// assert_almost_eq!(z1.population_covariance(z2), -11.0 / 3.0, 1e-14);
    /// # }
    /// ```
    fn population_covariance(self, other: Self) -> T;

    /// Estimates the quadratic mean (Root Mean Square) of the data
    ///
    /// # Remarks
    ///
    /// Returns `f64::NAN` if data is empty or any entry is `f64::NAN`
    ///
    /// # Examples
    ///
    /// ```
    /// #[macro_use]
    /// extern crate statrs;
    ///
    /// use std::f64;
    /// use statrs::statistics::Statistics;
    ///
    /// # fn main() {
    /// let x = &[];
    /// assert!(x.quadratic_mean().is_nan());
    ///
    /// let y = &[0.0, f64::NAN, 3.0, -2.0];
    /// assert!(y.quadratic_mean().is_nan());
    ///
    /// let z = &[0.0, 3.0, -2.0];
    /// // test value from online calculator, could be more accurate
    /// assert_almost_eq!(z.quadratic_mean(), 2.08167, 1e-5);
    /// # }
    /// ```
    fn quadratic_mean(self) -> T;
}
statrs-0.18.0/src/statistics/traits.rs
use ::num_traits::float::Float;

/// The `Min` trait specifies that an object has a minimum value
pub trait Min<T> {
    /// Returns the minimum value in the domain of a given distribution.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Min;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!(0.0, n.min());
    /// ```
    fn min(&self) -> T;
}

/// The `Max` trait specifies that an object has a maximum value
pub trait Max<T> {
    /// Returns the maximum value in the domain of a given distribution.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Max;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!(1.0, n.max());
    /// ```
    fn max(&self) -> T;
}

pub trait DiscreteDistribution<T: Float> {
    /// Returns the mean, if it exists.
    fn mean(&self) -> Option<T> {
        None
    }

    /// Returns the variance, if it exists.
    fn variance(&self) -> Option<T> {
        None
    }

    /// Returns the standard deviation, if it exists.
    fn std_dev(&self) -> Option<T> {
        self.variance().map(|var| var.sqrt())
    }

    /// Returns the entropy, if it exists.
    fn entropy(&self) -> Option<T> {
        None
    }

    /// Returns the skewness, if it exists.
    fn skewness(&self) -> Option<T> {
        None
    }
}

pub trait Distribution<T: Float> {
    /// Returns the mean, if it exists.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Distribution;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!(0.5, n.mean().unwrap());
    /// ```
    fn mean(&self) -> Option<T> {
        None
    }

    /// Returns the variance, if it exists.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Distribution;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!(1.0 / 12.0, n.variance().unwrap());
    /// ```
    fn variance(&self) -> Option<T> {
        None
    }

    /// Returns the standard deviation, if it exists.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Distribution;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!((1f64 / 12f64).sqrt(), n.std_dev().unwrap());
    /// ```
    fn std_dev(&self) -> Option<T> {
        self.variance().map(|var| var.sqrt())
    }

    /// Returns the entropy, if it exists.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Distribution;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!(0.0, n.entropy().unwrap());
    /// ```
    fn entropy(&self) -> Option<T> {
        None
    }

    /// Returns the skewness, if it exists.
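`Distribution<T>` and `DiscreteDistribution<T>` rely on default methods. Every moment defaults to `None`, and `std_dev` is derived from `variance`. As a result, an implementor only overrides what it can compute in closed form. The sketch below shows that pattern with a plain `f64` trait instead of the `Float`-generic one. The names `Moments`, `UniformSketch`, and `NoMoments` are hypothetical and are not statrs types.

```rust
/// Hypothetical trait with the same default-method shape.
trait Moments {
    fn mean(&self) -> Option<f64> {
        None
    }
    fn variance(&self) -> Option<f64> {
        None
    }
    /// Derived from `variance` unless an implementor overrides it.
    fn std_dev(&self) -> Option<f64> {
        self.variance().map(f64::sqrt)
    }
}

/// Hypothetical continuous uniform distribution on [min, max].
struct UniformSketch {
    min: f64,
    max: f64,
}

impl Moments for UniformSketch {
    fn mean(&self) -> Option<f64> {
        Some((self.min + self.max) / 2.0)
    }
    fn variance(&self) -> Option<f64> {
        let width = self.max - self.min;
        Some(width * width / 12.0)
    }
    // std_dev comes from the trait's default implementation.
}

/// A type that overrides nothing reports `None` for every moment.
struct NoMoments;
impl Moments for NoMoments {}
```

For `UniformSketch { min: 0.0, max: 1.0 }`, `std_dev()` returns `Some(sqrt(1/12))` without being implemented explicitly. This mirrors the `Uniform` doc example for `std_dev`.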
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Distribution;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!(0.0, n.skewness().unwrap());
    /// ```
    fn skewness(&self) -> Option<T> {
        None
    }
}

/// The `MeanN` trait specifies the mean of a multidimensional distribution.
// TODO: Clarify the traits of multidimensional distributions
pub trait MeanN<T> {
    fn mean(&self) -> Option<T>;
}

// TODO: Clarify the traits of multidimensional distributions
pub trait VarianceN<T> {
    fn variance(&self) -> Option<T>;
}

/// The `Median` trait returns the median of the distribution.
pub trait Median<T> {
    /// Returns the median.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Median;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!(0.5, n.median());
    /// ```
    fn median(&self) -> T;
}

/// The `Mode` trait specifies that an object has a closed-form solution
/// for its mode(s)
pub trait Mode<T> {
    /// Returns the mode, if one exists.
    ///
    /// # Examples
    ///
    /// ```
    /// use statrs::statistics::Mode;
    /// use statrs::distribution::Uniform;
    ///
    /// let n = Uniform::new(0.0, 1.0).unwrap();
    /// assert_eq!(Some(0.5), n.mode());
    /// ```
    fn mode(&self) -> T;
}

statrs-0.18.0/src/stats_tests/fisher.rs

use super::Alternative;
use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric, HypergeometricError};

const EPSILON: f64 = 1.0 - 1e-4;

/// Binary search in two-sided test with starting bound as argument
fn binary_search(
    n: u64,
    n1: u64,
    n2: u64,
    mode: u64,
    p_exact: f64,
    epsilon: f64,
    upper: bool,
) -> u64 {
    let (mut min_val, mut max_val) = {
        if upper {
            (mode, n)
        } else {
            (0, mode)
        }
    };

    let population = n1 + n2;
    let successes = n1;
    let draws = n;
    let dist = Hypergeometric::new(population, successes, draws).unwrap();
    let mut guess = 0;
    loop {
        if max_val - min_val <= 1 {
            break;
        }
        guess = {
            if max_val == min_val + 1 && guess == min_val {
                max_val
            } else {
                (max_val + min_val) / 2
            }
        };

        let ng = {
            if upper {
                guess - 1
            } else {
                guess + 1
            }
        };

        let pmf_comp = dist.pmf(ng);
        let p_guess = dist.pmf(guess);
        if p_guess <= p_exact && p_exact < pmf_comp {
            break;
        }
        if p_guess < p_exact {
            max_val = guess
        } else {
            min_val = guess
        }
    }

    if guess == 0 {
        guess = min_val
    }
    if upper {
        loop {
            if guess > 0 && dist.pmf(guess) < p_exact * epsilon {
                guess -= 1;
            } else {
                break;
            }
        }
        loop {
            if dist.pmf(guess) > p_exact / epsilon {
                guess += 1;
            } else {
                break;
            }
        }
    } else {
        loop {
            if dist.pmf(guess) < p_exact * epsilon {
                guess += 1;
            } else {
                break;
            }
        }
        loop {
            if guess > 0 && dist.pmf(guess) > p_exact / epsilon {
                guess -= 1;
            } else {
                break;
            }
        }
    }
    guess
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
#[non_exhaustive]
pub enum FishersExactTestError {
    /// The table does not describe a valid [`Hypergeometric`] distribution.
    /// Make sure that the contingency table stores the data in row-major order.
    TableInvalidForHypergeometric(HypergeometricError),
}

impl std::fmt::Display for FishersExactTestError {
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            FishersExactTestError::TableInvalidForHypergeometric(hg_err) => {
                writeln!(f, "Cannot create a Hypergeometric distribution from the data in the contingency table.")?;
                writeln!(f, "Is it in row-major order?")?;
                write!(f, "Inner error: '{hg_err}'")
            }
        }
    }
}

impl std::error::Error for FishersExactTestError {}

impl From<HypergeometricError> for FishersExactTestError {
    fn from(value: HypergeometricError) -> Self {
        Self::TableInvalidForHypergeometric(value)
    }
}

/// Performs Fisher's exact test on a 2x2 contingency table.
/// Based on SciPy's Fisher exact test (`scipy.stats.fisher_exact`).
/// Expects the table in row-major order, i.e. `[a, b, c, d]` for rows `[a, b]` and `[c, d]`.
/// Returns the [odds ratio](https://en.wikipedia.org/wiki/Odds_ratio) and the p-value.
///
/// # Examples
///
/// ```
/// use statrs::stats_tests::fishers_exact_with_odds_ratio;
/// use statrs::stats_tests::Alternative;
/// let table = [3, 5, 4, 50];
/// let (odds_ratio, p_value) = fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap();
/// ```
pub fn fishers_exact_with_odds_ratio(
    table: &[u64; 4],
    alternative: Alternative,
) -> Result<(f64, f64), FishersExactTestError> {
    // If both values in a row or column are zero, the p-value is 1 and the odds ratio is NaN.
    match table {
        [0, _, 0, _] | [_, 0, _, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a column
        [0, 0, _, _] | [_, _, 0, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a row
        _ => (), // continue
    }

    let odds_ratio = {
        if table[1] > 0 && table[2] > 0 {
            // Convert to f64 before multiplying so large counts cannot overflow `u64`.
            (table[0] as f64 * table[3] as f64) / (table[1] as f64 * table[2] as f64)
        } else {
            f64::INFINITY
        }
    };

    let p_value = fishers_exact(table, alternative)?;
    Ok((odds_ratio, p_value))
}

/// Performs Fisher's exact test on a 2x2 contingency table.
/// Based on SciPy's Fisher exact test (`scipy.stats.fisher_exact`).
/// Expects the table in row-major order.
/// Returns only the p-value.
///
/// # Examples
///
/// ```
/// use statrs::stats_tests::fishers_exact;
/// use statrs::stats_tests::Alternative;
/// let table = [3, 5, 4, 50];
/// let p_value = fishers_exact(&table, Alternative::Less).unwrap();
/// ```
pub fn fishers_exact(
    table: &[u64; 4],
    alternative: Alternative,
) -> Result<f64, FishersExactTestError> {
    // If both values in a row or column are zero, the p-value is 1.
    match table {
        [0, _, 0, _] | [_, 0, _, 0] => return Ok(1.0), // both 0 in a column
        [0, 0, _, _] | [_, _, 0, 0] => return Ok(1.0), // both 0 in a row
        _ => (), // continue
    }

    let n1 = table[0] + table[1];
    let n2 = table[2] + table[3];
    let n = table[0] + table[2];

    let p_value = {
        let population = n1 + n2;
        let successes = n1;

        match alternative {
            Alternative::Less => {
                let draws = n;
                let dist = Hypergeometric::new(population, successes, draws)?;
                dist.cdf(table[0])
            }
            Alternative::Greater => {
                let draws = table[1] + table[3];
                let dist = Hypergeometric::new(population, successes, draws)?;
                dist.cdf(table[1])
            }
            Alternative::TwoSided => {
                let draws = n;
                let dist = Hypergeometric::new(population, successes, draws)?;

                let p_exact = dist.pmf(table[0]);
                let mode = ((n + 1) * (n1 + 1)) / (n1 + n2 + 2);
                let p_mode = dist.pmf(mode);

                if (p_exact - p_mode).abs() / p_exact.max(p_mode) <= 1.0 - EPSILON {
                    return Ok(1.0);
                }

                if table[0] < mode {
                    let p_lower = dist.cdf(table[0]);
                    if dist.pmf(n) > p_exact / EPSILON {
                        return Ok(p_lower);
                    }
                    let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, true);
                    return Ok(p_lower + 1.0 - dist.cdf(guess - 1));
                }

                let p_upper = 1.0 - dist.cdf(table[0] - 1);
                if dist.pmf(0) > p_exact / EPSILON {
                    return Ok(p_upper);
                }
                let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, false);
                p_upper + dist.cdf(guess)
            }
        }
    };

    Ok(p_value.min(1.0))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::prec;

    /// Test fishers_exact by comparing against values from scipy.
    #[test]
    fn test_fishers_exact() {
        let cases = [
            (
                [3, 5, 4, 50],
                0.9963034765672599,
                0.03970749246529277,
                0.03970749246529276,
            ),
            (
                [61, 118, 2, 1],
                0.27535061623455315,
                0.9598172545684959,
                0.27535061623455315,
            ),
            (
                [172, 46, 90, 127],
                1.0,
                6.662405187351769e-16,
                9.041009036528785e-16,
            ),
            (
                [127, 38, 112, 43],
                0.8637599357870167,
                0.20040942958644145,
                0.3687862842650179,
            ),
            (
                [186, 177, 111, 154],
                0.9918518696328176,
                0.012550663906725129,
                0.023439141644624434,
            ),
            (
                [137, 49, 135, 183],
                0.999999999998533,
                5.6517533666400615e-12,
                8.870999836202932e-12,
            ),
            (
                [37, 115, 37, 152],
                0.8834621182590621,
                0.17638403366123565,
                0.29400927608021704,
            ),
            (
                [124, 117, 119, 175],
                0.9956704915461392,
                0.007134712391455461,
                0.011588218284387445,
            ),
            (
                [70, 114, 41, 118],
                0.9945558498544903,
                0.010384865876586255,
                0.020438291037108678,
            ),
            (
                [173, 21, 89, 7],
                0.2303739114068352,
                0.8808002774812677,
                0.4027047267306024,
            ),
            (
                [18, 147, 123, 58],
                4.077820702304103e-29,
                0.9999999999999817,
                0.0,
            ),
            (
                [116, 20, 92, 186],
                0.9999999999998267,
                6.598118571034892e-25,
                8.164831402188242e-25,
            ),
            (
                [9, 22, 44, 38],
                0.01584272038710196,
                0.9951463496539362,
                0.021581786662999272,
            ),
            (
                [9, 101, 135, 7],
                3.3336213533847776e-50,
                1.0,
                3.3336213533847776e-50,
            ),
            (
                [153, 27, 191, 144],
                0.9999999999950817,
                2.473736787266208e-11,
                3.185816623300107e-11,
            ),
            (
                [111, 195, 189, 69],
                6.665245982898848e-19,
                0.9999999999994574,
                1.0735744915712542e-18,
            ),
            (
                [125, 21, 31, 131],
                0.99999999999974,
                9.720661317939016e-34,
                1.0352129312860277e-33,
            ),
            (
                [201, 192, 69, 179],
                0.9999999988714893,
                3.1477232259550017e-09,
                4.761075937088169e-09,
            ),
            (
                [124, 138, 159, 160],
                0.30153826772785475,
                0.7538974235759873,
                0.5601766196310243,
            ),
        ];
        for (table, less_expected, greater_expected, two_sided_expected) in cases.iter() {
            for (alternative, expected) in [
                Alternative::Less,
                Alternative::Greater,
                Alternative::TwoSided,
            ]
            .iter()
            .zip(vec![less_expected, greater_expected, two_sided_expected])
            {
                let p_value = fishers_exact(table, *alternative).unwrap();
                assert!(prec::almost_eq(p_value, *expected, 1e-12));
            }
        }
    }

    #[test]
    fn test_fishers_exact_for_trivial() {
        let cases = [[0, 0, 1, 2], [1, 2, 0, 0], [1, 0, 2, 0], [0, 1, 0, 2]];
        for table in cases.iter() {
            assert_eq!(fishers_exact(table, Alternative::Less).unwrap(), 1.0)
        }
    }

    #[test]
    fn test_fishers_exact_with_odds() {
        let table = [3, 5, 4, 50];
        let (odds_ratio, p_value) =
            fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap();
        assert!(prec::almost_eq(p_value, 0.9963034765672599, 1e-12));
        assert!(prec::almost_eq(odds_ratio, 7.5, 1e-1));
    }
}

statrs-0.18.0/src/stats_tests/mod.rs

pub mod fisher;

/// Specifies an [alternative hypothesis](https://en.wikipedia.org/wiki/Alternative_hypothesis)
#[derive(Debug, Copy, Clone)]
pub enum Alternative {
    #[doc(alias = "two-tailed")]
    #[doc(alias = "two tailed")]
    TwoSided,
    #[doc(alias = "one-tailed")]
    #[doc(alias = "one tailed")]
    Less,
    #[doc(alias = "one-tailed")]
    #[doc(alias = "one tailed")]
    Greater,
}

pub use fisher::{fishers_exact, fishers_exact_with_odds_ratio};

statrs-0.18.0/tests/gather_nist_data.sh

#! /bin/bash

# This script downloads and preprocesses the data files for the nist_tests.rs
# integration test of statrs. The data is downloaded to the directory given by
# the environment variable STATRS_NIST_DATA_DIR (default: tests).

process_file() {
    # Define input and output file names
    SOURCE=$1
    FILENAME=$2
    TARGET=${STATRS_NIST_DATA_DIR-tests}/${FILENAME}

    echo -e ${FILENAME} '\n\tDownloading...'
    curl -fsSL ${SOURCE}/$FILENAME > ${TARGET}

    # Extract line numbers for Certified Values and Data from the header
    INFO=$(grep "Certified Values:" $TARGET)
    CERTIFIED_VALUES_START=$(echo $INFO | awk '{print $4}')
    CERTIFIED_VALUES_END=$(echo $INFO | awk '{print $6}')
    INFO=$(grep "Data :" $TARGET)
    DATA_START=$(echo $INFO | awk '{print $4}')
    DATA_END=$(echo $INFO | awk '{print $6}')

    echo -e '\tFormatting...'
    # Extract and reformat sections
    sed -n -i \
        -e "${CERTIFIED_VALUES_START},${CERTIFIED_VALUES_END}p" \
        -e "${DATA_START},${DATA_END}p" \
        $TARGET
}

URL='https://www.itl.nist.gov/div898/strd/univ/data'
for file in Lottery.dat Lew.dat Mavro.dat Michelso.dat NumAcc1.dat NumAcc2.dat NumAcc3.dat
do
    process_file $URL $file
done

statrs-0.18.0/tests/nist_tests.rs

//! This test relies on data that is reusable but not distributable by statrs.
//! As such, the data must be downloaded from the relevant NIST StRD dataset.
//! The parsing for testing assumes the data to be of the form
//! ```text
//! sample mean :
//! sample std_dev :
//! sample correlation:
//! [zero or more blank lines]
//! data0
//! data1
//! data2
//! ...
//! ```
//! This test can be run on its own from the shell in this folder with
//! ```sh
//! ./gather_nist_data.sh && cargo test -- --ignored nist_
//! ```

use anyhow::Result;
use approx::assert_relative_eq;
use statrs::statistics::Statistics;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::{env, fs};

struct TestCase {
    certified: CertifiedValues,
    values: Vec<f64>,
}

impl std::fmt::Debug for TestCase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "TestCase({:?}, [...])", self.certified)
    }
}

#[derive(Debug)]
struct CertifiedValues {
    mean: f64,
    std_dev: f64,
    corr: f64,
}

impl std::fmt::Display for CertifiedValues {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "μ={:.3e}, σ={:.3e}, r={:.3e}",
            self.mean, self.std_dev, self.corr
        )
    }
}

const NIST_DATA_DIR_ENV: &str = "STATRS_NIST_DATA_DIR";
const FILENAMES: [&str; 7] = [
    "Lottery.dat",
    "Lew.dat",
    "Mavro.dat",
    "Michelso.dat",
    "NumAcc1.dat",
    "NumAcc2.dat",
    "NumAcc3.dat",
];

fn get_path(fname: &str, prefix: Option<&str>) -> PathBuf {
    if let Some(prefix) = prefix {
        [prefix, fname].iter().collect()
    } else {
        ["tests", fname].iter().collect()
    }
}

#[test]
#[ignore = "NIST tests should not run from typical `cargo test` calls"]
fn nist_strd_univariate_mean() {
    for fname in FILENAMES {
        let filepath = get_path(fname, env::var(NIST_DATA_DIR_ENV).ok().as_deref());
        let case = parse_file(filepath)
            .unwrap_or_else(|e| panic!("failed parsing file {fname} with `{e:?}`"));
        assert_relative_eq!(case.values.mean(), case.certified.mean, epsilon = 1e-12);
    }
}

#[test]
#[ignore = "NIST tests should not run from typical `cargo test` calls"]
fn nist_strd_univariate_std_dev() {
    for fname in FILENAMES {
        let filepath = get_path(fname, env::var(NIST_DATA_DIR_ENV).ok().as_deref());
        let case = parse_file(filepath)
            .unwrap_or_else(|e| panic!("failed parsing file {fname} with `{e:?}`"));
        assert_relative_eq!(
            case.values.std_dev(),
            case.certified.std_dev,
            epsilon = 1e-10
        );
    }
}

fn parse_certified_value(line: String) -> Result<f64> {
    line.chars()
        .skip_while(|&c| c != ':')
        .skip(1) // skip through ':' delimiter
        .skip_while(|&c| c.is_whitespace()) // effectively `String` trim
        .take_while(|&c| matches!(c, '0'..='9' | '-' | '.'))
        .collect::<String>()
        .parse::<f64>()
        .map_err(|e| e.into())
}

fn parse_file(path: impl AsRef<std::path::Path>) -> anyhow::Result<TestCase> {
    let f = fs::File::open(path)?;
    let reader = BufReader::new(f);
    let mut lines = reader.lines();

    let mean = parse_certified_value(lines.next().expect("file should not be exhausted")?)?;
    let std_dev = parse_certified_value(lines.next().expect("file should not be exhausted")?)?;
    let corr = parse_certified_value(lines.next().expect("file should not be exhausted")?)?;

    Ok(TestCase {
        certified: CertifiedValues {
            mean,
            std_dev,
            corr,
        },
        values: lines
            .map_while(|line| line.ok()?.trim().parse().ok())
            .collect(),
    })
}

#[test]
#[ignore = "NIST tests should not run from typical `cargo test` calls"]
fn nist_test_covariance_consistent_with_variance() {}

#[test]
#[ignore = "NIST tests should not run from typical `cargo test` calls"]
fn nist_test_covariance_is_symmetric() {}
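The sketch below is a self-contained cross-check of the one-sided branches of `fishers_exact`. It is not statrs code. It sums hypergeometric probabilities directly for a row-major table `[a, b, c, d]`.

As `fishers_exact` does, it assumes that under the null hypothesis the top-left cell `a` is hypergeometric, with:
- population `a + b + c + d`,
- `a + b` successes (first row),
- `a + c` draws (first column).

The helper names `binom`, `hypergeom_pmf`, `fisher_one_sided` and `odds_ratio` are hypothetical. Building binomial coefficients in `f64` this way is only exact for modest table sizes, so large tables are better served by a log-space computation.

```rust
/// Binomial coefficient C(n, k) as f64. After step i the running value equals
/// C(n - k + i, i), an integer, so the result is exact while it stays below 2^53.
fn binom(n: u64, k: u64) -> f64 {
    if k > n {
        return 0.0;
    }
    let mut r = 1.0_f64;
    for i in 1..=k {
        r = r * (n - k + i) as f64 / i as f64;
    }
    r
}

/// Hypergeometric PMF: probability of exactly `k` successes in `draws` draws,
/// without replacement, from `population` items containing `successes` successes.
fn hypergeom_pmf(population: u64, successes: u64, draws: u64, k: u64) -> f64 {
    if k > successes || k > draws || draws - k > population - successes {
        return 0.0; // outside the support
    }
    binom(successes, k) * binom(population - successes, draws - k) / binom(population, draws)
}

/// One-sided Fisher p-values (less, greater) for a row-major table [a, b, c, d],
/// i.e. rows [a, b] and [c, d]. Under the null hypothesis `a` is hypergeometric
/// with population a + b + c + d, a + b successes and a + c draws.
fn fisher_one_sided(table: [u64; 4]) -> (f64, f64) {
    let [a, b, c, d] = table;
    let (population, successes, draws) = (a + b + c + d, a + b, a + c);
    let mut p_less = 0.0; // accumulates P(X <= a)
    let mut p_greater = 0.0; // accumulates P(X >= a)
    for k in 0..=successes.min(draws) {
        let p = hypergeom_pmf(population, successes, draws, k);
        if k <= a {
            p_less += p;
        }
        if k >= a {
            p_greater += p;
        }
    }
    (p_less, p_greater)
}

/// Sample odds ratio (a * d) / (b * c); converting to f64 first avoids u64 overflow.
/// A zero in b or c gives infinity (or NaN when the numerator is also zero).
fn odds_ratio(table: [u64; 4]) -> f64 {
    let [a, b, c, d] = table.map(|x| x as f64);
    (a * d) / (b * c)
}

fn main() {
    let table = [3, 5, 4, 50];
    let (p_less, p_greater) = fisher_one_sided(table);
    println!("odds ratio = {}", odds_ratio(table));
    println!("p(less) = {p_less:.16}, p(greater) = {p_greater:.16}");
}
```

For `[3, 5, 4, 50]` this should agree with the values in the tests above:
- the `Less` and `Greater` p-values in `test_fishers_exact` (about 0.99630 and 0.03971),
- the odds ratio of 7.5 asserted in `test_fishers_exact_with_odds`.

The two-sided branch is deliberately left out: it needs the mode-relative tail search that `binary_search` performs.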