diff --git a/README.md b/README.md index 8b079051..69d32d40 100644 --- a/README.md +++ b/README.md @@ -216,7 +216,7 @@ The derive implementation supports the following attributes: - `codec(encoded_as = "OtherType")`: Needs to be placed above a field and makes the field being encoded by using `OtherType`. - `codec(index = 0)`: Needs to be placed above an enum variant to make the variant use the given - index when encoded. By default the index is determined by counting from `0` beginning wth the + index when encoded. By default the index is determined by counting from `0` beginning with the first variant. - `codec(encode_bound)`, `codec(decode_bound)` and `codec(mel_bound)`: All 3 attributes take in a `where` clause for the `Encode`, `Decode` and `MaxEncodedLen` trait implementation for diff --git a/derive/src/decode.rs b/derive/src/decode.rs index 1e228776..593305c2 100644 --- a/derive/src/decode.rs +++ b/derive/src/decode.rs @@ -44,10 +44,18 @@ pub fn quote( Ok(variants) => variants, Err(e) => return e.to_compile_error(), }; - - let recurse = variants.iter().enumerate().map(|(i, v)| { + match utils::check_indexes(variants.iter()).map_err(|e| e.to_compile_error()) { + Ok(()) => (), + Err(e) => return e, + }; + let mut items = vec![]; + for (index, v) in variants.iter().enumerate() { let name = &v.ident; - let index = utils::variant_index(v, i); + let index = match utils::variant_index(v, index).map_err(|e| e.into_compile_error()) + { + Ok(i) => i, + Err(e) => return e, + }; let create = create_instance( quote! { #type_name #type_generics :: #name }, @@ -57,7 +65,7 @@ pub fn quote( crate_path, ); - quote_spanned! { v.span() => + let item = quote_spanned! { v.span() => #[allow(clippy::unnecessary_cast)] __codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => { // NOTE: This lambda is necessary to work around an upstream bug @@ -68,8 +76,9 @@ pub fn quote( #create })(); }, - } - }); + }; + items.push(item); + } let read_byte_err_msg = format!("Could not decode `{type_name}`, failed to read variant byte"); @@ -79,7 +88,7 @@ pub fn quote( match #input.read_byte() .map_err(|e| e.chain(#read_byte_err_msg))? { - #( #recurse )* + #( #items )* _ => { #[allow(clippy::redundant_closure_call)] return (move || { diff --git a/derive/src/encode.rs b/derive/src/encode.rs index df7af38a..4df20a7c 100644 --- a/derive/src/encode.rs +++ b/derive/src/encode.rs @@ -306,12 +306,19 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS if variants.is_empty() { return quote!(); } - - let recurse = variants.iter().enumerate().map(|(i, f)| { + match utils::check_indexes(variants.iter()).map_err(|e| e.to_compile_error()) { + Ok(()) => (), + Err(e) => return e, + }; + let mut items = vec![]; + for (index, f) in variants.iter().enumerate() { let name = &f.ident; - let index = utils::variant_index(f, i); - - match f.fields { + let index = match utils::variant_index(f, index).map_err(|e| e.into_compile_error()) + { + Ok(i) => i, + Err(e) => return e, + }; + let item = match f.fields { Fields::Named(ref fields) => { let fields = &fields.named; let field_name = |_, ident: &Option| quote!(#ident); @@ -389,12 +396,12 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS [hinting, encoding] }, - } - }); - - let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting); - let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding); + }; + items.push(item) + } + let recurse_hinting = items.iter().map(|[hinting, _]| hinting); + let recurse_encoding = items.iter().map(|[_, encoding]| encoding); let hinting = quote! { // The variant index uses 1 byte. 1_usize + match *#self_ { diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 2f9ea7d9..8ba6d3de 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -94,8 +94,8 @@ fn wrap_with_dummy_const( /// * if variant has attribute: `#[codec(index = "$n")]` then n /// * else if variant has discriminant (like 3 in `enum T { A = 3 }`) then the discriminant. /// * else its position in the variant set, excluding skipped variants, but including variant with -/// discriminant or attribute. Warning this position does collision with discriminant or attribute -/// index. +/// discriminant or attribute. Warning this position does collision with discriminant or attribute +/// index. /// /// variant attributes: /// * `#[codec(skip)]`: the variant is not encoded. diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 735e2348..cd3d5375 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -17,9 +17,9 @@ //! NOTE: attributes finder must be checked using check_attribute first, //! otherwise the macro can panic. -use std::str::FromStr; +use std::{collections::HashMap, str::FromStr}; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DataEnum, @@ -38,11 +38,29 @@ where }) } +/// check usage of variant indexes with #[scale(index = $int)] attribute or +/// explicit discriminant on the variant +pub fn check_indexes<'a, I: Iterator>(values: I) -> syn::Result<()> { + let mut map: HashMap = HashMap::new(); + for (i, v) in values.enumerate() { + let index = variant_index(v, i)?; + if let Some(span) = map.insert(index, v.span()) { + let mut error = syn::Error::new( + v.span(), + "scale codec error: Invalid variant index, the variant index is duplicated.", + ); + error.combine(syn::Error::new(span, "Variant index used here.")); + return Err(error); + } + } + Ok(()) +} + /// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute /// is found, fall back to the discriminant or just the variant index. -pub fn variant_index(v: &Variant, i: usize) -> TokenStream { +pub fn variant_index(v: &Variant, index: usize) -> syn::Result { // first look for an attribute - let index = find_meta_item(v.attrs.iter(), |meta| { + let codec_index = find_meta_item(v.attrs.iter(), |meta| { if let Meta::NameValue(ref nv) = meta { if nv.path.is_ident("index") { if let Expr::Lit(ExprLit { lit: Lit::Int(ref v), .. }) = nv.value { @@ -56,14 +74,30 @@ pub fn variant_index(v: &Variant, i: usize) -> TokenStream { None }); - - // then fallback to discriminant or just index - index.map(|i| quote! { #i }).unwrap_or_else(|| { - v.discriminant - .as_ref() - .map(|(_, expr)| quote! { #expr }) - .unwrap_or_else(|| quote! { #i }) - }) + if let Some(index) = codec_index { + Ok(index) + } else { + match v.discriminant.as_ref() { + Some((_, syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(v), .. }))) => { + let byte = v.base10_parse::().expect( + "scale codec error: Invalid variant index, discriminant doesn't fit u8.", + ); + Ok(byte) + }, + Some((_, expr)) => Err(syn::Error::new( + expr.span(), + "scale codec error: Invalid discriminant, only int literal are accepted, e.g. \ + `= 32`.", + )), + None => index.try_into().map_err(|_| { + syn::Error::new( + v.span(), + "scale codec error: Variant index is too large, only 256 variants are \ + supported.", + ) + }), + } + } } /// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given diff --git a/tests/scale_codec_ui/codec_duplicate_index.rs b/tests/scale_codec_ui/codec_duplicate_index.rs new file mode 100644 index 00000000..dc3e666e --- /dev/null +++ b/tests/scale_codec_ui/codec_duplicate_index.rs @@ -0,0 +1,17 @@ +#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)] +#[codec(crate = ::parity_scale_codec)] +enum T { + A = 3, + #[codec(index = 3)] + B, +} + +#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)] +#[codec(crate = ::parity_scale_codec)] +enum T1 { + A, + #[codec(index = 0)] + B, +} + +fn main() {} diff --git a/tests/scale_codec_ui/codec_duplicate_index.stderr b/tests/scale_codec_ui/codec_duplicate_index.stderr new file mode 100644 index 00000000..afa68f85 --- /dev/null +++ b/tests/scale_codec_ui/codec_duplicate_index.stderr @@ -0,0 +1,23 @@ +error: scale codec error: Invalid variant index, the variant index is duplicated. + --> tests/scale_codec_ui/codec_duplicate_index.rs:5:2 + | +5 | #[codec(index = 3)] + | ^ + +error: Variant index used here. + --> tests/scale_codec_ui/codec_duplicate_index.rs:4:2 + | +4 | A = 3, + | ^ + +error: scale codec error: Invalid variant index, the variant index is duplicated. + --> tests/scale_codec_ui/codec_duplicate_index.rs:13:2 + | +13 | #[codec(index = 0)] + | ^ + +error: Variant index used here. + --> tests/scale_codec_ui/codec_duplicate_index.rs:12:2 + | +12 | A, + | ^ diff --git a/tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs b/tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs new file mode 100644 index 00000000..03836635 --- /dev/null +++ b/tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs @@ -0,0 +1,16 @@ +#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)] +#[codec(crate = ::parity_scale_codec)] +enum T { + A = 1, + B, +} + +#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)] +#[codec(crate = ::parity_scale_codec)] +enum T2 { + #[codec(index = 1)] + A, + B, +} + +fn main() {} diff --git a/tests/scale_codec_ui/discriminant_variant_counted_in_default_index.stderr b/tests/scale_codec_ui/discriminant_variant_counted_in_default_index.stderr new file mode 100644 index 00000000..8ca151b3 --- /dev/null +++ b/tests/scale_codec_ui/discriminant_variant_counted_in_default_index.stderr @@ -0,0 +1,23 @@ +error: scale codec error: Invalid variant index, the variant index is duplicated. + --> tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs:5:2 + | +5 | B, + | ^ + +error: Variant index used here. + --> tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs:4:2 + | +4 | A = 1, + | ^ + +error: scale codec error: Invalid variant index, the variant index is duplicated. + --> tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs:13:2 + | +13 | B, + | ^ + +error: Variant index used here. + --> tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs:11:2 + | +11 | #[codec(index = 1)] + | ^ diff --git a/tests/variant_number.rs b/tests/variant_number.rs index 9bdaba0a..0c71689c 100644 --- a/tests/variant_number.rs +++ b/tests/variant_number.rs @@ -1,18 +1,6 @@ use parity_scale_codec::Encode; use parity_scale_codec_derive::Encode as DeriveEncode; -#[test] -fn discriminant_variant_counted_in_default_index() { - #[derive(DeriveEncode)] - enum T { - A = 1, - B, - } - - assert_eq!(T::A.encode(), vec![1]); - assert_eq!(T::B.encode(), vec![1]); -} - #[test] fn skipped_variant_not_counted_in_default_index() { #[derive(DeriveEncode)] @@ -27,14 +15,16 @@ fn skipped_variant_not_counted_in_default_index() { } #[test] -fn index_attr_variant_counted_and_reused_in_default_index() { +fn index_attr_variant_duplicates_indices() { + // Tests codec index overriding and that variant indexes are without duplicates #[derive(DeriveEncode)] enum T { + #[codec(index = 0)] + A = 1, #[codec(index = 1)] - A, - B, + B = 0, } - assert_eq!(T::A.encode(), vec![1]); + assert_eq!(T::A.encode(), vec![0]); assert_eq!(T::B.encode(), vec![1]); }