Skip to content

Commit

Permalink
Fixed perfect derives in conjunction with `#[derive(Asn1DefinedByRead…
Browse files Browse the repository at this point in the history
…)]` and `#[derive(Asn1DefinedByWrite)]` (#506)
  • Loading branch information
alex authored Nov 26, 2024
1 parent f8ca030 commit e23fbf1
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 16 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ asn1 = { version = "0.20", default-features = false }

## Changelog

### [0.20.1]

#### Fixes

- Fixed ["perfect derives"](https://smallcultfollowing.com/babysteps/blog/2022/04/12/implied-bounds-and-perfect-derive/)
in conjunction with `#[derive(Asn1DefinedByRead)]` and
`#[derive(Asn1DefinedByWrite)]`.
([#506](https://github.com/alex/rust-asn1/pull/506))

### [0.20.0]

#### :rotating_light: Breaking changes
Expand Down
96 changes: 80 additions & 16 deletions asn1_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub fn derive_asn1_read(input: proc_macro::TokenStream) -> proc_macro::TokenStre
let lifetime_name = add_lifetime_if_none(&mut generics);
add_bounds(
&mut generics,
all_field_types(&input.data, &input.generics),
all_field_types(&input.data, false, &input.generics),
syn::parse_quote!(asn1::Asn1Readable<#lifetime_name>),
syn::parse_quote!(asn1::Asn1DefinedByReadable<#lifetime_name, asn1::ObjectIdentifier>),
false,
Expand Down Expand Up @@ -61,7 +61,7 @@ pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStr
let mut input = syn::parse_macro_input!(input as syn::DeriveInput);

let name = input.ident;
let fields = all_field_types(&input.data, &input.generics);
let fields = all_field_types(&input.data, false, &input.generics);
add_bounds(
&mut input.generics,
fields,
Expand Down Expand Up @@ -146,10 +146,17 @@ pub fn derive_asn1_defined_by_read(input: proc_macro::TokenStream) -> proc_macro
let input = syn::parse_macro_input!(input as syn::DeriveInput);

let name = input.ident;
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let (_, ty_generics, _) = input.generics.split_for_impl();
let mut generics = input.generics.clone();
let lifetime_name = add_lifetime_if_none(&mut generics);
let (impl_generics, _, _) = generics.split_for_impl();
add_bounds(
&mut generics,
all_field_types(&input.data, true, &input.generics),
syn::parse_quote!(asn1::Asn1Readable<#lifetime_name>),
syn::parse_quote!(asn1::Asn1DefinedByReadable<#lifetime_name, asn1::ObjectIdentifier>),
false,
);
let (impl_generics, _, where_clause) = generics.split_for_impl();

let mut read_block = vec![];
let mut default_ident = None;
Expand Down Expand Up @@ -204,9 +211,17 @@ pub fn derive_asn1_defined_by_read(input: proc_macro::TokenStream) -> proc_macro

#[proc_macro_derive(Asn1DefinedByWrite, attributes(default, defined_by))]
pub fn derive_asn1_defined_by_write(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = syn::parse_macro_input!(input as syn::DeriveInput);
let mut input = syn::parse_macro_input!(input as syn::DeriveInput);

let name = input.ident;
let fields = all_field_types(&input.data, true, &input.generics);
add_bounds(
&mut input.generics,
fields,
syn::parse_quote!(asn1::Asn1Writable),
syn::parse_quote!(asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier>),
true,
);
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let mut write_blocks = vec![];
Expand Down Expand Up @@ -273,10 +288,19 @@ fn add_lifetime_if_none(generics: &mut syn::Generics) -> syn::Lifetime {
)));
};

generics.lifetimes().next().unwrap().lifetime.clone()
generics
.lifetimes()
.next()
.expect("No lifetime found")
.lifetime
.clone()
}

fn all_field_types(data: &syn::Data, generics: &syn::Generics) -> Vec<(syn::Type, OpType, bool)> {
fn all_field_types(
data: &syn::Data,
ignore_properties: bool,
generics: &syn::Generics,
) -> Vec<(syn::Type, OpType, bool)> {
let generic_params = generics
.params
.iter()
Expand All @@ -292,15 +316,27 @@ fn all_field_types(data: &syn::Data, generics: &syn::Generics) -> Vec<(syn::Type
let mut field_types = vec![];
match data {
syn::Data::Struct(v) => {
add_field_types(&mut field_types, &v.fields, None, &generic_params);
add_field_types(
&mut field_types,
&v.fields,
None,
ignore_properties,
&generic_params,
);
}
syn::Data::Enum(v) => {
for variant in &v.variants {
let (op_type, _) = extract_field_properties(&variant.attrs);
let op_type = if ignore_properties {
None
} else {
let (op_type, _) = extract_field_properties(&variant.attrs);
Some(op_type)
};
add_field_types(
&mut field_types,
&variant.fields,
Some(op_type),
op_type,
ignore_properties,
&generic_params,
);
}
Expand All @@ -314,17 +350,30 @@ fn add_field_types(
field_types: &mut Vec<(syn::Type, OpType, bool)>,
fields: &syn::Fields,
op_type: Option<OpType>,
ignore_properties: bool,
generic_params: &[syn::Ident],
) {
match fields {
syn::Fields::Named(v) => {
for f in &v.named {
add_field_type(field_types, f, op_type.clone(), generic_params);
add_field_type(
field_types,
f,
op_type.clone(),
ignore_properties,
generic_params,
);
}
}
syn::Fields::Unnamed(v) => {
for f in &v.unnamed {
add_field_type(field_types, f, op_type.clone(), generic_params);
add_field_type(
field_types,
f,
op_type.clone(),
ignore_properties,
generic_params,
);
}
}
syn::Fields::Unit => {}
Expand Down Expand Up @@ -376,6 +425,7 @@ fn add_field_type(
field_types: &mut Vec<(syn::Type, OpType, bool)>,
f: &syn::Field,
op_type: Option<OpType>,
ignore_properties: bool,
generic_params: &[syn::Ident],
) {
if !type_contains_generic_param(&f.ty, generic_params) {
Expand All @@ -391,6 +441,8 @@ fn add_field_type(
} else if let Some(OpType::Implicit(mut args)) = op_type {
args.required = true;
(OpType::Implicit(args), None)
} else if ignore_properties {
(OpType::Regular, None)
} else {
extract_field_properties(&f.attrs)
};
Expand Down Expand Up @@ -508,21 +560,33 @@ fn extract_field_properties(attrs: &[syn::Attribute]) -> (OpType, Option<syn::Ex
for attr in attrs {
if attr.path().is_ident("explicit") {
if let OpType::Regular = op_type {
op_type = OpType::Explicit(attr.parse_args::<OpTypeArgs>().unwrap());
op_type = OpType::Explicit(
attr.parse_args::<OpTypeArgs>()
.expect("Error parsing #[explicit]"),
);
} else {
panic!("Can't specify #[explicit] or #[implicit] more than once")
}
} else if attr.path().is_ident("implicit") {
if let OpType::Regular = op_type {
op_type = OpType::Implicit(attr.parse_args::<OpTypeArgs>().unwrap());
op_type = OpType::Implicit(
attr.parse_args::<OpTypeArgs>()
.expect("Error parsing #[implicit]"),
);
} else {
panic!("Can't specify #[explicit] or #[implicit] more than once")
}
} else if attr.path().is_ident("default") {
assert!(default.is_none(), "Can't specify #[default] more than once");
default = Some(attr.parse_args::<syn::Expr>().unwrap());
default = Some(
attr.parse_args::<syn::Expr>()
.expect("Error parsing #[default]"),
);
} else if attr.path().is_ident("defined_by") {
op_type = OpType::DefinedBy(attr.parse_args::<syn::Ident>().unwrap());
op_type = OpType::DefinedBy(
attr.parse_args::<syn::Ident>()
.expect("Error parsing #[defined_by]"),
);
}
}

Expand Down
48 changes: 48 additions & 0 deletions tests/derive_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -828,3 +828,51 @@ fn test_perfect_derive() {
(Ok(TaggedEnum::Explicit(1)), b"\xa1\x03\x02\x01\x01"),
]);
}

#[test]
fn test_defined_by_perfect_derive() {
trait X {
type Type: PartialEq + std::fmt::Debug;
}

#[derive(PartialEq, Debug)]
struct Op;
impl X for Op {
type Type = u64;
}

#[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug)]
struct S<T: X> {
oid: asn1::DefinedByMarker<asn1::ObjectIdentifier>,
#[defined_by(oid)]
value: Value<T>,
}

pub const OID1: asn1::ObjectIdentifier = asn1::oid!(1, 2, 3);
pub const OID2: asn1::ObjectIdentifier = asn1::oid!(1, 2, 4);

#[derive(asn1::Asn1DefinedByRead, asn1::Asn1DefinedByWrite, PartialEq, Debug)]
enum Value<T: X> {
#[defined_by(OID1)]
A(T::Type),
#[defined_by(OID2)]
B(T::Type),
}

assert_roundtrips::<S<Op>>(&[
(
Ok(S {
oid: asn1::DefinedByMarker::marker(),
value: Value::A(5),
}),
b"\x30\x07\x06\x02\x2a\x03\x02\x01\x05",
),
(
Ok(S {
oid: asn1::DefinedByMarker::marker(),
value: Value::B(7),
}),
b"\x30\x07\x06\x02\x2a\x04\x02\x01\x07",
),
]);
}

0 comments on commit e23fbf1

Please sign in to comment.