use std::{borrow::Cow, ops::Deref};
use proc_macro2::Span;
use syn::{
parse::{discouraged::Speculative, Parse, ParseStream},
punctuated::Punctuated,
spanned::Spanned,
Attribute, Data, Ident, Meta, Path, PredicateType, Result, Token, TraitBound,
TraitBoundModifier, Type, TypeParamBound, TypePath, WhereClause, WherePredicate,
};
use crate::{
util::{self, MetaListExt},
Error, Incomparable, Item, Skip, SkipGroup, Trait, TraitImpl, DERIVE_WHERE,
};
/// Item-level information collected from every `derive_where` attribute on an item.
#[derive(Default)]
pub struct ItemAttr {
    /// The item-level `skip_inner` option.
    pub skip_inner: Skip,
    /// The item-level `incomparable` option.
    pub incomparable: Incomparable,
    /// One entry per `derive_where` trait list; `ItemAttr::from_attrs` merges
    /// consecutive lists that have identical generics.
    pub derive_wheres: Vec<DeriveWhere>,
}
impl ItemAttr {
pub fn from_attrs(span: Span, data: &Data, attrs: &[Attribute]) -> Result<Self> {
let mut self_ = ItemAttr::default();
let mut skip_inners = Vec::new();
let mut incomparables = Vec::new();
for attr in attrs {
if attr.path().is_ident(DERIVE_WHERE) {
if let Meta::List(list) = &attr.meta {
if let Ok(nested) =
list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
{
match nested.len() {
0 => return Err(Error::empty(list.span())),
1 => {
let meta =
nested.into_iter().next().expect("unexpected empty list");
if meta.path().is_ident(Skip::SKIP_INNER) {
if let Data::Enum(_) = data {
return Err(Error::option_enum_skip_inner(meta.span()));
}
skip_inners.push(meta);
} else if meta.path().is_ident(Incomparable::INCOMPARABLE) {
incomparables.push(meta)
} else if meta.path().is_ident("crate") {
}
else {
self_
.derive_wheres
.push(DeriveWhere::from_attr(span, data, attr)?);
}
}
_ => self_
.derive_wheres
.push(DeriveWhere::from_attr(span, data, attr)?),
}
}
else {
self_
.derive_wheres
.push(DeriveWhere::from_attr(span, data, attr)?)
}
} else {
return Err(Error::option_syntax(attr.meta.span()));
}
}
}
if self_.derive_wheres.is_empty() {
return Err(Error::none(span));
}
self_
.derive_wheres
.dedup_by(|derive_where_1, derive_where_2| {
if derive_where_1.generics == derive_where_2.generics {
derive_where_2.spans.append(&mut derive_where_1.spans);
derive_where_2.traits.append(&mut derive_where_1.traits);
true
} else {
false
}
});
for derive_where in &self_.derive_wheres {
for (skip, trait_) in (1..).zip(&derive_where.traits) {
if let Some((span, _)) = derive_where
.spans
.iter()
.zip(&derive_where.traits)
.skip(skip)
.find(|(_, other_trait)| *other_trait == trait_)
{
return Err(Error::trait_duplicate(*span));
}
}
}
for meta in skip_inners {
self_
.skip_inner
.add_attribute(&self_.derive_wheres, None, &meta)?;
}
for meta in incomparables {
self_
.incomparable
.add_attribute(&meta, &self_.derive_wheres)?;
}
Ok(self_)
}
}
/// Traits and generics parsed from one `derive_where` attribute. After
/// merging, it may also hold several consecutive attributes that share the same generics.
pub struct DeriveWhere {
    /// Span of each entry in `traits`; the two vectors are index-aligned.
    pub spans: Vec<Span>,
    /// Traits to derive, in the order they were written.
    pub traits: Vec<DeriveTrait>,
    /// Entries written after `;`, each of which becomes a `where` predicate.
    pub generics: Vec<Generic>,
}
impl DeriveWhere {
fn from_attr(span: Span, data: &Data, attr: &Attribute) -> Result<Self> {
attr.parse_args_with(|input: ParseStream| {
let mut spans = Vec::new();
let mut traits = Vec::new();
let mut generics = Vec::new();
assert!(!input.is_empty());
while !input.is_empty() {
let (span, trait_) = DeriveTrait::from_stream(span, data, input)?;
spans.push(span);
traits.push(trait_);
if !input.is_empty() {
let mut fork = input.fork();
let no_delimiter_found = match <Token![,]>::parse(&fork) {
Ok(_) => {
input.advance_to(&fork);
None
}
Err(error) => {
fork = input.fork();
Some(error.span())
}
};
if <Token![;]>::parse(&fork).is_ok() {
input.advance_to(&fork);
if !input.is_empty() {
generics = Punctuated::<Generic, Token![,]>::parse_terminated(input)?
.into_iter()
.collect();
}
}
else if let Some(span) = no_delimiter_found {
return Err(Error::derive_where_delimiter(span));
}
}
}
Ok(Self {
generics,
spans,
traits,
})
})
}
pub fn contains(&self, trait_: Trait) -> bool {
self.traits
.iter()
.any(|derive_trait| derive_trait == trait_)
}
pub fn any_custom_bound(&self) -> bool {
self.generics.iter().any(|generic| match generic {
Generic::CustomBound(_) => true,
Generic::NoBound(_) => false,
})
}
pub fn has_type_param(&self, type_param: &Ident) -> bool {
self.generics.iter().any(|generic| match generic {
Generic::NoBound(Type::Path(TypePath { qself: None, path })) => {
if let Some(ident) = path.get_ident() {
ident == type_param
} else {
false
}
}
_ => false,
})
}
pub fn any_skip(&self) -> bool {
self.traits
.iter()
.any(|trait_| SkipGroup::trait_supported(**trait_))
}
pub fn where_clause(
&self,
where_clause: &mut Option<Cow<WhereClause>>,
trait_: &DeriveTrait,
item: &Item,
) {
if !self.generics.is_empty() {
let where_clause = where_clause.get_or_insert(Cow::Owned(WhereClause {
where_token: <Token![where]>::default(),
predicates: Punctuated::default(),
}));
for generic in &self.generics {
where_clause
.to_mut()
.predicates
.push(WherePredicate::Type(match generic {
Generic::CustomBound(type_bound) => type_bound.clone(),
Generic::NoBound(path) => PredicateType {
lifetimes: None,
bounded_ty: path.clone(),
colon_token: <Token![:]>::default(),
bounds: trait_.where_bounds(item),
},
}));
}
}
}
}
/// One entry of the generics list written after `;` in a `derive_where` attribute.
#[derive(PartialEq, Eq)]
pub enum Generic {
    /// A type predicate with explicit bounds, for example `T: Trait`; it is used as written.
    CustomBound(PredicateType),
    /// A bare type, for example `T`. `DeriveWhere::where_clause` bounds it with the
    /// derived trait's where bounds.
    NoBound(Type),
}
impl Parse for Generic {
fn parse(input: ParseStream) -> Result<Self> {
let fork = input.fork();
if let Ok(where_predicate) = WherePredicate::parse(&fork) {
input.advance_to(&fork);
if let WherePredicate::Type(path) = where_predicate {
Ok(Generic::CustomBound(path))
} else {
Err(Error::generic(where_predicate.span()))
}
} else {
match Type::parse(input) {
Ok(type_) => Ok(Generic::NoBound(type_)),
Err(error) => Err(Error::generic_syntax(error.span(), error)),
}
}
}
}
/// A trait requested in a `derive_where` attribute, together with any options
/// parsed for it.
#[derive(PartialEq, Eq)]
pub enum DeriveTrait {
    /// `Clone`.
    Clone,
    /// `Copy`.
    Copy,
    /// `Debug`.
    Debug,
    /// `Default`.
    Default,
    /// `Eq`.
    Eq,
    /// `Hash`.
    Hash,
    /// `Ord`.
    Ord,
    /// `PartialEq`.
    PartialEq,
    /// `PartialOrd`.
    PartialOrd,
    /// `Zeroize`.
    #[cfg(feature = "zeroize")]
    Zeroize {
        /// Overriding path of the crate providing the trait; `None` uses `zeroize`.
        crate_: Option<Path>,
    },
    /// `ZeroizeOnDrop`.
    #[cfg(feature = "zeroize")]
    ZeroizeOnDrop {
        /// Overriding path of the crate providing the trait; `None` uses `zeroize`.
        crate_: Option<Path>,
    },
}
impl Deref for DeriveTrait {
type Target = Trait;
fn deref(&self) -> &Self::Target {
use DeriveTrait::*;
match self {
Clone => &Trait::Clone,
Copy => &Trait::Copy,
Debug => &Trait::Debug,
Default => &Trait::Default,
Eq => &Trait::Eq,
Hash => &Trait::Hash,
Ord => &Trait::Ord,
PartialEq => &Trait::PartialEq,
PartialOrd => &Trait::PartialOrd,
#[cfg(feature = "zeroize")]
Zeroize { .. } => &Trait::Zeroize,
#[cfg(feature = "zeroize")]
ZeroizeOnDrop { .. } => &Trait::ZeroizeOnDrop,
}
}
}
impl PartialEq<Trait> for &DeriveTrait {
    /// Compares only the trait kind and ignores any options carried by the `DeriveTrait`.
    fn eq(&self, other: &Trait) -> bool {
        ***self == *other
    }
}
impl DeriveTrait {
pub fn path(&self) -> Path {
use DeriveTrait::*;
match self {
Clone => util::path_from_root_and_strs(self.crate_(), &["clone", "Clone"]),
Copy => util::path_from_root_and_strs(self.crate_(), &["marker", "Copy"]),
Debug => util::path_from_root_and_strs(self.crate_(), &["fmt", "Debug"]),
Default => util::path_from_root_and_strs(self.crate_(), &["default", "Default"]),
Eq => util::path_from_root_and_strs(self.crate_(), &["cmp", "Eq"]),
Hash => util::path_from_root_and_strs(self.crate_(), &["hash", "Hash"]),
Ord => util::path_from_root_and_strs(self.crate_(), &["cmp", "Ord"]),
PartialEq => util::path_from_root_and_strs(self.crate_(), &["cmp", "PartialEq"]),
PartialOrd => util::path_from_root_and_strs(self.crate_(), &["cmp", "PartialOrd"]),
#[cfg(feature = "zeroize")]
Zeroize { .. } => util::path_from_root_and_strs(self.crate_(), &["Zeroize"]),
#[cfg(feature = "zeroize")]
ZeroizeOnDrop { .. } => util::path_from_root_and_strs(self.crate_(), &["ZeroizeOnDrop"]),
}
}
pub fn crate_(&self) -> Path {
use DeriveTrait::*;
match self {
Clone => util::path_from_strs(&["core"]),
Copy => util::path_from_strs(&["core"]),
Debug => util::path_from_strs(&["core"]),
Default => util::path_from_strs(&["core"]),
Eq => util::path_from_strs(&["core"]),
Hash => util::path_from_strs(&["core"]),
Ord => util::path_from_strs(&["core"]),
PartialEq => util::path_from_strs(&["core"]),
PartialOrd => util::path_from_strs(&["core"]),
#[cfg(feature = "zeroize")]
Zeroize { crate_, .. } => {
if let Some(crate_) = crate_ {
crate_.clone()
} else {
util::path_from_strs(&["zeroize"])
}
}
#[cfg(feature = "zeroize")]
ZeroizeOnDrop { crate_, .. } => {
if let Some(crate_) = crate_ {
crate_.clone()
} else {
util::path_from_strs(&["zeroize"])
}
}
}
}
fn where_bounds(&self, data: &Item) -> Punctuated<TypeParamBound, Token![+]> {
let mut list = Punctuated::new();
list.push(TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: TraitBoundModifier::None,
lifetimes: None,
path: self.path(),
}));
if let Some(bound) = self.additional_where_bounds(data) {
list.push(bound)
}
list
}
fn from_stream(span: Span, data: &Data, input: ParseStream) -> Result<(Span, Self)> {
match Meta::parse(input) {
Ok(meta) => {
let trait_ = Trait::from_path(meta.path())?;
if let Data::Union(_) = data {
if !trait_.supports_union() {
return Err(Error::union(span));
}
}
match &meta {
Meta::Path(path) => Ok((path.span(), trait_.default_derive_trait())),
Meta::List(list) => {
let nested = list.parse_non_empty_nested_metas()?;
Ok((list.span(), trait_.parse_derive_trait(meta.span(), nested)?))
}
Meta::NameValue(name_value) => Err(Error::option_syntax(name_value.span())),
}
}
Err(error) => Err(Error::trait_syntax(error.span())),
}
}
}