Skip to content

Commit

Permalink
Only output type bounds when a field actually refers to a generic
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Nov 26, 2024
1 parent 44ef78d commit edfbe15
Showing 1 changed file with 72 additions and 7 deletions.
79 changes: 72 additions & 7 deletions asn1_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub fn derive_asn1_read(input: proc_macro::TokenStream) -> proc_macro::TokenStre
let lifetime_name = add_lifetime_if_none(&mut generics);
add_bounds(
&mut generics,
all_field_types(&input.data),
all_field_types(&input.data, &input.generics),
syn::parse_quote!(asn1::Asn1Readable<#lifetime_name>),
syn::parse_quote!(asn1::Asn1DefinedByReadable<#lifetime_name, asn1::ObjectIdentifier>),
false,
Expand Down Expand Up @@ -61,9 +61,10 @@ pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStr
let mut input = syn::parse_macro_input!(input as syn::DeriveInput);

let name = input.ident;
let fields = all_field_types(&input.data, &input.generics);
add_bounds(
&mut input.generics,
all_field_types(&input.data),
fields,
syn::parse_quote!(asn1::Asn1Writable),
syn::parse_quote!(asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier>),
true,
Expand Down Expand Up @@ -275,16 +276,33 @@ fn add_lifetime_if_none(generics: &mut syn::Generics) -> syn::Lifetime {
generics.lifetimes().next().unwrap().lifetime.clone()
}

fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, OpType, bool)> {
fn all_field_types(data: &syn::Data, generics: &syn::Generics) -> Vec<(syn::Type, OpType, bool)> {
let generic_params = generics
.params
.iter()
.filter_map(|p| {
if let syn::GenericParam::Type(tp) = p {
Some(tp.ident.clone())
} else {
None
}
})
.collect::<Vec<_>>();

let mut field_types = vec![];
match data {
syn::Data::Struct(v) => {
add_field_types(&mut field_types, &v.fields, None);
add_field_types(&mut field_types, &v.fields, None, &generic_params);
}
syn::Data::Enum(v) => {
for variant in &v.variants {
let (op_type, _) = extract_field_properties(&variant.attrs);
add_field_types(&mut field_types, &variant.fields, Some(op_type));
add_field_types(
&mut field_types,
&variant.fields,
Some(op_type),
&generic_params,
);
}
}
syn::Data::Union(_) => panic!("Unions not supported"),
Expand All @@ -296,27 +314,74 @@ fn add_field_types(
field_types: &mut Vec<(syn::Type, OpType, bool)>,
fields: &syn::Fields,
op_type: Option<OpType>,
generic_params: &[syn::Ident],
) {
match fields {
syn::Fields::Named(v) => {
for f in &v.named {
add_field_type(field_types, f, op_type.clone());
add_field_type(field_types, f, op_type.clone(), generic_params);
}
}
syn::Fields::Unnamed(v) => {
for f in &v.unnamed {
add_field_type(field_types, f, op_type.clone());
add_field_type(field_types, f, op_type.clone(), generic_params);
}
}
syn::Fields::Unit => {}
}
}

fn type_contains_generic_param(t: &syn::Type, generic_params: &[syn::Ident]) -> bool {
match t {
syn::Type::Array(v) => type_contains_generic_param(&v.elem, generic_params),
syn::Type::BareFn(_) => todo!("BareFn"),
syn::Type::Group(v) => type_contains_generic_param(&v.elem, generic_params),
syn::Type::ImplTrait(_) => todo!("ImplTrait"),
syn::Type::Infer(_) => false,
syn::Type::Macro(_) => false,
syn::Type::Never(_) => false,
syn::Type::Paren(v) => type_contains_generic_param(&v.elem, generic_params),
syn::Type::Path(v) => {
if let Some(q) = &v.qself {
if type_contains_generic_param(&q.ty, generic_params) {
return true;
}
} else if generic_params.contains(&v.path.segments[0].ident) {
return true;
}
v.path.segments.iter().any(|s| match &s.arguments {
syn::PathArguments::AngleBracketed(a) => a.args.iter().any(|ga| match ga {
syn::GenericArgument::Type(t) => type_contains_generic_param(t, generic_params),
_ => false,
}),
syn::PathArguments::Parenthesized(_) => todo!("ParenthesizedGenericArguments"),
syn::PathArguments::None => false,
})
}
syn::Type::Ptr(v) => type_contains_generic_param(&v.elem, generic_params),
syn::Type::Reference(v) => type_contains_generic_param(&v.elem, generic_params),
syn::Type::Slice(v) => type_contains_generic_param(&v.elem, generic_params),
syn::Type::TraitObject(_) => todo!("TraitObject"),
syn::Type::Tuple(v) => v
.elems
.iter()
.any(|t| type_contains_generic_param(t, generic_params)),
syn::Type::Verbatim(_) => false,

_ => false,
}
}

fn add_field_type(
field_types: &mut Vec<(syn::Type, OpType, bool)>,
f: &syn::Field,
op_type: Option<OpType>,
generic_params: &[syn::Ident],
) {
if !type_contains_generic_param(&f.ty, generic_params) {
return;
}

// If we have an op_type here, it means it came from an enum variant. In
// that case, even though it wasn't marked "required", it is for the
// purposes of how we're using it.
Expand Down

0 comments on commit edfbe15

Please sign in to comment.