Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow duplicated enum variant indices #628

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,20 @@ a.using_encoded(|ref slice| {
});

b.using_encoded(|ref slice| {
assert_eq!(slice, &b"\x01\x01\0\0\0\x02\0\0\0\0\0\0\0");
assert_eq!(slice, &b"\x00\x01\0\0\0\x02\0\0\0\0\0\0\0");
});

c.using_encoded(|ref slice| {
assert_eq!(slice, &b"\x02\x01\0\0\0\x02\0\0\0\0\0\0\0");
assert_eq!(slice, &b"\x01\x01\0\0\0\x02\0\0\0\0\0\0\0");
});

let mut da: &[u8] = b"\x0f";
assert_eq!(EnumType::decode(&mut da).ok(), Some(a));

let mut db: &[u8] = b"\x01\x01\0\0\0\x02\0\0\0\0\0\0\0";
let mut db: &[u8] = b"\x00\x01\0\0\0\x02\0\0\0\0\0\0\0";
assert_eq!(EnumType::decode(&mut db).ok(), Some(b));

let mut dc: &[u8] = b"\x02\x01\0\0\0\x02\0\0\0\0\0\0\0";
let mut dc: &[u8] = b"\x01\x01\0\0\0\x02\0\0\0\0\0\0\0";
assert_eq!(EnumType::decode(&mut dc).ok(), Some(c));

let mut dz: &[u8] = &[0];
Expand Down Expand Up @@ -216,7 +216,7 @@ The derive implementation supports the following attributes:
- `codec(encoded_as = "OtherType")`: Needs to be placed above a field and makes the field being
encoded by using `OtherType`.
- `codec(index = 0)`: Needs to be placed above an enum variant to make the variant use the given
index when encoded. By default the index is determined by counting from `0` beginning wth the
index when encoded. By default the index is determined by counting from `0` beginning with the
first variant.
- `codec(encode_bound)`, `codec(decode_bound)` and `codec(mel_bound)`: All 3 attributes take
in a `where` clause for the `Encode`, `Decode` and `MaxEncodedLen` trait implementation for
Expand Down
26 changes: 18 additions & 8 deletions derive/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use proc_macro2::{Ident, Span, TokenStream};
use syn::{spanned::Spanned, Data, Error, Field, Fields};

use crate::utils;
use crate::utils::{self, UsedIndexes};

/// Generate function block for function `Decode::decode`.
///
Expand Down Expand Up @@ -56,10 +56,19 @@ pub fn quote(
)
.to_compile_error();
}

