ruma-macros-0.10.5/.cargo_vcs_info.json0000644000000001600000000000100133420ustar { "git": { "sha1": "67b2ec7d34eb35e47c7bf1d0da0e6326049179ac" }, "path_in_vcs": "crates/ruma-macros" }ruma-macros-0.10.5/Cargo.toml0000644000000025450000000000100113510ustar # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO # # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies # to registry (e.g., crates.io) dependencies. # # If you are reading this file be aware that the original Cargo.toml # will likely look very different (and much more reasonable). # See Cargo.toml.orig for the original contents. [package] edition = "2021" rust-version = "1.60" name = "ruma-macros" version = "0.10.5" description = "Procedural macros used by the Ruma crates." homepage = "https://www.ruma.io/" readme = "README.md" keywords = [ "matrix", "chat", "messaging", "ruma", ] categories = [ "api-bindings", "web-programming", ] license = "MIT" repository = "https://github.com/ruma/ruma" [lib] proc-macro = true [dependencies.once_cell] version = "1.13.0" [dependencies.proc-macro-crate] version = "1.0.0" [dependencies.proc-macro2] version = "1.0.24" [dependencies.quote] version = "1.0.8" [dependencies.ruma-identifiers-validation] version = "0.9.0" default-features = false [dependencies.serde] version = "1.0.139" features = ["derive"] [dependencies.syn] version = "1.0.57" features = [ "extra-traits", "full", "visit", ] [dependencies.toml] version = "0.5.9" [features] compat = [] ruma-macros-0.10.5/Cargo.toml.orig000064400000000000000000000014131046102023000150230ustar 00000000000000[package] categories = ["api-bindings", "web-programming"] description = "Procedural macros used by the Ruma crates." 
homepage = "https://www.ruma.io/" keywords = ["matrix", "chat", "messaging", "ruma"] license = "MIT" name = "ruma-macros" readme = "README.md" repository = "https://github.com/ruma/ruma" version = "0.10.5" edition = "2021" rust-version = "1.60" [lib] proc-macro = true [features] compat = [] [dependencies] once_cell = "1.13.0" proc-macro-crate = "1.0.0" proc-macro2 = "1.0.24" quote = "1.0.8" ruma-identifiers-validation = { version = "0.9.0", path = "../ruma-identifiers-validation", default-features = false } serde = { version = "1.0.139", features = ["derive"] } syn = { version = "1.0.57", features = ["extra-traits", "full", "visit"] } toml = "0.5.9" ruma-macros-0.10.5/README.md000064400000000000000000000002201046102023000134060ustar 00000000000000# ruma-macros **ruma-macros** provides procedural macros for easily generating types for [Ruma] crates. [Ruma]: https://github.com/ruma/ruma/ ruma-macros-0.10.5/src/api/api_metadata.rs000064400000000000000000000254031046102023000164600ustar 00000000000000//! Details of the `metadata` section of the procedural macro. use quote::ToTokens; use syn::{ braced, parse::{Parse, ParseStream}, Ident, LitBool, LitStr, Token, }; use super::{auth_scheme::AuthScheme, util, version::MatrixVersionLiteral}; mod kw { syn::custom_keyword!(metadata); syn::custom_keyword!(description); syn::custom_keyword!(method); syn::custom_keyword!(name); syn::custom_keyword!(unstable_path); syn::custom_keyword!(r0_path); syn::custom_keyword!(stable_path); syn::custom_keyword!(rate_limited); syn::custom_keyword!(authentication); syn::custom_keyword!(added); syn::custom_keyword!(deprecated); syn::custom_keyword!(removed); } /// The result of processing the `metadata` section of the macro. pub struct Metadata { /// The description field. pub description: LitStr, /// The method field. pub method: Ident, /// The name field. pub name: LitStr, /// The unstable path field. pub unstable_path: Option, /// The pre-v1.1 path field. 
pub r0_path: Option, /// The stable path field. pub stable_path: Option, /// The rate_limited field. pub rate_limited: LitBool, /// The authentication field. pub authentication: AuthScheme, /// The added field. pub added: Option, /// The deprecated field. pub deprecated: Option, /// The removed field. pub removed: Option, } fn set_field(field: &mut Option, value: T) -> syn::Result<()> { match field { Some(existing_value) => { let mut error = syn::Error::new_spanned(value, "duplicate field assignment"); error.combine(syn::Error::new_spanned(existing_value, "first one here")); Err(error) } None => { *field = Some(value); Ok(()) } } } impl Parse for Metadata { fn parse(input: ParseStream<'_>) -> syn::Result { let metadata_kw: kw::metadata = input.parse()?; let _: Token![:] = input.parse()?; let field_values; braced!(field_values in input); let field_values = field_values.parse_terminated::(FieldValue::parse)?; let mut description = None; let mut method = None; let mut name = None; let mut unstable_path = None; let mut r0_path = None; let mut stable_path = None; let mut rate_limited = None; let mut authentication = None; let mut added = None; let mut deprecated = None; let mut removed = None; for field_value in field_values { match field_value { FieldValue::Description(d) => set_field(&mut description, d)?, FieldValue::Method(m) => set_field(&mut method, m)?, FieldValue::Name(n) => set_field(&mut name, n)?, FieldValue::UnstablePath(p) => set_field(&mut unstable_path, p)?, FieldValue::R0Path(p) => set_field(&mut r0_path, p)?, FieldValue::StablePath(p) => set_field(&mut stable_path, p)?, FieldValue::RateLimited(rl) => set_field(&mut rate_limited, rl)?, FieldValue::Authentication(a) => set_field(&mut authentication, a)?, FieldValue::Added(v) => set_field(&mut added, v)?, FieldValue::Deprecated(v) => set_field(&mut deprecated, v)?, FieldValue::Removed(v) => set_field(&mut removed, v)?, } } let missing_field = |name| syn::Error::new_spanned(metadata_kw, format!("missing 
field `{}`", name)); let stable_or_r0 = stable_path.as_ref().or(r0_path.as_ref()); if let Some(path) = stable_or_r0 { if added.is_none() { return Err(syn::Error::new_spanned( path, "stable path was defined, while `added` version was not defined", )); } } if let Some(deprecated) = &deprecated { if added.is_none() { return Err(syn::Error::new_spanned( deprecated, "deprecated version is defined while added version is not defined", )); } } // note: It is possible that matrix will remove endpoints in a single version, while not // having a deprecation version inbetween, but that would not be allowed by their own // deprecation policy, so lets just assume there's always a deprecation version before a // removal one. // // If matrix does so anyways, we can just alter this. if let Some(removed) = &removed { if deprecated.is_none() { return Err(syn::Error::new_spanned( removed, "removed version is defined while deprecated version is not defined", )); } } if let Some(added) = &added { if stable_or_r0.is_none() { return Err(syn::Error::new_spanned( added, "added version is defined, but no stable or r0 path exists", )); } } if let Some(r0) = &r0_path { let added = added.as_ref().expect("we error if r0 or stable is defined without added"); if added.major.get() == 1 && added.minor > 0 { return Err(syn::Error::new_spanned( r0, "r0 defined while added version is newer than v1.0", )); } if stable_path.is_none() { return Err(syn::Error::new_spanned(r0, "r0 defined without stable path")); } if !r0.value().contains("/r0/") { return Err(syn::Error::new_spanned(r0, "r0 endpoint does not contain /r0/")); } } if let Some(stable) = &stable_path { if stable.value().contains("/r0/") { return Err(syn::Error::new_spanned( stable, "stable endpoint contains /r0/ (did you make a copy-paste error?)", )); } } if unstable_path.is_none() && r0_path.is_none() && stable_path.is_none() { return Err(syn::Error::new_spanned( metadata_kw, "need to define one of [r0_path, stable_path, unstable_path]", )); } 
Ok(Self { description: description.ok_or_else(|| missing_field("description"))?, method: method.ok_or_else(|| missing_field("method"))?, name: name.ok_or_else(|| missing_field("name"))?, unstable_path, r0_path, stable_path, rate_limited: rate_limited.ok_or_else(|| missing_field("rate_limited"))?, authentication: authentication.ok_or_else(|| missing_field("authentication"))?, added, deprecated, removed, }) } } enum Field { Description, Method, Name, UnstablePath, R0Path, StablePath, RateLimited, Authentication, Added, Deprecated, Removed, } impl Parse for Field { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(kw::description) { let _: kw::description = input.parse()?; Ok(Self::Description) } else if lookahead.peek(kw::method) { let _: kw::method = input.parse()?; Ok(Self::Method) } else if lookahead.peek(kw::name) { let _: kw::name = input.parse()?; Ok(Self::Name) } else if lookahead.peek(kw::unstable_path) { let _: kw::unstable_path = input.parse()?; Ok(Self::UnstablePath) } else if lookahead.peek(kw::r0_path) { let _: kw::r0_path = input.parse()?; Ok(Self::R0Path) } else if lookahead.peek(kw::stable_path) { let _: kw::stable_path = input.parse()?; Ok(Self::StablePath) } else if lookahead.peek(kw::rate_limited) { let _: kw::rate_limited = input.parse()?; Ok(Self::RateLimited) } else if lookahead.peek(kw::authentication) { let _: kw::authentication = input.parse()?; Ok(Self::Authentication) } else if lookahead.peek(kw::added) { let _: kw::added = input.parse()?; Ok(Self::Added) } else if lookahead.peek(kw::deprecated) { let _: kw::deprecated = input.parse()?; Ok(Self::Deprecated) } else if lookahead.peek(kw::removed) { let _: kw::removed = input.parse()?; Ok(Self::Removed) } else { Err(lookahead.error()) } } } enum FieldValue { Description(LitStr), Method(Ident), Name(LitStr), UnstablePath(EndpointPath), R0Path(EndpointPath), StablePath(EndpointPath), RateLimited(LitBool), Authentication(AuthScheme), 
Added(MatrixVersionLiteral), Deprecated(MatrixVersionLiteral), Removed(MatrixVersionLiteral), } impl Parse for FieldValue { fn parse(input: ParseStream<'_>) -> syn::Result { let field: Field = input.parse()?; let _: Token![:] = input.parse()?; Ok(match field { Field::Description => Self::Description(input.parse()?), Field::Method => Self::Method(input.parse()?), Field::Name => Self::Name(input.parse()?), Field::UnstablePath => Self::UnstablePath(input.parse()?), Field::R0Path => Self::R0Path(input.parse()?), Field::StablePath => Self::StablePath(input.parse()?), Field::RateLimited => Self::RateLimited(input.parse()?), Field::Authentication => Self::Authentication(input.parse()?), Field::Added => Self::Added(input.parse()?), Field::Deprecated => Self::Deprecated(input.parse()?), Field::Removed => Self::Removed(input.parse()?), }) } } #[derive(Clone)] pub struct EndpointPath(LitStr); impl EndpointPath { pub fn value(&self) -> String { self.0.value() } } impl Parse for EndpointPath { fn parse(input: ParseStream<'_>) -> syn::Result { let path: LitStr = input.parse()?; if util::is_valid_endpoint_path(&path.value()) { Ok(Self(path)) } else { Err(syn::Error::new_spanned( &path, "path may only contain printable ASCII characters with no spaces", )) } } } impl ToTokens for EndpointPath { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { self.0.to_tokens(tokens); } } ruma-macros-0.10.5/src/api/api_request.rs000064400000000000000000000102611046102023000163640ustar 00000000000000//! Details of the `request` section of the procedural macro. use std::collections::btree_map::{BTreeMap, Entry}; use proc_macro2::TokenStream; use quote::quote; use syn::{ parse_quote, punctuated::Punctuated, spanned::Spanned, visit::Visit, Attribute, Field, Ident, Lifetime, Token, }; use super::{ api_metadata::Metadata, kw, util::{all_cfgs, extract_cfg}, }; /// The result of processing the `request` section of the macro. 
pub(crate) struct Request { /// The `request` keyword pub(super) request_kw: kw::request, /// The attributes that will be applied to the struct definition. pub(super) attributes: Vec, /// The fields of the request. pub(super) fields: Punctuated, } impl Request { /// The combination of every fields unique lifetime annotation. fn all_lifetimes(&self) -> BTreeMap> { let mut lifetimes = BTreeMap::new(); struct Visitor<'lt> { field_cfg: Option, lifetimes: &'lt mut BTreeMap>, } impl<'ast> Visit<'ast> for Visitor<'_> { fn visit_lifetime(&mut self, lt: &'ast Lifetime) { match self.lifetimes.entry(lt.clone()) { Entry::Vacant(v) => { v.insert(self.field_cfg.clone()); } Entry::Occupied(mut o) => { let lifetime_cfg = o.get_mut(); // If at least one field uses this lifetime and has no cfg attribute, we // don't need a cfg attribute for the lifetime either. *lifetime_cfg = Option::zip(lifetime_cfg.as_ref(), self.field_cfg.as_ref()) .map(|(a, b)| { let expr_a = extract_cfg(a); let expr_b = extract_cfg(b); parse_quote! { #[cfg( any( #expr_a, #expr_b ) )] } }); } } } } for field in &self.fields { let field_cfg = if field.attrs.is_empty() { None } else { all_cfgs(&field.attrs) }; Visitor { lifetimes: &mut lifetimes, field_cfg }.visit_type(&field.ty); } lifetimes } pub(super) fn expand( &self, metadata: &Metadata, error_ty: &TokenStream, ruma_common: &TokenStream, ) -> TokenStream { let ruma_macros = quote! { #ruma_common::exports::ruma_macros }; let docs = format!( "Data for a request to the `{}` API endpoint.\n\n{}", metadata.name.value(), metadata.description.value(), ); let struct_attributes = &self.attributes; let method = &metadata.method; let authentication = &metadata.authentication; let unstable_attr = metadata.unstable_path.as_ref().map(|p| quote! { unstable = #p, }); let r0_attr = metadata.r0_path.as_ref().map(|p| quote! { r0 = #p, }); let stable_attr = metadata.stable_path.as_ref().map(|p| quote! 
{ stable = #p, }); let request_ident = Ident::new("Request", self.request_kw.span()); let lifetimes = self.all_lifetimes(); let lifetimes = lifetimes.iter().map(|(lt, attr)| quote! { #attr #lt }); let fields = &self.fields; quote! { #[doc = #docs] #[derive( Clone, Debug, #ruma_macros::Request, #ruma_common::serde::Incoming, #ruma_common::serde::_FakeDeriveSerde, )] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)] #[ruma_api( method = #method, authentication = #authentication, #unstable_attr #r0_attr #stable_attr error_ty = #error_ty, )] #( #struct_attributes )* pub struct #request_ident < #(#lifetimes),* > { #fields } } } } ruma-macros-0.10.5/src/api/api_response.rs000064400000000000000000000033161046102023000165350ustar 00000000000000//! Details of the `response` section of the procedural macro. use proc_macro2::TokenStream; use quote::quote; use syn::{punctuated::Punctuated, spanned::Spanned, Attribute, Field, Ident, Token}; use super::{api_metadata::Metadata, kw}; /// The result of processing the `response` section of the macro. pub(crate) struct Response { /// The `response` keyword pub(super) response_kw: kw::response, /// The attributes that will be applied to the struct definition. pub attributes: Vec, /// The fields of the response. pub fields: Punctuated, } impl Response { pub(super) fn expand( &self, metadata: &Metadata, error_ty: &TokenStream, ruma_common: &TokenStream, ) -> TokenStream { let ruma_macros = quote! { #ruma_common::exports::ruma_macros }; let docs = format!("Data in the response from the `{}` API endpoint.", metadata.name.value()); let struct_attributes = &self.attributes; let response_ident = Ident::new("Response", self.response_kw.span()); let fields = &self.fields; quote! 
{ #[doc = #docs] #[derive( Clone, Debug, #ruma_macros::Response, #ruma_common::serde::Incoming, #ruma_common::serde::_FakeDeriveSerde, )] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[incoming_derive(!Deserialize, #ruma_macros::_FakeDeriveRumaApi)] #[ruma_api(error_ty = #error_ty)] #( #struct_attributes )* pub struct #response_ident { #fields } } } } ruma-macros-0.10.5/src/api/attribute.rs000064400000000000000000000110061046102023000160440ustar 00000000000000//! Details of the `#[ruma_api(...)]` attributes. use syn::{ parse::{Parse, ParseStream}, Ident, LitStr, Token, Type, }; mod kw { syn::custom_keyword!(body); syn::custom_keyword!(raw_body); syn::custom_keyword!(path); syn::custom_keyword!(query); syn::custom_keyword!(query_map); syn::custom_keyword!(header); syn::custom_keyword!(authentication); syn::custom_keyword!(method); syn::custom_keyword!(error_ty); syn::custom_keyword!(unstable); syn::custom_keyword!(r0); syn::custom_keyword!(stable); syn::custom_keyword!(manual_body_serde); } pub enum RequestMeta { NewtypeBody, RawBody, Path, Query, QueryMap, Header(Ident), } impl Parse for RequestMeta { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(kw::body) { let _: kw::body = input.parse()?; Ok(Self::NewtypeBody) } else if lookahead.peek(kw::raw_body) { let _: kw::raw_body = input.parse()?; Ok(Self::RawBody) } else if lookahead.peek(kw::path) { let _: kw::path = input.parse()?; Ok(Self::Path) } else if lookahead.peek(kw::query) { let _: kw::query = input.parse()?; Ok(Self::Query) } else if lookahead.peek(kw::query_map) { let _: kw::query_map = input.parse()?; Ok(Self::QueryMap) } else if lookahead.peek(kw::header) { let _: kw::header = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::Header) } else { Err(lookahead.error()) } } } pub enum DeriveRequestMeta { Authentication(Type), Method(Type), ErrorTy(Type), UnstablePath(LitStr), R0Path(LitStr), 
StablePath(LitStr), } impl Parse for DeriveRequestMeta { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(kw::authentication) { let _: kw::authentication = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::Authentication) } else if lookahead.peek(kw::method) { let _: kw::method = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::Method) } else if lookahead.peek(kw::error_ty) { let _: kw::error_ty = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::ErrorTy) } else if lookahead.peek(kw::unstable) { let _: kw::unstable = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::UnstablePath) } else if lookahead.peek(kw::r0) { let _: kw::r0 = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::R0Path) } else if lookahead.peek(kw::stable) { let _: kw::stable = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::StablePath) } else { Err(lookahead.error()) } } } pub enum ResponseMeta { NewtypeBody, RawBody, Header(Ident), } impl Parse for ResponseMeta { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(kw::body) { let _: kw::body = input.parse()?; Ok(Self::NewtypeBody) } else if lookahead.peek(kw::raw_body) { let _: kw::raw_body = input.parse()?; Ok(Self::RawBody) } else if lookahead.peek(kw::header) { let _: kw::header = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::Header) } else { Err(lookahead.error()) } } } #[allow(clippy::large_enum_variant)] pub enum DeriveResponseMeta { ManualBodySerde, ErrorTy(Type), } impl Parse for DeriveResponseMeta { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(kw::manual_body_serde) { let _: kw::manual_body_serde = input.parse()?; Ok(Self::ManualBodySerde) } else if lookahead.peek(kw::error_ty) { let _: 
kw::error_ty = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(Self::ErrorTy) } else { Err(lookahead.error()) } } } ruma-macros-0.10.5/src/api/auth_scheme.rs000064400000000000000000000026441046102023000163360ustar 00000000000000use proc_macro2::TokenStream; use quote::ToTokens; use syn::parse::{Parse, ParseStream}; mod kw { syn::custom_keyword!(None); syn::custom_keyword!(AccessToken); syn::custom_keyword!(ServerSignatures); syn::custom_keyword!(QueryOnlyAccessToken); } pub enum AuthScheme { None(kw::None), AccessToken(kw::AccessToken), ServerSignatures(kw::ServerSignatures), QueryOnlyAccessToken(kw::QueryOnlyAccessToken), } impl Parse for AuthScheme { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(kw::None) { input.parse().map(Self::None) } else if lookahead.peek(kw::AccessToken) { input.parse().map(Self::AccessToken) } else if lookahead.peek(kw::ServerSignatures) { input.parse().map(Self::ServerSignatures) } else if lookahead.peek(kw::QueryOnlyAccessToken) { input.parse().map(Self::QueryOnlyAccessToken) } else { Err(lookahead.error()) } } } impl ToTokens for AuthScheme { fn to_tokens(&self, tokens: &mut TokenStream) { match self { AuthScheme::None(kw) => kw.to_tokens(tokens), AuthScheme::AccessToken(kw) => kw.to_tokens(tokens), AuthScheme::ServerSignatures(kw) => kw.to_tokens(tokens), AuthScheme::QueryOnlyAccessToken(kw) => kw.to_tokens(tokens), } } } ruma-macros-0.10.5/src/api/request/incoming.rs000064400000000000000000000226741046102023000173510ustar 00000000000000use proc_macro2::TokenStream; use quote::quote; use syn::Field; use super::{Request, RequestField}; use crate::api::auth_scheme::AuthScheme; impl Request { pub fn expand_incoming(&self, ruma_common: &TokenStream) -> TokenStream { let http = quote! { #ruma_common::exports::http }; let serde = quote! { #ruma_common::exports::serde }; let serde_json = quote! 
{ #ruma_common::exports::serde_json }; let method = &self.method; let error_ty = &self.error_ty; let incoming_request_type = if self.has_lifetimes() { quote! { IncomingRequest } } else { quote! { Request } }; // FIXME: the rest of the field initializer expansions are gated `cfg(...)` // except this one. If we get errors about missing fields in IncomingRequest for // a path field look here. let (parse_request_path, path_vars) = if self.has_path_fields() { let path_vars: Vec<_> = self.path_fields_ordered().filter_map(|f| f.ident.as_ref()).collect(); let parse_request_path = quote! { let (#(#path_vars,)*) = #serde::Deserialize::deserialize( #serde::de::value::SeqDeserializer::<_, #serde::de::value::Error>::new( path_args.iter().map(::std::convert::AsRef::as_ref) ) )?; }; (parse_request_path, quote! { #(#path_vars,)* }) } else { (TokenStream::new(), TokenStream::new()) }; let (parse_query, query_vars) = if let Some(field) = self.query_map_field() { let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let parse = quote! { #( #cfg_attrs )* let #field_name = #ruma_common::serde::urlencoded::from_str( &request.uri().query().unwrap_or(""), )?; }; ( parse, quote! { #( #cfg_attrs )* #field_name, }, ) } else if self.has_query_fields() { let (decls, names) = vars( self.fields.iter().filter_map(RequestField::as_query_field), quote! { request_query }, ); let request_query_ty = if self.lifetimes.query.is_empty() { quote! { RequestQuery } } else { quote! { IncomingRequestQuery } }; let parse = quote! 
{ let request_query: #request_query_ty = #ruma_common::serde::urlencoded::from_str( &request.uri().query().unwrap_or("") )?; #decls }; (parse, names) } else { (TokenStream::new(), TokenStream::new()) }; let (parse_headers, header_vars) = if self.has_header_fields() { let (decls, names): (TokenStream, Vec<_>) = self .header_fields() .map(|request_field| { let (field, header_name) = match request_field { RequestField::Header(field, header_name) => (field, header_name), _ => panic!("expected request field to be header variant"), }; let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); let field_name = &field.ident; let header_name_string = header_name.to_string(); let (some_case, none_case) = match &field.ty { syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.last().unwrap().ident == "Option" => { (quote! { Some(str_value.to_owned()) }, quote! { None }) } _ => ( quote! { str_value.to_owned() }, quote! { return Err( #ruma_common::api::error::HeaderDeserializationError::MissingHeader( #header_name_string.into() ).into(), ) }, ), }; let decl = quote! { #( #cfg_attrs )* let #field_name = match headers.get(#http::header::#header_name) { Some(header_value) => { let str_value = header_value.to_str()?; #some_case } None => #none_case, }; }; ( decl, quote! { #( #cfg_attrs )* #field_name }, ) }) .unzip(); let parse = quote! { let headers = request.headers(); #decls }; (parse, quote! { #(#names,)* }) } else { (TokenStream::new(), TokenStream::new()) }; let extract_body = self.has_body_fields().then(|| { let request_body_ty = if self.lifetimes.body.is_empty() { quote! { RequestBody } } else { quote! { IncomingRequestBody } }; quote! { let request_body: #request_body_ty = { let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref( request.body(), ); #serde_json::from_slice(match body { // If the request body is completely empty, pretend it is an empty JSON // object instead. 
This allows requests with only optional body parameters // to be deserialized in that case. [] => b"{}", b => b, })? }; } }); let (parse_body, body_vars) = if let Some(field) = self.raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let parse = quote! { let #field_name = ::std::convert::AsRef::<[u8]>::as_ref(request.body()).to_vec(); }; (parse, quote! { #field_name, }) } else { vars(self.body_fields(), quote! { request_body }) }; let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { quote! { #[automatically_derived] #[cfg(feature = "server")] impl #ruma_common::api::IncomingNonAuthRequest for #incoming_request_type {} } }); quote! { #[automatically_derived] #[cfg(feature = "server")] impl #ruma_common::api::IncomingRequest for #incoming_request_type { type EndpointError = #error_ty; type OutgoingResponse = Response; const METADATA: #ruma_common::api::Metadata = self::METADATA; fn try_from_http_request( request: #http::Request, path_args: &[S], ) -> ::std::result::Result where B: ::std::convert::AsRef<[::std::primitive::u8]>, S: ::std::convert::AsRef<::std::primitive::str>, { if request.method() != #http::Method::#method { return Err(#ruma_common::api::error::FromHttpRequestError::MethodMismatch { expected: #http::Method::#method, received: request.method().clone(), }); } #parse_request_path #parse_query #parse_headers #extract_body #parse_body ::std::result::Result::Ok(Self { #path_vars #query_vars #header_vars #body_vars }) } } #non_auth_impl } } } fn vars<'a>( fields: impl IntoIterator, src: TokenStream, ) -> (TokenStream, TokenStream) { fields .into_iter() .map(|field| { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); let decl = quote! { #( #cfg_attrs )* let #field_name = #src.#field_name; }; ( decl, quote! 
{ #( #cfg_attrs )* #field_name, }, ) }) .unzip() } ruma-macros-0.10.5/src/api/request/outgoing.rs000064400000000000000000000236371046102023000174010ustar 00000000000000use proc_macro2::TokenStream; use quote::quote; use syn::{Field, LitStr}; use super::{Request, RequestField}; use crate::api::{auth_scheme::AuthScheme, util}; impl Request { pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream { let bytes = quote! { #ruma_common::exports::bytes }; let http = quote! { #ruma_common::exports::http }; let percent_encoding = quote! { #ruma_common::exports::percent_encoding }; let method = &self.method; let error_ty = &self.error_ty; let (unstable_path, r0_path, stable_path) = if self.has_path_fields() { let path_format_args_call_with_percent_encoding = |s: &LitStr| -> TokenStream { util::path_format_args_call(s.value(), &percent_encoding) }; ( self.unstable_path.as_ref().map(path_format_args_call_with_percent_encoding), self.r0_path.as_ref().map(path_format_args_call_with_percent_encoding), self.stable_path.as_ref().map(path_format_args_call_with_percent_encoding), ) } else { ( self.unstable_path.as_ref().map(|path| quote! { format_args!(#path) }), self.r0_path.as_ref().map(|path| quote! { format_args!(#path) }), self.stable_path.as_ref().map(|path| quote! { format_args!(#path) }), ) }; let unstable_path = util::map_option_literal(&unstable_path); let r0_path = util::map_option_literal(&r0_path); let stable_path = util::map_option_literal(&stable_path); let request_query_string = if let Some(field) = self.query_map_field() { let field_name = field.ident.as_ref().expect("expected field to have identifier"); quote! {{ // This function exists so that the compiler will throw an error when the type of // the field with the query_map attribute doesn't implement // `IntoIterator`. 
// // This is necessary because the `ruma_common::serde::urlencoded::to_string` call will // result in a runtime error when the type cannot be encoded as a list key-value // pairs (?key1=value1&key2=value2). // // By asserting that it implements the iterator trait, we can ensure that it won't // fail. fn assert_trait_impl(_: &T) where T: ::std::iter::IntoIterator< Item = (::std::string::String, ::std::string::String), >, {} let request_query = RequestQuery(self.#field_name); assert_trait_impl(&request_query.0); format_args!( "?{}", #ruma_common::serde::urlencoded::to_string(request_query)? ) }} } else if self.has_query_fields() { let request_query_init_fields = struct_init_fields( self.fields.iter().filter_map(RequestField::as_query_field), quote! { self }, ); quote! {{ let request_query = RequestQuery { #request_query_init_fields }; format_args!( "?{}", #ruma_common::serde::urlencoded::to_string(request_query)? ) }} } else { quote! { "" } }; // If there are no body fields, the request body will be empty (not `{}`), so the // `application/json` content-type would be wrong. It may also cause problems with CORS // policies that don't allow the `Content-Type` header (for things such as `.well-known` // that are commonly handled by something else than a homeserver). let mut header_kvs = if self.raw_body_field().is_some() || self.has_body_fields() { quote! { req_headers.insert( #http::header::CONTENT_TYPE, #http::header::HeaderValue::from_static("application/json"), ); } } else { TokenStream::new() }; header_kvs.extend(self.header_fields().map(|request_field| { let (field, header_name) = match request_field { RequestField::Header(field, header_name) => (field, header_name), _ => unreachable!("expected request field to be header variant"), }; let field_name = &field.ident; match &field.ty { syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.last().unwrap().ident == "Option" => { quote! 
{ if let Some(header_val) = self.#field_name.as_ref() { req_headers.insert( #http::header::#header_name, #http::header::HeaderValue::from_str(header_val)?, ); } } } _ => quote! { req_headers.insert( #http::header::#header_name, #http::header::HeaderValue::from_str(self.#field_name.as_ref())?, ); }, } })); header_kvs.extend(match self.authentication { AuthScheme::AccessToken(_) => quote! { req_headers.insert( #http::header::AUTHORIZATION, ::std::convert::TryFrom::<_>::try_from(::std::format!( "Bearer {}", access_token .get_required_for_endpoint() .ok_or(#ruma_common::api::error::IntoHttpError::NeedsAuthentication)?, ))?, ); }, AuthScheme::None(_) => quote! { if let Some(access_token) = access_token.get_not_required_for_endpoint() { req_headers.insert( #http::header::AUTHORIZATION, ::std::convert::TryFrom::<_>::try_from( ::std::format!("Bearer {}", access_token), )? ); } }, AuthScheme::QueryOnlyAccessToken(_) | AuthScheme::ServerSignatures(_) => quote! {}, }); let request_body = if let Some(field) = self.raw_body_field() { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); quote! { #ruma_common::serde::slice_to_buf(&self.#field_name) } } else if self.has_body_fields() { let initializers = struct_init_fields(self.body_fields(), quote! { self }); quote! { #ruma_common::serde::json_to_buf(&RequestBody { #initializers })? } } else if method == "GET" { quote! { ::default() } } else { quote! { #ruma_common::serde::slice_to_buf(b"{}") } }; let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); let non_auth_impl = matches!(self.authentication, AuthScheme::None(_)).then(|| { quote! { #[automatically_derived] #[cfg(feature = "client")] impl #impl_generics #ruma_common::api::OutgoingNonAuthRequest for Request #ty_generics #where_clause {} } }); quote! 
{ #[automatically_derived] #[cfg(feature = "client")] impl #impl_generics #ruma_common::api::OutgoingRequest for Request #ty_generics #where_clause { type EndpointError = #error_ty; type IncomingResponse = Response; const METADATA: #ruma_common::api::Metadata = self::METADATA; fn try_into_http_request( self, base_url: &::std::primitive::str, access_token: #ruma_common::api::SendAccessToken<'_>, considering_versions: &'_ [#ruma_common::api::MatrixVersion], ) -> ::std::result::Result<#http::Request, #ruma_common::api::error::IntoHttpError> { let metadata = self::METADATA; let mut req_builder = #http::Request::builder() .method(#http::Method::#method) .uri(::std::format!( "{}{}{}", base_url.strip_suffix('/').unwrap_or(base_url), #ruma_common::api::select_path(considering_versions, &metadata, #unstable_path, #r0_path, #stable_path)?, #request_query_string, )); if let Some(mut req_headers) = req_builder.headers_mut() { #header_kvs } let http_request = req_builder.body(#request_body)?; Ok(http_request) } } #non_auth_impl } } } /// Produces code for a struct initializer for the given field kind to be accessed through the /// given variable name. fn struct_init_fields<'a>( fields: impl IntoIterator, src: TokenStream, ) -> TokenStream { fields .into_iter() .map(|field| { let field_name = field.ident.as_ref().expect("expected field to have an identifier"); let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::>(); quote! 
{ #( #cfg_attrs )* #field_name: #src.#field_name, } }) .collect() } ruma-macros-0.10.5/src/api/request.rs000064400000000000000000000407031046102023000155370ustar 00000000000000use std::collections::{BTreeMap, BTreeSet}; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ parse::{Parse, ParseStream}, parse_quote, punctuated::Punctuated, DeriveInput, Field, Generics, Ident, Lifetime, LitStr, Token, Type, }; use super::{ attribute::{DeriveRequestMeta, RequestMeta}, auth_scheme::AuthScheme, util::collect_lifetime_idents, }; use crate::util::import_ruma_common; mod incoming; mod outgoing; pub fn expand_derive_request(input: DeriveInput) -> syn::Result { let fields = match input.data { syn::Data::Struct(s) => s.fields, _ => panic!("This derive macro only works on structs"), }; let mut lifetimes = RequestLifetimes::default(); let fields = fields .into_iter() .map(|f| { let f = RequestField::try_from(f)?; let ty = &f.field().ty; match &f { RequestField::Header(..) => collect_lifetime_idents(&mut lifetimes.header, ty), RequestField::Body(_) => collect_lifetime_idents(&mut lifetimes.body, ty), RequestField::NewtypeBody(_) => collect_lifetime_idents(&mut lifetimes.body, ty), RequestField::RawBody(_) => collect_lifetime_idents(&mut lifetimes.body, ty), RequestField::Path(_) => collect_lifetime_idents(&mut lifetimes.path, ty), RequestField::Query(_) => collect_lifetime_idents(&mut lifetimes.query, ty), RequestField::QueryMap(_) => collect_lifetime_idents(&mut lifetimes.query, ty), } Ok(f) }) .collect::>()?; let mut authentication = None; let mut error_ty = None; let mut method = None; let mut unstable_path = None; let mut r0_path = None; let mut stable_path = None; for attr in input.attrs { if !attr.path.is_ident("ruma_api") { continue; } let metas = attr.parse_args_with(Punctuated::::parse_terminated)?; for meta in metas { match meta { DeriveRequestMeta::Authentication(t) => authentication = Some(parse_quote!(#t)), DeriveRequestMeta::Method(t) => method = 
Some(parse_quote!(#t)), DeriveRequestMeta::ErrorTy(t) => error_ty = Some(t), DeriveRequestMeta::UnstablePath(s) => unstable_path = Some(s), DeriveRequestMeta::R0Path(s) => r0_path = Some(s), DeriveRequestMeta::StablePath(s) => stable_path = Some(s), } } } let request = Request { ident: input.ident, generics: input.generics, fields, lifetimes, authentication: authentication.expect("missing authentication attribute"), method: method.expect("missing method attribute"), unstable_path, r0_path, stable_path, error_ty: error_ty.expect("missing error_ty attribute"), }; request.check()?; Ok(request.expand_all()) } #[derive(Default)] struct RequestLifetimes { pub body: BTreeSet, pub path: BTreeSet, pub query: BTreeSet, pub header: BTreeSet, } struct Request { ident: Ident, generics: Generics, lifetimes: RequestLifetimes, fields: Vec, authentication: AuthScheme, method: Ident, unstable_path: Option, r0_path: Option, stable_path: Option, error_ty: Type, } impl Request { fn body_fields(&self) -> impl Iterator { self.fields.iter().filter_map(RequestField::as_body_field) } fn has_body_fields(&self) -> bool { self.fields .iter() .any(|f| matches!(f, RequestField::Body(_) | RequestField::NewtypeBody(_))) } fn has_newtype_body(&self) -> bool { self.fields.iter().any(|f| matches!(f, RequestField::NewtypeBody(_))) } fn has_header_fields(&self) -> bool { self.fields.iter().any(|f| matches!(f, RequestField::Header(..))) } fn has_path_fields(&self) -> bool { self.fields.iter().any(|f| matches!(f, RequestField::Path(_))) } fn has_query_fields(&self) -> bool { self.fields.iter().any(|f| matches!(f, RequestField::Query(_))) } fn has_lifetimes(&self) -> bool { !(self.lifetimes.body.is_empty() && self.lifetimes.path.is_empty() && self.lifetimes.query.is_empty() && self.lifetimes.header.is_empty()) } fn header_fields(&self) -> impl Iterator { self.fields.iter().filter(|f| matches!(f, RequestField::Header(..))) } fn path_fields_ordered(&self) -> impl Iterator { let map: BTreeMap = self .fields 
.iter() .filter_map(RequestField::as_path_field) .map(|f| (f.ident.as_ref().unwrap().to_string(), f)) .collect(); self.stable_path .as_ref() .or(self.r0_path.as_ref()) .or(self.unstable_path.as_ref()) .expect("one of the paths to be defined") .value() .split('/') .filter_map(|s| { s.strip_prefix(':') .map(|s| *map.get(s).expect("path args have already been checked")) }) .collect::>() .into_iter() } fn raw_body_field(&self) -> Option<&Field> { self.fields.iter().find_map(RequestField::as_raw_body_field) } fn query_map_field(&self) -> Option<&Field> { self.fields.iter().find_map(RequestField::as_query_map_field) } fn expand_all(&self) -> TokenStream { let ruma_common = import_ruma_common(); let ruma_macros = quote! { #ruma_common::exports::ruma_macros }; let serde = quote! { #ruma_common::exports::serde }; let request_body_struct = self.has_body_fields().then(|| { let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] }); let fields = self.fields.iter().filter_map(RequestField::as_body_field); // Though we don't track the difference between newtype body and body // for lifetimes, the outer check and the macro failing if it encounters // an illegal combination of field attributes, is enough to guarantee // `body_lifetimes` correctness. let lifetimes = &self.lifetimes.body; let derive_deserialize = lifetimes.is_empty().then(|| quote! { #serde::Deserialize }); quote! { /// Data in the request body. #[cfg(any(feature = "client", feature = "server"))] #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)] #[cfg_attr(feature = "client", derive(#serde::Serialize))] #[cfg_attr( feature = "server", derive(#ruma_common::serde::Incoming, #derive_deserialize) )] #serde_attr struct RequestBody< #(#lifetimes),* > { #(#fields),* } } }); let request_query_def = if let Some(f) = self.query_map_field() { let field = Field { ident: None, colon_token: None, ..f.clone() }; Some(quote! 
{ (#field); }) } else if self.has_query_fields() { let fields = self.fields.iter().filter_map(RequestField::as_query_field); Some(quote! { { #(#fields),* } }) } else { None }; let request_query_struct = request_query_def.map(|def| { let lifetimes = &self.lifetimes.query; let derive_deserialize = lifetimes.is_empty().then(|| quote! { #serde::Deserialize }); quote! { /// Data in the request's query string. #[cfg(any(feature = "client", feature = "server"))] #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)] #[cfg_attr(feature = "client", derive(#serde::Serialize))] #[cfg_attr( feature = "server", derive(#ruma_common::serde::Incoming, #derive_deserialize) )] struct RequestQuery< #(#lifetimes),* > #def } }); let outgoing_request_impl = self.expand_outgoing(&ruma_common); let incoming_request_impl = self.expand_incoming(&ruma_common); quote! { #request_body_struct #request_query_struct #outgoing_request_impl #incoming_request_impl } } pub(super) fn check(&self) -> syn::Result<()> { // TODO: highlight problematic fields let path_fields: Vec<_> = self.fields.iter().filter_map(RequestField::as_path_field).collect(); self.check_path(&path_fields, self.unstable_path.as_ref())?; self.check_path(&path_fields, self.r0_path.as_ref())?; self.check_path(&path_fields, self.stable_path.as_ref())?; let newtype_body_fields = self.fields.iter().filter(|field| { matches!(field, RequestField::NewtypeBody(_) | RequestField::RawBody(_)) }); let has_newtype_body_field = match newtype_body_fields.count() { 0 => false, 1 => true, _ => { return Err(syn::Error::new_spanned( &self.ident, "Can't have more than one newtype body field", )) } }; let query_map_fields = self.fields.iter().filter(|f| matches!(f, RequestField::QueryMap(_))); let has_query_map_field = match query_map_fields.count() { 0 => false, 1 => true, _ => { return Err(syn::Error::new_spanned( &self.ident, "Can't have more than one query_map field", )) } }; let has_body_fields = 
self.fields.iter().any(|f| matches!(f, RequestField::Body(_))); let has_query_fields = self.fields.iter().any(|f| matches!(f, RequestField::Query(_))); if has_newtype_body_field && has_body_fields { return Err(syn::Error::new_spanned( &self.ident, "Can't have both a newtype body field and regular body fields", )); } if has_query_map_field && has_query_fields { return Err(syn::Error::new_spanned( &self.ident, "Can't have both a query map field and regular query fields", )); } // TODO when/if `&[(&str, &str)]` is supported remove this if has_query_map_field && !self.lifetimes.query.is_empty() { return Err(syn::Error::new_spanned( &self.ident, "Lifetimes are not allowed for query_map fields", )); } if self.method == "GET" && (has_body_fields || has_newtype_body_field) { return Err(syn::Error::new_spanned( &self.ident, "GET endpoints can't have body fields", )); } Ok(()) } fn check_path(&self, fields: &[&Field], path: Option<&LitStr>) -> syn::Result<()> { let path = if let Some(lit) = path { lit } else { return Ok(()) }; let path_args: Vec<_> = path .value() .split('/') .filter_map(|s| s.strip_prefix(':').map(str::to_string)) .collect(); let field_map: BTreeMap<_, _> = fields.iter().map(|&f| (f.ident.as_ref().unwrap().to_string(), f)).collect(); // test if all macro fields exist in the path for (name, field) in field_map.iter() { if !path_args.contains(name) { return Err({ let mut err = syn::Error::new_spanned( field, "this path argument field is not defined in...", ); err.combine(syn::Error::new_spanned(path, "...this path.")); err }); } } // test if all path fields exists in macro fields for arg in &path_args { if !field_map.contains_key(arg) { return Err(syn::Error::new_spanned( path, format!( "a corresponding request path argument field for \"{}\" does not exist", arg ), )); } } Ok(()) } } /// The types of fields that a request can have. enum RequestField { /// JSON data in the body of the request. Body(Field), /// Data in an HTTP header. 
Header(Field, Ident), /// A specific data type in the body of the request. NewtypeBody(Field), /// Arbitrary bytes in the body of the request. RawBody(Field), /// Data that appears in the URL path. Path(Field), /// Data that appears in the query string. Query(Field), /// Data that appears in the query string as dynamic key-value pairs. QueryMap(Field), } impl RequestField { /// Creates a new `RequestField`. fn new(field: Field, kind_attr: Option) -> Self { if let Some(attr) = kind_attr { match attr { RequestMeta::NewtypeBody => RequestField::NewtypeBody(field), RequestMeta::RawBody => RequestField::RawBody(field), RequestMeta::Path => RequestField::Path(field), RequestMeta::Query => RequestField::Query(field), RequestMeta::QueryMap => RequestField::QueryMap(field), RequestMeta::Header(header) => RequestField::Header(field, header), } } else { RequestField::Body(field) } } /// Return the contained field if this request field is a body kind. pub fn as_body_field(&self) -> Option<&Field> { match self { RequestField::Body(field) | RequestField::NewtypeBody(field) => Some(field), _ => None, } } /// Return the contained field if this request field is a raw body kind. pub fn as_raw_body_field(&self) -> Option<&Field> { match self { RequestField::RawBody(field) => Some(field), _ => None, } } /// Return the contained field if this request field is a path kind. pub fn as_path_field(&self) -> Option<&Field> { match self { RequestField::Path(field) => Some(field), _ => None, } } /// Return the contained field if this request field is a query kind. pub fn as_query_field(&self) -> Option<&Field> { match self { RequestField::Query(field) => Some(field), _ => None, } } /// Return the contained field if this request field is a query map kind. pub fn as_query_map_field(&self) -> Option<&Field> { match self { RequestField::QueryMap(field) => Some(field), _ => None, } } /// Gets the inner `Field` value. 
pub fn field(&self) -> &Field {
    match self {
        RequestField::Body(field)
        | RequestField::Header(field, _)
        | RequestField::NewtypeBody(field)
        | RequestField::RawBody(field)
        | RequestField::Path(field)
        | RequestField::Query(field)
        | RequestField::QueryMap(field) => field,
    }
}
}

impl TryFrom<Field> for RequestField {
    type Error = syn::Error;

    /// Classifies `field` by its (at most one) `#[ruma_api(...)]` attribute.
    ///
    /// The `ruma_api` attribute is removed from the field; all other attributes are kept.
    /// Returns an error if more than one `ruma_api` attribute is present or if its arguments
    /// don't parse as a `RequestMeta`.
    fn try_from(mut field: Field) -> syn::Result<Self> {
        // Split our own kind attribute off; everything else (docs, cfgs, serde attributes, ...)
        // stays attached to the field.
        let (mut api_attrs, attrs) =
            field.attrs.into_iter().partition::<Vec<_>, _>(|attr| attr.path.is_ident("ruma_api"));
        field.attrs = attrs;

        let kind_attr = match api_attrs.as_slice() {
            [] => None,
            [_] => Some(api_attrs.pop().unwrap().parse_args::<RequestMeta>()?),
            _ => {
                return Err(syn::Error::new_spanned(
                    &api_attrs[1],
                    "multiple field kind attribute found, there can only be one",
                ));
            }
        };

        Ok(RequestField::new(field, kind_attr))
    }
}

impl Parse for RequestField {
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        input.call(Field::parse_named)?.try_into()
    }
}

impl ToTokens for RequestField {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.field().to_tokens(tokens);
    }
}
ruma-macros-0.10.5/src/api/response/incoming.rs000064400000000000000000000134571046102023000175160ustar 00000000000000use proc_macro2::TokenStream;
use quote::quote;
use syn::Type;

use super::{Response, ResponseField};

impl Response {
    /// Generates the client-side `IncomingResponse` impl for `Response`.
    ///
    /// Successful responses (status < 400) are turned into `Self` by reading header fields,
    /// deserializing the JSON body into `ResponseBody` and/or copying the raw body bytes.
    /// Error responses are handed to `error_ty`'s `EndpointError::try_from_http_response`.
    pub fn expand_incoming(&self, error_ty: &Type, ruma_common: &TokenStream) -> TokenStream {
        let http = quote! { #ruma_common::exports::http };
        let serde_json = quote! { #ruma_common::exports::serde_json };

        // Headers are cloned into a mutable map so each header field can `remove` its value.
        let extract_response_headers = self.has_header_fields().then(|| {
            quote! {
                let mut headers = response.headers().clone();
            }
        });

        let typed_response_body_decl = self.has_body_fields().then(|| {
            quote! {
                let response_body: ResponseBody = {
                    let body = ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref(
                        response.body(),
                    );

                    #serde_json::from_slice(match body {
                        // If the response body is completely empty, pretend it is an empty
                        // JSON object instead. This allows responses with only optional body
                        // parameters to be deserialized in that case.
                        [] => b"{}",
                        b => b,
                    })?
                };
            }
        });

        let response_init_fields = {
            let mut fields = vec![];
            let mut raw_body = None;

            for response_field in &self.fields {
                let field = response_field.field();
                let field_name =
                    field.ident.as_ref().expect("expected field to have an identifier");
                let cfg_attrs =
                    field.attrs.iter().filter(|a| a.path.is_ident("cfg")).collect::<Vec<_>>();

                fields.push(match response_field {
                    ResponseField::Body(_) | ResponseField::NewtypeBody(_) => {
                        quote! {
                            #( #cfg_attrs )*
                            #field_name: response_body.#field_name
                        }
                    }
                    ResponseField::Header(_, header_name) => {
                        // `Option<_>` header fields tolerate a missing header; all other
                        // header fields panic in the generated code if it is absent.
                        let optional_header = match &field.ty {
                            syn::Type::Path(syn::TypePath {
                                path: syn::Path { segments, .. }, ..
                            }) if segments.last().unwrap().ident == "Option" => {
                                quote! {
                                    #( #cfg_attrs )*
                                    #field_name: {
                                        headers.remove(#http::header::#header_name)
                                            .map(|h| h.to_str().map(|s| s.to_owned()))
                                            .transpose()?
                                    }
                                }
                            }
                            _ => quote! {
                                #( #cfg_attrs )*
                                #field_name: {
                                    headers.remove(#http::header::#header_name)
                                        .expect("response missing expected header")
                                        .to_str()?
                                        .to_owned()
                                }
                            },
                        };
                        quote! { #optional_header }
                    }
                    // This field must be initialized last to avoid a `use of moved value` error.
                    // We are guaranteed only one raw body field because of the check in
                    // `Response::check`.
                    ResponseField::RawBody(_) => {
                        raw_body = Some(quote! {
                            #( #cfg_attrs )*
                            #field_name: {
                                ::std::convert::AsRef::<[::std::primitive::u8]>::as_ref(
                                    response.body(),
                                )
                                .to_vec()
                            }
                        });
                        // skip adding to the vec
                        continue;
                    }
                });
            }

            fields.extend(raw_body);

            quote! {
                #(#fields,)*
            }
        };

        quote! {
            #[automatically_derived]
            #[cfg(feature = "client")]
            impl #ruma_common::api::IncomingResponse for Response {
                type EndpointError = #error_ty;

                fn try_from_http_response<T: ::std::convert::AsRef<[::std::primitive::u8]>>(
                    response: #http::Response<T>,
                ) -> ::std::result::Result<
                    Self,
                    #ruma_common::api::error::FromHttpResponseError<#error_ty>,
                > {
                    if response.status().as_u16() < 400 {
                        #extract_response_headers
                        #typed_response_body_decl

                        ::std::result::Result::Ok(Self {
                            #response_init_fields
                        })
                    } else {
                        match <#error_ty as #ruma_common::api::EndpointError>::try_from_http_response(
                            response
                        ) {
                            ::std::result::Result::Ok(err) => {
                                Err(#ruma_common::api::error::ServerError::Known(err).into())
                            }
                            ::std::result::Result::Err(response_err) => {
                                Err(#ruma_common::api::error::ServerError::Unknown(response_err).into())
                            }
                        }
                    }
                }
            }
        }
    }
}
ruma-macros-0.10.5/src/api/response/outgoing.rs000064400000000000000000000063051046102023000175400ustar 00000000000000use proc_macro2::TokenStream;
use quote::quote;

use super::{Response, ResponseField};

impl Response {
    /// Generates the server-side `OutgoingResponse` impl for `Response`.
    ///
    /// The generated code sets a JSON `Content-Type`, inserts every header field and uses
    /// either the raw body field's bytes or the JSON-serialized `ResponseBody` as the body.
    pub fn expand_outgoing(&self, ruma_common: &TokenStream) -> TokenStream {
        let bytes = quote! { #ruma_common::exports::bytes };
        let http = quote! { #ruma_common::exports::http };

        let serialize_response_headers = self.fields.iter().filter_map(|response_field| {
            response_field.as_header_field().map(|(field, header_name)| {
                let field_name =
                    field.ident.as_ref().expect("expected field to have an identifier");

                // `Option<_>` header fields are only inserted when they hold a value.
                match &field.ty {
                    syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. })
                        if segments.last().unwrap().ident == "Option" =>
                    {
                        quote! {
                            if let Some(header) = self.#field_name {
                                headers.insert(
                                    #http::header::#header_name,
                                    header.parse()?,
                                );
                            }
                        }
                    }
                    _ => quote! {
                        headers.insert(
                            #http::header::#header_name,
                            self.#field_name.parse()?,
                        );
                    },
                }
            })
        });

        let body = if let Some(field) =
            self.fields.iter().find_map(ResponseField::as_raw_body_field)
        {
            let field_name = field.ident.as_ref().expect("expected field to have an identifier");
            quote! { #ruma_common::serde::slice_to_buf(&self.#field_name) }
        } else {
            let fields = self.fields.iter().filter_map(|response_field| {
                response_field.as_body_field().map(|field| {
                    let field_name =
                        field.ident.as_ref().expect("expected field to have an identifier");
                    let cfg_attrs = field.attrs.iter().filter(|a| a.path.is_ident("cfg"));

                    quote! {
                        #( #cfg_attrs )*
                        #field_name: self.#field_name,
                    }
                })
            });

            quote! {
                #ruma_common::serde::json_to_buf(&ResponseBody { #(#fields)* })?
            }
        };

        quote! {
            #[automatically_derived]
            #[cfg(feature = "server")]
            impl #ruma_common::api::OutgoingResponse for Response {
                fn try_into_http_response<T: ::std::default::Default + #bytes::BufMut>(
                    self,
                ) -> ::std::result::Result<#http::Response<T>, #ruma_common::api::error::IntoHttpError> {
                    let mut resp_builder = #http::Response::builder()
                        .header(#http::header::CONTENT_TYPE, "application/json");

                    if let Some(mut headers) = resp_builder.headers_mut() {
                        #(#serialize_response_headers)*
                    }

                    ::std::result::Result::Ok(resp_builder.body(#body)?)
                }
            }
        }
    }
}
ruma-macros-0.10.5/src/api/response.rs000064400000000000000000000206341046102023000157060ustar 00000000000000use std::ops::Not;

use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{
    parse::{Parse, ParseStream},
    punctuated::Punctuated,
    visit::Visit,
    DeriveInput, Field, Generics, Ident, Lifetime, Token, Type,
};

use super::attribute::{DeriveResponseMeta, ResponseMeta};
use crate::util::import_ruma_common;

mod incoming;
mod outgoing;

/// Expands the response derive macro for `input`.
///
/// Classifies every struct field, reads the struct-level `#[ruma_api(...)]` attributes,
/// validates the result and generates the body struct plus the incoming/outgoing impls.
///
/// # Panics
///
/// Panics if `input` is not a struct or if no `error_ty` attribute was given.
pub fn expand_derive_response(input: DeriveInput) -> syn::Result<TokenStream> {
    let fields = match input.data {
        syn::Data::Struct(s) => s.fields,
        _ => panic!("This derive macro only works on structs"),
    };

    let fields = fields.into_iter().map(ResponseField::try_from).collect::<syn::Result<_>>()?;
    let mut manual_body_serde = false;
    let mut error_ty = None;
    for attr in input.attrs {
        if !attr.path.is_ident("ruma_api") {
            continue;
        }

        let metas =
            attr.parse_args_with(Punctuated::<DeriveResponseMeta, Token![,]>::parse_terminated)?;
        for meta in metas {
            match meta {
                DeriveResponseMeta::ManualBodySerde => manual_body_serde = true,
                DeriveResponseMeta::ErrorTy(t) => error_ty = Some(t),
            }
        }
    }

    let response = Response {
        ident: input.ident,
        generics: input.generics,
        fields,
        manual_body_serde,
        // Same diagnostic as the request derive instead of a bare `unwrap` panic.
        error_ty: error_ty.expect("missing error_ty attribute"),
    };

    response.check()?;
    Ok(response.expand_all())
}

struct Response {
    ident: Ident,
    generics: Generics,
    fields: Vec<ResponseField>,
    manual_body_serde: bool,
    error_ty: Type,
}

impl Response {
    /// Whether or not this response has any data in the HTTP body.
    fn has_body_fields(&self) -> bool {
        self.fields
            .iter()
            .any(|f| matches!(f, ResponseField::Body(_) | ResponseField::NewtypeBody(_)))
    }

    /// Whether or not this response has a newtype body field.
    fn has_newtype_body(&self) -> bool {
        self.fields.iter().any(|f| matches!(f, ResponseField::NewtypeBody(_)))
    }

    /// Whether or not this response has a raw body field.
    fn has_raw_body(&self) -> bool {
        self.fields.iter().any(|f| matches!(f, ResponseField::RawBody(_)))
    }

    /// Whether or not this response has any data in HTTP headers.
    fn has_header_fields(&self) -> bool {
        self.fields.iter().any(|f| matches!(f, ResponseField::Header(..)))
    }

    /// Generates all output: the `ResponseBody` struct (unless there is a raw body field)
    /// and the outgoing/incoming response impls.
    fn expand_all(&self) -> TokenStream {
        let ruma_common = import_ruma_common();
        let ruma_macros = quote! { #ruma_common::exports::ruma_macros };
        let serde = quote! { #ruma_common::exports::serde };

        let response_body_struct = (!self.has_raw_body()).then(|| {
            // With `manual_body_serde`, the user provides the serde impls themselves.
            let serde_derives = self.manual_body_serde.not().then(|| {
                quote! {
                    #[cfg_attr(feature = "client", derive(#serde::Deserialize))]
                    #[cfg_attr(feature = "server", derive(#serde::Serialize))]
                }
            });

            let serde_attr = self.has_newtype_body().then(|| quote! { #[serde(transparent)] });
            let fields = self.fields.iter().filter_map(ResponseField::as_body_field);

            quote! {
                /// Data in the response body.
                #[cfg(any(feature = "client", feature = "server"))]
                #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
                #serde_derives
                #serde_attr
                struct ResponseBody { #(#fields),* }
            }
        });

        let outgoing_response_impl = self.expand_outgoing(&ruma_common);
        let incoming_response_impl = self.expand_incoming(&self.error_ty, &ruma_common);

        quote! {
            #response_body_struct

            #outgoing_response_impl
            #incoming_response_impl
        }
    }

    /// Validates the field combination.
    ///
    /// Rejects more than one newtype/raw body field, and a newtype/raw body field mixed
    /// with regular body fields.
    ///
    /// # Panics
    ///
    /// Panics if the response struct has generic parameters or a where clause.
    pub fn check(&self) -> syn::Result<()> {
        // TODO: highlight problematic fields

        assert!(
            self.generics.params.is_empty() && self.generics.where_clause.is_none(),
            "This macro doesn't support generic types"
        );

        let newtype_body_fields = self
            .fields
            .iter()
            .filter(|f| matches!(f, ResponseField::NewtypeBody(_) | ResponseField::RawBody(_)));

        let has_newtype_body_field = match newtype_body_fields.count() {
            0 => false,
            1 => true,
            _ => {
                return Err(syn::Error::new_spanned(
                    &self.ident,
                    "Can't have more than one newtype body field",
                ))
            }
        };

        let has_body_fields = self.fields.iter().any(|f| matches!(f, ResponseField::Body(_)));
        if has_newtype_body_field && has_body_fields {
            return Err(syn::Error::new_spanned(
                &self.ident,
                "Can't have both a newtype body field and regular body fields",
            ));
        }

        Ok(())
    }
}

/// The types of fields that a response can have.
enum ResponseField {
    /// JSON data in the body of the response.
    Body(Field),

    /// Data in an HTTP header.
    Header(Field, Ident),

    /// A specific data type in the body of the response.
    NewtypeBody(Field),

    /// Arbitrary bytes in the body of the response.
    RawBody(Field),
}

impl ResponseField {
    /// Creates a new `ResponseField`; fields without a kind attribute are JSON body fields.
    fn new(field: Field, kind_attr: Option<ResponseMeta>) -> Self {
        if let Some(attr) = kind_attr {
            match attr {
                ResponseMeta::NewtypeBody => ResponseField::NewtypeBody(field),
                ResponseMeta::RawBody => ResponseField::RawBody(field),
                ResponseMeta::Header(header) => ResponseField::Header(field, header),
            }
        } else {
            ResponseField::Body(field)
        }
    }

    /// Gets the inner `Field` value.
fn field(&self) -> &Field {
    match self {
        ResponseField::Body(field)
        | ResponseField::Header(field, _)
        | ResponseField::NewtypeBody(field)
        | ResponseField::RawBody(field) => field,
    }
}

/// Return the contained field if this response field is a body kind.
fn as_body_field(&self) -> Option<&Field> {
    match self {
        ResponseField::Body(field) | ResponseField::NewtypeBody(field) => Some(field),
        _ => None,
    }
}

/// Return the contained field if this response field is a raw body kind.
fn as_raw_body_field(&self) -> Option<&Field> {
    match self {
        ResponseField::RawBody(field) => Some(field),
        _ => None,
    }
}

/// Return the contained field and HTTP header ident if this response field is a header kind.
fn as_header_field(&self) -> Option<(&Field, &Ident)> {
    match self {
        ResponseField::Header(field, ident) => Some((field, ident)),
        _ => None,
    }
}
}

impl TryFrom<Field> for ResponseField {
    type Error = syn::Error;

    /// Classifies `field` by its (at most one) `#[ruma_api(...)]` attribute.
    ///
    /// Fails if the field's type contains a lifetime, if more than one `ruma_api` attribute
    /// is present, or if the attribute's arguments don't parse as a `ResponseMeta`.
    fn try_from(mut field: Field) -> syn::Result<Self> {
        if has_lifetime(&field.ty) {
            return Err(syn::Error::new_spanned(
                field.ident,
                "Lifetimes on Response fields cannot be supported until GAT are stable",
            ));
        }

        // Split our own kind attribute off; all other attributes stay on the field.
        let (mut api_attrs, attrs) =
            field.attrs.into_iter().partition::<Vec<_>, _>(|attr| attr.path.is_ident("ruma_api"));
        field.attrs = attrs;

        let kind_attr = match api_attrs.as_slice() {
            [] => None,
            [_] => Some(api_attrs.pop().unwrap().parse_args::<ResponseMeta>()?),
            _ => {
                return Err(syn::Error::new_spanned(
                    &api_attrs[1],
                    "multiple field kind attribute found, there can only be one",
                ));
            }
        };

        Ok(ResponseField::new(field, kind_attr))
    }
}

impl Parse for ResponseField {
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        input.call(Field::parse_named)?.try_into()
    }
}

impl ToTokens for ResponseField {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.field().to_tokens(tokens);
    }
}

/// Whether `ty` contains any explicitly written lifetime (as visited by `syn::visit`).
fn has_lifetime(ty: &Type) -> bool {
    struct Visitor {
        found_lifetime: bool,
    }

    impl<'ast> Visit<'ast> for Visitor {
        fn visit_lifetime(&mut self, _lt: &'ast Lifetime) {
            self.found_lifetime = true;
        }
    }

    let mut vis = Visitor { found_lifetime: false };
    vis.visit_type(ty);
    vis.found_lifetime
}
ruma-macros-0.10.5/src/api/util.rs000064400000000000000000000054721046102023000150300ustar 00000000000000//! Functions to aid the `Api::to_tokens` method.

use std::collections::BTreeSet;

use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{parse_quote, visit::Visit, Attribute, Lifetime, NestedMeta, Type};

/// Turns `Some(v)` / `None` into the tokens `::std::option::Option::Some(v)` /
/// `::std::option::Option::None`.
pub fn map_option_literal<T: ToTokens>(ver: &Option<T>) -> TokenStream {
    match ver {
        Some(v) => quote! { ::std::option::Option::Some(#v) },
        None => quote! { ::std::option::Option::None },
    }
}

/// Whether every byte of `string` is printable, non-space ASCII (`0x21..=0x7E`).
pub fn is_valid_endpoint_path(string: &str) -> bool {
    string.as_bytes().iter().all(|b| (0x21..=0x7E).contains(b))
}

/// Adds every lifetime that appears in `ty` to `lifetimes`.
pub fn collect_lifetime_idents(lifetimes: &mut BTreeSet<Lifetime>, ty: &Type) {
    struct Visitor<'lt>(&'lt mut BTreeSet<Lifetime>);
    impl<'ast> Visit<'ast> for Visitor<'_> {
        fn visit_lifetime(&mut self, lt: &'ast Lifetime) {
            self.0.insert(lt.clone());
        }
    }

    Visitor(lifetimes).visit_type(ty);
}

/// Combines the predicates of all `#[cfg(...)]` attributes in `cfgs` into `all(...)`.
///
/// Returns `None` if there are no `cfg` attributes.
pub fn all_cfgs_expr(cfgs: &[Attribute]) -> Option<TokenStream> {
    let sub_cfgs: Vec<_> = cfgs.iter().filter_map(extract_cfg).collect();
    (!sub_cfgs.is_empty()).then(|| quote! { all( #(#sub_cfgs),* ) })
}

/// Like [`all_cfgs_expr`], but wraps the combined predicate in a `#[cfg(...)]` attribute.
pub fn all_cfgs(cfgs: &[Attribute]) -> Option<Attribute> {
    let cfg_expr = all_cfgs_expr(cfgs)?;
    Some(parse_quote! { #[cfg( #cfg_expr )] })
}

/// Extracts the single predicate of a `#[cfg(...)]` attribute.
///
/// Returns `None` for non-`cfg` attributes.
///
/// # Panics
///
/// Panics if the `cfg` attribute is not of the form `cfg(<one predicate>)`.
pub fn extract_cfg(attr: &Attribute) -> Option<NestedMeta> {
    if !attr.path.is_ident("cfg") {
        return None;
    }

    let meta = attr.parse_meta().expect("cfg attribute can be parsed to syn::Meta");
    let mut list = match meta {
        syn::Meta::List(l) => l,
        _ => panic!("unexpected cfg syntax"),
    };

    assert!(list.path.is_ident("cfg"), "expected cfg attributes only");
    assert_eq!(list.nested.len(), 1, "expected one item inside cfg()");

    Some(list.nested.pop().unwrap().into_value())
}

/// Builds a `format_args!(...)` call for an endpoint path such as `/a/:b/c`.
///
/// Every `:name` segment is replaced with `{}`. The matching argument is
/// `self.name`, stringified and percent-encoded (all non-alphanumeric characters).
///
/// # Panics
///
/// Panics if a `:` does not directly follow a `/`.
pub fn path_format_args_call(
    mut format_string: String,
    percent_encoding: &TokenStream,
) -> TokenStream {
    let mut format_args = Vec::new();

    while let Some(start_of_segment) = format_string.find(':') {
        // ':' should only ever appear at the start of a segment. Checking the prefix avoids
        // the `start_of_segment - 1` underflow when the string starts with ':'.
        assert!(
            format_string[..start_of_segment].ends_with('/'),
            "':' should only ever appear at the start of a path segment"
        );

        let end_of_segment = match format_string[start_of_segment..].find('/') {
            Some(rel_pos) => start_of_segment + rel_pos,
            None => format_string.len(),
        };

        let path_var =
            Ident::new(&format_string[start_of_segment + 1..end_of_segment], Span::call_site());
        format_args.push(quote! {
            #percent_encoding::utf8_percent_encode(
                &::std::string::ToString::to_string(&self.#path_var),
                #percent_encoding::NON_ALPHANUMERIC,
            )
        });
        format_string.replace_range(start_of_segment..end_of_segment, "{}");
    }

    quote! {
        format_args!(#format_string, #(#format_args),*)
    }
}
ruma-macros-0.10.5/src/api/version.rs000064400000000000000000000027471046102023000155420ustar 00000000000000use std::num::NonZeroU8;

use proc_macro2::TokenStream;
use quote::{format_ident, quote, ToTokens};
use syn::{parse::Parse, Error, LitFloat};

/// A Matrix spec version written as a float literal `X.Y`, e.g. `1.1`.
#[derive(Clone)]
pub struct MatrixVersionLiteral {
    pub(crate) major: NonZeroU8,
    pub(crate) minor: u8,
}

impl Parse for MatrixVersionLiteral {
    /// Parses an unsuffixed float literal `X.Y` with a non-zero `u8` major and a `u8` minor.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let fl: LitFloat = input.parse()?;

        if !fl.suffix().is_empty() {
            return Err(Error::new_spanned(
                fl,
                "matrix version has to be only two positive numbers separated by a `.`",
            ));
        }

        let ver_vec: Vec<String> = fl.to_string().split('.').map(&str::to_owned).collect();

        let ver: [String; 2] = ver_vec.try_into().map_err(|_| {
            Error::new_spanned(&fl, "did not contain only both an X and Y value like X.Y")
        })?;

        let major: NonZeroU8 = ver[0].parse().map_err(|e| {
            Error::new_spanned(&fl, format!("major number failed to parse as >0 number: {e}"))
        })?;
        let minor: u8 = ver[1]
            .parse()
            .map_err(|e| Error::new_spanned(&fl, format!("minor number failed to parse: {e}")))?;

        Ok(Self { major, minor })
    }
}

impl ToTokens for MatrixVersionLiteral {
    /// Emits the corresponding `MatrixVersion::VX_Y` variant path.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let variant = format_ident!("V{}_{}", u8::from(self.major), self.minor);
        // NOTE(review): hard-codes `::ruma_common`, while other expansions resolve the crate
        // path via `import_ruma_common()`; confirm all users depend on ruma-common directly.
        tokens.extend(quote! { ::ruma_common::api::MatrixVersion::#variant });
    }
}
ruma-macros-0.10.5/src/api.rs000064400000000000000000000160731046102023000140520ustar 00000000000000//! Methods and types for generating API endpoints.
use std::{env, fs, path::Path}; use once_cell::sync::Lazy; use proc_macro2::{Span, TokenStream}; use quote::quote; use serde::{de::IgnoredAny, Deserialize}; use syn::{ braced, parse::{Parse, ParseStream}, Attribute, Field, Token, Type, }; use self::{api_metadata::Metadata, api_request::Request, api_response::Response}; use crate::util::import_ruma_common; mod api_metadata; mod api_request; mod api_response; mod attribute; mod auth_scheme; pub mod request; pub mod response; mod util; mod version; mod kw { use syn::custom_keyword; custom_keyword!(error); custom_keyword!(request); custom_keyword!(response); } /// The result of processing the `ruma_api` macro, ready for output back to source code. pub struct Api { /// The `metadata` section of the macro. metadata: Metadata, /// The `request` section of the macro. request: Option, /// The `response` section of the macro. response: Option, /// The `error` section of the macro. error_ty: Option, } impl Api { pub fn expand_all(self) -> TokenStream { let maybe_error = ensure_feature_presence().map(syn::Error::to_compile_error); let ruma_common = import_ruma_common(); let http = quote! { #ruma_common::exports::http }; let metadata = &self.metadata; let description = &metadata.description; let method = &metadata.method; let name = &metadata.name; let unstable_path = util::map_option_literal(&metadata.unstable_path); let r0_path = util::map_option_literal(&metadata.r0_path); let stable_path = util::map_option_literal(&metadata.stable_path); let rate_limited = &self.metadata.rate_limited; let authentication = &self.metadata.authentication; let added = util::map_option_literal(&metadata.added); let deprecated = util::map_option_literal(&metadata.deprecated); let removed = util::map_option_literal(&metadata.removed); let error_ty = self.error_ty.map_or_else( || quote! { #ruma_common::api::error::MatrixError }, |err_ty| quote! 
{ #err_ty }, ); let request = self.request.map(|req| req.expand(metadata, &error_ty, &ruma_common)); let response = self.response.map(|res| res.expand(metadata, &error_ty, &ruma_common)); let metadata_doc = format!("Metadata for the `{}` API endpoint.", name.value()); quote! { #maybe_error #[doc = #metadata_doc] pub const METADATA: #ruma_common::api::Metadata = #ruma_common::api::Metadata { description: #description, method: #http::Method::#method, name: #name, unstable_path: #unstable_path, r0_path: #r0_path, stable_path: #stable_path, added: #added, deprecated: #deprecated, removed: #removed, rate_limited: #rate_limited, authentication: #ruma_common::api::AuthScheme::#authentication, }; #request #response #[cfg(not(any(feature = "client", feature = "server")))] type _SilenceUnusedError = #error_ty; } } } impl Parse for Api { fn parse(input: ParseStream<'_>) -> syn::Result { let metadata: Metadata = input.parse()?; let req_attrs = input.call(Attribute::parse_outer)?; let (request, attributes) = if input.peek(kw::request) { let request = parse_request(input, req_attrs)?; let after_req_attrs = input.call(Attribute::parse_outer)?; (Some(request), after_req_attrs) } else { // There was no `request` field so the attributes are for `response` (None, req_attrs) }; let response = if input.peek(kw::response) { Some(parse_response(input, attributes)?) 
} else if !attributes.is_empty() { return Err(syn::Error::new_spanned( &attributes[0], "attributes are not supported on the error type", )); } else { None }; let error_ty = input .peek(kw::error) .then(|| { let _: kw::error = input.parse()?; let _: Token![:] = input.parse()?; input.parse() }) .transpose()?; Ok(Self { metadata, request, response, error_ty }) } } fn parse_request(input: ParseStream<'_>, attributes: Vec) -> syn::Result { let request_kw: kw::request = input.parse()?; let _: Token![:] = input.parse()?; let fields; braced!(fields in input); let fields = fields.parse_terminated::<_, Token![,]>(Field::parse_named)?; Ok(Request { request_kw, attributes, fields }) } fn parse_response(input: ParseStream<'_>, attributes: Vec) -> syn::Result { let response_kw: kw::response = input.parse()?; let _: Token![:] = input.parse()?; let fields; braced!(fields in input); let fields = fields.parse_terminated::<_, Token![,]>(Field::parse_named)?; Ok(Response { attributes, fields, response_kw }) } // Returns an error with a helpful error if the crate `ruma_api!` is used from doesn't declare both // a `client` and a `server` feature. 
fn ensure_feature_presence() -> Option<&'static syn::Error> { #[derive(Deserialize)] struct CargoToml { features: Features, } #[derive(Deserialize)] struct Features { client: Option, server: Option, } static RESULT: Lazy> = Lazy::new(|| { let manifest_dir = env::var("CARGO_MANIFEST_DIR") .map_err(|_| syn::Error::new(Span::call_site(), "Failed to read CARGO_MANIFEST_DIR"))?; let manifest_file = Path::new(&manifest_dir).join("Cargo.toml"); let manifest_bytes = fs::read(manifest_file) .map_err(|_| syn::Error::new(Span::call_site(), "Failed to read Cargo.toml"))?; let manifest_parsed: CargoToml = toml::from_slice(&manifest_bytes) .map_err(|_| syn::Error::new(Span::call_site(), "Failed to parse Cargo.toml"))?; if manifest_parsed.features.client.is_none() { return Err(syn::Error::new( Span::call_site(), "This crate doesn't define a `client` feature in its `Cargo.toml`.\n\ Please add a `client` feature such that generated `OutgoingRequest` and \ `IncomingResponse` implementations can be enabled.", )); } if manifest_parsed.features.server.is_none() { return Err(syn::Error::new( Span::call_site(), "This crate doesn't define a `server` feature in its `Cargo.toml`.\n\ Please add a `server` feature such that generated `IncomingRequest` and \ `OutgoingResponse` implementations can be enabled.", )); } Ok(()) }); RESULT.as_ref().err() } ruma-macros-0.10.5/src/events/event.rs000064400000000000000000000450411046102023000157230ustar 00000000000000//! Implementation of the top level `*Event` derive macro. use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{parse_quote, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericParam}; use super::{ event_parse::{to_kind_variation, EventKind, EventKindVariation}, util::{has_prev_content, is_non_stripped_room_event}, }; use crate::{import_ruma_common, util::to_camel_case}; /// Derive `Event` macro code generation. 
pub fn expand_event(input: DeriveInput) -> syn::Result { let ruma_common = import_ruma_common(); let ident = &input.ident; let (kind, var) = to_kind_variation(ident).ok_or_else(|| { syn::Error::new_spanned(ident, "not a valid ruma event struct identifier") })?; let fields: Vec<_> = if let Data::Struct(DataStruct { fields: Fields::Named(FieldsNamed { named, .. }), .. }) = &input.data { if !named.iter().any(|f| f.ident.as_ref().unwrap() == "content") { return Err(syn::Error::new( Span::call_site(), "struct must contain a `content` field", )); } named.iter().cloned().collect() } else { return Err(syn::Error::new_spanned( input.ident, "the `Event` derive only supports structs with named fields", )); }; let mut res = TokenStream::new(); res.extend(expand_serialize_event(&input, var, &fields, &ruma_common)); res.extend( expand_deserialize_event(&input, kind, var, &fields, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); if var.is_sync() { res.extend(expand_sync_from_into_full(&input, kind, var, &fields, &ruma_common)); } if matches!(kind, EventKind::MessageLike | EventKind::State) && matches!(var, EventKindVariation::Original | EventKindVariation::OriginalSync) { res.extend(expand_redact_event(&input, kind, var, &fields, &ruma_common)); } if is_non_stripped_room_event(kind, var) { res.extend(expand_eq_ord_event(&input)); } Ok(res) } fn expand_serialize_event( input: &DeriveInput, var: EventKindVariation, fields: &[Field], ruma_common: &TokenStream, ) -> TokenStream { let serde = quote! { #ruma_common::exports::serde }; let ident = &input.ident; let (impl_gen, ty_gen, where_clause) = input.generics.split_for_impl(); let serialize_fields: Vec<_> = fields .iter() .map(|field| { let name = field.ident.as_ref().unwrap(); if name == "content" && var.is_redacted() { quote! { if #ruma_common::events::RedactedEventContent::has_serialize_fields(&self.content) { state.serialize_field("content", &self.content)?; } } } else if name == "unsigned" { quote! 
{ if !#ruma_common::serde::is_empty(&self.unsigned) { state.serialize_field("unsigned", &self.unsigned)?; } } } else { let name_s = name.to_string(); match &field.ty { syn::Type::Path(syn::TypePath { path: syn::Path { segments, .. }, .. }) if segments.last().unwrap().ident == "Option" => { quote! { if let Some(content) = self.#name.as_ref() { state.serialize_field(#name_s, content)?; } } } _ => quote! { state.serialize_field(#name_s, &self.#name)?; }, } } }) .collect(); quote! { #[automatically_derived] impl #impl_gen #serde::ser::Serialize for #ident #ty_gen #where_clause { fn serialize(&self, serializer: S) -> Result where S: #serde::ser::Serializer, { use #serde::ser::{SerializeStruct as _, Error as _}; let mut state = serializer.serialize_struct(stringify!(#ident), 7)?; let event_type = #ruma_common::events::EventContent::event_type(&self.content); state.serialize_field("type", &event_type)?; #( #serialize_fields )* state.end() } } } } fn expand_deserialize_event( input: &DeriveInput, kind: EventKind, var: EventKindVariation, fields: &[Field], ruma_common: &TokenStream, ) -> syn::Result { let serde = quote! { #ruma_common::exports::serde }; let serde_json = quote! { #ruma_common::exports::serde_json }; let ident = &input.ident; // we know there is a content field already let content_type = &fields .iter() // we also know that the fields are named and have an ident .find(|f| f.ident.as_ref().unwrap() == "content") .unwrap() .ty; let (impl_generics, ty_gen, where_clause) = input.generics.split_for_impl(); let is_generic = !input.generics.params.is_empty(); let enum_variants: Vec<_> = fields .iter() .map(|field| { let name = field.ident.as_ref().unwrap(); to_camel_case(name) }) .collect(); let deserialize_var_types: Vec<_> = fields .iter() .map(|field| { let name = field.ident.as_ref().unwrap(); if name == "content" || (name == "unsigned" && has_prev_content(kind, var)) { if is_generic { quote! { ::std::boxed::Box<#serde_json::value::RawValue> } } else { quote! 
{ #content_type } } } else if name == "state_key" && var == EventKindVariation::Initial { quote! { ::std::string::String } } else { let ty = &field.ty; quote! { #ty } } }) .collect(); let ok_or_else_fields: Vec<_> = fields .iter() .map(|field| { let name = field.ident.as_ref().unwrap(); Ok(if name == "content" { if is_generic && var.is_redacted() { quote! { let content = match C::has_deserialize_fields() { #ruma_common::events::HasDeserializeFields::False => { C::empty(&event_type).map_err(#serde::de::Error::custom)? }, #ruma_common::events::HasDeserializeFields::True => { let json = content.ok_or_else( || #serde::de::Error::missing_field("content"), )?; C::from_parts(&event_type, &json) .map_err(#serde::de::Error::custom)? }, #ruma_common::events::HasDeserializeFields::Optional => { let json = content.unwrap_or( #serde_json::value::RawValue::from_string("{}".to_owned()) .unwrap() ); C::from_parts(&event_type, &json) .map_err(#serde::de::Error::custom)? }, }; } } else if is_generic { quote! { let content = { let json = content .ok_or_else(|| #serde::de::Error::missing_field("content"))?; C::from_parts(&event_type, &json).map_err(#serde::de::Error::custom)? }; } } else { quote! { let content = content.ok_or_else( || #serde::de::Error::missing_field("content"), )?; } } } else if name == "unsigned" { if has_prev_content(kind, var) { quote! { let unsigned = unsigned.map(|json| { #ruma_common::events::StateUnsignedFromParts::_from_parts( &event_type, &json, ).map_err(#serde::de::Error::custom) }).transpose()?.unwrap_or_default(); } } else { quote! { let unsigned = unsigned.unwrap_or_default(); } } } else if name == "state_key" && var == EventKindVariation::Initial { let ty = &field.ty; quote! { let state_key: ::std::string::String = state_key.unwrap_or_default(); let state_key: #ty = <#ty as #serde::de::Deserialize>::deserialize( #serde::de::IntoDeserializer::::into_deserializer(state_key), )?; } } else { quote! 
{ let #name = #name.ok_or_else(|| { #serde::de::Error::missing_field(stringify!(#name)) })?; } }) }) .collect::>()?; let field_names: Vec<_> = fields.iter().flat_map(|f| &f.ident).collect(); let deserialize_impl_gen = if is_generic { let gen = &input.generics.params; quote! { <'de, #gen> } } else { quote! { <'de> } }; let deserialize_phantom_type = if is_generic { quote! { ::std::marker::PhantomData } } else { quote! {} }; Ok(quote! { #[automatically_derived] impl #deserialize_impl_gen #serde::de::Deserialize<'de> for #ident #ty_gen #where_clause { fn deserialize(deserializer: D) -> Result where D: #serde::de::Deserializer<'de>, { #[derive(#serde::Deserialize)] #[serde(field_identifier, rename_all = "snake_case")] enum Field { // since this is represented as an enum we have to add it so the JSON picks it // up Type, #( #enum_variants, )* #[serde(other)] Unknown, } /// Visits the fields of an event struct to handle deserialization of /// the `content` and `prev_content` fields. struct EventVisitor #impl_generics (#deserialize_phantom_type #ty_gen); #[automatically_derived] impl #deserialize_impl_gen #serde::de::Visitor<'de> for EventVisitor #ty_gen #where_clause { type Value = #ident #ty_gen; fn expecting( &self, formatter: &mut ::std::fmt::Formatter<'_>, ) -> ::std::fmt::Result { write!(formatter, "struct implementing {}", stringify!(#content_type)) } fn visit_map(self, mut map: A) -> Result where A: #serde::de::MapAccess<'de>, { let mut event_type: Option = None; #( let mut #field_names: Option<#deserialize_var_types> = None; )* while let Some(key) = map.next_key()? 
{ match key { Field::Unknown => { let _: #serde::de::IgnoredAny = map.next_value()?; }, Field::Type => { if event_type.is_some() { return Err(#serde::de::Error::duplicate_field("type")); } event_type = Some(map.next_value()?); } #( Field::#enum_variants => { if #field_names.is_some() { return Err(#serde::de::Error::duplicate_field( stringify!(#field_names), )); } #field_names = Some(map.next_value()?); } )* } } let event_type = event_type.ok_or_else(|| #serde::de::Error::missing_field("type"))?; #( #ok_or_else_fields )* Ok(#ident { #( #field_names ),* }) } } deserializer.deserialize_map(EventVisitor(#deserialize_phantom_type)) } } }) } fn expand_redact_event( input: &DeriveInput, kind: EventKind, var: EventKindVariation, fields: &[Field], ruma_common: &TokenStream, ) -> syn::Result { let redacted_type = kind.to_event_ident(var.to_redacted())?; let ident = &input.ident; let mut generics = input.generics.clone(); if generics.params.is_empty() { return Ok(TokenStream::new()); } assert_eq!(generics.params.len(), 1, "expected one generic parameter"); let ty_param = match &generics.params[0] { GenericParam::Type(ty) => ty.ident.clone(), _ => panic!("expected a type parameter"), }; let where_clause = generics.make_where_clause(); where_clause.predicates.push(parse_quote! { #ty_param: #ruma_common::events::RedactContent }); let assoc_type_bounds = (kind == EventKind::State).then(|| quote! { StateKey = #ty_param::StateKey }); let trait_name = format_ident!("Redacted{kind}Content"); let redacted_event_content_bound = quote! { #ruma_common::events::#trait_name<#assoc_type_bounds> }; where_clause.predicates.push(parse_quote! 
{ <#ty_param as #ruma_common::events::RedactContent>::Redacted: #redacted_event_content_bound }); let (impl_generics, ty_gen, where_clause) = generics.split_for_impl(); let fields = fields.iter().filter_map(|field| { let ident = field.ident.as_ref().unwrap(); if ident == "content" || ident == "prev_content" { None } else if ident == "unsigned" { Some(quote! { unsigned: #ruma_common::events::RedactedUnsigned::new_because( ::std::boxed::Box::new(redaction), ) }) } else { Some(quote! { #ident: self.#ident }) } }); Ok(quote! { #[automatically_derived] impl #impl_generics #ruma_common::events::Redact for #ident #ty_gen #where_clause { type Redacted = #ruma_common::events::#redacted_type< <#ty_param as #ruma_common::events::RedactContent>::Redacted, >; fn redact( self, redaction: #ruma_common::events::room::redaction::SyncRoomRedactionEvent, version: &#ruma_common::RoomVersionId, ) -> Self::Redacted { let content = #ruma_common::events::RedactContent::redact(self.content, version); #ruma_common::events::#redacted_type { content, #(#fields),* } } } }) } fn expand_sync_from_into_full( input: &DeriveInput, kind: EventKind, var: EventKindVariation, fields: &[Field], ruma_common: &TokenStream, ) -> syn::Result { let ident = &input.ident; let full_struct = kind.to_event_ident(var.to_full())?; let (impl_generics, ty_gen, where_clause) = input.generics.split_for_impl(); let fields: Vec<_> = fields.iter().flat_map(|f| &f.ident).collect(); Ok(quote! { #[automatically_derived] impl #impl_generics ::std::convert::From<#full_struct #ty_gen> for #ident #ty_gen #where_clause { fn from(event: #full_struct #ty_gen) -> Self { let #full_struct { #( #fields, )* .. } = event; Self { #( #fields, )* } } } #[automatically_derived] impl #impl_generics #ident #ty_gen #where_clause { /// Convert this sync event into a full event, one with a room_id field. 
pub fn into_full_event( self, room_id: #ruma_common::OwnedRoomId, ) -> #full_struct #ty_gen { let Self { #( #fields, )* } = self; #full_struct { #( #fields, )* room_id, } } } }) } fn expand_eq_ord_event(input: &DeriveInput) -> TokenStream { let ident = &input.ident; let (impl_gen, ty_gen, where_clause) = input.generics.split_for_impl(); quote! { #[automatically_derived] impl #impl_gen ::std::cmp::PartialEq for #ident #ty_gen #where_clause { /// Checks if two `EventId`s are equal. fn eq(&self, other: &Self) -> ::std::primitive::bool { self.event_id == other.event_id } } #[automatically_derived] impl #impl_gen ::std::cmp::Eq for #ident #ty_gen #where_clause {} #[automatically_derived] impl #impl_gen ::std::cmp::PartialOrd for #ident #ty_gen #where_clause { /// Compares `EventId`s and orders them lexicographically. fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> { self.event_id.partial_cmp(&other.event_id) } } #[automatically_derived] impl #impl_gen ::std::cmp::Ord for #ident #ty_gen #where_clause { /// Compares `EventId`s and orders them lexicographically. fn cmp(&self, other: &Self) -> ::std::cmp::Ordering { self.event_id.cmp(&other.event_id) } } } } ruma-macros-0.10.5/src/events/event_content.rs000064400000000000000000000630171046102023000174600ustar 00000000000000//! Implementations of the EventContent derive macro. #![allow(clippy::too_many_arguments)] // FIXME use std::borrow::Cow; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse::{Parse, ParseStream}, DeriveInput, Field, Ident, LitStr, Token, Type, }; use crate::util::m_prefix_name_to_type_name; use super::event_parse::{EventKind, EventKindVariation}; mod kw { // This `content` field is kept when the event is redacted. syn::custom_keyword!(skip_redaction); // Do not emit any redacted event code. syn::custom_keyword!(custom_redacted); // The kind of event content this is. 
syn::custom_keyword!(kind); syn::custom_keyword!(type_fragment); // The type to use for a state events' `state_key` field. syn::custom_keyword!(state_key_type); // The type to use for a state events' `unsigned` field. syn::custom_keyword!(unsigned_type); // Another type string accepted for deserialization. syn::custom_keyword!(alias); } /// Parses struct attributes for `*EventContent` derives. /// /// `#[ruma_event(type = "m.room.alias")]` enum EventStructMeta { /// Variant holds the "m.whatever" event type. Type(LitStr), Kind(EventKind), /// This attribute signals that the events redacted form is manually implemented and should not /// be generated. CustomRedacted, StateKeyType(Box), UnsignedType(Box), /// Variant that holds alternate event type accepted for deserialization. Alias(LitStr), } impl EventStructMeta { fn get_event_type(&self) -> Option<&LitStr> { match self { Self::Type(t) => Some(t), _ => None, } } fn get_event_kind(&self) -> Option { match self { Self::Kind(k) => Some(*k), _ => None, } } fn get_state_key_type(&self) -> Option<&Type> { match self { Self::StateKeyType(ty) => Some(ty), _ => None, } } fn get_unsigned_type(&self) -> Option<&Type> { match self { Self::UnsignedType(ty) => Some(ty), _ => None, } } fn get_alias(&self) -> Option<&LitStr> { match self { Self::Alias(t) => Some(t), _ => None, } } } impl Parse for EventStructMeta { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(Token![type]) { let _: Token![type] = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(EventStructMeta::Type) } else if lookahead.peek(kw::kind) { let _: kw::kind = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(EventStructMeta::Kind) } else if lookahead.peek(kw::custom_redacted) { let _: kw::custom_redacted = input.parse()?; Ok(EventStructMeta::CustomRedacted) } else if lookahead.peek(kw::state_key_type) { let _: kw::state_key_type = input.parse()?; let _: Token![=] = 
input.parse()?; input.parse().map(EventStructMeta::StateKeyType) } else if lookahead.peek(kw::unsigned_type) { let _: kw::unsigned_type = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(EventStructMeta::UnsignedType) } else if lookahead.peek(kw::alias) { let _: kw::alias = input.parse()?; let _: Token![=] = input.parse()?; input.parse().map(EventStructMeta::Alias) } else { Err(lookahead.error()) } } } /// Parses field attributes for `*EventContent` derives. /// /// `#[ruma_event(skip_redaction)]` enum EventFieldMeta { /// Fields marked with `#[ruma_event(skip_redaction)]` are kept when the event is /// redacted. SkipRedaction, /// The given field holds a part of the event type (replaces the `*` in a `m.foo.*` event /// type). TypeFragment, } impl Parse for EventFieldMeta { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(kw::skip_redaction) { let _: kw::skip_redaction = input.parse()?; Ok(EventFieldMeta::SkipRedaction) } else if lookahead.peek(kw::type_fragment) { let _: kw::type_fragment = input.parse()?; Ok(EventFieldMeta::TypeFragment) } else { Err(lookahead.error()) } } } struct MetaAttrs(Vec); impl MetaAttrs { fn is_custom(&self) -> bool { self.0.iter().any(|a| matches!(a, &EventStructMeta::CustomRedacted)) } fn get_event_type(&self) -> Option<&LitStr> { self.0.iter().find_map(|a| a.get_event_type()) } fn get_event_kind(&self) -> Option { self.0.iter().find_map(|a| a.get_event_kind()) } fn get_state_key_type(&self) -> Option<&Type> { self.0.iter().find_map(|a| a.get_state_key_type()) } fn get_unsigned_type(&self) -> Option<&Type> { self.0.iter().find_map(|a| a.get_unsigned_type()) } fn get_aliases(&self) -> impl Iterator { self.0.iter().filter_map(|a| a.get_alias()) } } impl Parse for MetaAttrs { fn parse(input: ParseStream<'_>) -> syn::Result { let attrs = syn::punctuated::Punctuated::::parse_terminated(input)?; Ok(Self(attrs.into_iter().collect())) } } /// Create an `EventContent` 
implementation for a struct. pub fn expand_event_content( input: &DeriveInput, ruma_common: &TokenStream, ) -> syn::Result { let content_attr = input .attrs .iter() .filter(|attr| attr.path.is_ident("ruma_event")) .map(|attr| attr.parse_args::()) .collect::>>()?; let mut event_types: Vec<_> = content_attr.iter().filter_map(|attrs| attrs.get_event_type()).collect(); let event_type = match event_types.as_slice() { [] => { return Err(syn::Error::new( Span::call_site(), "no event type attribute found, \ add `#[ruma_event(type = \"any.room.event\", kind = Kind)]` \ below the event content derive", )); } [_] => event_types.pop().unwrap(), _ => { return Err(syn::Error::new( Span::call_site(), "multiple event type attributes found, there can only be one", )); } }; let mut event_kinds: Vec<_> = content_attr.iter().filter_map(|attrs| attrs.get_event_kind()).collect(); let event_kind = match event_kinds.as_slice() { [] => None, [_] => Some(event_kinds.pop().unwrap()), _ => { return Err(syn::Error::new( Span::call_site(), "multiple event kind attributes found, there can only be one", )); } }; let state_key_types: Vec<_> = content_attr.iter().filter_map(|attrs| attrs.get_state_key_type()).collect(); let state_key_type = match (event_kind, state_key_types.as_slice()) { (Some(EventKind::State), []) => { return Err(syn::Error::new( Span::call_site(), "no state_key_type attribute found, please specify one", )); } (Some(EventKind::State), [ty]) => Some(quote! { #ty }), (Some(EventKind::State), _) => { return Err(syn::Error::new( Span::call_site(), "multiple state_key_type attribute found, there can only be one", )); } (_, []) => None, (_, [ty, ..]) => { return Err(syn::Error::new_spanned( ty, "state_key_type attribute is not valid for non-state event kinds", )); } }; let unsigned_types: Vec<_> = content_attr.iter().filter_map(|attrs| attrs.get_unsigned_type()).collect(); let unsigned_type = match unsigned_types.as_slice() { [] => None, [ty] => Some(quote! 
{ #ty }), _ => { return Err(syn::Error::new( Span::call_site(), "multiple unsigned attributes found, there can only be one", )); } }; let ident = &input.ident; let fields = match &input.data { syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.iter(), _ => { return Err(syn::Error::new( Span::call_site(), "event content types need to be structs", )); } }; let event_type_s = event_type.value(); let prefix = event_type_s.strip_suffix(".*"); if prefix.unwrap_or(&event_type_s).contains('*') { return Err(syn::Error::new_spanned( event_type, "event type may only contain `*` as part of a `.*` suffix", )); } if prefix.is_some() && !event_kind.map_or(false, |k| k.is_account_data()) { return Err(syn::Error::new_spanned( event_type, "only account data events may contain a `.*` suffix", )); } let aliases: Vec<_> = content_attr.iter().flat_map(|attrs| attrs.get_aliases()).collect(); for alias in &aliases { if alias.value().ends_with(".*") != prefix.is_some() { return Err(syn::Error::new_spanned( event_type, "aliases should have the same `.*` suffix, or lack thereof, as the main event type", )); } } // We only generate redacted content structs for state and message-like events let redacted_event_content = needs_redacted(&content_attr, event_kind).then(|| { generate_redacted_event_content( ident, fields.clone(), event_type, event_kind, state_key_type.as_ref(), unsigned_type.clone(), &aliases, ruma_common, ) .unwrap_or_else(syn::Error::into_compile_error) }); let event_content_impl = generate_event_content_impl( ident, fields, event_type, event_kind, state_key_type.as_ref(), unsigned_type, &aliases, ruma_common, ) .unwrap_or_else(syn::Error::into_compile_error); let static_event_content_impl = event_kind .map(|k| generate_static_event_content_impl(ident, k, false, event_type, ruma_common)); let type_aliases = event_kind.map(|k| { generate_event_type_aliases(k, ident, &event_type.value(), ruma_common) .unwrap_or_else(syn::Error::into_compile_error) }); Ok(quote! 
{ #redacted_event_content #event_content_impl #static_event_content_impl #type_aliases }) } fn generate_redacted_event_content<'a>( ident: &Ident, fields: impl Iterator, event_type: &LitStr, event_kind: Option, state_key_type: Option<&TokenStream>, unsigned_type: Option, aliases: &[&LitStr], ruma_common: &TokenStream, ) -> syn::Result { assert!( !event_type.value().contains('*'), "Event type shouldn't contain a `*`, this should have been checked previously" ); let serde = quote! { #ruma_common::exports::serde }; let serde_json = quote! { #ruma_common::exports::serde_json }; let doc = format!("Redacted form of [`{}`]", ident); let redacted_ident = format_ident!("Redacted{ident}"); let kept_redacted_fields: Vec<_> = fields .map(|f| { let mut keep_field = false; let attrs = f .attrs .iter() .map(|a| -> syn::Result<_> { if a.path.is_ident("ruma_event") { if let EventFieldMeta::SkipRedaction = a.parse_args()? { keep_field = true; } // don't re-emit our `ruma_event` attributes Ok(None) } else { Ok(Some(a.clone())) } }) .filter_map(Result::transpose) .collect::>()?; if keep_field { Ok(Some(Field { attrs, ..f.clone() })) } else { Ok(None) } }) .filter_map(Result::transpose) .collect::>()?; let redaction_struct_fields = kept_redacted_fields.iter().flat_map(|f| &f.ident); let (redacted_fields, redacted_return) = if kept_redacted_fields.is_empty() { (quote! { ; }, quote! { Ok(#redacted_ident {}) }) } else { ( quote! { { #( #kept_redacted_fields, )* } }, quote! { Err(#serde::de::Error::custom( format!("this redacted event has fields that cannot be constructed") )) }, ) }; let (has_deserialize_fields, has_serialize_fields) = if kept_redacted_fields.is_empty() { (quote! { #ruma_common::events::HasDeserializeFields::False }, quote! { false }) } else { (quote! { #ruma_common::events::HasDeserializeFields::True }, quote! { true }) }; let constructor = kept_redacted_fields.is_empty().then(|| { let doc = format!("Creates an empty {}.", redacted_ident); quote! 
{ impl #redacted_ident { #[doc = #doc] pub fn new() -> Self { Self } } } }); let redacted_event_content = generate_event_content_impl( &redacted_ident, kept_redacted_fields.iter(), event_type, event_kind, state_key_type, unsigned_type, aliases, ruma_common, ) .unwrap_or_else(syn::Error::into_compile_error); let sub_trait_name = event_kind.map(|kind| format_ident!("Redacted{kind}Content")); let static_event_content_impl = event_kind.map(|kind| { generate_static_event_content_impl(&redacted_ident, kind, true, event_type, ruma_common) }); let mut event_types = aliases.to_owned(); event_types.push(event_type); let event_types = quote! { [#(#event_types,)*] }; Ok(quote! { // this is the non redacted event content's impl #[automatically_derived] impl #ruma_common::events::RedactContent for #ident { type Redacted = #redacted_ident; fn redact(self, version: &#ruma_common::RoomVersionId) -> #redacted_ident { #redacted_ident { #( #redaction_struct_fields: self.#redaction_struct_fields, )* } } } #[doc = #doc] #[derive(Clone, Debug, #serde::Deserialize, #serde::Serialize)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] pub struct #redacted_ident #redacted_fields #constructor #redacted_event_content #[automatically_derived] impl #ruma_common::events::RedactedEventContent for #redacted_ident { fn empty(ev_type: &str) -> #serde_json::Result { if !#event_types.contains(&ev_type) { return Err(#serde::de::Error::custom( format!("expected event type as one of `{:?}`, found `{}`", #event_types, ev_type) )); } #redacted_return } fn has_serialize_fields(&self) -> bool { #has_serialize_fields } fn has_deserialize_fields() -> #ruma_common::events::HasDeserializeFields { #has_deserialize_fields } } #[automatically_derived] impl #ruma_common::events::#sub_trait_name for #redacted_ident {} #static_event_content_impl }) } fn generate_event_type_aliases( event_kind: EventKind, ident: &Ident, event_type: &str, ruma_common: &TokenStream, ) -> syn::Result { // The 
redaction module has its own event types. if ident == "RoomRedactionEventContent" { return Ok(quote! {}); } let ident_s = ident.to_string(); let ev_type_s = ident_s.strip_suffix("Content").ok_or_else(|| { syn::Error::new_spanned(ident, "Expected content struct name ending in `Content`") })?; let type_aliases = [ EventKindVariation::None, EventKindVariation::Sync, EventKindVariation::Original, EventKindVariation::OriginalSync, EventKindVariation::Stripped, EventKindVariation::Initial, EventKindVariation::Redacted, EventKindVariation::RedactedSync, ] .iter() .filter_map(|&var| Some((var, event_kind.to_event_ident(var).ok()?))) .map(|(var, ev_struct)| { let ev_type = format_ident!("{var}{ev_type_s}"); let doc_text = match var { EventKindVariation::None | EventKindVariation::Original => "", EventKindVariation::Sync | EventKindVariation::OriginalSync => { " from a `sync_events` response" } EventKindVariation::Stripped => " from an invited room preview", EventKindVariation::Redacted => " that has been redacted", EventKindVariation::RedactedSync => { " from a `sync_events` response that has been redacted" } EventKindVariation::Initial => " for creating a room", }; let ev_type_doc = format!("An `{}` event{}.", event_type, doc_text); let content_struct = if var.is_redacted() { Cow::Owned(format_ident!("Redacted{ident}")) } else { Cow::Borrowed(ident) }; quote! { #[doc = #ev_type_doc] pub type #ev_type = #ruma_common::events::#ev_struct<#content_struct>; } }) .collect(); Ok(type_aliases) } fn generate_event_content_impl<'a>( ident: &Ident, mut fields: impl Iterator, event_type: &LitStr, event_kind: Option, state_key_type: Option<&TokenStream>, unsigned_type: Option, aliases: &[&'a LitStr], ruma_common: &TokenStream, ) -> syn::Result { let serde = quote! { #ruma_common::exports::serde }; let serde_json = quote! 
{ #ruma_common::exports::serde_json }; let (event_type_ty_decl, event_type_ty, event_type_fn_impl); let type_suffix_data = event_type .value() .strip_suffix('*') .map(|type_prefix| { let type_fragment_field = fields .find_map(|f| { f.attrs.iter().filter(|a| a.path.is_ident("ruma_event")).find_map(|a| { match a.parse_args() { Ok(EventFieldMeta::TypeFragment) => Some(Ok(f)), Ok(_) => None, Err(e) => Some(Err(e)), } }) }) .transpose()? .ok_or_else(|| { syn::Error::new_spanned( event_type, "event type with a `.*` suffix requires there to be a \ `#[ruma_event(type_fragment)]` field", ) })? .ident .as_ref() .expect("type fragment field needs to have a name"); >::Ok((type_prefix.to_owned(), type_fragment_field)) }) .transpose()?; match event_kind { Some(kind) => { let i = kind.to_event_type_enum(); event_type_ty_decl = None; event_type_ty = quote! { #ruma_common::events::#i }; event_type_fn_impl = match &type_suffix_data { Some((type_prefix, type_fragment_field)) => { let format = type_prefix.to_owned() + "{}"; quote! { ::std::convert::From::from(::std::format!(#format, self.#type_fragment_field)) } } None => quote! { ::std::convert::From::from(#event_type) }, }; } None => { let camel_case_type_name = m_prefix_name_to_type_name(event_type)?; let i = format_ident!("{}EventType", camel_case_type_name); event_type_ty_decl = Some(quote! { /// Implementation detail, you don't need to care about this. #[doc(hidden)] pub struct #i { // Set to None for intended type, Some for a different one ty: ::std::option::Option, } impl #serde::Serialize for #i { fn serialize(&self, serializer: S) -> ::std::result::Result where S: #serde::Serializer, { let s = self.ty.as_ref().map(|t| &t.0[..]).unwrap_or(#event_type); serializer.serialize_str(s) } } }); event_type_ty = quote! { #i }; event_type_fn_impl = quote! 
{ #event_type_ty { ty: ::std::option::Option::None } }; } } let sub_trait_impl = event_kind.map(|kind| { let trait_name = format_ident!("{kind}Content"); let state_event_content_impl = (event_kind == Some(EventKind::State)).then(|| { assert!(state_key_type.is_some()); let unsigned_type = unsigned_type .unwrap_or_else(|| quote! { #ruma_common::events::StateUnsigned }); quote! { type StateKey = #state_key_type; type Unsigned = #unsigned_type; } }); quote! { #[automatically_derived] impl #ruma_common::events::#trait_name for #ident { #state_event_content_impl } } }); let event_types = aliases.iter().chain([&event_type]); let from_parts_fn_impl = if let Some((_, type_fragment_field)) = &type_suffix_data { let type_prefixes = event_types.map(|ev_type| { ev_type .value() .strip_suffix('*') .expect("aliases have already been checked to have the same suffix") .to_owned() }); let type_prefixes = quote! { [#(#type_prefixes,)*] }; quote! { if let Some(type_fragment) = #type_prefixes.iter().find_map(|prefix| ev_type.strip_prefix(prefix)) { let mut content: Self = #serde_json::from_str(content.get())?; content.#type_fragment_field = type_fragment.to_owned(); ::std::result::Result::Ok(content) } else { ::std::result::Result::Err(#serde::de::Error::custom( ::std::format!("expected event type starting with one of `{:?}`, found `{}`", #type_prefixes, ev_type) )) } } } else { let event_types = quote! { [#(#event_types,)*] }; quote! { if !#event_types.contains(&ev_type) { return ::std::result::Result::Err(#serde::de::Error::custom( ::std::format!("expected event type as one of `{:?}`, found `{}`", #event_types, ev_type) )); } #serde_json::from_str(content.get()) } }; Ok(quote! 
{ #event_type_ty_decl #[automatically_derived] impl #ruma_common::events::EventContent for #ident { type EventType = #event_type_ty; fn event_type(&self) -> Self::EventType { #event_type_fn_impl } fn from_parts( ev_type: &::std::primitive::str, content: &#serde_json::value::RawValue, ) -> #serde_json::Result { #from_parts_fn_impl } } #sub_trait_impl }) } fn generate_static_event_content_impl( ident: &Ident, event_kind: EventKind, redacted: bool, event_type: &LitStr, ruma_common: &TokenStream, ) -> TokenStream { let event_kind = match event_kind { EventKind::GlobalAccountData => quote! { GlobalAccountData }, EventKind::RoomAccountData => quote! { RoomAccountData }, EventKind::Ephemeral => quote! { EphemeralRoomData }, EventKind::MessageLike => quote! { MessageLike { redacted: #redacted } }, EventKind::State => quote! { State { redacted: #redacted } }, EventKind::ToDevice => quote! { ToDevice }, EventKind::RoomRedaction | EventKind::Presence | EventKind::Decrypted | EventKind::HierarchySpaceChild => { unreachable!("not a valid event content kind") } }; quote! { impl #ruma_common::events::StaticEventContent for #ident { const KIND: #ruma_common::events::EventKind = #ruma_common::events::EventKind::#event_kind; const TYPE: &'static ::std::primitive::str = #event_type; } } } fn needs_redacted(input: &[MetaAttrs], event_kind: Option) -> bool { // `is_custom` means that the content struct does not need a generated // redacted struct also. If no `custom_redacted` attrs are found the content // needs a redacted struct generated. !input.iter().any(|a| a.is_custom()) && matches!(event_kind, Some(EventKind::MessageLike) | Some(EventKind::State)) } ruma-macros-0.10.5/src/events/event_enum.rs000064400000000000000000000717631046102023000167610ustar 00000000000000//! Implementation of event enum and event content enum macros. 
use std::{fmt, iter::zip};

use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, IdentFragment, ToTokens};
use syn::{Attribute, Data, DataEnum, DeriveInput, Ident, LitStr, Path};

use super::event_parse::{EventEnumDecl, EventEnumEntry, EventKind};
use crate::util::m_prefix_name_to_type_name;

/// Whether events of this kind and variation are non-stripped room events, i.e. message-like or
/// state events in their full (`None`) or `Sync` form.
///
/// Used in `EVENT_FIELDS` to decide which enums get `origin_server_ts` and `event_id` accessors.
pub(crate) fn is_non_stripped_room_event(kind: EventKind, var: EventEnumVariation) -> bool {
    matches!(kind, EventKind::MessageLike | EventKind::State)
        && matches!(var, EventEnumVariation::None | EventEnumVariation::Sync)
}

/// Predicate telling whether an event enum of the given kind / variation has a given field.
type EventKindFn = fn(EventKind, EventEnumVariation) -> bool;

/// This const is used to generate the accessor methods for the `Any*Event` enums.
///
/// Each entry pairs a field name with a predicate; an accessor method named after the field is
/// only generated for enums whose kind / variation satisfies the predicate.
///
/// DO NOT alter the field names unless the structs in `ruma_common::events::event_kinds` have
/// changed.
const EVENT_FIELDS: &[(&str, EventKindFn)] = &[
    ("origin_server_ts", is_non_stripped_room_event),
    ("room_id", |kind, var| {
        matches!(kind, EventKind::MessageLike | EventKind::State | EventKind::Ephemeral)
            && matches!(var, EventEnumVariation::None)
    }),
    ("event_id", is_non_stripped_room_event),
    ("sender", |kind, var| {
        matches!(kind, EventKind::MessageLike | EventKind::State | EventKind::ToDevice)
            && var != EventEnumVariation::Initial
    }),
];

/// Create the event enums, the content enum and their related impls for one `event_enum!`
/// declaration.
pub fn expand_event_enums(input: &EventEnumDecl) -> syn::Result { use EventEnumVariation as V; let ruma_common = crate::import_ruma_common(); let mut res = TokenStream::new(); let kind = input.kind; let attrs = &input.attrs; let docs: Vec<_> = input.events.iter().map(EventEnumEntry::docs).collect::>()?; let variants: Vec<_> = input.events.iter().map(EventEnumEntry::to_variant).collect::>()?; let events = &input.events; let docs = &docs; let variants = &variants; let ruma_common = &ruma_common; res.extend(expand_content_enum(kind, events, docs, attrs, variants, ruma_common)); res.extend( expand_event_enum(kind, V::None, events, docs, attrs, variants, ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); if matches!(kind, EventKind::MessageLike | EventKind::State) { res.extend( expand_event_enum(kind, V::Sync, events, docs, attrs, variants, ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( expand_redact(kind, V::None, variants, ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( expand_redact(kind, V::Sync, variants, ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( expand_from_full_event(kind, V::None, variants) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( expand_into_full_event(kind, V::Sync, variants, ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); } if matches!(kind, EventKind::Ephemeral) { res.extend( expand_event_enum(kind, V::Sync, events, docs, attrs, variants, ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); } if matches!(kind, EventKind::State) { res.extend( expand_event_enum(kind, V::Stripped, events, docs, attrs, variants, ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( expand_event_enum(kind, V::Initial, events, docs, attrs, variants, ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); } Ok(res) } fn expand_event_enum( kind: EventKind, var: EventEnumVariation, events: 
&[EventEnumEntry], docs: &[TokenStream], attrs: &[Attribute], variants: &[EventEnumVariant], ruma_common: &TokenStream, ) -> syn::Result { let event_struct = kind.to_event_ident(var.into())?; let ident = kind.to_event_enum_ident(var.into())?; let variant_decls = variants.iter().map(|v| v.decl()); let content: Vec<_> = events .iter() .map(|event| { event .stable_name() .map(|stable_name| to_event_path(stable_name, &event.ev_path, kind, var)) }) .collect::>()?; let custom_ty = format_ident!("Custom{}Content", kind); let deserialize_impl = expand_deserialize_impl(kind, var, events, ruma_common)?; let field_accessor_impl = expand_accessor_methods(kind, var, variants, ruma_common)?; let from_impl = expand_from_impl(&ident, &content, variants); Ok(quote! { #( #attrs )* #[derive(Clone, Debug)] #[allow(clippy::large_enum_variant, unused_qualifications)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] pub enum #ident { #( #docs #variant_decls(#content), )* /// An event not defined by the Matrix specification #[doc(hidden)] _Custom( #ruma_common::events::#event_struct<#ruma_common::events::_custom::#custom_ty>, ), } #deserialize_impl #field_accessor_impl #from_impl }) } fn expand_deserialize_impl( kind: EventKind, var: EventEnumVariation, events: &[EventEnumEntry], ruma_common: &TokenStream, ) -> syn::Result { let serde = quote! { #ruma_common::exports::serde }; let serde_json = quote! { #ruma_common::exports::serde_json }; let ident = kind.to_event_enum_ident(var.into())?; let match_arms: TokenStream = events .iter() .map(|event| { let variant = event.to_variant()?; let variant_attrs = { let attrs = &variant.attrs; quote! { #(#attrs)* } }; let self_variant = variant.ctor(quote! { Self }); let content = to_event_path(event.stable_name()?, &event.ev_path, kind, var); let ev_types = event.aliases.iter().chain([&event.ev_type]); Ok(quote! 
{ #variant_attrs #(#ev_types)|* => { let event = #serde_json::from_str::<#content>(json.get()) .map_err(D::Error::custom)?; Ok(#self_variant(event)) }, }) }) .collect::>()?; Ok(quote! { #[allow(unused_qualifications)] impl<'de> #serde::de::Deserialize<'de> for #ident { fn deserialize(deserializer: D) -> Result where D: #serde::de::Deserializer<'de>, { use #serde::de::Error as _; let json = Box::<#serde_json::value::RawValue>::deserialize(deserializer)?; let #ruma_common::events::EventTypeDeHelper { ev_type, .. } = #ruma_common::serde::from_raw_json_value(&json)?; match &*ev_type { #match_arms _ => { let event = #serde_json::from_str(json.get()).map_err(D::Error::custom)?; Ok(Self::_Custom(event)) }, } } } }) } fn expand_from_impl( ty: &Ident, content: &[TokenStream], variants: &[EventEnumVariant], ) -> TokenStream { let from_impls = content.iter().zip(variants).map(|(content, variant)| { let ident = &variant.ident; let attrs = &variant.attrs; quote! { #[allow(unused_qualifications)] #[automatically_derived] #(#attrs)* impl ::std::convert::From<#content> for #ty { fn from(c: #content) -> Self { Self::#ident(c) } } } }); quote! { #( #from_impls )* } } fn expand_from_full_event( kind: EventKind, var: EventEnumVariation, variants: &[EventEnumVariant], ) -> syn::Result { let ident = kind.to_event_enum_ident(var.into())?; let sync = kind.to_event_enum_ident(var.to_sync().into())?; let ident_variants = variants.iter().map(|v| v.match_arm(&ident)); let self_variants = variants.iter().map(|v| v.ctor(quote! { Self })); Ok(quote! 
{ #[automatically_derived] impl ::std::convert::From<#ident> for #sync { fn from(event: #ident) -> Self { match event { #( #ident_variants(event) => { #self_variants(::std::convert::From::from(event)) }, )* #ident::_Custom(event) => { Self::_Custom(::std::convert::From::from(event)) }, } } } }) } fn expand_into_full_event( kind: EventKind, var: EventEnumVariation, variants: &[EventEnumVariant], ruma_common: &TokenStream, ) -> syn::Result { let ident = kind.to_event_enum_ident(var.into())?; let full = kind.to_event_enum_ident(var.to_full().into())?; let self_variants = variants.iter().map(|v| v.match_arm(quote! { Self })); let full_variants = variants.iter().map(|v| v.ctor(&full)); Ok(quote! { #[automatically_derived] impl #ident { /// Convert this sync event into a full event (one with a `room_id` field). pub fn into_full_event(self, room_id: #ruma_common::OwnedRoomId) -> #full { match self { #( #self_variants(event) => { #full_variants(event.into_full_event(room_id)) }, )* Self::_Custom(event) => { #full::_Custom(event.into_full_event(room_id)) }, } } } }) } /// Create a content enum from `EventEnumInput`. fn expand_content_enum( kind: EventKind, events: &[EventEnumEntry], docs: &[TokenStream], attrs: &[Attribute], variants: &[EventEnumVariant], ruma_common: &TokenStream, ) -> syn::Result { let serde = quote! { #ruma_common::exports::serde }; let serde_json = quote! { #ruma_common::exports::serde_json }; let ident = kind.to_content_enum(); let event_type_enum = kind.to_event_type_enum(); let content: Vec<_> = events .iter() .map(|event| { let stable_name = event.stable_name()?; Ok(to_event_content_path(kind, stable_name, &event.ev_path, None)) }) .collect::>()?; let event_type_match_arms: TokenStream = zip(zip(events, variants), &content) .map(|((event, variant), ev_content)| { let variant_attrs = { let attrs = &variant.attrs; quote! { #(#attrs)* } }; let variant_ctor = variant.ctor(quote! 
{ Self }); let ev_types = event.aliases.iter().chain([&event.ev_type]); let ev_types = if event.ev_type.value().ends_with(".*") { let ev_types = ev_types.map(|ev_type| { ev_type .value() .strip_suffix(".*") .expect("aliases have already been checked to have the same suffix") .to_owned() }); quote! { _s if #(_s.starts_with(#ev_types))||* } } else { quote! { #(#ev_types)|* } }; Ok(quote! { #variant_attrs #ev_types => { let content = #ev_content::from_parts(event_type, input)?; ::std::result::Result::Ok(#variant_ctor(content)) }, }) }) .collect::>()?; let variant_decls = variants.iter().map(|v| v.decl()).collect::>(); let variant_arms = variants.iter().map(|v| v.match_arm(quote! { Self })).collect::>(); let sub_trait_name = format_ident!("{kind}Content"); let state_event_content_impl = (kind == EventKind::State).then(|| { quote! { type StateKey = String; // FIXME: Not actually used type Unsigned = #ruma_common::events::StateUnsigned; } }); let from_impl = expand_from_impl(&ident, &content, variants); let serialize_custom_event_error_path = quote! { #ruma_common::events::serialize_custom_event_error }.to_string(); Ok(quote! 
{ #( #attrs )* #[derive(Clone, Debug, #serde::Serialize)] #[serde(untagged)] #[allow(clippy::large_enum_variant)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] pub enum #ident { #( #docs #variant_decls(#content), )* #[doc(hidden)] #[serde(serialize_with = #serialize_custom_event_error_path)] _Custom { event_type: crate::PrivOwnedStr, }, } #[automatically_derived] impl #ruma_common::events::EventContent for #ident { type EventType = #ruma_common::events::#event_type_enum; fn event_type(&self) -> Self::EventType { match self { #( #variant_arms(content) => content.event_type(), )* Self::_Custom { event_type } => ::std::convert::From::from(&event_type.0[..]), } } fn from_parts( event_type: &::std::primitive::str, input: &#serde_json::value::RawValue, ) -> #serde_json::Result { match event_type { #event_type_match_arms ty => { ::std::result::Result::Ok(Self::_Custom { event_type: crate::PrivOwnedStr(ty.into()), }) } } } } #[automatically_derived] impl #ruma_common::events::#sub_trait_name for #ident { #state_event_content_impl } #from_impl }) } fn expand_redact( kind: EventKind, var: EventEnumVariation, variants: &[EventEnumVariant], ruma_common: &TokenStream, ) -> syn::Result { let ident = kind.to_event_enum_ident(var.into())?; let self_variants = variants.iter().map(|v| v.match_arm(quote! { Self })); let redacted_variants = variants.iter().map(|v| v.ctor(&ident)); Ok(quote! 
{ #[automatically_derived] impl #ruma_common::events::Redact for #ident { type Redacted = Self; fn redact( self, redaction: #ruma_common::events::room::redaction::SyncRoomRedactionEvent, version: &#ruma_common::RoomVersionId, ) -> Self { match self { #( #self_variants(event) => #redacted_variants( #ruma_common::events::Redact::redact(event, redaction, version), ), )* Self::_Custom(event) => Self::_Custom( #ruma_common::events::Redact::redact(event, redaction, version), ) } } } }) } fn expand_accessor_methods( kind: EventKind, var: EventEnumVariation, variants: &[EventEnumVariant], ruma_common: &TokenStream, ) -> syn::Result { let ident = kind.to_event_enum_ident(var.into())?; let event_type_enum = format_ident!("{}Type", kind); let self_variants: Vec<_> = variants.iter().map(|v| v.match_arm(quote! { Self })).collect(); let maybe_redacted = kind.is_timeline() && matches!(var, EventEnumVariation::None | EventEnumVariation::Sync); let event_type_match_arms = if maybe_redacted { quote! { #( #self_variants(event) => event.event_type(), )* Self::_Custom(event) => event.event_type(), } } else { quote! { #( #self_variants(event) => #ruma_common::events::EventContent::event_type(&event.content), )* Self::_Custom(event) => ::std::convert::From::from( #ruma_common::events::EventContent::event_type(&event.content), ), } }; let content_enum = kind.to_content_enum(); let content_variants: Vec<_> = variants.iter().map(|v| v.ctor(&content_enum)).collect(); let content_accessor = if maybe_redacted { quote! { /// Returns the content for this event if it is not redacted, or `None` if it is. 
pub fn original_content(&self) -> Option<#content_enum> { match self { #( #self_variants(event) => { event.as_original().map(|ev| #content_variants(ev.content.clone())) } )* Self::_Custom(event) => event.as_original().map(|ev| { #content_enum::_Custom { event_type: crate::PrivOwnedStr( ::std::convert::From::from( ::std::string::ToString::to_string( &#ruma_common::events::EventContent::event_type( &ev.content, ), ), ), ), } }), } } } } else { quote! { /// Returns the content for this event. pub fn content(&self) -> #content_enum { match self { #( #self_variants(event) => #content_variants(event.content.clone()), )* Self::_Custom(event) => #content_enum::_Custom { event_type: crate::PrivOwnedStr( ::std::convert::From::from( ::std::string::ToString::to_string( &#ruma_common::events::EventContent::event_type(&event.content) ) ), ), }, } } } }; let methods = EVENT_FIELDS.iter().map(|(name, has_field)| { has_field(kind, var).then(|| { let docs = format!("Returns this event's `{}` field.", name); let ident = Ident::new(name, Span::call_site()); let field_type = field_return_type(name, ruma_common); let variants = variants.iter().map(|v| v.match_arm(quote! { Self })); let call_parens = maybe_redacted.then(|| quote! { () }); let ampersand = (*name != "origin_server_ts").then(|| quote! { & }); quote! { #[doc = #docs] pub fn #ident(&self) -> #field_type { match self { #( #variants(event) => #ampersand event.#ident #call_parens, )* Self::_Custom(event) => #ampersand event.#ident #call_parens, } } } }) }); let state_key_accessor = (kind == EventKind::State).then(|| { let variants = variants.iter().map(|v| v.match_arm(quote! { Self })); let call_parens = maybe_redacted.then(|| quote! { () }); quote! { /// Returns this event's `state_key` field. 
pub fn state_key(&self) -> &::std::primitive::str { match self { #( #variants(event) => &event.state_key #call_parens .as_ref(), )* Self::_Custom(event) => &event.state_key #call_parens .as_ref(), } } } }); let maybe_redacted_accessors = maybe_redacted.then(|| { let variants = variants.iter().map(|v| v.match_arm(quote! { Self })); let variants2 = variants.clone(); quote! { /// Returns this event's `transaction_id` from inside `unsigned`, if there is one. pub fn transaction_id(&self) -> Option<&#ruma_common::TransactionId> { match self { #( #variants(event) => { event.as_original().and_then(|ev| ev.unsigned.transaction_id.as_deref()) } )* Self::_Custom(event) => { event.as_original().and_then(|ev| ev.unsigned.transaction_id.as_deref()) } } } /// Returns this event's `relations` from inside `unsigned`, if that field exists. pub fn relations(&self) -> Option<&#ruma_common::events::Relations> { match self { #( #variants2(event) => { event.as_original().and_then(|ev| ev.unsigned.relations.as_ref()) } )* Self::_Custom(event) => { event.as_original().and_then(|ev| ev.unsigned.relations.as_ref()) } } } } }); Ok(quote! { #[automatically_derived] impl #ident { /// Returns the `type` of this event. pub fn event_type(&self) -> #ruma_common::events::#event_type_enum { match self { #event_type_match_arms } } #content_accessor #( #methods )* #state_key_accessor #maybe_redacted_accessors } }) } fn to_event_path( name: &LitStr, path: &Path, kind: EventKind, var: EventEnumVariation, ) -> TokenStream { let event = m_prefix_name_to_type_name(name).unwrap(); let event_name = if kind == EventKind::ToDevice { assert_eq!(var, EventEnumVariation::None); format_ident!("ToDevice{}Event", event) } else { format_ident!("{}{}Event", var, event) }; quote! 
{ #path::#event_name } } fn to_event_content_path( kind: EventKind, name: &LitStr, path: &Path, prefix: Option<&str>, ) -> TokenStream { let event = m_prefix_name_to_type_name(name).unwrap(); let content_str = match kind { EventKind::ToDevice => { format_ident!("ToDevice{}{}EventContent", prefix.unwrap_or(""), event) } _ => format_ident!("{}{}EventContent", prefix.unwrap_or(""), event), }; quote! { #path::#content_str } } fn field_return_type(name: &str, ruma_common: &TokenStream) -> TokenStream { match name { "origin_server_ts" => quote! { #ruma_common::MilliSecondsSinceUnixEpoch }, "room_id" => quote! { &#ruma_common::RoomId }, "event_id" => quote! { &#ruma_common::EventId }, "sender" => quote! { &#ruma_common::UserId }, _ => panic!("the `ruma_macros::event_enum::EVENT_FIELD` const was changed"), } } pub(crate) struct EventEnumVariant { pub attrs: Vec, pub ident: Ident, } impl EventEnumVariant { pub(crate) fn to_tokens(&self, prefix: Option, with_attrs: bool) -> TokenStream where T: ToTokens, { let mut tokens = TokenStream::new(); if with_attrs { for attr in &self.attrs { attr.to_tokens(&mut tokens); } } if let Some(p) = prefix { tokens.extend(quote! 
{ #p :: }); } self.ident.to_tokens(&mut tokens); tokens } pub(crate) fn decl(&self) -> TokenStream { self.to_tokens::(None, true) } pub(crate) fn match_arm(&self, prefix: impl ToTokens) -> TokenStream { self.to_tokens(Some(prefix), true) } pub(crate) fn ctor(&self, prefix: impl ToTokens) -> TokenStream { self.to_tokens(Some(prefix), false) } } impl EventEnumEntry { pub(crate) fn has_type_fragment(&self) -> bool { self.ev_type.value().ends_with(".*") } pub(crate) fn to_variant(&self) -> syn::Result { let attrs = self.attrs.clone(); let ident = m_prefix_name_to_type_name(self.stable_name()?)?; Ok(EventEnumVariant { attrs, ident }) } pub(crate) fn stable_name(&self) -> syn::Result<&LitStr> { if self.ev_type.value().starts_with("m.") { Ok(&self.ev_type) } else { self.aliases.iter().find(|alias| alias.value().starts_with("m.")).ok_or_else(|| { syn::Error::new( Span::call_site(), format!( "A matrix event must declare a well-known type that starts with `m.` \ either as the main type or as an alias, found `{}`", self.ev_type.value() ), ) }) } } pub(crate) fn docs(&self) -> syn::Result { let stable_name = self.stable_name()?; let mut doc = quote! { #[doc = #stable_name] }; if self.ev_type != *stable_name { let unstable_name = format!("This variant uses the unstable type `{}`.", self.ev_type.value()); doc.extend(quote! { #[doc = ""] #[doc = #unstable_name] }); } match self.aliases.len() { 0 => {} 1 => { let alias = format!( "This variant can also be deserialized from the `{}` type.", self.aliases[0].value() ); doc.extend(quote! { #[doc = ""] #[doc = #alias] }); } _ => { let aliases = format!( "This variant can also be deserialized from the following types: {}.", self.aliases .iter() .map(|alias| format!("`{}`", alias.value())) .collect::>() .join(", ") ); doc.extend(quote! { #[doc = ""] #[doc = #aliases] }); } } Ok(doc) } } pub(crate) fn expand_from_impls_derived(input: DeriveInput) -> TokenStream { let variants = match &input.data { Data::Enum(DataEnum { variants, .. 
}) => variants, _ => panic!("this derive macro only works with enums"), }; let from_impls = variants.iter().map(|variant| match &variant.fields { syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => { let inner_struct = &fields.unnamed.first().unwrap().ty; let var_ident = &variant.ident; let id = &input.ident; quote! { #[automatically_derived] impl ::std::convert::From<#inner_struct> for #id { fn from(c: #inner_struct) -> Self { Self::#var_ident(c) } } } } _ => { panic!("this derive macro only works with enum variants with a single unnamed field") } }); quote! { #( #from_impls )* } } // If the variants of this enum change `to_event_path` needs to be updated as well. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum EventEnumVariation { None, Sync, Stripped, Initial, } impl From for crate::events::event_parse::EventKindVariation { fn from(v: EventEnumVariation) -> Self { match v { EventEnumVariation::None => Self::None, EventEnumVariation::Sync => Self::Sync, EventEnumVariation::Stripped => Self::Stripped, EventEnumVariation::Initial => Self::Initial, } } } // FIXME: Duplicated with the other EventKindVariation type impl EventEnumVariation { pub fn to_sync(self) -> Self { match self { EventEnumVariation::None => EventEnumVariation::Sync, _ => panic!("No sync form of {self:?}"), } } pub fn to_full(self) -> Self { match self { EventEnumVariation::Sync => EventEnumVariation::None, _ => panic!("No full form of {self:?}"), } } } impl IdentFragment for EventEnumVariation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { EventEnumVariation::None => write!(f, ""), EventEnumVariation::Sync => write!(f, "Sync"), EventEnumVariation::Stripped => write!(f, "Stripped"), EventEnumVariation::Initial => write!(f, "Initial"), } } } ruma-macros-0.10.5/src/events/event_parse.rs000064400000000000000000000272671046102023000171270ustar 00000000000000//! Implementation of event enum and event content enum macros. 
use std::fmt; use proc_macro2::Span; use quote::{format_ident, IdentFragment}; use syn::{ braced, parse::{self, Parse, ParseStream}, punctuated::Punctuated, Attribute, Ident, LitStr, Path, Token, }; /// Custom keywords for the `event_enum!` macro mod kw { syn::custom_keyword!(kind); syn::custom_keyword!(events); syn::custom_keyword!(alias); } // If the variants of this enum change `to_event_path` needs to be updated as well. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum EventKindVariation { None, Sync, Original, OriginalSync, Stripped, Initial, Redacted, RedactedSync, } impl fmt::Display for EventKindVariation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { EventKindVariation::None => write!(f, ""), EventKindVariation::Sync => write!(f, "Sync"), EventKindVariation::Original => write!(f, "Original"), EventKindVariation::OriginalSync => write!(f, "OriginalSync"), EventKindVariation::Stripped => write!(f, "Stripped"), EventKindVariation::Initial => write!(f, "Initial"), EventKindVariation::Redacted => write!(f, "Redacted"), EventKindVariation::RedactedSync => write!(f, "RedactedSync"), } } } impl EventKindVariation { pub fn is_redacted(self) -> bool { matches!(self, Self::Redacted | Self::RedactedSync) } pub fn is_sync(self) -> bool { matches!(self, Self::OriginalSync | Self::RedactedSync) } pub fn to_redacted(self) -> Self { match self { EventKindVariation::Original => EventKindVariation::Redacted, EventKindVariation::OriginalSync => EventKindVariation::RedactedSync, _ => panic!("No redacted form of {self:?}"), } } pub fn to_full(self) -> Self { match self { EventKindVariation::OriginalSync => EventKindVariation::Original, EventKindVariation::RedactedSync => EventKindVariation::Redacted, _ => panic!("No original (unredacted) form of {self:?}"), } } } // If the variants of this enum change `to_event_path` needs to be updated as well. 
#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub enum EventKind { GlobalAccountData, RoomAccountData, Ephemeral, MessageLike, State, ToDevice, RoomRedaction, Presence, HierarchySpaceChild, Decrypted, } impl fmt::Display for EventKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { EventKind::GlobalAccountData => write!(f, "GlobalAccountDataEvent"), EventKind::RoomAccountData => write!(f, "RoomAccountDataEvent"), EventKind::Ephemeral => write!(f, "EphemeralRoomEvent"), EventKind::MessageLike => write!(f, "MessageLikeEvent"), EventKind::State => write!(f, "StateEvent"), EventKind::ToDevice => write!(f, "ToDeviceEvent"), EventKind::RoomRedaction => write!(f, "RoomRedactionEvent"), EventKind::Presence => write!(f, "PresenceEvent"), EventKind::HierarchySpaceChild => write!(f, "HierarchySpaceChildEvent"), EventKind::Decrypted => unreachable!(), } } } impl IdentFragment for EventKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(self, f) } } impl IdentFragment for EventKindVariation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(self, f) } } impl EventKind { pub fn is_account_data(self) -> bool { matches!(self, Self::GlobalAccountData | Self::RoomAccountData) } pub fn is_timeline(self) -> bool { matches!(self, Self::MessageLike | Self::RoomRedaction | Self::State) } pub fn to_event_ident(self, var: EventKindVariation) -> syn::Result { use EventKindVariation as V; match (self, var) { (_, V::None) | (Self::Ephemeral | Self::MessageLike | Self::State, V::Sync) | ( Self::MessageLike | Self::RoomRedaction | Self::State, V::Original | V::OriginalSync | V::Redacted | V::RedactedSync, ) | (Self::State, V::Stripped | V::Initial) => Ok(format_ident!("{var}{self}")), _ => Err(syn::Error::new( Span::call_site(), format!( "({:?}, {:?}) is not a valid event kind / variation combination", self, var ), )), } } pub fn to_event_enum_ident(self, var: EventKindVariation) -> syn::Result { 
Ok(format_ident!("Any{}", self.to_event_ident(var)?)) } pub fn to_event_type_enum(self) -> Ident { format_ident!("{}Type", self) } /// `Any[kind]EventContent` pub fn to_content_enum(self) -> Ident { format_ident!("Any{}Content", self) } } impl Parse for EventKind { fn parse(input: ParseStream<'_>) -> syn::Result { let ident: Ident = input.parse()?; Ok(match ident.to_string().as_str() { "GlobalAccountData" => EventKind::GlobalAccountData, "RoomAccountData" => EventKind::RoomAccountData, "EphemeralRoom" => EventKind::Ephemeral, "MessageLike" => EventKind::MessageLike, "State" => EventKind::State, "ToDevice" => EventKind::ToDevice, id => { return Err(syn::Error::new_spanned( ident, format!( "valid event kinds are GlobalAccountData, RoomAccountData, EphemeralRoom, \ MessageLike, State, ToDevice found `{}`", id ), )); } }) } } // This function is only used in the `Event` derive macro expansion code. /// Validates the given `ident` is a valid event struct name and returns a tuple of enums /// representing the name. 
pub fn to_kind_variation(ident: &Ident) -> Option<(EventKind, EventKindVariation)> { let ident_str = ident.to_string(); match ident_str.as_str() { "GlobalAccountDataEvent" => Some((EventKind::GlobalAccountData, EventKindVariation::None)), "RoomAccountDataEvent" => Some((EventKind::RoomAccountData, EventKindVariation::None)), "EphemeralRoomEvent" => Some((EventKind::Ephemeral, EventKindVariation::None)), "SyncEphemeralRoomEvent" => Some((EventKind::Ephemeral, EventKindVariation::Sync)), "OriginalMessageLikeEvent" => Some((EventKind::MessageLike, EventKindVariation::Original)), "OriginalSyncMessageLikeEvent" => { Some((EventKind::MessageLike, EventKindVariation::OriginalSync)) } "RedactedMessageLikeEvent" => Some((EventKind::MessageLike, EventKindVariation::Redacted)), "RedactedSyncMessageLikeEvent" => { Some((EventKind::MessageLike, EventKindVariation::RedactedSync)) } "OriginalStateEvent" => Some((EventKind::State, EventKindVariation::Original)), "OriginalSyncStateEvent" => Some((EventKind::State, EventKindVariation::OriginalSync)), "StrippedStateEvent" => Some((EventKind::State, EventKindVariation::Stripped)), "InitialStateEvent" => Some((EventKind::State, EventKindVariation::Initial)), "RedactedStateEvent" => Some((EventKind::State, EventKindVariation::Redacted)), "RedactedSyncStateEvent" => Some((EventKind::State, EventKindVariation::RedactedSync)), "ToDeviceEvent" => Some((EventKind::ToDevice, EventKindVariation::None)), "PresenceEvent" => Some((EventKind::Presence, EventKindVariation::None)), "HierarchySpaceChildEvent" => { Some((EventKind::HierarchySpaceChild, EventKindVariation::Stripped)) } "OriginalRoomRedactionEvent" => Some((EventKind::RoomRedaction, EventKindVariation::None)), "OriginalSyncRoomRedactionEvent" => { Some((EventKind::RoomRedaction, EventKindVariation::OriginalSync)) } "RedactedRoomRedactionEvent" => { Some((EventKind::RoomRedaction, EventKindVariation::Redacted)) } "RedactedSyncRoomRedactionEvent" => { Some((EventKind::RoomRedaction, 
EventKindVariation::RedactedSync)) } "DecryptedOlmV1Event" | "DecryptedMegolmV1Event" => { Some((EventKind::Decrypted, EventKindVariation::None)) } _ => None, } } pub struct EventEnumEntry { pub attrs: Vec, pub aliases: Vec, pub ev_type: LitStr, pub ev_path: Path, } impl Parse for EventEnumEntry { fn parse(input: ParseStream<'_>) -> syn::Result { let (ruma_enum_attrs, attrs) = input .call(Attribute::parse_outer)? .into_iter() .partition::, _>(|attr| attr.path.is_ident("ruma_enum")); let ev_type: LitStr = input.parse()?; let _: Token![=>] = input.parse()?; let ev_path = input.call(Path::parse_mod_style)?; let has_suffix = ev_type.value().ends_with(".*"); let mut aliases = Vec::with_capacity(ruma_enum_attrs.len()); for attr_list in ruma_enum_attrs { for alias_attr in attr_list .parse_args_with(Punctuated::::parse_terminated)? { let alias = alias_attr.into_inner(); if alias.value().ends_with(".*") == has_suffix { aliases.push(alias); } else { return Err(syn::Error::new_spanned( &attr_list, "aliases should have the same `.*` suffix, or lack thereof, as the main event type", )); } } } Ok(Self { attrs, aliases, ev_type, ev_path }) } } /// The entire `event_enum!` macro structure directly as it appears in the source code. pub struct EventEnumDecl { /// Outer attributes on the field, such as a docstring. pub attrs: Vec, /// The event kind. pub kind: EventKind, /// An array of valid matrix event types. /// /// This will generate the variants of the event type "kind". There needs to be a corresponding /// variant in `ruma_common::events::EventType` for this event (converted to a valid Rust-style /// type name by stripping `m.`, replacing the remaining dots by underscores and then /// converting from snake_case to CamelCase). pub events: Vec, } /// The entire `event_enum!` macro structure directly as it appears in the source code. 
pub struct EventEnumInput { pub(crate) enums: Vec, } impl Parse for EventEnumInput { fn parse(input: ParseStream<'_>) -> parse::Result { let mut enums = vec![]; while !input.is_empty() { let attrs = input.call(Attribute::parse_outer)?; let _: Token![enum] = input.parse()?; let kind: EventKind = input.parse()?; let content; braced!(content in input); let events = content.parse_terminated::<_, Token![,]>(EventEnumEntry::parse)?; let events = events.into_iter().collect(); enums.push(EventEnumDecl { attrs, kind, events }); } Ok(EventEnumInput { enums }) } } pub struct EventEnumAliasAttr(LitStr); impl EventEnumAliasAttr { pub fn into_inner(self) -> LitStr { self.0 } } impl Parse for EventEnumAliasAttr { fn parse(input: ParseStream<'_>) -> syn::Result { let _: kw::alias = input.parse()?; let _: Token![=] = input.parse()?; let s: LitStr = input.parse()?; Ok(Self(s)) } } ruma-macros-0.10.5/src/events/event_type.rs000064400000000000000000000230531046102023000167630ustar 00000000000000use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{parse_quote, Ident, LitStr}; use super::event_parse::{EventEnumEntry, EventEnumInput, EventKind}; pub fn expand_event_type_enum( input: EventEnumInput, ruma_common: TokenStream, ) -> syn::Result { let mut room: Vec<&Vec> = vec![]; let mut state: Vec<&Vec> = vec![]; let mut message: Vec<&Vec> = vec![]; let mut ephemeral: Vec<&Vec> = vec![]; let mut room_account: Vec<&Vec> = vec![]; let mut global_account: Vec<&Vec> = vec![]; let mut to_device: Vec<&Vec> = vec![]; for event in &input.enums { match event.kind { EventKind::GlobalAccountData => global_account.push(&event.events), EventKind::RoomAccountData => room_account.push(&event.events), EventKind::Ephemeral => ephemeral.push(&event.events), EventKind::MessageLike => { message.push(&event.events); room.push(&event.events); } EventKind::State => { state.push(&event.events); room.push(&event.events); } EventKind::ToDevice => to_device.push(&event.events), EventKind::RoomRedaction | 
EventKind::Presence | EventKind::Decrypted | EventKind::HierarchySpaceChild => {} } } let presence = vec![EventEnumEntry { attrs: vec![], aliases: vec![], ev_type: LitStr::new("m.presence", Span::call_site()), ev_path: parse_quote! { #ruma_common::events::presence }, }]; let mut all = input.enums.iter().map(|e| &e.events).collect::>(); all.push(&presence); let mut res = TokenStream::new(); res.extend( generate_enum("EventType", &all, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( generate_enum("RoomEventType", &room, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( generate_enum("StateEventType", &state, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( generate_enum("MessageLikeEventType", &message, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( generate_enum("EphemeralRoomEventType", &ephemeral, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( generate_enum("RoomAccountDataEventType", &room_account, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( generate_enum("GlobalAccountDataEventType", &global_account, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); res.extend( generate_enum("ToDeviceEventType", &to_device, &ruma_common) .unwrap_or_else(syn::Error::into_compile_error), ); Ok(res) } fn generate_enum( ident: &str, input: &[&Vec], ruma_common: &TokenStream, ) -> syn::Result { let serde = quote! { #ruma_common::exports::serde }; let enum_doc = format!("The type of `{}` this is.", ident.strip_suffix("Type").unwrap()); let deprecated_attr = (ident == "EventType").then(|| { quote! 
{ #[deprecated = "use a fine-grained event type enum like RoomEventType instead"] } }); let ident = Ident::new(ident, Span::call_site()); let mut deduped: Vec<&EventEnumEntry> = vec![]; for item in input.iter().copied().flatten() { if let Some(idx) = deduped.iter().position(|e| e.ev_type == item.ev_type) { // If there is a variant without config attributes use that if deduped[idx].attrs != item.attrs && item.attrs.is_empty() { deduped[idx] = item; } } else { deduped.push(item); } } let event_types = deduped.iter().map(|e| &e.ev_type); let variants: Vec<_> = deduped .iter() .map(|e| { let start = e.to_variant()?.decl(); let data = e.has_type_fragment().then(|| quote! { (::std::string::String) }); Ok(quote! { #start #data }) }) .collect::>()?; let to_cow_str_match_arms: Vec<_> = deduped .iter() .map(|e| { let v = e.to_variant()?; let start = v.match_arm(quote! { Self }); let ev_type = &e.ev_type; Ok(if let Some(prefix) = ev_type.value().strip_suffix(".*") { let fstr = prefix.to_owned() + ".{}"; quote! { #start(_s) => ::std::borrow::Cow::Owned(::std::format!(#fstr, _s)) } } else { quote! { #start => ::std::borrow::Cow::Borrowed(#ev_type) } }) }) .collect::>()?; let mut from_str_match_arms = TokenStream::new(); for event in &deduped { let v = event.to_variant()?; let ctor = v.ctor(quote! { Self }); let ev_types = event.aliases.iter().chain([&event.ev_type]); let attrs = &event.attrs; if event.ev_type.value().ends_with(".*") { for ev_type in ev_types { let name = ev_type.value(); let prefix = name .strip_suffix('*') .expect("aliases have already been checked to have the same suffix"); from_str_match_arms.extend(quote! { #(#attrs)* // Use if-let guard once available _s if _s.starts_with(#prefix) => { #ctor(::std::convert::From::from(_s.strip_prefix(#prefix).unwrap())) } }); } } else { from_str_match_arms.extend(quote! 
{ #(#attrs)* #(#ev_types)|* => #ctor, }); } } let from_ident_for_room = if ident == "StateEventType" || ident == "MessageLikeEventType" { let match_arms: Vec<_> = deduped .iter() .map(|e| { let v = e.to_variant()?; let ident_var = v.match_arm(quote! { #ident }); let room_var = v.ctor(quote! { Self }); Ok(if e.has_type_fragment() { quote! { #ident_var (_s) => #room_var (_s) } } else { quote! { #ident_var => #room_var } }) }) .collect::>()?; Some(quote! { #[allow(deprecated)] impl ::std::convert::From<#ident> for RoomEventType { fn from(s: #ident) -> Self { match s { #(#match_arms,)* #ident ::_Custom(_s) => Self::_Custom(_s), } } } }) } else { None }; Ok(quote! { #[doc = #enum_doc] /// /// This type can hold an arbitrary string. To build events with a custom type, convert it /// from a string with `::from() / .into()`. To check for events that are not available as a /// documented variant here, use its string representation, obtained through `.to_string()`. #deprecated_attr #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] pub enum #ident { #( #[doc = #event_types] #variants, )* #[doc(hidden)] _Custom(crate::PrivOwnedStr), } #[allow(deprecated)] impl #ident { fn to_cow_str(&self) -> ::std::borrow::Cow<'_, ::std::primitive::str> { match self { #(#to_cow_str_match_arms,)* Self::_Custom(crate::PrivOwnedStr(s)) => ::std::borrow::Cow::Borrowed(s), } } } #[allow(deprecated)] impl ::std::fmt::Display for #ident { fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { self.to_cow_str().fmt(f) } } #[allow(deprecated)] impl ::std::convert::From<&::std::primitive::str> for #ident { fn from(s: &::std::primitive::str) -> Self { match s { #from_str_match_arms _ => Self::_Custom(crate::PrivOwnedStr(::std::convert::From::from(s))), } } } #[allow(deprecated)] impl ::std::convert::From<::std::string::String> for #ident { fn from(s: ::std::string::String) -> Self { 
::std::convert::From::from(s.as_str()) } } #[allow(deprecated)] impl<'de> #serde::Deserialize<'de> for #ident { fn deserialize(deserializer: D) -> ::std::result::Result where D: #serde::Deserializer<'de> { let s = #ruma_common::serde::deserialize_cow_str(deserializer)?; Ok(::std::convert::From::from(&s[..])) } } #[allow(deprecated)] impl #serde::Serialize for #ident { fn serialize(&self, serializer: S) -> ::std::result::Result where S: #serde::Serializer, { self.to_cow_str().serialize(serializer) } } #from_ident_for_room }) } ruma-macros-0.10.5/src/events/util.rs000064400000000000000000000012321046102023000155510ustar 00000000000000use super::event_parse::{EventKind, EventKindVariation}; pub(crate) fn is_non_stripped_room_event(kind: EventKind, var: EventKindVariation) -> bool { matches!(kind, EventKind::MessageLike | EventKind::State) && matches!( var, EventKindVariation::Original | EventKindVariation::OriginalSync | EventKindVariation::Redacted | EventKindVariation::RedactedSync ) } pub(crate) fn has_prev_content(kind: EventKind, var: EventKindVariation) -> bool { matches!(kind, EventKind::State) && matches!(var, EventKindVariation::Original | EventKindVariation::OriginalSync) } ruma-macros-0.10.5/src/events.rs000064400000000000000000000002331046102023000145740ustar 00000000000000//! Methods and types for generating events. pub mod event; pub mod event_content; pub mod event_enum; pub mod event_parse; pub mod event_type; mod util; ruma-macros-0.10.5/src/identifiers.rs000064400000000000000000000546271046102023000156150ustar 00000000000000//! Methods and types for generating identifiers. 
use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, Fields, ImplGenerics, Index, ItemStruct, LitStr, Path, Token, }; pub struct IdentifierInput { pub dollar_crate: Path, pub id: LitStr, } impl Parse for IdentifierInput { fn parse(input: ParseStream<'_>) -> syn::Result { let dollar_crate = input.parse()?; let _: Token![,] = input.parse()?; let id = input.parse()?; Ok(Self { dollar_crate, id }) } } pub fn expand_id_zst(input: ItemStruct) -> syn::Result { let id = &input.ident; let owned = format_ident!("Owned{id}"); let owned_decl = expand_owned_id(&input); let meta = input.attrs.iter().filter(|attr| attr.path.is_ident("ruma_id")).try_fold( IdZstMeta::default(), |meta, attr| { let list: Punctuated = attr.parse_args_with(Punctuated::parse_terminated)?; list.into_iter().try_fold(meta, IdZstMeta::merge) }, )?; let extra_impls = if let Some(validate) = meta.validate { expand_checked_impls(&input, validate) } else { assert!( input.generics.params.is_empty(), "generic unchecked IDs are not currently supported" ); expand_unchecked_impls(&input) }; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); // So we don't have to insert #where_clause everywhere when it is always None in practice assert_eq!(where_clause, None, "where clauses on identifier types are not currently supported"); let as_str_docs = format!("Creates a string slice from this `{}`.", id); let as_bytes_docs = format!("Creates a byte slice from this `{}`.", id); let as_str_impl = match &input.fields { Fields::Named(_) | Fields::Unit => { syn::Error::new(Span::call_site(), "Only tuple structs are supported currently.") .into_compile_error() } Fields::Unnamed(u) => { let last_idx = Index::from(u.unnamed.len() - 1); quote! { &self.#last_idx } } }; let id_ty = quote! { #id #ty_generics }; let owned_ty = quote! 
{ #owned #ty_generics }; let partial_eq_string = expand_partial_eq_string(id_ty.clone(), &impl_generics); // FIXME: Remove? let box_partial_eq_string = expand_partial_eq_string(quote! { Box<#id_ty> }, &impl_generics); Ok(quote! { #owned_decl impl #impl_generics #id_ty { pub(super) const fn from_borrowed(s: &str) -> &Self { unsafe { std::mem::transmute(s) } } pub(super) fn from_box(s: Box) -> Box { unsafe { Box::from_raw(Box::into_raw(s) as _) } } pub(super) fn from_rc(s: std::rc::Rc) -> std::rc::Rc { unsafe { std::rc::Rc::from_raw(std::rc::Rc::into_raw(s) as _) } } pub(super) fn from_arc(s: std::sync::Arc) -> std::sync::Arc { unsafe { std::sync::Arc::from_raw(std::sync::Arc::into_raw(s) as _) } } pub(super) fn into_owned(self: Box) -> Box { unsafe { Box::from_raw(Box::into_raw(self) as _) } } #[doc = #as_str_docs] #[inline] pub fn as_str(&self) -> &str { #as_str_impl } #[doc = #as_bytes_docs] #[inline] pub fn as_bytes(&self) -> &[u8] { self.as_str().as_bytes() } } impl #impl_generics Clone for Box<#id_ty> { fn clone(&self) -> Self { (**self).into() } } impl #impl_generics ToOwned for #id_ty { type Owned = #owned_ty; fn to_owned(&self) -> Self::Owned { #owned::from_ref(self) } } impl #impl_generics AsRef for #id_ty { fn as_ref(&self) -> &str { self.as_str() } } impl #impl_generics AsRef for Box<#id_ty> { fn as_ref(&self) -> &str { self.as_str() } } impl #impl_generics From<&#id_ty> for String { fn from(id: &#id_ty) -> Self { id.as_str().to_owned() } } impl #impl_generics From> for String { fn from(id: Box<#id_ty>) -> Self { id.into_owned().into() } } impl #impl_generics From<&#id_ty> for Box<#id_ty> { fn from(id: &#id_ty) -> Self { <#id_ty>::from_box(id.as_str().into()) } } impl #impl_generics From<&#id_ty> for std::rc::Rc<#id_ty> { fn from(s: &#id_ty) -> std::rc::Rc<#id_ty> { let rc = std::rc::Rc::::from(s.as_str()); <#id_ty>::from_rc(rc) } } impl #impl_generics From<&#id_ty> for std::sync::Arc<#id_ty> { fn from(s: &#id_ty) -> std::sync::Arc<#id_ty> { let arc = 
std::sync::Arc::::from(s.as_str()); <#id_ty>::from_arc(arc) } } impl #impl_generics PartialEq<#id_ty> for Box<#id_ty> { fn eq(&self, other: &#id_ty) -> bool { self.as_str() == other.as_str() } } impl #impl_generics PartialEq<&'_ #id_ty> for Box<#id_ty> { fn eq(&self, other: &&#id_ty) -> bool { self.as_str() == other.as_str() } } impl #impl_generics PartialEq> for #id_ty { fn eq(&self, other: &Box<#id_ty>) -> bool { self.as_str() == other.as_str() } } impl #impl_generics PartialEq> for &'_ #id_ty { fn eq(&self, other: &Box<#id_ty>) -> bool { self.as_str() == other.as_str() } } impl #impl_generics std::fmt::Debug for #id_ty { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { ::fmt(self.as_str(), f) } } impl #impl_generics std::fmt::Display for #id_ty { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.as_str()) } } impl #impl_generics serde::Serialize for #id_ty { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_str(self.as_str()) } } #partial_eq_string #box_partial_eq_string #extra_impls }) } fn expand_owned_id(input: &ItemStruct) -> TokenStream { let id = &input.ident; let owned = format_ident!("Owned{id}"); let doc_header = format!("Owned variant of {id}"); let (impl_generics, ty_generics, _where_clause) = input.generics.split_for_impl(); let id_ty = quote! { #id #ty_generics }; let owned_ty = quote! { #owned #ty_generics }; let partial_eq_string = expand_partial_eq_string(owned_ty.clone(), &impl_generics); quote! { #[doc = #doc_header] /// /// The wrapper type for this type is variable, by default it'll use [`Box`], /// but you can change that by setting "`--cfg=ruma_identifiers_storage=...`" using /// `RUSTFLAGS` or `.cargo/config.toml` (under `[build]` -> `rustflags = ["..."]`) /// to the following; /// - `ruma_identifiers_storage="Arc"` to use [`Arc`](std::sync::Arc) as a wrapper type. 
pub struct #owned #impl_generics { #[cfg(not(any(ruma_identifiers_storage = "Arc")))] inner: Box<#id_ty>, #[cfg(ruma_identifiers_storage = "Arc")] inner: std::sync::Arc<#id_ty>, } impl #impl_generics #owned_ty { fn from_ref(v: &#id_ty) -> Self { Self { #[cfg(not(any(ruma_identifiers_storage = "Arc")))] inner: #id::from_box(v.as_str().into()), #[cfg(ruma_identifiers_storage = "Arc")] inner: #id::from_arc(v.as_str().into()), } } } impl #impl_generics AsRef<#id_ty> for #owned_ty { fn as_ref(&self) -> &#id_ty { &*self.inner } } impl #impl_generics AsRef for #owned_ty { fn as_ref(&self) -> &str { (*self.inner).as_ref() } } impl #impl_generics std::clone::Clone for #owned_ty { fn clone(&self) -> Self { (&*self.inner).into() } } impl #impl_generics std::ops::Deref for #owned_ty { type Target = #id_ty; fn deref(&self) -> &Self::Target { &self.inner } } impl #impl_generics std::borrow::Borrow<#id_ty> for #owned_ty { fn borrow(&self) -> &#id_ty { self.as_ref() } } impl #impl_generics From<&'_ #id_ty> for #owned_ty { fn from(id: &#id_ty) -> #owned_ty { #owned { inner: id.into() } } } impl #impl_generics From> for #owned_ty { fn from(b: Box<#id_ty>) -> #owned_ty { Self { inner: b.into() } } } impl #impl_generics From> for #owned_ty { fn from(a: std::sync::Arc<#id_ty>) -> #owned_ty { Self { #[cfg(not(any(ruma_identifiers_storage = "Arc")))] inner: a.as_ref().into(), #[cfg(ruma_identifiers_storage = "Arc")] inner: a, } } } impl #impl_generics From<#owned_ty> for Box<#id_ty> { fn from(a: #owned_ty) -> Box<#id_ty> { #[cfg(not(any(ruma_identifiers_storage = "Arc")))] { a.inner } #[cfg(ruma_identifiers_storage = "Arc")] { a.inner.as_ref().into() } } } impl #impl_generics From<#owned_ty> for std::sync::Arc<#id_ty> { fn from(a: #owned_ty) -> std::sync::Arc<#id_ty> { #[cfg(not(any(ruma_identifiers_storage = "Arc")))] { a.inner.into() } #[cfg(ruma_identifiers_storage = "Arc")] { a.inner } } } impl #impl_generics std::fmt::Display for #owned_ty { fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.as_str()) } } impl #impl_generics std::fmt::Debug for #owned_ty { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { ::fmt(self.as_str(), f) } } impl #impl_generics std::cmp::PartialEq for #owned_ty { fn eq(&self, other: &Self) -> bool { self.as_str() == other.as_str() } } impl #impl_generics std::cmp::Eq for #owned_ty {} impl #impl_generics std::cmp::PartialOrd for #owned_ty { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl #impl_generics std::cmp::Ord for #owned_ty { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.as_str().cmp(other.as_str()) } } impl #impl_generics std::hash::Hash for #owned_ty { fn hash(&self, state: &mut H) where H: std::hash::Hasher, { self.as_str().hash(state) } } impl #impl_generics serde::Serialize for #owned_ty { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_str(self.as_str()) } } #partial_eq_string impl #impl_generics PartialEq<#id_ty> for #owned_ty { fn eq(&self, other: &#id_ty) -> bool { AsRef::<#id_ty>::as_ref(self) == other } } impl #impl_generics PartialEq<#owned_ty> for #id_ty { fn eq(&self, other: &#owned_ty) -> bool { self == AsRef::<#id_ty>::as_ref(other) } } impl #impl_generics PartialEq<&#id_ty> for #owned_ty { fn eq(&self, other: &&#id_ty) -> bool { AsRef::<#id_ty>::as_ref(self) == *other } } impl #impl_generics PartialEq<#owned_ty> for &#id_ty { fn eq(&self, other: &#owned_ty) -> bool { *self == AsRef::<#id_ty>::as_ref(other) } } impl #impl_generics PartialEq> for #owned_ty { fn eq(&self, other: &Box<#id_ty>) -> bool { AsRef::<#id_ty>::as_ref(self) == AsRef::<#id_ty>::as_ref(other) } } impl #impl_generics PartialEq<#owned_ty> for Box<#id_ty> { fn eq(&self, other: &#owned_ty) -> bool { AsRef::<#id_ty>::as_ref(self) == AsRef::<#id_ty>::as_ref(other) } } impl #impl_generics PartialEq> for #owned_ty { fn eq(&self, other: &std::sync::Arc<#id_ty>) -> bool { 
AsRef::<#id_ty>::as_ref(self) == AsRef::<#id_ty>::as_ref(other) } } impl #impl_generics PartialEq<#owned_ty> for std::sync::Arc<#id_ty> { fn eq(&self, other: &#owned_ty) -> bool { AsRef::<#id_ty>::as_ref(self) == AsRef::<#id_ty>::as_ref(other) } } } } fn expand_checked_impls(input: &ItemStruct, validate: Path) -> TokenStream { let id = &input.ident; let owned = format_ident!("Owned{id}"); let (impl_generics, ty_generics, _where_clause) = input.generics.split_for_impl(); let generic_params = &input.generics.params; let parse_doc_header = format!("Try parsing a `&str` into an `Owned{}`.", id); let parse_box_doc_header = format!("Try parsing a `&str` into a `Box<{}>`.", id); let parse_rc_docs = format!("Try parsing a `&str` into an `Rc<{}>`.", id); let parse_arc_docs = format!("Try parsing a `&str` into an `Arc<{}>`.", id); let id_ty = quote! { #id #ty_generics }; let owned_ty = quote! { #owned #ty_generics }; quote! { impl #impl_generics #id_ty { #[doc = #parse_doc_header] /// /// The same can also be done using `FromStr`, `TryFrom` or `TryInto`. /// This function is simply more constrained and thus useful in generic contexts. pub fn parse( s: impl AsRef, ) -> Result<#owned_ty, crate::IdParseError> { let s = s.as_ref(); #validate(s)?; Ok(#id::from_borrowed(s).to_owned()) } #[doc = #parse_box_doc_header] /// /// The same can also be done using `FromStr`, `TryFrom` or `TryInto`. /// This function is simply more constrained and thus useful in generic contexts. 
pub fn parse_box( s: impl AsRef + Into>, ) -> Result, crate::IdParseError> { #validate(s.as_ref())?; Ok(#id::from_box(s.into())) } #[doc = #parse_rc_docs] pub fn parse_rc( s: impl AsRef + Into>, ) -> Result, crate::IdParseError> { #validate(s.as_ref())?; Ok(#id::from_rc(s.into())) } #[doc = #parse_arc_docs] pub fn parse_arc( s: impl AsRef + Into>, ) -> Result, crate::IdParseError> { #validate(s.as_ref())?; Ok(#id::from_arc(s.into())) } } impl<'de, #generic_params> serde::Deserialize<'de> for Box<#id_ty> { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { use serde::de::Error; let s = String::deserialize(deserializer)?; match #id::parse_box(s) { Ok(o) => Ok(o), Err(e) => Err(D::Error::custom(e)), } } } impl<'de, #generic_params> serde::Deserialize<'de> for #owned_ty { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { use serde::de::Error; let s = String::deserialize(deserializer)?; match #id::parse(s) { Ok(o) => Ok(o), Err(e) => Err(D::Error::custom(e)), } } } impl<'a, #generic_params> std::convert::TryFrom<&'a str> for &'a #id_ty { type Error = crate::IdParseError; fn try_from(s: &'a str) -> Result { #validate(s)?; Ok(<#id_ty>::from_borrowed(s)) } } impl #impl_generics std::str::FromStr for Box<#id_ty> { type Err = crate::IdParseError; fn from_str(s: &str) -> Result { <#id_ty>::parse_box(s) } } impl #impl_generics std::convert::TryFrom<&str> for Box<#id_ty> { type Error = crate::IdParseError; fn try_from(s: &str) -> Result { <#id_ty>::parse_box(s) } } impl #impl_generics std::convert::TryFrom for Box<#id_ty> { type Error = crate::IdParseError; fn try_from(s: String) -> Result { <#id_ty>::parse_box(s) } } impl #impl_generics std::str::FromStr for #owned_ty { type Err = crate::IdParseError; fn from_str(s: &str) -> Result { <#id_ty>::parse(s) } } impl #impl_generics std::convert::TryFrom<&str> for #owned_ty { type Error = crate::IdParseError; fn try_from(s: &str) -> Result { <#id_ty>::parse(s) } } impl 
#impl_generics std::convert::TryFrom for #owned_ty { type Error = crate::IdParseError; fn try_from(s: String) -> Result { <#id_ty>::parse(s) } } } } fn expand_unchecked_impls(input: &ItemStruct) -> TokenStream { let id = &input.ident; let owned = format_ident!("Owned{id}"); quote! { impl<'a> From<&'a str> for &'a #id { fn from(s: &'a str) -> Self { #id::from_borrowed(s) } } impl From<&str> for #owned { fn from(s: &str) -> Self { <&#id>::from(s).into() } } impl From> for #owned { fn from(s: Box) -> Self { <&#id>::from(&*s).into() } } impl From for #owned { fn from(s: String) -> Self { <&#id>::from(s.as_str()).into() } } impl From<&str> for Box<#id> { fn from(s: &str) -> Self { #id::from_box(s.into()) } } impl From> for Box<#id> { fn from(s: Box) -> Self { #id::from_box(s) } } impl From for Box<#id> { fn from(s: String) -> Self { #id::from_box(s.into()) } } impl From> for Box { fn from(id: Box<#id>) -> Self { id.into_owned() } } impl<'de> serde::Deserialize<'de> for Box<#id> { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { Box::::deserialize(deserializer).map(#id::from_box) } } impl<'de> serde::Deserialize<'de> for #owned { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { // FIXME: Deserialize inner, convert that Box::::deserialize(deserializer).map(#id::from_box).map(Into::into) } } } } fn expand_partial_eq_string(ty: TokenStream, impl_generics: &ImplGenerics<'_>) -> TokenStream { IntoIterator::into_iter([ (ty.clone(), quote! { str }), (ty.clone(), quote! { &str }), (ty.clone(), quote! { String }), (quote! { str }, ty.clone()), (quote! { &str }, ty.clone()), (quote! { String }, ty), ]) .map(|(lhs, rhs)| { quote! 
{ impl #impl_generics PartialEq<#rhs> for #lhs { fn eq(&self, other: &#rhs) -> bool { AsRef::::as_ref(self) == AsRef::::as_ref(other) } } } }) .collect() } mod kw { syn::custom_keyword!(validate); } #[derive(Default)] struct IdZstMeta { validate: Option, } impl IdZstMeta { fn merge(self, other: IdZstMeta) -> syn::Result { let validate = match (self.validate, other.validate) { (None, None) => None, (Some(val), None) | (None, Some(val)) => Some(val), (Some(a), Some(b)) => { let mut error = syn::Error::new_spanned(b, "duplicate attribute argument"); error.combine(syn::Error::new_spanned(a, "note: first one here")); return Err(error); } }; Ok(Self { validate }) } } impl Parse for IdZstMeta { fn parse(input: ParseStream<'_>) -> syn::Result { let _: kw::validate = input.parse()?; let _: Token![=] = input.parse()?; let validate = Some(input.parse()?); Ok(Self { validate }) } } ruma-macros-0.10.5/src/lib.rs000064400000000000000000000374751046102023000140600ustar 00000000000000#![doc(html_favicon_url = "https://www.ruma.io/favicon.ico")] #![doc(html_logo_url = "https://www.ruma.io/images/logo.png")] //! Procedural macros used by ruma crates. //! //! See the documentation for the individual macros for usage details. 
#![warn(missing_docs)] // https://github.com/rust-lang/rust-clippy/issues/9029 #![allow(clippy::derive_partial_eq_without_eq)] use identifiers::expand_id_zst; use proc_macro::TokenStream; use proc_macro2 as pm2; use quote::quote; use ruma_identifiers_validation::{ device_key_id, event_id, key_id, mxc_uri, room_alias_id, room_id, room_version_id, server_name, user_id, }; use syn::{parse_macro_input, DeriveInput, ItemEnum, ItemStruct}; mod api; mod events; mod identifiers; mod serde; mod util; use self::{ api::{request::expand_derive_request, response::expand_derive_response, Api}, events::{ event::expand_event, event_content::expand_event_content, event_enum::{expand_event_enums, expand_from_impls_derived}, event_parse::EventEnumInput, event_type::expand_event_type_enum, }, identifiers::IdentifierInput, serde::{ as_str_as_ref_str::expand_as_str_as_ref_str, deserialize_from_cow_str::expand_deserialize_from_cow_str, display_as_ref_str::expand_display_as_ref_str, enum_as_ref_str::expand_enum_as_ref_str, enum_from_string::expand_enum_from_string, eq_as_ref_str::expand_partial_eq_as_ref_str, incoming::expand_derive_incoming, ord_as_ref_str::{expand_ord_as_ref_str, expand_partial_ord_as_ref_str}, serialize_as_ref_str::expand_serialize_as_ref_str, }, util::import_ruma_common, }; /// Generates an enum to represent the various Matrix event types. /// /// This macro also implements the necessary traits for the type to serialize and deserialize /// itself. /// /// # Examples /// /// ```ignore /// # // HACK: This is "ignore" because of cyclical dependency drama. /// use ruma_macros::event_enum; /// /// event_enum! 
{ /// enum ToDevice { /// "m.any.event", /// "m.other.event", /// } /// /// enum State { /// "m.more.events", /// "m.different.event", /// } /// } /// ``` /// (The enum name has to be a valid identifier for `::parse`) //// TODO: Change above (`::parse`) to [] after fully qualified syntax is //// supported: https://github.com/rust-lang/rust/issues/74563 #[proc_macro] pub fn event_enum(input: TokenStream) -> TokenStream { let event_enum_input = syn::parse_macro_input!(input as EventEnumInput); let ruma_common = import_ruma_common(); let enums = event_enum_input .enums .iter() .map(|e| expand_event_enums(e).unwrap_or_else(syn::Error::into_compile_error)) .collect::(); let event_types = expand_event_type_enum(event_enum_input, ruma_common) .unwrap_or_else(syn::Error::into_compile_error); let tokens = quote! { #enums #event_types }; tokens.into() } /// Generates an implementation of `ruma_common::events::EventContent`. /// /// Also generates type aliases depending on the kind of event, with the final `Content` of the type /// name removed and prefixed added. For instance, a message-like event content type /// `FooEventContent` will have the following aliases generated: /// /// * `type FooEvent = MessageLikeEvent` /// * `type SyncFooEvent = SyncMessageLikeEvent` /// * `type OriginalFooEvent = OriginalMessageLikeEvent` /// * `type OriginalSyncFooEvent = OriginalSyncMessageLikeEvent` /// * `type RedactedFooEvent = RedactedMessageLikeEvent` /// * `type RedactedSyncFooEvent = RedactedSyncMessageLikeEvent` /// /// You can use `cargo doc` to find out more details, its `--document-private-items` flag also lets /// you generate documentation for binaries or private parts of a library. 
#[proc_macro_derive(EventContent, attributes(ruma_event))] pub fn derive_event_content(input: TokenStream) -> TokenStream { let ruma_common = import_ruma_common(); let input = parse_macro_input!(input as DeriveInput); expand_event_content(&input, &ruma_common).unwrap_or_else(syn::Error::into_compile_error).into() } /// Generates implementations needed to serialize and deserialize Matrix events. #[proc_macro_derive(Event, attributes(ruma_event))] pub fn derive_event(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_event(input).unwrap_or_else(syn::Error::into_compile_error).into() } /// Generates `From` implementations for event enums. #[proc_macro_derive(EventEnumFromEvent)] pub fn derive_from_event_to_enum(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_from_impls_derived(input).into() } /// Generate methods and trait impl's for ZST identifier type. #[proc_macro_derive(IdZst, attributes(ruma_id))] pub fn derive_id_zst(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as ItemStruct); expand_id_zst(input).unwrap_or_else(syn::Error::into_compile_error).into() } /// Compile-time checked `DeviceKeyId` construction. #[proc_macro] pub fn device_key_id(input: TokenStream) -> TokenStream { let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput); assert!(device_key_id::validate(&id.value()).is_ok(), "Invalid device key id"); let output = quote! { <&#dollar_crate::DeviceKeyId as ::std::convert::TryFrom<&str>>::try_from(#id).unwrap() }; output.into() } /// Compile-time checked `EventId` construction. #[proc_macro] pub fn event_id(input: TokenStream) -> TokenStream { let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput); assert!(event_id::validate(&id.value()).is_ok(), "Invalid event id"); let output = quote! 
{ <&#dollar_crate::EventId as ::std::convert::TryFrom<&str>>::try_from(#id).unwrap() }; output.into() } /// Compile-time checked `RoomAliasId` construction. #[proc_macro] pub fn room_alias_id(input: TokenStream) -> TokenStream { let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput); assert!(room_alias_id::validate(&id.value()).is_ok(), "Invalid room_alias_id"); let output = quote! { <&#dollar_crate::RoomAliasId as ::std::convert::TryFrom<&str>>::try_from(#id).unwrap() }; output.into() } /// Compile-time checked `RoomId` construction. #[proc_macro] pub fn room_id(input: TokenStream) -> TokenStream { let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput); assert!(room_id::validate(&id.value()).is_ok(), "Invalid room_id"); let output = quote! { <&#dollar_crate::RoomId as ::std::convert::TryFrom<&str>>::try_from(#id).unwrap() }; output.into() } /// Compile-time checked `RoomVersionId` construction. #[proc_macro] pub fn room_version_id(input: TokenStream) -> TokenStream { let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput); assert!(room_version_id::validate(&id.value()).is_ok(), "Invalid room_version_id"); let output = quote! { <#dollar_crate::RoomVersionId as ::std::convert::TryFrom<&str>>::try_from(#id).unwrap() }; output.into() } /// Compile-time checked `ServerSigningKeyId` construction. #[proc_macro] pub fn server_signing_key_id(input: TokenStream) -> TokenStream { let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput); assert!(key_id::validate(&id.value()).is_ok(), "Invalid server_signing_key_id"); let output = quote! { <&#dollar_crate::ServerSigningKeyId as ::std::convert::TryFrom<&str>>::try_from(#id).unwrap() }; output.into() } /// Compile-time checked `ServerName` construction. 
#[proc_macro]
pub fn server_name(input: TokenStream) -> TokenStream {
    let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput);
    // Emit a spanned compile error rather than panicking: a panic only yields an unspanned
    // "proc macro panicked" diagnostic.
    if server_name::validate(&id.value()).is_err() {
        return syn::Error::new_spanned(&id, "Invalid server_name").into_compile_error().into();
    }

    // The literal was validated above, so the generated `unwrap` cannot fail.
    let output = quote! {
        <&#dollar_crate::ServerName as ::std::convert::TryFrom<&str>>::try_from(#id).unwrap()
    };

    output.into()
}

/// Compile-time checked `MxcUri` construction.
///
/// An invalid literal is reported as a compile error pointing at the literal.
#[proc_macro]
pub fn mxc_uri(input: TokenStream) -> TokenStream {
    let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput);
    if mxc_uri::validate(&id.value()).is_err() {
        return syn::Error::new_spanned(&id, "Invalid mxc://").into_compile_error().into();
    }

    // `MxcUri` converts infallibly via `From<&str>`; validation is only done at compile time.
    let output = quote! {
        <&#dollar_crate::MxcUri as ::std::convert::From<&str>>::from(#id)
    };

    output.into()
}

/// Compile-time checked `UserId` construction.
///
/// An invalid literal is reported as a compile error pointing at the literal.
#[proc_macro]
pub fn user_id(input: TokenStream) -> TokenStream {
    let IdentifierInput { dollar_crate, id } = parse_macro_input!(input as IdentifierInput);
    if user_id::validate(&id.value()).is_err() {
        return syn::Error::new_spanned(&id, "Invalid user_id").into_compile_error().into();
    }

    let output = quote! {
        <&#dollar_crate::UserId as ::std::convert::TryFrom<&str>>::try_from(#id).unwrap()
    };

    output.into()
}

/// Generates an 'Incoming' version of the type this derive macro is used on.
///
/// This type will be a fully-owned version of the input type, using no lifetime generics.
///
/// By default, the generated type will derive `Debug` and `serde::Deserialize`. To derive
/// additional traits, use `#[incoming_derive(ExtraDeriveMacro)]`. To disable the default derives,
/// use `#[incoming_derive(!Debug, !Deserialize)]`.
#[proc_macro_derive(Incoming, attributes(incoming_derive))]
pub fn derive_incoming(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    expand_derive_incoming(input).unwrap_or_else(syn::Error::into_compile_error).into()
}

/// Derive the `AsRef<str>` trait for an enum.
#[proc_macro_derive(AsRefStr, attributes(ruma_enum))]
pub fn derive_enum_as_ref_str(input: TokenStream) -> TokenStream {
    let item_enum = parse_macro_input!(input as ItemEnum);
    match expand_enum_as_ref_str(&item_enum) {
        Ok(expanded) => expanded.into(),
        Err(error) => error.into_compile_error().into(),
    }
}

/// Derives `From<T>` for an enum, for any `T: AsRef<str> + Into<Box<str>>`.
///
/// Unit variants are matched against their (optionally renamed) string; the single
/// data-carrying variant acts as the fallback for every other string.
#[proc_macro_derive(FromString, attributes(ruma_enum))]
pub fn derive_enum_from_string(input: TokenStream) -> TokenStream {
    let item_enum = parse_macro_input!(input as ItemEnum);
    match expand_enum_from_string(&item_enum) {
        Ok(expanded) => expanded.into(),
        Err(error) => error.into_compile_error().into(),
    }
}

// FIXME: The following macros aren't actually interested in type details beyond name (and possibly
// generics in the future). They probably shouldn't use `DeriveInput`.

/// Derives an inherent `as_str()` method that forwards to the type's `AsRef<str>` implementation.
#[proc_macro_derive(AsStrAsRefStr, attributes(ruma_enum))]
pub fn derive_as_str_as_ref_str(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);
    match expand_as_str_as_ref_str(&derive_input.ident) {
        Ok(expanded) => expanded.into(),
        Err(error) => error.into_compile_error().into(),
    }
}

/// Derives `std::fmt::Display` by writing out the type's `AsRef<str>` representation.
#[proc_macro_derive(DisplayAsRefStr)]
pub fn derive_display_as_ref_str(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);
    match expand_display_as_ref_str(&derive_input.ident) {
        Ok(expanded) => expanded.into(),
        Err(error) => error.into_compile_error().into(),
    }
}

/// Derives `serde::Serialize` by serializing the type's `AsRef<str>` representation.
#[proc_macro_derive(SerializeAsRefStr)]
pub fn derive_serialize_as_ref_str(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);
    match expand_serialize_as_ref_str(&derive_input.ident) {
        Ok(expanded) => expanded.into(),
        Err(error) => error.into_compile_error().into(),
    }
}

/// Derives `serde::Deserialize` by deserializing a `Cow<str>` and converting it with the type's
/// `From<Cow<str>>` implementation.
#[proc_macro_derive(DeserializeFromCowStr)] pub fn derive_deserialize_from_cow_str(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_deserialize_from_cow_str(&input.ident) .unwrap_or_else(syn::Error::into_compile_error) .into() } /// Derive the `PartialOrd` trait using the `AsRef` implementation of the type. #[proc_macro_derive(PartialOrdAsRefStr)] pub fn derive_partial_ord_as_ref_str(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_partial_ord_as_ref_str(&input.ident) .unwrap_or_else(syn::Error::into_compile_error) .into() } /// Derive the `Ord` trait using the `AsRef` implementation of the type. #[proc_macro_derive(OrdAsRefStr)] pub fn derive_ord_as_ref_str(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_ord_as_ref_str(&input.ident).unwrap_or_else(syn::Error::into_compile_error).into() } /// Derive the `PartialEq` trait using the `AsRef` implementation of the type. #[proc_macro_derive(PartialEqAsRefStr)] pub fn derive_partial_eq_as_ref_str(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_partial_eq_as_ref_str(&input.ident).unwrap_or_else(syn::Error::into_compile_error).into() } /// Shorthand for the derives `AsRefStr`, `FromString`, `DisplayAsRefStr`, `SerializeAsRefStr` and /// `DeserializeFromCowStr`. #[proc_macro_derive(StringEnum, attributes(ruma_enum))] pub fn derive_string_enum(input: TokenStream) -> TokenStream { fn expand_all(input: ItemEnum) -> syn::Result { let as_ref_str_impl = expand_enum_as_ref_str(&input)?; let from_string_impl = expand_enum_from_string(&input)?; let as_str_impl = expand_as_str_as_ref_str(&input.ident)?; let display_impl = expand_display_as_ref_str(&input.ident)?; let serialize_impl = expand_serialize_as_ref_str(&input.ident)?; let deserialize_impl = expand_deserialize_from_cow_str(&input.ident)?; Ok(quote! 
{ #as_ref_str_impl #from_string_impl #as_str_impl #display_impl #serialize_impl #deserialize_impl }) } let input = parse_macro_input!(input as ItemEnum); expand_all(input).unwrap_or_else(syn::Error::into_compile_error).into() } /// A derive macro that generates no code, but registers the serde attribute so both `#[serde(...)]` /// and `#[cfg_attr(..., serde(...))]` are accepted on the type, its fields and (in case the input /// is an enum) variants fields. #[doc(hidden)] #[proc_macro_derive(_FakeDeriveSerde, attributes(serde))] pub fn fake_derive_serde(_input: TokenStream) -> TokenStream { TokenStream::new() } /// > ⚠ If this is the only documentation you see, please navigate to the docs for /// > `ruma_common::api::ruma_api`, where actual documentation can be found. #[proc_macro] pub fn ruma_api(input: TokenStream) -> TokenStream { let api = parse_macro_input!(input as Api); api.expand_all().into() } /// Internal helper taking care of the request-specific parts of `ruma_api!`. #[proc_macro_derive(Request, attributes(ruma_api))] pub fn derive_request(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_derive_request(input).unwrap_or_else(syn::Error::into_compile_error).into() } /// Internal helper taking care of the response-specific parts of `ruma_api!`. #[proc_macro_derive(Response, attributes(ruma_api))] pub fn derive_response(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); expand_derive_response(input).unwrap_or_else(syn::Error::into_compile_error).into() } /// A derive macro that generates no code, but registers the ruma_api attribute so both /// `#[ruma_api(...)]` and `#[cfg_attr(..., ruma_api(...))]` are accepted on the type, its fields /// and (in case the input is an enum) variants fields. 
#[doc(hidden)] #[proc_macro_derive(_FakeDeriveRumaApi, attributes(ruma_api))] pub fn fake_derive_ruma_api(_input: TokenStream) -> TokenStream { TokenStream::new() } ruma-macros-0.10.5/src/serde/as_str_as_ref_str.rs000064400000000000000000000007051046102023000201000ustar 00000000000000use proc_macro2::{Ident, TokenStream}; use quote::quote; pub fn expand_as_str_as_ref_str(ident: &Ident) -> syn::Result { let as_str_doc = format!("Creates a string slice from this `{}`.", ident); Ok(quote! { #[automatically_derived] #[allow(deprecated)] impl #ident { #[doc = #as_str_doc] pub fn as_str(&self) -> &str { self.as_ref() } } }) } ruma-macros-0.10.5/src/serde/attr.rs000064400000000000000000000026701046102023000153530ustar 00000000000000use syn::{ parse::{Parse, ParseStream}, LitStr, Token, }; use super::case::RenameRule; mod kw { syn::custom_keyword!(alias); syn::custom_keyword!(rename); syn::custom_keyword!(rename_all); } #[derive(Default)] pub struct EnumAttrs { pub rename: Option, pub aliases: Vec, } pub enum Attr { Alias(LitStr), Rename(LitStr), } impl Parse for Attr { fn parse(input: ParseStream<'_>) -> syn::Result { let lookahead = input.lookahead1(); if lookahead.peek(kw::alias) { let _: kw::alias = input.parse()?; let _: Token![=] = input.parse()?; Ok(Self::Alias(input.parse()?)) } else if lookahead.peek(kw::rename) { let _: kw::rename = input.parse()?; let _: Token![=] = input.parse()?; Ok(Self::Rename(input.parse()?)) } else { Err(lookahead.error()) } } } pub struct RenameAllAttr(RenameRule); impl RenameAllAttr { pub fn into_inner(self) -> RenameRule { self.0 } } impl Parse for RenameAllAttr { fn parse(input: ParseStream<'_>) -> syn::Result { let _: kw::rename_all = input.parse()?; let _: Token![=] = input.parse()?; let s: LitStr = input.parse()?; Ok(Self( s.value() .parse() .map_err(|_| syn::Error::new_spanned(s, "invalid value for rename_all"))?, )) } } ruma-macros-0.10.5/src/serde/case.rs000064400000000000000000000211411046102023000153060ustar 00000000000000//! 
Code to convert the Rust-styled field/variant (e.g. `my_field`, `MyType`) to the //! case of the source (e.g. `my-field`, `MY_FIELD`). //! //! This is a minimally modified version of the same code [in serde]. //! //! [serde]: https://github.com/serde-rs/serde/blame/a9f8ea0a1e8ba1206f8c28d96b924606847b85a9/serde_derive/src/internals/case.rs use std::str::FromStr; use self::RenameRule::*; /// The different possible ways to change case of fields in a struct, or variants in an enum. #[derive(Copy, Clone, PartialEq)] pub enum RenameRule { /// Don't apply a default rename rule. None, /// Rename direct children to "lowercase" style. LowerCase, /// Rename direct children to "UPPERCASE" style. Uppercase, /// Rename direct children to "PascalCase" style, as typically used for /// enum variants. PascalCase, /// Rename direct children to "camelCase" style. CamelCase, /// Rename direct children to "snake_case" style, as commonly used for /// fields. SnakeCase, /// Rename direct children to "SCREAMING_SNAKE_CASE" style, as commonly /// used for constants. ScreamingSnakeCase, /// Rename direct children to "kebab-case" style. KebabCase, /// Rename direct children to "SCREAMING-KEBAB-CASE" style. ScreamingKebabCase, /// Rename direct children to "M_MATRIX_ERROR_CASE" style, as used for responses with error in /// Matrix spec. MatrixErrorCase, /// Rename the direct children to "m.snake_case" style. MatrixSnakeCase, /// Rename the direct children to "m.dotted.case" style. MatrixDottedCase, } impl RenameRule { /// Apply a renaming rule to an enum variant, returning the version expected in the source. 
pub fn apply_to_variant(&self, variant: &str) -> String { match *self { None | PascalCase => variant.to_owned(), LowerCase => variant.to_ascii_lowercase(), Uppercase => variant.to_ascii_uppercase(), CamelCase => variant[..1].to_ascii_lowercase() + &variant[1..], SnakeCase => { let mut snake = String::new(); for (i, ch) in variant.char_indices() { if i > 0 && ch.is_uppercase() { snake.push('_'); } snake.push(ch.to_ascii_lowercase()); } snake } ScreamingSnakeCase => SnakeCase.apply_to_variant(variant).to_ascii_uppercase(), KebabCase => SnakeCase.apply_to_variant(variant).replace('_', "-"), ScreamingKebabCase => ScreamingSnakeCase.apply_to_variant(variant).replace('_', "-"), MatrixErrorCase => String::from("M_") + &ScreamingSnakeCase.apply_to_variant(variant), MatrixSnakeCase => String::from("m.") + &SnakeCase.apply_to_variant(variant), MatrixDottedCase => { String::from("m.") + &SnakeCase.apply_to_variant(variant).replace('_', ".") } } } /// Apply a renaming rule to a struct field, returning the version expected in the source. #[allow(dead_code)] pub fn apply_to_field(&self, field: &str) -> String { match *self { None | LowerCase | SnakeCase => field.to_owned(), Uppercase => field.to_ascii_uppercase(), PascalCase => { let mut pascal = String::new(); let mut capitalize = true; for ch in field.chars() { if ch == '_' { capitalize = true; } else if capitalize { pascal.push(ch.to_ascii_uppercase()); capitalize = false; } else { pascal.push(ch); } } pascal } CamelCase => { let pascal = PascalCase.apply_to_field(field); pascal[..1].to_ascii_lowercase() + &pascal[1..] 
} ScreamingSnakeCase => field.to_ascii_uppercase(), KebabCase => field.replace('_', "-"), ScreamingKebabCase => ScreamingSnakeCase.apply_to_field(field).replace('_', "-"), MatrixErrorCase => String::from("M_") + &ScreamingSnakeCase.apply_to_field(field), MatrixSnakeCase => String::from("m.") + field, MatrixDottedCase => String::from("m.") + &field.replace('_', "."), } } } impl FromStr for RenameRule { type Err = (); fn from_str(rename_all_str: &str) -> Result { match rename_all_str { "lowercase" => Ok(LowerCase), "UPPERCASE" => Ok(Uppercase), "PascalCase" => Ok(PascalCase), "camelCase" => Ok(CamelCase), "snake_case" => Ok(SnakeCase), "SCREAMING_SNAKE_CASE" => Ok(ScreamingSnakeCase), "kebab-case" => Ok(KebabCase), "SCREAMING-KEBAB-CASE" => Ok(ScreamingKebabCase), "M_MATRIX_ERROR_CASE" => Ok(MatrixErrorCase), "m.snake_case" => Ok(MatrixSnakeCase), "m.dotted.case" => Ok(MatrixDottedCase), _ => Err(()), } } } #[test] fn rename_variants() { for &( original, lower, upper, camel, snake, screaming, kebab, screaming_kebab, matrix_error, m_snake, m_dotted, ) in &[ ( "Outcome", "outcome", "OUTCOME", "outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", "M_OUTCOME", "m.outcome", "m.outcome", ), ( "VeryTasty", "verytasty", "VERYTASTY", "veryTasty", "very_tasty", "VERY_TASTY", "very-tasty", "VERY-TASTY", "M_VERY_TASTY", "m.very_tasty", "m.very.tasty", ), ("A", "a", "A", "a", "a", "A", "a", "A", "M_A", "m.a", "m.a"), ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42", "M_Z42", "m.z42", "m.z42"), ] { assert_eq!(None.apply_to_variant(original), original); assert_eq!(LowerCase.apply_to_variant(original), lower); assert_eq!(Uppercase.apply_to_variant(original), upper); assert_eq!(PascalCase.apply_to_variant(original), original); assert_eq!(CamelCase.apply_to_variant(original), camel); assert_eq!(SnakeCase.apply_to_variant(original), snake); assert_eq!(ScreamingSnakeCase.apply_to_variant(original), screaming); assert_eq!(KebabCase.apply_to_variant(original), kebab); 
assert_eq!(ScreamingKebabCase.apply_to_variant(original), screaming_kebab); assert_eq!(MatrixErrorCase.apply_to_variant(original), matrix_error); assert_eq!(MatrixSnakeCase.apply_to_variant(original), m_snake); assert_eq!(MatrixDottedCase.apply_to_variant(original), m_dotted); } } #[test] fn rename_fields() { for &( original, upper, pascal, camel, screaming, kebab, screaming_kebab, matrix_error, m_snake, m_dotted, ) in &[ ( "outcome", "OUTCOME", "Outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", "M_OUTCOME", "m.outcome", "m.outcome", ), ( "very_tasty", "VERY_TASTY", "VeryTasty", "veryTasty", "VERY_TASTY", "very-tasty", "VERY-TASTY", "M_VERY_TASTY", "m.very_tasty", "m.very.tasty", ), ("a", "A", "A", "a", "A", "a", "A", "M_A", "m.a", "m.a"), ("z42", "Z42", "Z42", "z42", "Z42", "z42", "Z42", "M_Z42", "m.z42", "m.z42"), ] { assert_eq!(None.apply_to_field(original), original); assert_eq!(Uppercase.apply_to_field(original), upper); assert_eq!(PascalCase.apply_to_field(original), pascal); assert_eq!(CamelCase.apply_to_field(original), camel); assert_eq!(SnakeCase.apply_to_field(original), original); assert_eq!(ScreamingSnakeCase.apply_to_field(original), screaming); assert_eq!(KebabCase.apply_to_field(original), kebab); assert_eq!(ScreamingKebabCase.apply_to_field(original), screaming_kebab); assert_eq!(MatrixErrorCase.apply_to_field(original), matrix_error); assert_eq!(MatrixSnakeCase.apply_to_field(original), m_snake); assert_eq!(MatrixDottedCase.apply_to_field(original), m_dotted); } } ruma-macros-0.10.5/src/serde/deserialize_from_cow_str.rs000064400000000000000000000015111046102023000214550ustar 00000000000000use proc_macro2::{Ident, TokenStream}; use quote::quote; use crate::util::import_ruma_common; pub fn expand_deserialize_from_cow_str(ident: &Ident) -> syn::Result { let ruma_common = import_ruma_common(); Ok(quote! 
{ #[automatically_derived] #[allow(deprecated)] impl<'de> #ruma_common::exports::serde::de::Deserialize<'de> for #ident { fn deserialize(deserializer: D) -> ::std::result::Result where D: #ruma_common::exports::serde::de::Deserializer<'de>, { type CowStr<'a> = ::std::borrow::Cow<'a, ::std::primitive::str>; let cow = #ruma_common::serde::deserialize_cow_str(deserializer)?; Ok(::std::convert::From::>::from(cow)) } } }) } ruma-macros-0.10.5/src/serde/display_as_ref_str.rs000064400000000000000000000007241046102023000202530ustar 00000000000000use proc_macro2::{Ident, TokenStream}; use quote::quote; pub fn expand_display_as_ref_str(ident: &Ident) -> syn::Result { Ok(quote! { #[automatically_derived] #[allow(deprecated)] impl ::std::fmt::Display for #ident { fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { f.write_str(::std::convert::AsRef::<::std::primitive::str>::as_ref(self)) } } }) } ruma-macros-0.10.5/src/serde/enum_as_ref_str.rs000064400000000000000000000043751046102023000175600ustar 00000000000000use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{Fields, FieldsNamed, FieldsUnnamed, ItemEnum}; use super::{ attr::EnumAttrs, util::{get_enum_attributes, get_rename_rule}, }; pub fn expand_enum_as_ref_str(input: &ItemEnum) -> syn::Result { let enum_name = &input.ident; let rename_rule = get_rename_rule(input)?; let branches: Vec<_> = input .variants .iter() .map(|v| { let variant_name = &v.ident; let EnumAttrs { rename, .. } = get_enum_attributes(v)?; let (field_capture, variant_str) = match (rename, &v.fields) { (None, Fields::Unit) => ( None, rename_rule.apply_to_variant(&variant_name.to_string()).into_token_stream(), ), (Some(rename), Fields::Unit) => (None, rename.into_token_stream()), (None, Fields::Named(FieldsNamed { named: fields, .. })) | (None, Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. 
})) => { if fields.len() != 1 { return Err(syn::Error::new_spanned( v, "multiple data fields are not supported", )); } let capture = match &fields[0].ident { Some(name) => quote! { { #name: inner } }, None => quote! { (inner) }, }; (Some(capture), quote! { &inner.0 }) } (Some(_), _) => { return Err(syn::Error::new_spanned( v, "ruma_enum(rename) is only allowed on unit variants", )); } }; Ok(quote! { #enum_name :: #variant_name #field_capture => #variant_str }) }) .collect::>()?; Ok(quote! { #[automatically_derived] #[allow(deprecated)] impl ::std::convert::AsRef<::std::primitive::str> for #enum_name { fn as_ref(&self) -> &::std::primitive::str { match self { #(#branches),* } } } }) } ruma-macros-0.10.5/src/serde/enum_from_string.rs000064400000000000000000000064211046102023000177540ustar 00000000000000use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{Fields, FieldsNamed, FieldsUnnamed, ItemEnum}; use super::{ attr::EnumAttrs, util::{get_enum_attributes, get_rename_rule}, }; pub fn expand_enum_from_string(input: &ItemEnum) -> syn::Result { let enum_name = &input.ident; let rename_rule = get_rename_rule(input)?; let mut fallback = None; let branches: Vec<_> = input .variants .iter() .map(|v| { let variant_name = &v.ident; let EnumAttrs { rename, aliases } = get_enum_attributes(v)?; let variant_str = match (rename, &v.fields) { (None, Fields::Unit) => Some( rename_rule.apply_to_variant(&variant_name.to_string()).into_token_stream(), ), (Some(rename), Fields::Unit) => Some(rename.into_token_stream()), (None, Fields::Named(FieldsNamed { named: fields, .. })) | (None, Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. 
})) => { if fields.len() != 1 { return Err(syn::Error::new_spanned( v, "multiple data fields are not supported", )); } if fallback.is_some() { return Err(syn::Error::new_spanned( v, "multiple data-carrying variants are not supported", )); } let member = match &fields[0].ident { Some(name) => name.into_token_stream(), None => quote! { 0 }, }; let ty = &fields[0].ty; fallback = Some(quote! { _ => #enum_name::#variant_name { #member: #ty(s.into()), } }); None } (Some(_), _) => { return Err(syn::Error::new_spanned( v, "ruma_enum(rename) is only allowed on unit variants", )); } }; Ok(variant_str.map(|s| { quote! { #( #aliases => #enum_name :: #variant_name, )* #s => #enum_name :: #variant_name } })) }) .collect::>()?; // Remove `None` from the iterator to avoid emitting consecutive commas in repetition let branches = branches.iter().flatten(); if fallback.is_none() { return Err(syn::Error::new(Span::call_site(), "required fallback variant not found")); } Ok(quote! { #[automatically_derived] #[allow(deprecated)] impl ::std::convert::From for #enum_name where T: ::std::convert::AsRef<::std::primitive::str> + ::std::convert::Into<::std::boxed::Box<::std::primitive::str>> { fn from(s: T) -> Self { match s.as_ref() { #( #branches, )* #fallback } } } }) } ruma-macros-0.10.5/src/serde/eq_as_ref_str.rs000064400000000000000000000010141046102023000172040ustar 00000000000000use proc_macro2::{Ident, TokenStream}; use quote::quote; pub fn expand_partial_eq_as_ref_str(ident: &Ident) -> syn::Result { Ok(quote! 
{ #[automatically_derived] #[allow(deprecated)] impl ::std::cmp::PartialEq for #ident { fn eq(&self, other: &Self) -> bool { let other = ::std::convert::AsRef::<::std::primitive::str>::as_ref(other); ::std::convert::AsRef::<::std::primitive::str>::as_ref(self) == other } } }) } ruma-macros-0.10.5/src/serde/incoming.rs000064400000000000000000000245631046102023000162110ustar 00000000000000use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse::{Parse, ParseStream}, parse_quote, punctuated::Punctuated, AngleBracketedGenericArguments, Attribute, Data, DeriveInput, GenericArgument, GenericParam, Generics, Ident, ItemType, ParenthesizedGenericArguments, Path, PathArguments, Token, Type, TypePath, TypeReference, TypeSlice, }; use crate::util::import_ruma_common; pub fn expand_derive_incoming(mut ty_def: DeriveInput) -> syn::Result { let ruma_common = import_ruma_common(); let mut found_lifetime = false; match &mut ty_def.data { Data::Union(_) => panic!("#[derive(Incoming)] does not support Union types"), Data::Enum(e) => { for var in &mut e.variants { for field in &mut var.fields { if strip_lifetimes(&mut field.ty, &ruma_common) { found_lifetime = true; } } } } Data::Struct(s) => { for field in &mut s.fields { if !matches!(field.vis, syn::Visibility::Public(_)) { return Err(syn::Error::new_spanned(field, "All fields must be marked `pub`")); } if strip_lifetimes(&mut field.ty, &ruma_common) { found_lifetime = true; } } } } let ident = format_ident!("Incoming{}", ty_def.ident, span = Span::call_site()); if !found_lifetime { let doc = format!( "Convenience type alias for [{}], for consistency with other [{}] types.", &ty_def.ident, ident ); let mut type_alias: ItemType = parse_quote! { type X = Y; }; type_alias.vis = ty_def.vis.clone(); type_alias.ident = ident; type_alias.generics = ty_def.generics.clone(); type_alias.ty = Box::new(TypePath { qself: None, path: ty_def.ident.clone().into() }.into()); return Ok(quote! 
{ #[doc = #doc] #type_alias }); } let meta: Vec = ty_def .attrs .iter() .filter(|attr| attr.path.is_ident("incoming_derive")) .map(|attr| attr.parse_args()) .collect::>()?; let mut derive_debug = true; let mut derive_deserialize = true; let mut derives: Vec<_> = meta .into_iter() .flat_map(|m| m.derive_macs) .filter_map(|derive_mac| match derive_mac { DeriveMac::Regular(id) => Some(quote! { #id }), DeriveMac::NegativeDebug => { derive_debug = false; None } DeriveMac::NegativeDeserialize => { derive_deserialize = false; None } }) .collect(); if derive_debug { derives.push(quote! { ::std::fmt::Debug }); } derives.push(if derive_deserialize { quote! { #ruma_common::exports::serde::Deserialize } } else { quote! { #ruma_common::serde::_FakeDeriveSerde } }); ty_def.attrs.retain(filter_input_attrs); clean_generics(&mut ty_def.generics); let doc = format!("'Incoming' variant of [{}].", &ty_def.ident); ty_def.ident = ident; Ok(quote! { #[doc = #doc] #[derive( #( #derives ),* )] #ty_def }) } /// Keep any `cfg`, `cfg_attr`, `serde` or `non_exhaustive` attributes found and pass them to the /// Incoming variant. fn filter_input_attrs(attr: &Attribute) -> bool { attr.path.is_ident("cfg") || attr.path.is_ident("cfg_attr") || attr.path.is_ident("serde") || attr.path.is_ident("non_exhaustive") || attr.path.is_ident("allow") } fn clean_generics(generics: &mut Generics) { generics.params = generics .params .clone() .into_iter() .filter(|param| !matches!(param, GenericParam::Lifetime(_))) .collect(); } fn strip_lifetimes(field_type: &mut Type, ruma_common: &TokenStream) -> bool { match field_type { // T<'a> -> IncomingT // The IncomingT has to be declared by the user of this derive macro. Type::Path(TypePath { path, .. }) => { let mut has_lifetimes = false; let mut is_lifetime_generic = false; for seg in &mut path.segments { // strip generic lifetimes match &mut seg.arguments { PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. 
}) => { *args = args .clone() .into_iter() .map(|mut ty| { if let GenericArgument::Type(ty) = &mut ty { if strip_lifetimes(ty, ruma_common) { has_lifetimes = true; }; } ty }) .filter(|arg| { if let GenericArgument::Lifetime(_) = arg { is_lifetime_generic = true; false } else { true } }) .collect(); } PathArguments::Parenthesized(ParenthesizedGenericArguments { inputs, .. }) => { *inputs = inputs .clone() .into_iter() .map(|mut ty| { if strip_lifetimes(&mut ty, ruma_common) { has_lifetimes = true; }; ty }) .collect(); } _ => {} } } // If a type has a generic lifetime parameter there must be an `Incoming` variant of // that type. if is_lifetime_generic { if let Some(name) = path.segments.last_mut() { let incoming_ty_ident = format_ident!("Incoming{}", name.ident); name.ident = incoming_ty_ident; } } has_lifetimes || is_lifetime_generic } Type::Reference(TypeReference { elem, .. }) => { let special_replacement = match &mut **elem { Type::Path(ty) => { let path = &ty.path; let last_seg = path.segments.last().unwrap(); if last_seg.ident == "str" { // &str -> String Some(parse_quote! { ::std::string::String }) } else if last_seg.ident == "RawJsonValue" { Some(parse_quote! { ::std::boxed::Box<#path> }) } else if last_seg.ident == "ClientSecret" || last_seg.ident == "DeviceId" || last_seg.ident == "DeviceKeyId" || last_seg.ident == "DeviceSigningKeyId" || last_seg.ident == "EventId" || last_seg.ident == "KeyId" || last_seg.ident == "MxcUri" || last_seg.ident == "ServerName" || last_seg.ident == "SessionId" || last_seg.ident == "RoomAliasId" || last_seg.ident == "RoomId" || last_seg.ident == "RoomOrAliasId" || last_seg.ident == "RoomName" || last_seg.ident == "ServerSigningKeyId" || last_seg.ident == "SigningKeyId" || last_seg.ident == "TransactionId" || last_seg.ident == "UserId" { let ident = format_ident!("Owned{}", last_seg.ident); Some(parse_quote! { #ruma_common::#ident }) } else { None } } // &[T] -> Vec Type::Slice(TypeSlice { elem, .. 
}) => {
                    // Recursively strip the lifetimes of the slice's elements.
                    strip_lifetimes(&mut *elem, ruma_common);
                    Some(parse_quote! { Vec<#elem> })
                }
                _ => None,
            };

            *field_type = match special_replacement {
                Some(ty) => ty,
                None => {
                    // Strip lifetimes of `elem`.
                    strip_lifetimes(elem, ruma_common);
                    // Replace reference with `elem`.
                    (**elem).clone()
                }
            };

            true
        }
        Type::Tuple(syn::TypeTuple { elems, .. }) => {
            let mut has_lifetime = false;
            for elem in elems {
                if strip_lifetimes(elem, ruma_common) {
                    has_lifetime = true;
                }
            }
            has_lifetime
        }
        _ => false,
    }
}

/// Parsed arguments of a derive-list attribute: a comma-separated list of [`DeriveMac`] entries.
///
/// NOTE(review): presumably the arguments of an `incoming_derive` attribute (the name appears in
/// the `DeriveMac` parse error message) — confirm against the caller.
pub struct Meta {
    /// The entries in the order they were written.
    derive_macs: Vec,
}

impl Parse for Meta {
    /// Parses zero or more comma-separated [`DeriveMac`] entries, allowing a trailing comma.
    fn parse(input: ParseStream<'_>) -> syn::Result {
        Ok(Self {
            derive_macs: Punctuated::<_, Token![,]>::parse_terminated(input)?.into_iter().collect(),
        })
    }
}

/// One entry of a [`Meta`] list.
pub enum DeriveMac {
    /// A plain (non-negated) path, such as a derive macro name.
    Regular(Path),
    /// The negated entry `!Debug`.
    NegativeDebug,
    /// The negated entry `!Deserialize`.
    NegativeDeserialize,
}

impl Parse for DeriveMac {
    /// Parses either a path or a `!`-prefixed identifier.
    ///
    /// Only `!Debug` and `!Deserialize` are accepted as negated entries; any other negated
    /// identifier is reported as an error spanned to that identifier.
    fn parse(input: ParseStream<'_>) -> syn::Result {
        if input.peek(Token![!]) {
            let _: Token![!] = input.parse()?;
            let mac: Ident = input.parse()?;

            if mac == "Debug" {
                Ok(Self::NegativeDebug)
            } else if mac == "Deserialize" {
                Ok(Self::NegativeDeserialize)
            } else {
                Err(syn::Error::new_spanned(
                    mac,
                    "Negative incoming_derive can only be used for Debug and Deserialize",
                ))
            }
        } else {
            let mac = input.parse()?;
            Ok(Self::Regular(mac))
        }
    }
}
ruma-macros-0.10.5/src/serde/ord_as_ref_str.rs000064400000000000000000000020411046102023000173640ustar 00000000000000use proc_macro2::{Ident, TokenStream};
use quote::quote;

/// Generates a `PartialOrd` impl for `ident` that compares two values by their `AsRef<str>`
/// string representations.
pub fn expand_partial_ord_as_ref_str(ident: &Ident) -> syn::Result {
    Ok(quote! {
        #[automatically_derived]
        #[allow(deprecated)]
        impl ::std::cmp::PartialOrd for #ident {
            fn partial_cmp(&self, other: &Self) -> ::std::option::Option<::std::cmp::Ordering> {
                let other = ::std::convert::AsRef::<::std::primitive::str>::as_ref(other);
                ::std::convert::AsRef::<::std::primitive::str>::as_ref(self).partial_cmp(other)
            }
        }
    })
}

/// Generates an `Ord` impl for `ident` that compares two values by their `AsRef<str>` string
/// representations.
pub fn expand_ord_as_ref_str(ident: &Ident) -> syn::Result {
    Ok(quote! {
        #[automatically_derived]
        #[allow(deprecated)]
        impl ::std::cmp::Ord for #ident {
            fn cmp(&self, other: &Self) -> ::std::cmp::Ordering {
                let other = ::std::convert::AsRef::<::std::primitive::str>::as_ref(other);
                ::std::convert::AsRef::<::std::primitive::str>::as_ref(self).cmp(other)
            }
        }
    })
}
ruma-macros-0.10.5/src/serde/serialize_as_ref_str.rs000064400000000000000000000012641046102023000205750ustar 00000000000000use proc_macro2::{Ident, TokenStream};
use quote::quote;

use crate::util::import_ruma_common;

/// Generates a `Serialize` impl for `ident` that serializes a value as its `AsRef<str>` string.
///
/// The serde traits are referenced through the `exports::serde` path of the located
/// `ruma_common` crate (see `import_ruma_common`).
pub fn expand_serialize_as_ref_str(ident: &Ident) -> syn::Result {
    let ruma_common = import_ruma_common();

    Ok(quote! {
        #[automatically_derived]
        #[allow(deprecated)]
        impl #ruma_common::exports::serde::ser::Serialize for #ident {
            fn serialize(&self, serializer: S) -> ::std::result::Result
            where
                S: #ruma_common::exports::serde::ser::Serializer,
            {
                ::std::convert::AsRef::<::std::primitive::str>::as_ref(self).serialize(serializer)
            }
        }
    })
}
ruma-macros-0.10.5/src/serde/util.rs000064400000000000000000000031541046102023000153540ustar 00000000000000use proc_macro2::Span;
use syn::{punctuated::Punctuated, ItemEnum, Token, Variant};

use super::{
    attr::{Attr, EnumAttrs, RenameAllAttr},
    case::RenameRule,
};

/// Reads the enum-level rename rule from the `#[ruma_enum(...)]` attributes on `input`.
///
/// Every `ruma_enum` attribute on the enum is parsed as a `RenameAllAttr`, and parse errors are
/// propagated. Returns `RenameRule::None` when there is no such attribute, the rule when there is
/// exactly one, and an error when there are several.
pub fn get_rename_rule(input: &ItemEnum) -> syn::Result {
    let rules: Vec<_> = input
        .attrs
        .iter()
        .filter(|attr| attr.path.is_ident("ruma_enum"))
        .map(|attr| attr.parse_args::().map(RenameAllAttr::into_inner))
        .collect::>()?;

    match rules.len() {
        0 => Ok(RenameRule::None),
        1 => Ok(rules[0]),
        _ => Err(syn::Error::new(
            Span::call_site(),
            "found multiple ruma_enum(rename_all) attributes",
        )),
    }
}

/// Collects the variant-level options from the `#[ruma_enum(...)]` attributes on `input`.
///
/// Each attribute may hold a comma-separated list of options. At most one `rename` is allowed
/// across all of the variant's attributes, and a second one is an error. `alias` may be repeated,
/// and every value is kept in the order it appears.
pub fn get_enum_attributes(input: &Variant) -> syn::Result {
    let mut attributes = EnumAttrs::default();

    for attr in &input.attrs {
        if !attr.path.is_ident("ruma_enum") {
            continue;
        }

        let enum_attrs = attr.parse_args_with(Punctuated::<_, Token![,]>::parse_terminated)?;
        for enum_attr in enum_attrs {
            match enum_attr {
                Attr::Rename(s) => {
                    if attributes.rename.is_some() {
                        return Err(syn::Error::new(
                            Span::call_site(),
                            "found multiple ruma_enum(rename) attributes",
                        ));
                    }
                    attributes.rename = Some(s);
                }
                Attr::Alias(s) => {
                    attributes.aliases.push(s);
                }
            }
        }
    }

    Ok(attributes)
}
ruma-macros-0.10.5/src/serde.rs000064400000000000000000000004771046102023000144040ustar 00000000000000//! Methods and types for (de)serialization.

pub mod as_str_as_ref_str;
pub mod attr;
pub mod case;
pub mod deserialize_from_cow_str;
pub mod display_as_ref_str;
pub mod enum_as_ref_str;
pub mod enum_from_string;
pub mod eq_as_ref_str;
pub mod incoming;
pub mod ord_as_ref_str;
pub mod serialize_as_ref_str;

mod util;
ruma-macros-0.10.5/src/util.rs000064400000000000000000000035611046102023000142540ustar 00000000000000use proc_macro2::TokenStream;
use proc_macro_crate::{crate_name, FoundCrate};
use quote::{format_ident, quote};
use syn::{Ident, LitStr};

/// Returns the path tokens under which generated code should reference `ruma_common` items.
///
/// Dependencies are checked in this order, using the name `crate_name` reports for each:
/// `ruma-common` and `ruma` are used directly (`::name`), while `matrix-sdk` and
/// `matrix-sdk-appservice` are used via their `ruma` module (`::name::ruma`).
/// If none of these is found, the path falls back to `::ruma_common`.
///
/// Only `FoundCrate::Name` results are matched. Any other result falls through to the next
/// candidate, including a lookup error or `FoundCrate::Itself`.
pub(crate) fn import_ruma_common() -> TokenStream {
    if let Ok(FoundCrate::Name(name)) = crate_name("ruma-common") {
        let import = format_ident!("{name}");
        quote! { ::#import }
    } else if let Ok(FoundCrate::Name(name)) = crate_name("ruma") {
        let import = format_ident!("{name}");
        quote! { ::#import }
    } else if let Ok(FoundCrate::Name(name)) = crate_name("matrix-sdk") {
        let import = format_ident!("{name}");
        quote! { ::#import::ruma }
    } else if let Ok(FoundCrate::Name(name)) = crate_name("matrix-sdk-appservice") {
        let import = format_ident!("{name}");
        quote! { ::#import::ruma }
    } else {
        quote! { ::ruma_common }
    }
}

/// Converts a snake_case field ident like `foo_bar` to CamelCase (`FooBar`).
///
/// The ident is split on `_`, the first character of each piece is upper-cased, and the pieces
/// are concatenated. The returned ident keeps the span of `name`.
pub(crate) fn to_camel_case(name: &Ident) -> Ident {
    let span = name.span();
    let name = name.to_string();

    // NOTE(review): `unwrap()` panics when a piece is empty, which happens with a leading,
    // trailing or doubled `_` (e.g. a field named `_foo`). `&s[1..]` slices by byte, so it also
    // panics if a piece starts with a multi-byte character.
    let s: String = name
        .split('_')
        .map(|s| s.chars().next().unwrap().to_uppercase().to_string() + &s[1..])
        .collect();
    Ident::new(&s, span)
}

/// Splits the given string on `.` and `_` removing the `m.` then camel casing to give a Rust type
/// name.
pub(crate) fn m_prefix_name_to_type_name(name: &LitStr) -> syn::Result { let span = name.span(); let name = name.value(); let name = name.strip_prefix("m.").ok_or_else(|| { syn::Error::new( span, format!("well-known matrix events have to start with `m.` found `{}`", name), ) })?; let s: String = name .strip_suffix(".*") .unwrap_or(name) .split(&['.', '_'] as &[char]) .map(|s| s.chars().next().unwrap().to_uppercase().to_string() + &s[1..]) .collect(); Ok(Ident::new(&s, span)) }