yew_macro/hook/
signature.rs1use std::iter::once;
2use std::mem::take;
3
4use proc_macro2::{Span, TokenStream};
5use proc_macro_error::emit_error;
6use quote::{quote, ToTokens};
7use syn::punctuated::{Pair, Punctuated};
8use syn::spanned::Spanned;
9use syn::visit_mut::VisitMut;
10use syn::{
11 parse_quote, parse_quote_spanned, visit_mut, FnArg, GenericParam, Ident, Lifetime,
12 LifetimeParam, Pat, Receiver, ReturnType, Signature, Type, TypeImplTrait, TypeParam,
13 TypeParamBound, TypeReference, WherePredicate,
14};
15
16use super::lifetime;
17
18fn type_is_generic(ty: &Type, param: &TypeParam) -> bool {
19 match ty {
20 Type::Path(path) => path.path.is_ident(¶m.ident),
21 _ => false,
22 }
23}
24
/// Signature visitor state: records whether any `impl Trait` type was encountered.
#[derive(Default)]
pub struct CollectArgs {
    /// Becomes `true` once an `impl Trait` type has been visited.
    needs_boxing: bool,
}

impl CollectArgs {
    /// Creates a collector that has not yet seen any `impl Trait` type.
    pub fn new() -> Self {
        Self {
            needs_boxing: false,
        }
    }
}
35
36impl VisitMut for CollectArgs {
37 fn visit_type_impl_trait_mut(&mut self, impl_trait: &mut TypeImplTrait) {
38 self.needs_boxing = true;
39
40 visit_mut::visit_type_impl_trait_mut(self, impl_trait);
41 }
42
43 fn visit_receiver_mut(&mut self, recv: &mut Receiver) {
44 emit_error!(recv, "methods cannot be hooks");
45
46 visit_mut::visit_receiver_mut(self, recv);
47 }
48}
49
50pub struct HookSignature {
51 pub hook_lifetime: Lifetime,
52 pub sig: Signature,
53 pub output_type: Type,
54 pub needs_boxing: bool,
55}
56
57impl HookSignature {
58 fn rewrite_return_type(hook_lifetime: &Lifetime, rt_type: &ReturnType) -> (ReturnType, Type) {
59 let bound = quote! { #hook_lifetime + };
60
61 match rt_type {
62 ReturnType::Default => (
63 parse_quote! { -> impl #bound ::yew::functional::Hook<Output = ()> },
64 parse_quote! { () },
65 ),
66 ReturnType::Type(arrow, ref return_type) => {
67 if let Type::Reference(ref m) = &**return_type {
68 if m.lifetime.is_none() {
69 let mut return_type_ref = m.clone();
70 return_type_ref.lifetime = parse_quote!('hook);
71
72 let return_type_ref = Type::Reference(return_type_ref);
73
74 return (
75 parse_quote_spanned! {
76 return_type.span() => #arrow impl #bound ::yew::functional::Hook<Output = #return_type_ref>
77 },
78 return_type_ref,
79 );
80 }
81 }
82
83 (
84 parse_quote_spanned! {
85 return_type.span() => #arrow impl #bound ::yew::functional::Hook<Output = #return_type>
86 },
87 *return_type.clone(),
88 )
89 }
90 }
91 }
92
93 pub fn rewrite(sig: &Signature) -> Self {
95 let mut sig = sig.clone();
96
97 let mut arg_info = CollectArgs::new();
98 arg_info.visit_signature_mut(&mut sig);
99
100 let mut lifetimes = lifetime::CollectLifetimes::new("'arg", sig.ident.span());
101 for arg in sig.inputs.iter_mut() {
102 match arg {
103 FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
104 FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
105 }
106 }
107
108 let Signature {
109 ref mut generics,
110 output: ref return_type,
111 ..
112 } = sig;
113
114 let hook_lifetime = Lifetime::new("'hook", Span::mixed_site());
115 let mut params: Punctuated<_, _> = once(hook_lifetime.clone())
116 .chain(lifetimes.elided)
117 .map(|lifetime| {
118 GenericParam::Lifetime(LifetimeParam {
119 attrs: vec![],
120 lifetime,
121 colon_token: None,
122 bounds: Default::default(),
123 })
124 })
125 .map(|param| Pair::new(param, Some(Default::default())))
126 .chain(take(&mut generics.params).into_pairs())
127 .collect();
128
129 for type_param in params.iter_mut().skip(1) {
130 match type_param {
131 GenericParam::Lifetime(param) => {
132 if let Some(predicate) = generics
133 .where_clause
134 .iter_mut()
135 .flat_map(|c| &mut c.predicates)
136 .find_map(|predicate| match predicate {
137 WherePredicate::Lifetime(p) if p.lifetime == param.lifetime => Some(p),
138 _ => None,
139 })
140 {
141 predicate.bounds.push(hook_lifetime.clone());
142 } else {
143 param.colon_token = Some(param.colon_token.unwrap_or_default());
144 param.bounds.push(hook_lifetime.clone());
145 }
146 }
147
148 GenericParam::Type(param) => {
149 if let Some(predicate) = generics
150 .where_clause
151 .iter_mut()
152 .flat_map(|c| &mut c.predicates)
153 .find_map(|predicate| match predicate {
154 WherePredicate::Type(p) if type_is_generic(&p.bounded_ty, param) => {
155 Some(p)
156 }
157 _ => None,
158 })
159 {
160 predicate
161 .bounds
162 .push(TypeParamBound::Lifetime(hook_lifetime.clone()));
163 } else {
164 param.colon_token = Some(param.colon_token.unwrap_or_default());
165 param
166 .bounds
167 .push(TypeParamBound::Lifetime(hook_lifetime.clone()));
168 }
169 }
170
171 GenericParam::Const(_) => {}
172 }
173 }
174
175 generics.params = params;
176
177 let (output, output_type) = Self::rewrite_return_type(&hook_lifetime, return_type);
178 sig.output = output;
179
180 Self {
181 hook_lifetime,
182 sig,
183 output_type,
184 needs_boxing: arg_info.needs_boxing,
185 }
186 }
187
188 pub fn phantom_types(&self) -> Vec<Ident> {
189 self.sig
190 .generics
191 .type_params()
192 .map(|ty_param| ty_param.ident.clone())
193 .collect()
194 }
195
196 pub fn phantom_lifetimes(&self) -> Vec<TypeReference> {
197 self.sig
198 .generics
199 .lifetimes()
200 .map(|life| TypeReference {
201 and_token: Default::default(),
202 lifetime: Some(life.lifetime.clone()),
203 mutability: None,
204 elem: Box::new(Type::Tuple(syn::TypeTuple {
205 paren_token: Default::default(),
206 elems: Default::default(),
207 })),
208 })
209 .collect()
210 }
211
212 pub fn input_args(&self) -> Vec<Ident> {
213 self.sig
214 .inputs
215 .iter()
216 .filter_map(|m| {
217 if let FnArg::Typed(m) = m {
218 if let Pat::Ident(ref m) = *m.pat {
219 return Some(m.ident.clone());
220 }
221 }
222
223 None
224 })
225 .collect()
226 }
227
228 pub fn input_types(&self) -> Vec<Type> {
229 self.sig
230 .inputs
231 .iter()
232 .filter_map(|m| {
233 if let FnArg::Typed(m) = m {
234 return Some(*m.ty.clone());
235 }
236
237 None
238 })
239 .collect()
240 }
241
242 pub fn call_generics(&self) -> TokenStream {
243 let mut generics = self.sig.generics.clone();
244
245 generics.params = generics
247 .params
248 .into_iter()
249 .filter(|m| !matches!(m, GenericParam::Lifetime(_)))
250 .collect();
251
252 let (_impl_generics, ty_generics, _where_clause) = generics.split_for_impl();
253 ty_generics.as_turbofish().to_token_stream()
254 }
255}