From 909a79944d4b6ab32031acfd02c3eba76f85b379 Mon Sep 17 00:00:00 2001 From: AlexSherbinin Date: Thu, 18 Jul 2024 21:34:22 +0300 Subject: [PATCH] feat: added support for enums that consuming strings they were parsed from, changed behaviour to derive `From` trait instead of `TryFrom` when default variant is defined --- strum_macros/Cargo.toml | 2 +- strum_macros/src/helpers/case_style.rs | 34 ++--- strum_macros/src/helpers/lifetime_check.rs | 20 +++ strum_macros/src/helpers/metadata.rs | 46 +++---- strum_macros/src/helpers/mod.rs | 1 + strum_macros/src/macros/enum_is.rs | 3 +- strum_macros/src/macros/enum_try_as.rs | 4 +- strum_macros/src/macros/strings/display.rs | 26 ++-- .../src/macros/strings/from_string.rs | 123 ++++++++++-------- strum_tests/src/lib.rs | 2 +- strum_tests/tests/as_ref_str.rs | 2 +- strum_tests/tests/enum_is.rs | 33 +++-- strum_tests/tests/enum_try_as.rs | 16 +-- strum_tests/tests/enum_variant_table.rs | 4 +- strum_tests/tests/from_str.rs | 16 +++ strum_tests/tests/to_string.rs | 2 +- 16 files changed, 187 insertions(+), 147 deletions(-) create mode 100644 strum_macros/src/helpers/lifetime_check.rs diff --git a/strum_macros/Cargo.toml b/strum_macros/Cargo.toml index c053ac26..7ccf2525 100644 --- a/strum_macros/Cargo.toml +++ b/strum_macros/Cargo.toml @@ -23,7 +23,7 @@ heck = "0.5.0" proc-macro2 = "1.0" quote = "1.0" rustversion = "1.0" -syn = { version = "2.0", features = ["parsing", "extra-traits"] } +syn = { version = "2.0", features = ["parsing", "extra-traits", "visit"] } [dev-dependencies] strum = { path = "../strum", version= "0.26" } diff --git a/strum_macros/src/helpers/case_style.rs b/strum_macros/src/helpers/case_style.rs index bcea7886..dabcecd3 100644 --- a/strum_macros/src/helpers/case_style.rs +++ b/strum_macros/src/helpers/case_style.rs @@ -116,6 +116,23 @@ impl CaseStyleHelpers for Ident { } } +/// heck doesn't treat numbers as new words, but this function does. +/// E.g. for input `Hello2You`, heck would output `hello2_you`, and snakify would output `hello_2_you`. +pub fn snakify(s: &str) -> String { + let mut output: Vec = s.to_string().to_snake_case().chars().collect(); + let mut num_starts = vec![]; + for (pos, c) in output.iter().enumerate() { + if c.is_ascii_digit() && pos != 0 && !output[pos - 1].is_ascii_digit() { + num_starts.push(pos); + } + } + // need to do in reverse, because after inserting, all chars after the point of insertion are off + for i in num_starts.into_iter().rev() { + output.insert(i, '_') + } + output.into_iter().collect() +} + #[cfg(test)] mod tests { use super::*; @@ -159,20 +176,3 @@ mod tests { assert_eq!(MixedCase, f("mixed_case").unwrap()); } } - -/// heck doesn't treat numbers as new words, but this function does. -/// E.g. for input `Hello2You`, heck would output `hello2_you`, and snakify would output `hello_2_you`. -pub fn snakify(s: &str) -> String { - let mut output: Vec = s.to_string().to_snake_case().chars().collect(); - let mut num_starts = vec![]; - for (pos, c) in output.iter().enumerate() { - if c.is_digit(10) && pos != 0 && !output[pos - 1].is_digit(10) { - num_starts.push(pos); - } - } - // need to do in reverse, because after inserting, all chars after the point of insertion are off - for i in num_starts.into_iter().rev() { - output.insert(i, '_') - } - output.into_iter().collect() -} diff --git a/strum_macros/src/helpers/lifetime_check.rs b/strum_macros/src/helpers/lifetime_check.rs new file mode 100644 index 00000000..fa550d73 --- /dev/null +++ b/strum_macros/src/helpers/lifetime_check.rs @@ -0,0 +1,20 @@ +use syn::visit::Visit; + +#[derive(Default)] +struct LifetimeVisitor { + contains_lifetime: bool, +} + +impl<'ast> Visit<'ast> for LifetimeVisitor { + fn visit_lifetime(&mut self, i: &'ast syn::Lifetime) { + self.contains_lifetime = true; + syn::visit::visit_lifetime(self, i); + } +} + +pub fn contains_lifetime(ty: &syn::Type) -> bool { + let mut visitor = LifetimeVisitor::default(); + visitor.visit_type(ty); + + visitor.contains_lifetime +} diff --git a/strum_macros/src/helpers/metadata.rs b/strum_macros/src/helpers/metadata.rs index 94100a7f..0084bc0c 100644 --- a/strum_macros/src/helpers/metadata.rs +++ b/strum_macros/src/helpers/metadata.rs @@ -87,7 +87,7 @@ impl Parse for EnumMeta { } pub enum EnumDiscriminantsMeta { - Derive { kw: kw::derive, paths: Vec }, + Derive { _kw: kw::derive, paths: Vec }, Name { kw: kw::name, name: Ident }, Vis { kw: kw::vis, vis: Visibility }, Other { path: Path, nested: TokenStream }, @@ -96,12 +96,12 @@ pub enum EnumDiscriminantsMeta { impl Parse for EnumDiscriminantsMeta { fn parse(input: ParseStream) -> syn::Result { if input.peek(kw::derive) { - let kw = input.parse()?; + let _kw = input.parse()?; let content; parenthesized!(content in input); let paths = content.parse_terminated(Path::parse, Token![,])?; Ok(EnumDiscriminantsMeta::Derive { - kw, + _kw, paths: paths.into_iter().collect(), }) } else if input.peek(kw::name) { @@ -154,7 +154,7 @@ pub enum VariantMeta { value: LitStr, }, Serialize { - kw: kw::serialize, + _kw: kw::serialize, value: LitStr, }, Documentation { @@ -175,45 +175,38 @@ pub enum VariantMeta { value: bool, }, Props { - kw: kw::props, + _kw: kw::props, props: Vec<(LitStr, LitStr)>, }, } impl Parse for VariantMeta { fn parse(input: ParseStream) -> syn::Result { - let lookahead = input.lookahead1(); - if lookahead.peek(kw::message) { - let kw = input.parse()?; + if let Ok(kw) = input.parse() { let _: Token![=] = input.parse()?; let value = input.parse()?; Ok(VariantMeta::Message { kw, value }) - } else if lookahead.peek(kw::detailed_message) { - let kw = input.parse()?; + } else if let Ok(kw) = input.parse() { let _: Token![=] = input.parse()?; let value = input.parse()?; Ok(VariantMeta::DetailedMessage { kw, value }) - } else if lookahead.peek(kw::serialize) { - let kw = input.parse()?; + } else if let Ok(_kw) = input.parse() { let _: Token![=] = input.parse()?; let value = input.parse()?; - Ok(VariantMeta::Serialize { kw, value }) - } else if lookahead.peek(kw::to_string) { - let kw = input.parse()?; + Ok(VariantMeta::Serialize { _kw, value }) + } else if let Ok(kw) = input.parse() { let _: Token![=] = input.parse()?; let value = input.parse()?; Ok(VariantMeta::ToString { kw, value }) - } else if lookahead.peek(kw::disabled) { - Ok(VariantMeta::Disabled(input.parse()?)) - } else if lookahead.peek(kw::default) { - Ok(VariantMeta::Default(input.parse()?)) - } else if lookahead.peek(kw::default_with) { - let kw = input.parse()?; + } else if let Ok(kw) = input.parse() { + Ok(VariantMeta::Disabled(kw)) + } else if let Ok(kw) = input.parse() { + Ok(VariantMeta::Default(kw)) + } else if let Ok(kw) = input.parse() { let _: Token![=] = input.parse()?; let value = input.parse()?; Ok(VariantMeta::DefaultWith { kw, value }) - } else if lookahead.peek(kw::ascii_case_insensitive) { - let kw = input.parse()?; + } else if let Ok(kw) = input.parse() { let value = if input.peek(Token![=]) { let _: Token![=] = input.parse()?; input.parse::()?.value @@ -221,20 +214,19 @@ impl Parse for VariantMeta { true }; Ok(VariantMeta::AsciiCaseInsensitive { kw, value }) - } else if lookahead.peek(kw::props) { - let kw = input.parse()?; + } else if let Ok(_kw) = input.parse() { let content; parenthesized!(content in input); let props = content.parse_terminated(Prop::parse, Token![,])?; Ok(VariantMeta::Props { - kw, + _kw, props: props .into_iter() .map(|Prop(k, v)| (LitStr::new(&k.to_string(), k.span()), v)) .collect(), }) } else { - Err(lookahead.error()) + Err(input.lookahead1().error()) } } } diff --git a/strum_macros/src/helpers/mod.rs b/strum_macros/src/helpers/mod.rs index 23d60b53..e421dff6 100644 --- a/strum_macros/src/helpers/mod.rs +++ b/strum_macros/src/helpers/mod.rs @@ -5,6 +5,7 @@ pub use self::variant_props::HasStrumVariantProperties; pub mod case_style; pub mod inner_variant_props; +pub mod lifetime_check; mod metadata; pub mod type_props; pub mod variant_props; diff --git a/strum_macros/src/macros/enum_is.rs b/strum_macros/src/macros/enum_is.rs index c239628d..3237475f 100644 --- a/strum_macros/src/macros/enum_is.rs +++ b/strum_macros/src/macros/enum_is.rs @@ -42,6 +42,5 @@ pub fn enum_is_inner(ast: &DeriveInput) -> syn::Result { impl #impl_generics #enum_name #ty_generics #where_clause { #(#variants)* } - } - .into()) + }) } diff --git a/strum_macros/src/macros/enum_try_as.rs b/strum_macros/src/macros/enum_try_as.rs index c6d0127c..203a4e92 100644 --- a/strum_macros/src/macros/enum_try_as.rs +++ b/strum_macros/src/macros/enum_try_as.rs @@ -64,9 +64,7 @@ pub fn enum_try_as_inner(ast: &DeriveInput) -> syn::Result { } }) }, - _ => { - return None; - } + _ => None } }) diff --git a/strum_macros/src/macros/strings/display.rs b/strum_macros/src/macros/strings/display.rs index 5e7a4229..9c9ae353 100644 --- a/strum_macros/src/macros/strings/display.rs +++ b/strum_macros/src/macros/strings/display.rs @@ -37,7 +37,8 @@ pub fn display_inner(ast: &DeriveInput) -> syn::Result { .enumerate() .map(|(index, field)| { assert!(field.ident.is_none()); - let ident = syn::parse_str::(format!("field{}", index).as_str()).unwrap(); + let ident = + syn::parse_str::(format!("field{}", index).as_str()).unwrap(); quote! { ref #ident } }) .collect(); @@ -97,14 +98,14 @@ pub fn display_inner(ast: &DeriveInput) -> syn::Result { #name::#ident #params => ::core::fmt::Display::fmt(&format!(#output, #args), f) } } - }, + } Fields::Unnamed(ref unnamed_fields) => { let used_vars = capture_format_strings(&output)?; if used_vars.iter().any(String::is_empty) { return Err(syn::Error::new_spanned( &output, "Empty {} is not allowed; Use manual numbering ({0})", - )) + )); } if used_vars.is_empty() { quote! { #name::#ident #params => ::core::fmt::Display::fmt(#output, f) } @@ -157,14 +158,17 @@ pub fn display_inner(ast: &DeriveInput) -> syn::Result { } fn capture_format_string_idents(string_literal: &LitStr) -> syn::Result> { - capture_format_strings(string_literal)?.into_iter().map(|ident| { - syn::parse_str::(ident.as_str()).map_err(|_| { - syn::Error::new_spanned( - string_literal, - "Invalid identifier inside format string bracket", - ) + capture_format_strings(string_literal)? + .into_iter() + .map(|ident| { + syn::parse_str::(ident.as_str()).map_err(|_| { + syn::Error::new_spanned( + string_literal, + "Invalid identifier inside format string bracket", + ) + }) }) - }).collect() + .collect() } fn capture_format_strings(string_literal: &LitStr) -> syn::Result> { @@ -193,7 +197,7 @@ fn capture_format_strings(string_literal: &LitStr) -> syn::Result> { ))?; let inside_brackets = &format_str[start_index + 1..i]; - let ident_str = inside_brackets.split(":").next().unwrap().trim_end(); + let ident_str = inside_brackets.split(':').next().unwrap().trim_end(); var_used.push(ident_str.to_owned()); } } diff --git a/strum_macros/src/macros/strings/from_string.rs b/strum_macros/src/macros/strings/from_string.rs index de32cbbe..c2941ea3 100644 --- a/strum_macros/src/macros/strings/from_string.rs +++ b/strum_macros/src/macros/strings/from_string.rs @@ -3,8 +3,8 @@ use quote::quote; use syn::{Data, DeriveInput, Fields}; use crate::helpers::{ - non_enum_error, occurrence_error, HasInnerVariantProperties, HasStrumVariantProperties, - HasTypeProperties, + lifetime_check::contains_lifetime, non_enum_error, occurrence_error, HasInnerVariantProperties, + HasStrumVariantProperties, HasTypeProperties, }; pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { @@ -21,6 +21,7 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { let mut default_kw = None; let mut default = quote! { ::core::result::Result::Err(#strum_module_path::ParseError::VariantNotFound) }; + let mut is_default_generic_over_lifetime = false; let mut phf_exact_match_arms = Vec::new(); let mut standard_match_arms = Vec::new(); @@ -37,19 +38,27 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { return Err(occurrence_error(fst_kw, kw, "default")); } - match &variant.fields { - Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {} - _ => { - return Err(syn::Error::new_spanned( - variant, - "Default only works on newtype structs with a single String field", - )) + let fields_error = syn::Error::new_spanned( + variant, + "Default only works on newtype structs with a single String field", + ); + let field = match &variant.fields { + Fields::Unnamed(fields) => { + if let Some(field) = fields.unnamed.iter().next() { + field + } else { + return Err(fields_error); + } } - } + _ => return Err(fields_error), + }; + default_kw = Some(kw); default = quote! { ::core::result::Result::Ok(#name::#ident(s.into())) }; + is_default_generic_over_lifetime = contains_lifetime(&field.ty); + continue; } @@ -143,56 +152,60 @@ pub fn from_string_inner(ast: &DeriveInput) -> syn::Result { } }; - let from_str = quote! { - #[allow(clippy::use_self)] - impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause { - type Err = #strum_module_path::ParseError; - fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , ::Err> { - #phf_body - #standard_match_body + let error_ty = if default_kw.is_some() { + quote! { ::core::convert::Infallible } + } else { + quote! { #strum_module_path::ParseError } + }; + + let from_str_owned = if !is_default_generic_over_lifetime { + quote! { + #[allow(clippy::use_self)] + impl #impl_generics ::core::str::FromStr for #name #ty_generics #where_clause { + type Err = #error_ty; + fn from_str(s: &str) -> ::core::result::Result< #name #ty_generics , ::Err> { + ::core::convert::TryFrom::try_from(s) + } } } + } else { + TokenStream::default() }; - let try_from_str = try_from_str( - name, - &impl_generics, - &ty_generics, - where_clause, - &strum_module_path, - ); - - Ok(quote! { - #from_str - #try_from_str - }) -} -#[rustversion::before(1.34)] -fn try_from_str( - _name: &proc_macro2::Ident, - _impl_generics: &syn::ImplGenerics, - _ty_generics: &syn::TypeGenerics, - _where_clause: Option<&syn::WhereClause>, - _strum_module_path: &syn::Path, -) -> TokenStream { - Default::default() -} + let str_lifetime = if is_default_generic_over_lifetime { + ast.generics.lifetimes().next().map(|param| ¶m.lifetime) + } else { + None + }; -#[rustversion::since(1.34)] -fn try_from_str( - name: &proc_macro2::Ident, - impl_generics: &syn::ImplGenerics, - ty_generics: &syn::TypeGenerics, - where_clause: Option<&syn::WhereClause>, - strum_module_path: &syn::Path, -) -> TokenStream { - quote! { - #[allow(clippy::use_self)] - impl #impl_generics ::core::convert::TryFrom<&str> for #name #ty_generics #where_clause { - type Error = #strum_module_path::ParseError; - fn try_from(s: &str) -> ::core::result::Result< #name #ty_generics , >::Error> { - ::core::str::FromStr::from_str(s) + let from_str = if default_kw.is_some() { + quote! { + impl #impl_generics From<& #str_lifetime str> for #name #ty_generics #where_clause { + fn from(s: & #str_lifetime str) -> #name #ty_generics { + let result: Result<_, ::core::convert::Infallible> = (|| { + #phf_body + #standard_match_body + })(); + + result.unwrap() + } } } - } + } else { + quote! { + #[allow(clippy::use_self)] + impl #impl_generics ::core::convert::TryFrom<& #str_lifetime str> for #name #ty_generics #where_clause { + type Error = #error_ty; + fn try_from(s: & #str_lifetime str) -> ::core::result::Result< #name #ty_generics, >::Error> { + #phf_body + #standard_match_body + } + } + } + }; + + Ok(quote! { + #from_str_owned + #from_str + }) } diff --git a/strum_tests/src/lib.rs b/strum_tests/src/lib.rs index eb3f65dd..2d6f45e3 100644 --- a/strum_tests/src/lib.rs +++ b/strum_tests/src/lib.rs @@ -10,6 +10,6 @@ pub enum Color { Blue { hue: usize }, #[strum(serialize = "y", serialize = "yellow")] Yellow, - #[strum(disabled)] + #[strum(default)] Green(String), } diff --git a/strum_tests/tests/as_ref_str.rs b/strum_tests/tests/as_ref_str.rs index d492f174..f151677d 100644 --- a/strum_tests/tests/as_ref_str.rs +++ b/strum_tests/tests/as_ref_str.rs @@ -1,4 +1,4 @@ -#![allow(deprecated)] +#![allow(deprecated, dead_code)] use std::str::FromStr; use strum::{AsRefStr, AsStaticRef, AsStaticStr, EnumString, IntoStaticStr}; diff --git a/strum_tests/tests/enum_is.rs b/strum_tests/tests/enum_is.rs index d45ae4cb..ff9bd320 100644 --- a/strum_tests/tests/enum_is.rs +++ b/strum_tests/tests/enum_is.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::borrow::Cow; use strum::EnumIs; @@ -28,12 +30,9 @@ enum Foo { } #[test] fn generics_test() { - let foo = LifeTimeTest::One(Cow::Borrowed("Hello")); - assert!(foo.is_one()); - let foo = LifeTimeTest::Two("Hello"); - assert!(foo.is_two()); - let foo = LifeTimeTest::One(Cow::Owned("Hello".to_string())); - assert!(foo.is_one()); + assert!(LifeTimeTest::One(Cow::Borrowed("Hello")).is_one()); + assert!(LifeTimeTest::Two("Hello").is_two()); + assert!(LifeTimeTest::One(Cow::Owned("Hello".to_string())).is_one()); } #[test] fn simple_test() { @@ -47,19 +46,19 @@ fn named_0() { #[test] fn named_1() { - let foo = Foo::Named1 { - _a: Default::default(), - }; - assert!(foo.is_named_1()); + assert!(Foo::Named1 { + _a: Default::default() + } + .is_named_1()); } #[test] fn named_2() { - let foo = Foo::Named2 { + assert!(Foo::Named2 { _a: Default::default(), - _b: Default::default(), - }; - assert!(foo.is_named_2()); + _b: Default::default() + } + .is_named_2()); } #[test] @@ -69,14 +68,12 @@ fn unnamed_0() { #[test] fn unnamed_1() { - let foo = Foo::Unnamed1(Default::default()); - assert!(foo.is_unnamed_1()); + assert!(Foo::Unnamed1(Default::default()).is_unnamed_1()); } #[test] fn unnamed_2() { - let foo = Foo::Unnamed2(Default::default(), Default::default()); - assert!(foo.is_unnamed_2()); + assert!(Foo::Unnamed2(Default::default(), Default::default()).is_unnamed_2()); } #[test] diff --git a/strum_tests/tests/enum_try_as.rs b/strum_tests/tests/enum_try_as.rs index 6cbf81e5..7e973a87 100644 --- a/strum_tests/tests/enum_try_as.rs +++ b/strum_tests/tests/enum_try_as.rs @@ -19,23 +19,24 @@ enum Foo { #[test] fn unnamed_0() { - let foo = Foo::Unnamed0(); - assert_eq!(Some(()), foo.try_as_unnamed_0()); + assert_eq!(Some(()), Foo::Unnamed0().try_as_unnamed_0()); } #[test] fn unnamed_1() { - let foo = Foo::Unnamed1(128); - assert_eq!(Some(&128), foo.try_as_unnamed_1_ref()); + assert_eq!(Some(&128), Foo::Unnamed1(128).try_as_unnamed_1_ref()); } #[test] fn unnamed_2() { - let foo = Foo::Unnamed2(true, String::from("Hay")); - assert_eq!(Some((true, String::from("Hay"))), foo.try_as_unnamed_2()); + assert_eq!( + Some((true, String::from("Hay"))), + Foo::Unnamed2(true, String::from("Hay")).try_as_unnamed_2() + ); } #[test] +#[allow(clippy::disallowed_names)] fn can_mutate() { let mut foo = Foo::Unnamed1(128); if let Some(value) = foo.try_as_unnamed_1_mut() { @@ -46,6 +47,5 @@ fn can_mutate() { #[test] fn doesnt_match_other_variations() { - let foo = Foo::Unnamed1(66); - assert_eq!(None, foo.try_as_unnamed_0()); + assert_eq!(None, Foo::Unnamed1(66).try_as_unnamed_0()); } diff --git a/strum_tests/tests/enum_variant_table.rs b/strum_tests/tests/enum_variant_table.rs index 25e854fd..38a1be57 100644 --- a/strum_tests/tests/enum_variant_table.rs +++ b/strum_tests/tests/enum_variant_table.rs @@ -6,7 +6,7 @@ enum Color { Yellow, Green, #[strum(disabled)] - Teal, + _Teal, Blue, #[strum(disabled)] Indigo, @@ -16,7 +16,7 @@ enum Color { // because if it doesn't compile, enum variants that conflict with keywords won't work #[derive(EnumTable)] enum Keyword { - Const, + _Const, } #[test] diff --git a/strum_tests/tests/from_str.rs b/strum_tests/tests/from_str.rs index 734282bf..b6485323 100644 --- a/strum_tests/tests/from_str.rs +++ b/strum_tests/tests/from_str.rs @@ -28,6 +28,15 @@ enum Color { White(String), } +#[derive(Debug, Eq, PartialEq, EnumString)] +#[strum(serialize_all = "UPPERCASE")] +enum HttpMethod<'a> { + Get, + Post, + #[strum(default)] + Unrecognized(&'a str), +} + #[rustversion::since(1.34)] fn assert_from_str<'a, T>(a: T, from: &'a str) where @@ -229,3 +238,10 @@ fn color_default_with_white() { } } } + +#[test] +fn http_method_from() { + assert_eq!(HttpMethod::Get, HttpMethod::from("GET")); + assert_eq!(HttpMethod::Post, HttpMethod::from("POST")); + assert_eq!(HttpMethod::Unrecognized("HEAD"), HttpMethod::from("HEAD")); +} diff --git a/strum_tests/tests/to_string.rs b/strum_tests/tests/to_string.rs index c3c29730..ab6017e4 100644 --- a/strum_tests/tests/to_string.rs +++ b/strum_tests/tests/to_string.rs @@ -1,4 +1,4 @@ -#![allow(deprecated)] +#![allow(deprecated, clippy::to_string_trait_impl)] use std::str::FromStr; use std::string::ToString;