use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
    parse_quote, punctuated::Punctuated, Block, FnArg, Lifetime, ReturnType, Signature, Type,
    WhereClause,
};
use crate::parse::{AsyncItem, RecursionArgs};
impl ToTokens for AsyncItem {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.0.to_tokens(tokens);
    }
}
/// Rewrites the wrapped function in place.
///
/// Three things happen, in order:
/// 1. a `#[must_use]` attribute is appended to the function's attributes;
/// 2. the signature is rewritten by `transform_sig`, which consults `args`;
/// 3. the body is rewritten by `transform_block`.
pub fn expand(item: &mut AsyncItem, args: &RecursionArgs) {
    let function = &mut item.0;

    function.attrs.push(parse_quote!(#[must_use]));
    transform_sig(&mut function.sig, args);
    transform_block(&mut function.block);
}
/// Replaces `block` with a block whose only expression is
/// `Box::pin(async move { ..original body.. })`.
///
/// The brace token of the original block is carried over to the new block.
fn transform_block(block: &mut Block) {
    // Remember the original braces before the block is replaced.
    let original_braces = block.brace_token;

    let mut wrapped: Block = parse_quote!({
        Box::pin(async move #block)
    });
    wrapped.brace_token = original_braces;

    *block = wrapped;
}
/// A lifetime found on, or assigned to, a reference argument while the
/// signature is being rewritten.
enum ArgLifetime {
    /// A lifetime generated for a reference argument that had none; it still
    /// has to be declared in the function's generic parameters.
    New(Lifetime),
    /// A lifetime that was already present on a reference argument.
    Existing(Lifetime),
}
impl ArgLifetime {
    pub fn lifetime(self) -> Lifetime {
        match self {
            ArgLifetime::New(lt) | ArgLifetime::Existing(lt) => lt,
        }
    }
}
fn transform_sig(sig: &mut Signature, args: &RecursionArgs) {
    let ret = match &sig.output {
        ReturnType::Default => quote!(()),
        ReturnType::Type(_, ret) => quote!(#ret),
    };
    sig.asyncness = None;
    let mut ref_arguments = Vec::new();
    let mut self_lifetime = None;
    for arg in &mut sig.inputs {
        if let FnArg::Typed(pt) = arg {
            if let Type::Reference(tr) = pt.ty.as_mut() {
                ref_arguments.push(tr);
            }
        } else if let FnArg::Receiver(recv) = arg {
            if let Some((_, slt)) = &mut recv.reference {
                self_lifetime = Some(slt);
            }
        }
    }
    let mut counter = 0;
    let mut lifetimes = Vec::new();
    if !ref_arguments.is_empty() {
        for ra in &mut ref_arguments {
            if ra.lifetime.is_none() {
                let lt = Lifetime::new(&format!("'life{counter}"), Span::call_site());
                lifetimes.push(ArgLifetime::New(parse_quote!(#lt)));
                ra.lifetime = Some(lt);
                counter += 1;
            } else {
                let lt = ra.lifetime.as_ref().cloned().unwrap();
                let ident_matches = |x: &ArgLifetime| {
                    if let ArgLifetime::Existing(elt) = x {
                        elt.ident == lt.ident
                    } else {
                        false
                    }
                };
                if !lifetimes.iter().any(ident_matches) {
                    lifetimes.push(ArgLifetime::Existing(
                        ra.lifetime.as_ref().cloned().unwrap(),
                    ));
                }
            }
        }
    }
    let mut requires_lifetime = false;
    let mut where_clause_lifetimes = vec![];
    let mut where_clause_generics = vec![];
    let asr: Lifetime = parse_quote!('async_recursion);
    for param in sig.generics.type_params() {
        let ident = param.ident.clone();
        where_clause_generics.push(ident);
        requires_lifetime = true;
    }
    if !lifetimes.is_empty() {
        for alt in lifetimes {
            if let ArgLifetime::New(lt) = &alt {
                sig.generics.params.push(parse_quote!(#lt));
            }
            let lt = alt.lifetime();
            where_clause_lifetimes.push(lt);
        }
        requires_lifetime = true;
    }
    if let Some(slt) = self_lifetime {
        let lt = {
            if let Some(lt) = slt.as_mut() {
                lt.clone()
            } else {
                let lt: Lifetime = parse_quote!('life_self);
                sig.generics.params.push(parse_quote!(#lt));
                *slt = Some(lt.clone());
                lt
            }
        };
        where_clause_lifetimes.push(lt);
        requires_lifetime = true;
    }
    let box_lifetime: TokenStream = if requires_lifetime {
        sig.generics.params.push(parse_quote!('async_recursion));
        quote!(+ #asr)
    } else {
        quote!()
    };
    let send_bound: TokenStream = if args.send_bound {
        quote!(+ ::core::marker::Send)
    } else {
        quote!()
    };
    let where_clause = sig
        .generics
        .where_clause
        .get_or_insert_with(|| WhereClause {
            where_token: Default::default(),
            predicates: Punctuated::new(),
        });
    for generic_ident in where_clause_generics {
        where_clause
            .predicates
            .push(parse_quote!(#generic_ident : #asr));
    }
    for lifetime in where_clause_lifetimes {
        where_clause.predicates.push(parse_quote!(#lifetime : #asr));
    }
    sig.output = parse_quote! {
        -> ::core::pin::Pin<Box<
            dyn ::core::future::Future<Output = #ret> #box_lifetime #send_bound >>
    };
}