Skip to content

Commit

Permalink
Implement RustCodeGenerator::add_local_derive
Browse files Browse the repository at this point in the history
  • Loading branch information
pablosichert committed Oct 2, 2024
1 parent e49434b commit 2c10bdb
Showing 1 changed file with 133 additions and 0 deletions.
133 changes: 133 additions & 0 deletions asn1rs-model/src/generate/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub trait GeneratorSupplement<T> {
pub struct RustCodeGenerator {
models: Vec<Model<Rust>>,
global_derives: Vec<String>,
local_derives: HashMap<String, Vec<String>>,
local_attrs: HashMap<String, Vec<String>>,
direct_field_access: bool,
getter_and_setter: bool,
Expand All @@ -55,6 +56,7 @@ impl Default for RustCodeGenerator {
RustCodeGenerator {
models: Default::default(),
global_derives: Vec::default(),
local_derives: HashMap::new(),
local_attrs: HashMap::new(),
direct_field_access: true,
getter_and_setter: false,
Expand Down Expand Up @@ -93,6 +95,18 @@ impl RustCodeGenerator {
self
}

pub fn add_local_derive<N: Into<String>, I: Into<String>>(&mut self, name: N, derive: I) {
self.local_derives
.entry(name.into())
.or_default()
.push(derive.into());
}

pub fn without_additional_local_derives(mut self) -> Self {
self.local_derives.clear();
self
}

pub fn add_local_attr<N: Into<String>, I: Into<String>>(&mut self, name: N, attr: I) {
self.local_attrs
.entry(name.into())
Expand Down Expand Up @@ -913,6 +927,11 @@ impl RustCodeGenerator {
self.global_derives.iter().for_each(|derive| {
str_ct.derive(derive);
});
if let Some(local_derives) = self.local_derives.get(name) {
local_derives.iter().for_each(|derive| {
str_ct.derive(derive);
});
}
if let Some(local_attrs) = self.local_attrs.get(name) {
local_attrs.iter().for_each(|attr| {
str_ct.attr(attr);
Expand All @@ -935,6 +954,11 @@ impl RustCodeGenerator {
self.global_derives.iter().for_each(|derive| {
en_m.derive(derive);
});
if let Some(local_derives) = self.local_derives.get(name) {
local_derives.iter().for_each(|derive| {
en_m.derive(derive);
});
}
if let Some(local_attrs) = self.local_attrs.get(name) {
local_attrs.iter().for_each(|attr| {
en_m.r#macro(&format!("#[{attr}]")); // Workaround for missing `.attr` for enums in codegen
Expand Down Expand Up @@ -1036,6 +1060,115 @@ pub(crate) mod tests {
);
}

#[test]
pub fn test_struct_local_derive() {
let model = Model::try_from(Tokenizer::default().parse(
r#"Test DEFINITIONS AUTOMATIC TAGS ::=
BEGIN
MyStruct ::= SEQUENCE {
myField BOOLEAN
}
END
"#,
))
.unwrap()
.try_resolve()
.unwrap()
.to_rust();

let mut generator = RustCodeGenerator::from(model).without_additional_global_derives();
generator.add_local_derive("MyStruct", "MyDerive");
let (_file_name, file_content) = generator
.to_string_without_generators()
.into_iter()
.next()
.unwrap();

assert_starts_with_lines(
r#"
use asn1rs::prelude::*;
#[asn(sequence)]
#[derive(Default, Debug, Clone, PartialEq, Hash, MyDerive)]
pub struct MyStruct {
#[asn(boolean)] pub my_field: bool,
}
impl MyStruct {
}
"#,
&file_content,
);
}

#[test]
pub fn test_enum_local_derive() {
let model = Model::try_from(Tokenizer::default().parse(
r#"Test DEFINITIONS AUTOMATIC TAGS ::=
BEGIN
MyEnum ::= ENUMERATED {
a,
b
}
END
"#,
))
.unwrap()
.try_resolve()
.unwrap()
.to_rust();

let mut generator = RustCodeGenerator::from(model).without_additional_global_derives();
generator.add_local_derive("MyEnum", "MyDerive");
let (_file_name, file_content) = generator
.to_string_without_generators()
.into_iter()
.next()
.unwrap();

assert_starts_with_lines(
r#"
use asn1rs::prelude::*;
#[asn(enumerated)]
#[derive(Debug, Clone, PartialEq, Hash, Copy, PartialOrd, Eq, MyDerive, Default)]
pub enum MyEnum {
#[default] A,
B,
}
impl MyEnum {
pub fn variant(index: usize) -> Option<Self> {
match index {
0 => Some(MyEnum::A),
1 => Some(MyEnum::B),
_ => None,
}
}
pub const fn variants() -> [Self; 2] {
[
MyEnum::A,
MyEnum::B,
]
}
pub fn value_index(self) -> usize {
match self {
MyEnum::A => 0,
MyEnum::B => 1,
}
}
}
"#,
&file_content,
);
}

#[test]
pub fn test_struct_local_attr() {
let model = Model::try_from(Tokenizer::default().parse(
Expand Down

0 comments on commit 2c10bdb

Please sign in to comment.