let recurse = data_variants().enumerate().map(|(i, v)| {
let mut used_indexes =
match UsedIndexes::from(data_variants()).map_err(|e| e.to_compile_error()) {
Ok(index) => index,
Err(e) => return e,
};
let mut items = vec![];
for v in data_variants() {
let name = &v.ident;
let index = utils::variant_index(v, i);
let index = match used_indexes.variant_index(v).map_err(|e| e.into_compile_error())
{
Ok(i) => i,
Err(e) => return e,
};

let create = create_instance(
quote! { #type_name #type_generics :: #name },
Expand All @@ -69,7 +78,7 @@ pub fn quote(
crate_path,
);

quote_spanned! { v.span() =>
let item = quote_spanned! { v.span() =>
#[allow(clippy::unnecessary_cast)]
__codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => {
// NOTE: This lambda is necessary to work around an upstream bug
Expand All @@ -80,8 +89,9 @@ pub fn quote(
#create
})();
},
}
});
};
items.push(item);
}

let read_byte_err_msg =
format!("Could not decode `{type_name}`, failed to read variant byte");
Expand All @@ -91,7 +101,7 @@ pub fn quote(
match #input.read_byte()
.map_err(|e| e.chain(#read_byte_err_msg))?
{
#( #recurse )*
#( #items )*
_ => {
#[allow(clippy::redundant_closure_call)]
return (move || {
Expand Down
29 changes: 19 additions & 10 deletions derive/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::str::from_utf8;
use proc_macro2::{Ident, Span, TokenStream};
use syn::{punctuated::Punctuated, spanned::Spanned, token::Comma, Data, Error, Field, Fields};

use crate::utils;
use crate::{utils, utils::UsedIndexes};

type FieldsList = Punctuated<Field, Comma>;

Expand Down Expand Up @@ -313,12 +313,20 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
if data_variants().count() == 0 {
return quote!();
}

let recurse = data_variants().enumerate().map(|(i, f)| {
let mut used_indexes =
match UsedIndexes::from(data_variants()).map_err(|e| e.to_compile_error()) {
Ok(index) => index,
Err(e) => return e,
};
let mut items = vec![];
for f in data_variants() {
let name = &f.ident;
let index = utils::variant_index(f, i);

match f.fields {
let index = match used_indexes.variant_index(f).map_err(|e| e.into_compile_error())
{
Ok(i) => i,
Err(e) => return e,
};
let item = match f.fields {
Fields::Named(ref fields) => {
let fields = &fields.named;
let field_name = |_, ident: &Option<Ident>| quote!(#ident);
Expand Down Expand Up @@ -396,11 +404,12 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS

[hinting, encoding]
},
}
});
};
items.push(item)
}

let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting);
let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding);
let recurse_hinting = items.iter().map(|[hinting, _]| hinting);
let recurse_encoding = items.iter().map(|[_, encoding]| encoding);

let hinting = quote! {
// The variant index uses 1 byte.
Expand Down
6 changes: 3 additions & 3 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ fn wrap_with_dummy_const(
/// * if variant has attribute: `#[codec(index = "$n")]` then n
/// * else if variant has discriminant (like 3 in `enum T { A = 3 }`) then the discriminant.
/// * else its position in the variant set, excluding skipped variants, but including variant with
/// discriminant or attribute. Warning this position does collision with discriminant or attribute
/// index.
/// discriminant or attribute. Warning this position does collision with discriminant or attribute
/// index.
///
/// variant attributes:
/// * `#[codec(skip)]`: the variant is not encoded.
Expand All @@ -119,7 +119,7 @@ fn wrap_with_dummy_const(
/// assert_eq!(EnumType::A.encode(), vec![15]);
/// assert_eq!(EnumType::B.encode(), vec![]);
/// assert_eq!(EnumType::C.encode(), vec![3]);
/// assert_eq!(EnumType::D.encode(), vec![2]);
/// assert_eq!(EnumType::D.encode(), vec![0]);
/// ```
#[proc_macro_derive(Encode, attributes(codec))]
pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand Down
118 changes: 93 additions & 25 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
//! NOTE: attributes finder must be checked using check_attribute first,
//! otherwise the macro can panic.

use std::str::FromStr;
use std::{collections::HashMap, str::FromStr};

use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DeriveInput,
Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path, Variant,
ExprLit, Field, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta, MetaNameValue, NestedMeta, Path,
Variant,
};

fn find_meta_item<'a, F, R, I, M>(mut itr: I, mut pred: F) -> Option<R>
Expand All @@ -37,32 +38,99 @@ where
})
}

/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
// first look for an attribute
let index = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some(byte);
/// Indexes used while defining variants.
pub struct UsedIndexes {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing docs.

Also this struct should only store the used_set. The function from_iter is not required.

You only need variant_index, which should be a verbatim copy of the old variant_index function plus that you add the resulting index to used_set and check if it already exists.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added docs.
we need current to assign incremental indexes to variants in ascending order also we need to traverse the variants first and collect all the assigned indexes before we can process them to check for duplicates and have a set of indexes already used, so that when we encounter a variant without explicit discriminant/codec(index = ?) we could assign an index using UsedIndexes.current by incrementing it until the u8 is not contained inside the used set.

/// Map from index of the variant to it's attribute or definition span
used_set: HashMap<u8, Span>,
/// We need this u8 to correctly assign indexes to variants
/// that are not annotated by coded(index = ?) or explicit discriminant
current: u8,
}

impl UsedIndexes {
/// Build a Set of used indexes for use with #[scale(index = $int)] attribute or
/// explicit discriminant on the variant
pub fn from<'a, I: Iterator<Item = &'a Variant>>(values: I) -> syn::Result<Self> {
let mut map: HashMap<u8, Span> = HashMap::new();
for v in values {
if let Some((index, nv)) = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some((byte, nv.span()));
}
}
}
None
}) {
if let Some(span) = map.insert(index, nv.span()) {
let mut error = syn::Error::new(nv.span(), "Duplicate variant index. qed");
error.combine(syn::Error::new(span, "Variant index already defined here."));
return Err(error)
}
} else if let Some((
_,
expr @ syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(lit_int), .. }),
Copy link
Contributor

@gui1117 gui1117 Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this doesn't capture any discriminant, user can write this:

#[derive(Encode, Decode)]
enum A {
	A = 3 + 4,
}

I think it can be ok to constraint our implementation, but then we should compile-error if the expression is not a int literal

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated with to disallow this pattern

)) = v.discriminant.as_ref()
{
let index = lit_int
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
if let Some(span) = map.insert(index, expr.span()) {
let mut error = syn::Error::new(expr.span(), "Duplicate variant index. qed");
error.combine(syn::Error::new(span, "Variant index already defined here."));
return Err(error)
}
}
}
Ok(Self { current: 0, used_set: map })
}

