Skip to content

Commit

Permalink
Fix #229
Browse files Browse the repository at this point in the history
Possibly needs more tests covering combinations with non-basic groups
mixed with basic, tagged basic groups and optional fields (should work
though)
  • Loading branch information
rooooooooob committed Apr 11, 2024
1 parent bc9c3d7 commit f964d76
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 52 deletions.
110 changes: 72 additions & 38 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6552,22 +6552,30 @@ fn generate_enum(
// We avoid checking ALL variants if we can figure it out by instead checking the type.
// This only works when the variants don't have first types in common.
let mut non_overlapping_types_match = {
let mut first_types = BTreeSet::new();
let mut duplicates = false;
let mut all_first_types = BTreeSet::new();
let mut duplicates_or_unknown = false;
for variant in variants.iter() {
for first_type in variant.cbor_types_inner(types, rep) {
// to_byte(0) is used since cbor_event::Type doesn't implement
// Ord or Hash so we can't put it in a set. Since we fix the lenth
// to always 0 this still remains a 1-to-1 mapping to Type.
if !first_types.insert(first_type.to_byte(0)) {
duplicates = true;
match variant.cbor_types_inner(types, rep) {
Some(first_types) => {
for first_type in first_types.iter() {
// to_byte(0) is used since cbor_event::Type doesn't implement
// Ord or Hash so we can't put it in a set. Since we fix the lenth
// to always 0 this still remains a 1-to-1 mapping to Type.
if !all_first_types.insert(first_type.to_byte(0)) {
duplicates_or_unknown = true;
}
}
}
None => {
duplicates_or_unknown = true;
break;
}
}
}
if duplicates {
if duplicates_or_unknown {
None
} else {
let deser_covers_all_types = first_types.len() == 8;
let deser_covers_all_types = all_first_types.len() == 8;
Some((Block::new("match raw.cbor_type()?"), deser_covers_all_types))
}
};
Expand Down Expand Up @@ -6869,12 +6877,22 @@ fn generate_enum(
} else {
variant.name_as_var()
};
let (before, after) =
if cli.preserve_encodings || !variant.rust_type().is_fixed_value() {
(Cow::from(format!("let {var_names_str} = ")), ";")
} else {
(Cow::from(""), "")
// also used to check if it's a basic rust group
let basic_rust_group_def_len =
match ty.conceptual_type.resolve_alias_shallow() {
ConceptualRustType::Rust(ident) if types.is_plain_group(ident) => {
Some(types.rust_struct(ident).unwrap().cbor_len_info(types))
}
_ => None,
};
let (before, after) = if cli.preserve_encodings
|| !variant.rust_type().is_fixed_value()
|| basic_rust_group_def_len.is_some()
{
(Cow::from(format!("let {var_names_str} = ")), ";")
} else {
(Cow::from(""), "")
};
let mut variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
Expand All @@ -6883,34 +6901,49 @@ fn generate_enum(
cli,
);
let names_without_outer = enum_gen_info.names_without_outer();
// we can avoid this ugly block and directly do it as a line possibly
if variant_deser_code.content.as_single_line().is_some()
&& names_without_outer.len() == 1
{
variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
DeserializeBeforeAfter::new(
&format!("Ok({}::{}(", name, variant.name),
"))",
false,
),
DeserializeConfig::new(&variant.name_as_var()),
if let Some(len_info) = basic_rust_group_def_len {
// this will never be 1 line to don't bother with the below cases
variant_deser_code = surround_in_len_checks(
variant_deser_code,
len_info,
rep.unwrap(),
cli,
);
} else if names_without_outer.is_empty() {
variant_deser_code.content.line(&format!(
"Ok({}::{}({}))",
name, variant.name, var_names_str
));
variant_deser_code
.content
.line(&format!("Ok({}::{})", name, variant.name));
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
// we can avoid this ugly block and directly do it as a line possibly
if variant_deser_code.content.as_single_line().is_some()
&& names_without_outer.len() == 1
{
variant_deser_code = gen_scope.generate_deserialize(
types,
(variant.rust_type()).into(),
DeserializeBeforeAfter::new(
&format!("Ok({}::{}(", name, variant.name),
"))",
false,
),
DeserializeConfig::new(&variant.name_as_var()),
cli,
);
} else if names_without_outer.is_empty() {
variant_deser_code
.content
.line(&format!("Ok({}::{})", name, variant.name));
} else {
enum_gen_info.generate_constructor(
&mut variant_deser_code.content,
"Ok(",
")",
None,
);
}
variant_deser_code
}
variant_deser_code
}
EnumVariantData::Inlined(record) => make_inline_deser_code(
gen_scope,
Expand All @@ -6924,6 +6957,7 @@ fn generate_enum(
};
let cbor_types_str = variant
.cbor_types_inner(types, rep)
.expect("Already checked above")
.into_iter()
.map(cbor_type_code_str)
.collect::<Vec<_>>()
Expand Down
44 changes: 37 additions & 7 deletions src/intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2047,28 +2047,58 @@ impl EnumVariant {
}
}

/// Gets the next CBOR type after the passed in rep (array/map) tag
/// Returns None if this is not possible and brute-force deserialization
/// trying every variant should be used instead
pub fn cbor_types_inner(
&self,
types: &IntermediateTypes,
rep: Option<Representation>,
) -> Vec<CBORType> {
outer_rep: Option<Representation>,
) -> Option<Vec<CBORType>> {
match &self.data {
EnumVariantData::RustType(ty) => ty.cbor_types(types),
EnumVariantData::RustType(ty) => {
if ty.encodings.is_empty() && outer_rep.is_some() {
if let ConceptualRustType::Rust(ident) =
ty.conceptual_type.resolve_alias_shallow()
{
match types.rust_struct(ident).unwrap().variant() {
// we can't know this unless there's a way to provide this info
RustStructType::Extern => None,
RustStructType::Record(record) => {
let mut ret = vec![];
for field in record.fields.iter() {
ret.extend(field.rust_type.cbor_types(types));
if !field.optional {
break;
}
}
Some(ret)
}
RustStructType::GroupChoice { .. } => None,
_ => Some(ty.cbor_types(types)),
}
} else {
Some(ty.cbor_types(types))
}
} else {
Some(ty.cbor_types(types))
}
}
EnumVariantData::Inlined(record) => {
if rep.is_some() {
if outer_rep.is_some() {
let mut ret = vec![];
for field in record.fields.iter() {
ret.extend(field.rust_type.cbor_types(types));
if !field.optional {
break;
}
}
ret
Some(ret)
} else {
match record.rep {
Some(match record.rep {
Representation::Array => vec![CBORType::Array],
Representation::Map => vec![CBORType::Map],
}
})
}
}
}
Expand Down
34 changes: 33 additions & 1 deletion tests/core/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,45 @@ non_overlapping_type_choice_all = uint / nint / text / bytes / #6.30("hello worl

non_overlapping_type_choice_some = uint / nint / text

non_overlap_basic_embed = [
overlap_basic_embed = [
; @name identity
tag: 0 //
; @name x
tag: 1, hash: bytes .size 32
]

non_overlap_basic_embed = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, tag: 1
]

non_overlap_basic_embed_multi_fields = [
; @name first
x: uint, z: uint //
; @name second
y: text, z: uint
]

non_overlap_basic_embed_mixed = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, z: uint
]

third = (bytes, uint)

non_overlap_basic_embed_mixed_explicit = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, z: uint //
; we don't use name dsl due to: https://github.com/dcSpark/cddl-codegen/issues/230
third
]

enums = [
c_enum,
type_choice,
Expand Down
29 changes: 27 additions & 2 deletions tests/core/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,10 +276,35 @@ mod tests {
deser_test(&NonOverlappingTypeChoiceSome::Text("Hello, World!".into()));
}

#[test]
fn overlap_basic_embed() {
deser_test(&OverlapBasicEmbed::new_identity());
deser_test(&OverlapBasicEmbed::new_x(vec![85; 32]).unwrap());
}

#[test]
fn non_overlap_basic_embed() {
deser_test(&NonOverlapBasicEmbed::new_identity());
deser_test(&NonOverlapBasicEmbed::new_x(vec![85; 32]).unwrap());
deser_test(&NonOverlapBasicEmbed::new_first(100));
deser_test(&NonOverlapBasicEmbed::new_second("cddl".to_owned()));
}

#[test]
fn non_overlap_basic_embed_multi_fields() {
deser_test(&NonOverlapBasicEmbedMultiFields::new_first(100, 1_000_000));
deser_test(&NonOverlapBasicEmbedMultiFields::new_second("cddl".to_owned(), 0));
}

#[test]
fn non_overlap_basic_embed_mixed() {
deser_test(&NonOverlapBasicEmbedMixed::new_first(100));
deser_test(&NonOverlapBasicEmbedMixed::new_second("cddl".to_owned(), 0));
}

#[test]
fn non_overlap_basic_embed_mixed_explicit() {
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_first(100));
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_second("cddl".to_owned(), 0));
deser_test(&NonOverlapBasicEmbedMixedExplicit::new_third(vec![0xBA, 0xAD, 0xF0, 0x0D], 4));
}

#[test]
Expand Down
34 changes: 33 additions & 1 deletion tests/preserve-encodings/input.cddl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,45 @@ non_overlapping_type_choice_all = uint / nint / text / bytes / #6.13("hello worl

non_overlapping_type_choice_some = uint / nint / text ; @used_as_key

non_overlap_basic_embed = [
overlap_basic_embed = [
; @name identity
tag: 0 //
; @name x
tag: 1, hash: bytes .size 32
]

non_overlap_basic_embed = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, tag: 1
]

non_overlap_basic_embed_multi_fields = [
; @name first
x: uint, z: uint //
; @name second
y: text, z: uint
]

non_overlap_basic_embed_mixed = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, z: uint
]

third = (bytes, uint)

non_overlap_basic_embed_mixed_explicit = [
; @name first
x: uint, tag: 0 //
; @name second
y: text, z: uint //
; we don't use name dsl due to: https://github.com/dcSpark/cddl-codegen/issues/230
third
]

c_enum = 3 / 1 / 4

enums = [
Expand Down
Loading

0 comments on commit f964d76

Please sign in to comment.