None
});

// then fallback to discriminant or just index
index.map(|i| quote! { #i }).unwrap_or_else(|| {
v.discriminant
.as_ref()
.map(|(_, expr)| quote! { #expr })
.unwrap_or_else(|| quote! { #i })
})
/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(&mut self, v: &Variant) -> syn::Result<TokenStream> {
// first look for an attribute
let index = find_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
let byte = v
.base10_parse::<u8>()
.expect("Internal error, index attribute must have been checked");
return Some(byte);
}
}
}

None
});

index.map_or_else(
|| match v.discriminant.as_ref() {
Some((_, expr)) => Ok(quote! { #expr }),
None => {
let idx = self.next_index();
Ok(quote! { #idx })
},
},
|i| Ok(quote! { #i }),
)
}
pkhry marked this conversation as resolved.
Show resolved Hide resolved

fn next_index(&mut self) -> u8 {
loop {
if self.used_set.contains_key(&self.current) {
self.current += 1;
} else {
let index = self.current;
self.current += 1;
return index;
}
}
}
}

/// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given
Expand Down
8 changes: 4 additions & 4 deletions tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,17 @@ fn should_work_for_simple_enum() {
assert_eq!(slice, &b"\x0f");
});
b.using_encoded(|ref slice| {
assert_eq!(slice, &b"\x01\x01\0\0\0\x02\0\0\0\0\0\0\0");
assert_eq!(slice, &b"\x00\x01\0\0\0\x02\0\0\0\0\0\0\0");
});
c.using_encoded(|ref slice| {
assert_eq!(slice, &b"\x02\x01\0\0\0\x02\0\0\0\0\0\0\0");
assert_eq!(slice, &b"\x01\x01\0\0\0\x02\0\0\0\0\0\0\0");
});

let mut da: &[u8] = b"\x0f";
assert_eq!(EnumType::decode(&mut da).ok(), Some(a));
let mut db: &[u8] = b"\x01\x01\0\0\0\x02\0\0\0\0\0\0\0";
let mut db: &[u8] = b"\x00\x01\0\0\0\x02\0\0\0\0\0\0\0";
assert_eq!(EnumType::decode(&mut db).ok(), Some(b));
let mut dc: &[u8] = b"\x02\x01\0\0\0\x02\0\0\0\0\0\0\0";
let mut dc: &[u8] = b"\x01\x01\0\0\0\x02\0\0\0\0\0\0\0";
assert_eq!(EnumType::decode(&mut dc).ok(), Some(c));
let mut dz: &[u8] = &[0];
assert_eq!(EnumType::decode(&mut dz).ok(), None);
Expand Down
16 changes: 15 additions & 1 deletion tests/variant_number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn discriminant_variant_counted_in_default_index() {
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![0]);
}

#[test]
Expand All @@ -36,5 +36,19 @@ fn index_attr_variant_counted_and_reused_in_default_index() {
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![0]);
}
#[test]
fn index_attr_variant_duplicates_indices() {
// Tests codec index overriding and that variant indexes are without duplicates
#[derive(DeriveEncode)]
enum T {
#[codec(index = 0)]
A = 1,
#[codec(index = 1)]
B = 0,
}

assert_eq!(T::A.encode(), vec![0]);
assert_eq!(T::B.encode(), vec![1]);
}
Loading