diff --git a/benches/main.rs b/benches/main.rs index 958ba83..ef92ca6 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -1,9 +1,9 @@ use std::ops::Range; -use criterion::{criterion_group, criterion_main, Criterion, BenchmarkId, BatchSize, Throughput}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; +use fake::*; use rand::seq::IteratorRandom; use typed_graph::generic_graph::{GenericGraph, GenericResult}; -use fake::*; type TestGraph = GenericGraph; type TestResult = GenericResult; @@ -17,7 +17,7 @@ pub struct CompleteGraph { /// Number of node types (types start at 0) pub node_types: usize, /// Number of edge types (types start at 0) - pub edge_types: usize + pub edge_types: usize, } impl Dummy for TestGraph { @@ -25,21 +25,23 @@ impl Dummy for TestGraph { let mut g = TestGraph::default(); let node_count = config.width * config.height; for i in 0..node_count { - g.add_node((i, rng.gen_range(0..config.node_types))).unwrap(); + g.add_node((i, rng.gen_range(0..config.node_types))) + .unwrap(); } - + if config.edge_types != 0 { for x in 0..config.width { for y in 0..config.height { - if x == y { + if x == y { continue; } - + g.add_edge( - x, - y, - (x + y * config.width, rng.gen_range(0..config.edge_types)) - ).unwrap(); + x, + y, + (x + y * config.width, rng.gen_range(0..config.edge_types)), + ) + .unwrap(); } } } @@ -59,7 +61,7 @@ pub struct SparseGraph { /// Number of node types (types start at 0) pub node_types: usize, /// Number of edge types (types start at 0) - pub edge_types: usize + pub edge_types: usize, } impl Dummy for TestGraph { @@ -67,11 +69,16 @@ impl Dummy for TestGraph { let mut g = TestGraph::default(); let node_count = config.width * config.height; for i in 0..node_count { - g.add_node((i, rng.gen_range(0..config.node_types))).unwrap(); + g.add_node((i, rng.gen_range(0..config.node_types))) + .unwrap(); } for i in 0..node_count { - let edge_count = config.node_connections.clone().choose(rng).unwrap_or_default(); + let edge_count = config + .node_connections + .clone() + .choose(rng) + .unwrap_or_default(); for _ in 0..edge_count { // Figure out which node to connect to @@ -83,7 +90,12 @@ impl Dummy for TestGraph { continue; } - g.add_edge(i, target, (g.edge_count(), rng.gen_range(0..config.edge_types))).unwrap(); + g.add_edge( + i, + target, + (g.edge_count(), rng.gen_range(0..config.edge_types)), + ) + .unwrap(); } } @@ -93,145 +105,135 @@ impl Dummy for TestGraph { fn add_node(c: &mut Criterion) { let mut group = c.benchmark_group("Node+"); - + for side_length in (0..=100).step_by(25) { - group.throughput(Throughput::Elements((side_length*side_length) as u64)); + group.throughput(Throughput::Elements((side_length * side_length) as u64)); group.bench_with_input::<_, _, usize>( - BenchmarkId::from_parameter(side_length*side_length), - &side_length, + BenchmarkId::from_parameter(side_length * side_length), + &side_length, |b, side_length| { - - b.iter_batched_ref( - || { - TestGraph::default() - }, - |g| { - for i in 0..(side_length*side_length) { - g.add_node((i, i))?; - } + b.iter_batched_ref( + || TestGraph::default(), + |g| { + for i in 0..(side_length * side_length) { + g.add_node((i, i))?; + } - TestResult::Ok(()) - }, - BatchSize::SmallInput - ) - }); + TestResult::Ok(()) + }, + BatchSize::SmallInput, + ) + }, + ); } group.finish(); } fn remove_node(c: &mut Criterion) { let mut group = c.benchmark_group("Node-"); - + for side_length in (0..=100).step_by(25) { - group.throughput(Throughput::Elements((side_length*side_length) as u64)); + group.throughput(Throughput::Elements((side_length * side_length) as u64)); group.bench_with_input::<_, _, usize>( - BenchmarkId::from_parameter(side_length*side_length), - &side_length, + BenchmarkId::from_parameter(side_length * side_length), + &side_length, |b, side_length| { - - b.iter_batched_ref( - || { - CompleteGraph { - width: *side_length, - height: *side_length, - node_types: side_length*side_length, - edge_types: side_length*side_length - }.fake::() - }, - |g| { - let node_count = g.node_count(); - for i in 0..node_count { - g.remove_node(i)?; - } + b.iter_batched_ref( + || { + CompleteGraph { + width: *side_length, + height: *side_length, + node_types: side_length * side_length, + edge_types: side_length * side_length, + } + .fake::() + }, + |g| { + let node_count = g.node_count(); + for i in 0..node_count { + g.remove_node(i)?; + } - TestResult::Ok(()) - }, - BatchSize::SmallInput - ) - }); + TestResult::Ok(()) + }, + BatchSize::SmallInput, + ) + }, + ); } group.finish(); } fn add_edge(c: &mut Criterion) { let mut group = c.benchmark_group("Edge+"); - + for side_length in (0..=100).step_by(25) { - group.throughput(Throughput::Elements((side_length*side_length) as u64)); + group.throughput(Throughput::Elements((side_length * side_length) as u64)); group.bench_with_input::<_, _, usize>( - BenchmarkId::from_parameter(side_length*side_length), - &side_length, + BenchmarkId::from_parameter(side_length * side_length), + &side_length, |b, side_length| { - - b.iter_batched_ref( - || { - CompleteGraph { - width: *side_length, - height: *side_length, - node_types: side_length*side_length, - edge_types: 0 - }.fake::() - }, - |g| { - for x in 0..*side_length { - for y in 0..*side_length { - g.add_edge( - x, - y, - (x + y * 100, 0) - )?; + b.iter_batched_ref( + || { + CompleteGraph { + width: *side_length, + height: *side_length, + node_types: side_length * side_length, + edge_types: 0, + } + .fake::() + }, + |g| { + for x in 0..*side_length { + for y in 0..*side_length { + g.add_edge(x, y, (x + y * 100, 0))?; + } } - } - TestResult::Ok(()) - }, - BatchSize::SmallInput - ) - }); + TestResult::Ok(()) + }, + BatchSize::SmallInput, + ) + }, + ); } group.finish(); } fn remove_edge(c: &mut Criterion) { let mut group = c.benchmark_group("Edge-"); - + for side_length in (0..=100).step_by(25) { - group.throughput(Throughput::Elements((side_length*side_length) as u64)); + group.throughput(Throughput::Elements((side_length * side_length) as u64)); group.bench_with_input::<_, _, usize>( - BenchmarkId::from_parameter(side_length*side_length), - &side_length, + BenchmarkId::from_parameter(side_length * side_length), + &side_length, |b, side_length| { b.iter_batched_ref( || { CompleteGraph { width: *side_length, height: *side_length, - node_types: side_length*side_length, - edge_types: side_length*side_length - }.fake::() - }, + node_types: side_length * side_length, + edge_types: side_length * side_length, + } + .fake::() + }, |g| { let edge_count = g.edge_count(); for i in 0..edge_count { g.remove_edge(i)?; } - + TestResult::Ok(()) - }, - BatchSize::SmallInput + }, + BatchSize::SmallInput, ) - } + }, ); - } group.finish(); } -criterion_group!( - benches, - add_node, - remove_node, - add_edge, - remove_edge -); -criterion_main!(benches); \ No newline at end of file +criterion_group!(benches, add_node, remove_node, add_edge, remove_edge); +criterion_main!(benches); diff --git a/examples/json_graph.rs b/examples/json_graph.rs index ca1ef89..9af867d 100644 --- a/examples/json_graph.rs +++ b/examples/json_graph.rs @@ -1,25 +1,28 @@ +use serde::{Deserialize, Serialize}; +use serde_json::json; +use serde_json::value::Value; ///! In this example we will implement a json backed graph using serde_json -///! +///! ///! ## Implementation Details ///! The graph represents nodes and edges as json objects ///! The type and id of a node or edge are found in the "id" and "type" fields ///! The schema stores a list of allowed node types, edge types and relationships between edge types and node types ///! Migration of the schema is done by discarding any node or edge that does not have an allowed type by the new schema -///! +///! ///! ## Implementation Process ///! ### Define node and edge type ///! First we define the struct Weight which is a light wrapper around serde_json::Value ///! ```rust ///! pub struct Weight(Value); ///! ``` -///! +///! ///! ### Implementing graph traits ///! Then to use the Weight as a node and edge we need to implement NodeExt and EdgeExt ///! However before doing so we need to implement Id, Typed and PartialEq for Weight -///! +///! ///! The requirement for PartialEq comes from the fact that we set the Type for Typed as String. ///! The Type used in Typed is quick way of checking if a given weight has a specific type. -///! +///! ///! ### Implementing schema ///! Finally we define the schema ///! ```rust @@ -32,19 +35,15 @@ ///! edges: Vec<(String, String, String)> ///! } ///! ``` -///! +///! ///! Then to use the schema we implement SchemaExt and set the node and edge to Weight -///! +///! ///! ### Implementing migrations ///! Additionally we implement a migration strategy from one schema to another ///! For this example we just discard any node or edges that have types which are not allowed in the new schema - use std::collections::HashMap; use std::ops::{Deref, DerefMut}; -use serde_json::json; use typed_graph::*; -use serde_json::value::Value; -use serde::{Serialize, Deserialize}; type WeightId = u64; @@ -75,36 +74,32 @@ impl From for Weight { impl Weight { /// Get the id field from json and return it as a u64 - /// + /// /// If the id is not available this will return 0 fn get_id_from_json(&self) -> u64 { - self - .as_object() + self.as_object() .and_then(|obj| obj.get("id")) .and_then(|ty| ty.as_u64()) .unwrap_or_default() } - + /// Set the id if possible fn set_id_for_json(&mut self, id: u64) { - self - .as_object_mut() + self.as_object_mut() .and_then(|obj| obj.insert("id".to_string(), id.into())); - } - + } + /// Get the type and return it as a string refference - /// + /// /// If the type is not available this will return an empty string fn get_type_from_json(&self) -> &str { - self - .as_object() + self.as_object() .and_then(|obj| obj.get("type")) .and_then(|ty| ty.as_str()) .unwrap_or_default() } } - /// Make the id available to the graph impl Id for Weight { fn get_id(&self) -> WeightId { @@ -134,7 +129,7 @@ impl PartialEq for Weight { } /// Here we use the weight for both nodes and edges -/// Often you would want to store different values +/// Often you would want to store different values /// on the nodes and edges and therefore seperate the tow impl NodeExt for Weight {} impl EdgeExt for Weight {} @@ -149,11 +144,11 @@ struct JsonSchema { /// List of allowed node tyes nodes: Vec, /// List of allowed edge types, source node types and target node types - edges: Vec<(String, String, String)> + edges: Vec<(String, String, String)>, } /// Here we define the rules for the schema -/// +/// /// This is what is used to check if a given type is allowed impl SchemaExt for JsonSchema { type E = Weight; @@ -174,14 +169,14 @@ impl SchemaExt for JsonSchema { } /// Only let edges in the whitelist through - /// - /// This could be modified to also check for quantity by returning + /// + /// This could be modified to also check for quantity by returning /// DisAllowedEdge::ToMany if the count exceeds a specified amount fn allow_edge( - &self, + &self, _new_edge_count: usize, - edge_ty: ::Type, - source: ::Type, + edge_ty: ::Type, + source: ::Type, target: ::Type, ) -> Result<(), DisAllowedEdge> { // an edge is allowed if it is present in the schema @@ -216,22 +211,29 @@ impl MigrateSchema for JsonSchema { /// Update of edge types and node types should be consistent with that of update node and edge /// We achieve this here by just copy pasting the check we did for nodes and edges above - fn update_edge_type(&self, new_schema: &JsonSchema, edge_type: ::Type) -> Option<::Type> { + fn update_edge_type( + &self, + new_schema: &JsonSchema, + edge_type: ::Type, + ) -> Option<::Type> { let is_allowed = new_schema.edges.iter().any(|(ty, _, _)| &edge_type == ty); is_allowed.then(|| edge_type) } - fn update_node_type(&self, new_schema: &JsonSchema, node_type: ::Type) -> Option<::Type> { + fn update_node_type( + &self, + new_schema: &JsonSchema, + node_type: ::Type, + ) -> Option<::Type> { let is_allowed = new_schema.nodes.iter().any(|ty| &node_type == ty); is_allowed.then(|| node_type) } } - impl Migration for JsonSchema { // Here we can set a migration handler // A migration handler allows us to change the graph manually while the migration is being done - // In this example we are only interested in + // In this example we are only interested in type Handler = DefaultMigrationHandler; } @@ -245,17 +247,13 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { // -- let schemav0 = JsonSchema { version: "V0".to_string(), - nodes: vec![ - "A".to_string(), - "B".to_string(), - "C".to_string(), - ], + nodes: vec!["A".to_string(), "B".to_string(), "C".to_string()], edges: vec![ ("AB".to_string(), "A".to_string(), "B".to_string()), ("BC".to_string(), "B".to_string(), "C".to_string()), ("CA".to_string(), "C".to_string(), "A".to_string()), - ("CC".to_string(), "C".to_string(), "C".to_string()) - ] + ("CC".to_string(), "C".to_string(), "C".to_string()), + ], }; let mut gv0 = TypedGraph::new(schemav0); @@ -269,14 +267,11 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { let bc_id = gv0.add_edge(b_id, c_id, json!({"id": 1, "type": "BC"}))?; let ca_id = gv0.add_edge(c_id, a_id, json!({"id": 2, "type": "CA"}))?; - // trying to add a type that is not part of the schema will result in an error let new_node_id = gv0.add_node(json!({"id": 2, "type": "D"})); assert!(new_node_id.is_err()); println!("Adding node D"); println!("{:?}", new_node_id); - - // The same thing happens when trying to add an edge with a type that is not allowed let new_edge_id = gv0.add_edge(c_id, a_id, json!({"id": 2, "type": "AB"})); @@ -284,13 +279,9 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { println!("Adding edge AC"); println!("{:?}", new_edge_id); - - // Calling add on an id that is already used will update the type of the node or edge at that position // This only works if the replaced type is compatible with all the connected nodes and edges - - // We are also able to add multiple edges between the same nodes let dublicate_edge_id = gv0.add_edge(a_id, b_id, json!({"id": 3, "type": "AB"}))?; gv0.remove_edge(dublicate_edge_id)?; @@ -299,8 +290,6 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { let dublicate_edge_id = gv0.add_edge(c_id, c_id, json!({"id": 3, "type": "CC"}))?; gv0.remove_edge(dublicate_edge_id)?; - - // if we remove a node all its surrounding edges will be removed aswell let a = gv0.remove_node(a_id)?; assert_eq!(gv0.has_edge(ab_id), false); @@ -309,8 +298,6 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { gv0.add_edge(a_id, b_id, json!({"id": 0, "type": "AB"}))?; gv0.add_edge(c_id, a_id, json!({"id": 2, "type": "CA"}))?; - - // Traversal of the graph is done using the get_outgoing, get_incoming and get_incoming_and_outgoing functions let a_outgoing: Vec<_> = gv0.get_outgoing(a_id)?.collect(); let b_incoming: Vec<_> = gv0.get_incoming(b_id)?.collect(); @@ -327,8 +314,6 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { assert_eq!(b_incoming_edge.get_source(), a_id); assert_eq!(b_incoming_edge.get_target(), b_id); - - // When traversing in both directions at the same time it can be difficult to keep track of which direction the given edge is going // So to make this easer the get_inner and get_outer method can be used let b_both: Vec<_> = gv0.get_incoming_and_outgoing(b_id)?.collect(); @@ -345,11 +330,12 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { // But get_outer will always take you away from the starting node assert_ne!(edge0.get_outer(), edge1.get_outer()); - - // Using these short hands make traversal code work independant of direction // Here is an example of finding the longest path from a node in both directions - fn longest_distance(weight_id: WeightId, g: &TypedGraph) -> Option { + fn longest_distance( + weight_id: WeightId, + g: &TypedGraph, + ) -> Option { // Return None if the node does not exist g.get_node_safe(weight_id)?; @@ -359,20 +345,22 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { if visited.contains_key(&front_id) { continue; } - + visited.insert(front_id, distance); - + // here we can focus on writing the implementation instead of having to bother with directions for edge in g.get_incoming_and_outgoing(front_id).unwrap() { front.push((edge.get_outer(), distance + 1)); } } - + visited.values().max().copied() } - println!("Longest distance from {} is {:?}", b_id, longest_distance(b_id, &gv0)); - - + println!( + "Longest distance from {} is {:?}", + b_id, + longest_distance(b_id, &gv0) + ); // For deep searches of a specific depth the searches can be chained // Here we walk along the path A -> B -> C -> A @@ -390,22 +378,23 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { assert_eq!(outer.len(), 1); let outer_node = outer[0].get_id(); assert_eq!(outer_node, a_id); - - // An alternative way of doing the same thing is using a GraphWalker /// Create a function to tell how to move forward - /// Along with the next node it also returns the type of the edge that has been traversed - fn move_forward<'a>(n: &'a Weight, gv0: &'a TypedGraph) -> SchemaResult, WeightId, WeightId, JsonSchema> { - Ok( - gv0.get_outgoing(n.get_id())? - .map(|e| ( - e.get_weight().get_type(), - gv0.get_node(e.get_outer()).unwrap() - )) - ) - } + /// Along with the next node it also returns the type of the edge that has been traversed + fn move_forward<'a>( + n: &'a Weight, + gv0: &'a TypedGraph, + ) -> SchemaResult, WeightId, WeightId, JsonSchema> + { + Ok(gv0.get_outgoing(n.get_id())?.map(|e| { + ( + e.get_weight().get_type(), + gv0.get_node(e.get_outer()).unwrap(), + ) + })) + } // Now the walker can use the function to traverse the graph let outer: Vec<_> = gv0 @@ -415,13 +404,11 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { .progress(move_forward) .progress(move_forward) .many()?; - + assert_eq!(outer.len(), 1); let outer_node = outer[0].get_id(); assert_eq!(outer_node, a_id); - - // The main benefit of using the walker is that once the move_forward has been made the syntax becomes esier // and it allows for a state to be keept for each of the branches @@ -449,19 +436,14 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { // Finally we can see the resulting state of the branch assert_eq!(walker_target.state, vec!["AB", "BC", "CA"]); - - // Now we can try and migrate the graph to a new schema without the B node let schemav1 = JsonSchema { version: "V0".to_string(), - nodes: vec![ - "A".to_string(), - "C".to_string(), - ], + nodes: vec!["A".to_string(), "C".to_string()], edges: vec![ ("CA".to_string(), "C".to_string(), "A".to_string()), - ("CC".to_string(), "C".to_string(), "C".to_string()) - ] + ("CC".to_string(), "C".to_string(), "C".to_string()), + ], }; let gv1 = gv0.migrate(schemav1, &DefaultMigrationHandler)?; @@ -474,6 +456,5 @@ fn main() -> SchemaResult<(), WeightId, WeightId, JsonSchema> { assert_eq!(gv1.has_node(c_id), true); assert_eq!(gv1.has_edge(ca_id), true); - Ok(()) -} \ No newline at end of file +} diff --git a/examples/static_graph.rs b/examples/static_graph.rs index 628b326..dba60b7 100644 --- a/examples/static_graph.rs +++ b/examples/static_graph.rs @@ -1,10 +1,10 @@ ///! In this example we show a simple template for creating a staticly typed schemas -///! +///! ///! Everything is handled at compile time meaning the schema does not store any data ///! This makes it very safe to use as we are guaranteed to know which types of edges and nodes will be in the graph ///! ///! This example creates make_node_structand make_edge_struct which autogenerating NodeType, EdgeType, Node, Edge and a type for every kind of node and edge -///! +///! ///! The Node and Edge are then used by the schema to define a set of rules for the graph ///! ///! The schema creates the graph @@ -15,10 +15,12 @@ ///! A ---> B ---> C ///! AB BC ///! ``` - use std::fmt::Display; -use typed_graph::{Key, NodeExt, Id, Typed, EdgeExt, SchemaExt, DisAllowedEdge, TypedGraph, SchemaResult, DowncastMut, Downcast, TypedError, ToGraphWalker}; +use typed_graph::{ + DisAllowedEdge, Downcast, DowncastMut, EdgeExt, Id, Key, NodeExt, SchemaExt, SchemaResult, + ToGraphWalker, Typed, TypedError, TypedGraph, +}; /// Create NodeType, Node and a type for each NodeType macro_rules! make_node_struct { @@ -36,7 +38,7 @@ macro_rules! make_node_struct { write!(f, "{:?}", self) } } - + // Create the individual nodes $( #[derive(Debug, Clone)] @@ -73,20 +75,20 @@ macro_rules! make_node_struct { fn get_id(&self) -> K { self.id } - + fn set_id(&mut self, new_id: K) { self.id = new_id } } - + impl Typed for $name { type Type = NodeType; - + fn get_type(&self) -> Self::Type { NodeType::$name } } - + impl PartialEq for $name { fn eq(&self, other: &NodeType) -> bool { other == &NodeType::$name @@ -132,7 +134,7 @@ macro_rules! make_node_struct { } } } - + impl NodeExt for $name {} )* @@ -161,10 +163,10 @@ macro_rules! make_node_struct { } } } - + impl Typed for Node { type Type = NodeType; - + fn get_type(&self) -> Self::Type { match self { $( @@ -173,7 +175,7 @@ macro_rules! make_node_struct { } } } - + impl PartialEq for Node { fn eq(&self, other: &NodeType) -> bool { match self { @@ -204,7 +206,7 @@ macro_rules! make_edge_struct { write!(f, "{:?}", self) } } - + // Create the individual edges $( #[derive(Debug, Clone)] @@ -277,7 +279,7 @@ macro_rules! make_edge_struct { match self { Edge::$name(e) => Ok(e), e => Err(TypedError::DownCastFailed( - stringify!($name).to_string(), + stringify!($name).to_string(), e.get_type().to_string()) ) } @@ -329,10 +331,10 @@ macro_rules! make_edge_struct { } } } - + impl Typed for Edge { type Type = EdgeType; - + fn get_type(&self) -> Self::Type { match self { $( @@ -341,7 +343,7 @@ macro_rules! make_edge_struct { } } } - + impl PartialEq for Edge { fn eq(&self, other: &EdgeType) -> bool { match self { @@ -356,7 +358,7 @@ macro_rules! make_edge_struct { }; } -make_node_struct!{ +make_node_struct! { A { name: String }, @@ -368,7 +370,7 @@ make_node_struct!{ } } -make_edge_struct!{ +make_edge_struct! { AB { distance: usize }, @@ -391,25 +393,22 @@ impl SchemaExt for Schema { "Schema".to_string() } - fn allow_node( - &self, - _node_ty: NodeType - ) -> Result<(), typed_graph::DisAllowedNode> { + fn allow_node(&self, _node_ty: NodeType) -> Result<(), typed_graph::DisAllowedNode> { Ok(()) } fn allow_edge( - &self, - _new_edge_count: usize, - edge_ty: EdgeType, - source: NodeType, - target: NodeType, - ) -> Result<(), typed_graph::DisAllowedEdge> { + &self, + _new_edge_count: usize, + edge_ty: EdgeType, + source: NodeType, + target: NodeType, + ) -> Result<(), typed_graph::DisAllowedEdge> { match (source, target, edge_ty) { (NodeType::A, NodeType::B, EdgeType::AB) | (NodeType::B, NodeType::C, EdgeType::BC) | (NodeType::C, NodeType::A, EdgeType::CA) => Ok(()), - _ => Err(DisAllowedEdge::InvalidType) + _ => Err(DisAllowedEdge::InvalidType), } } } @@ -441,7 +440,7 @@ fn main() -> SchemaResult<(), usize, usize, Schema> { // If we want to retrieve data from the graph // We can get the generic node let node = g.get_node(a_id)?; - + // And then make requests on that println!("Node id = {} type = {}", node.get_id(), node.get_type()); @@ -469,22 +468,25 @@ fn main() -> SchemaResult<(), usize, usize, Schema> { println!("CA distance = {}", ca.distance); // We can now traverse the graph to calculate the total distance between all the nodes - + /// Function to retrieve the next node in the chain /// The generic allows us to specify the type of the edge we will encounter /// Since we have defined the schema we also know which type of node will be encountered - /// + /// /// For larer projects, these might be defined per node/edge - fn get_connected_node<'a, E>(node: &'a Node, g: &'a TypedGraph) -> SchemaResult)>, usize, usize, Schema> + fn get_connected_node<'a, E>( + node: &'a Node, + g: &'a TypedGraph, + ) -> SchemaResult)>, usize, usize, Schema> where - >::E: Downcast + >::E: Downcast, { - Ok(g - .get_outgoing(node.get_id())? - .map(|e| ( - e.get_weight_downcast::().unwrap(), - g.get_node(e.get_target()).unwrap() - ))) + Ok(g.get_outgoing(node.get_id())?.map(|e| { + ( + e.get_weight_downcast::().unwrap(), + g.get_node(e.get_target()).unwrap(), + ) + })) } // traverse the graph using a GraphWalker @@ -492,18 +494,21 @@ fn main() -> SchemaResult<(), usize, usize, Schema> { .to_walker(&g)? .set_state(0) .progress_with_state( - get_connected_node::>, + get_connected_node::>, // Boiler plate code for incrementing the distance - |mut acc, nc| {acc += nc.distance; acc } - ) - .progress_with_state( - get_connected_node::>, - |mut acc, nc| {acc += nc.distance; acc } - ) - .progress_with_state( - get_connected_node::>, - |mut acc, nc| {acc += nc.distance; acc } + |mut acc, nc| { + acc += nc.distance; + acc + }, ) + .progress_with_state(get_connected_node::>, |mut acc, nc| { + acc += nc.distance; + acc + }) + .progress_with_state(get_connected_node::>, |mut acc, nc| { + acc += nc.distance; + acc + }) .one_with_state()? .unwrap() .state; @@ -511,4 +516,4 @@ fn main() -> SchemaResult<(), usize, usize, Schema> { println!("ABCA distance = {}", distance); Ok(()) -} \ No newline at end of file +} diff --git a/src/generic_graph/generic_graph.rs b/src/generic_graph/generic_graph.rs index dbd0189..5913ce7 100644 --- a/src/generic_graph/generic_graph.rs +++ b/src/generic_graph/generic_graph.rs @@ -1,7 +1,10 @@ -use crate::{NodeExt, Id, Typed, EdgeExt, SchemaExt, DisAllowedNode, DisAllowedEdge, TypedGraph, SchemaResult, Key, TypeIdentifier}; use super::GenericWeight; +use crate::{ + DisAllowedEdge, DisAllowedNode, EdgeExt, Id, Key, NodeExt, SchemaExt, SchemaResult, + TypeIdentifier, Typed, TypedGraph, +}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use serde::{Serialize, Deserialize}; use std::hash::Hash; // Define a node and edge type @@ -16,37 +19,31 @@ pub type GenericGraph = TypedGraph pub type GenericResult = SchemaResult>; /// Common trait for all generic node or edge types -/// +/// /// This has further requirements than TypeIdentifier due to the GenericSchema pub trait GenericTypeIdentifier: TypeIdentifier + Eq + Hash {} -impl GenericTypeIdentifier for T -where - T: TypeIdentifier + Eq + Hash -{} +impl GenericTypeIdentifier for T where T: TypeIdentifier + Eq + Hash {} /// Schema capable of controlling all aspects of the graph -/// -/// The schema is build +/// +/// The schema is build #[derive(Default, Clone, Debug, Serialize, Deserialize)] -pub struct GenericSchema < - NT: GenericTypeIdentifier, - ET: GenericTypeIdentifier -> { +pub struct GenericSchema { node_whitelist: Option>, node_blacklist: Option>, edge_whitelist: Option>, edge_blacklist: Option>, endpoint_whitelist: Option>, endpoint_blacklist: Option>, - endpoint_max_quantity: Option> + endpoint_max_quantity: Option>, } impl GenericSchema { - pub fn new() -> Self + pub fn new() -> Self where NT: Default, - ET: Default + ET: Default, { Default::default() } @@ -76,7 +73,10 @@ impl GenericSchema } /// Edge filter: (EdgeType, NodeType, NodeType) - pub fn endpoint_whitelist(mut self, edge_endpoint_whitelist: Option>) -> Self { + pub fn endpoint_whitelist( + mut self, + edge_endpoint_whitelist: Option>, + ) -> Self { self.endpoint_whitelist = edge_endpoint_whitelist; self } @@ -87,14 +87,22 @@ impl GenericSchema self } - pub fn endpoint_max_quantity(mut self, endpoint_max_quantity: Option>) -> Self { + pub fn endpoint_max_quantity( + mut self, + endpoint_max_quantity: Option>, + ) -> Self { self.endpoint_max_quantity = endpoint_max_quantity; self } - } -impl SchemaExt for GenericSchema { +impl SchemaExt for GenericSchema +where + NK: Key, + EK: Key, + NT: GenericTypeIdentifier, + ET: GenericTypeIdentifier, +{ type N = GenericWeight; type E = GenericWeight; @@ -103,43 +111,46 @@ impl Sch } fn allow_edge( - &self, + &self, new_edge_count: usize, - edge_ty: ::Type, - source: ::Type, + edge_ty: ::Type, + source: ::Type, target: ::Type, ) -> Result<(), crate::DisAllowedEdge> { - let is_whitelist = self.edge_whitelist.as_ref().map_or(true, |l| l.contains(&edge_ty)); - let is_blacklist = self.edge_blacklist.as_ref().map_or(true, |l| !l.contains(&edge_ty)); - + let is_whitelist = self + .edge_whitelist + .as_ref() + .map_or(true, |l| l.contains(&edge_ty)); + let is_blacklist = self + .edge_blacklist + .as_ref() + .map_or(true, |l| !l.contains(&edge_ty)); + let endpoint = (source.clone(), target.clone(), edge_ty.clone()); let is_endpoint_whitelist = self .endpoint_whitelist .as_ref() .map_or(true, |l| l.contains(&endpoint)); - + let is_endpoint_blacklist = self .endpoint_blacklist .as_ref() .map_or(true, |l| !l.contains(&endpoint)); - - let is_allowed_type = is_whitelist && is_blacklist && is_endpoint_whitelist && is_endpoint_blacklist; - - + + let is_allowed_type = + is_whitelist && is_blacklist && is_endpoint_whitelist && is_endpoint_blacklist; + if !is_allowed_type { - return Err(DisAllowedEdge::InvalidType); + return Err(DisAllowedEdge::InvalidType); } - let is_endpoint_quantity = self - .endpoint_max_quantity - .as_ref() - .map_or(true, |l| l - .get(&endpoint) + let is_endpoint_quantity = self.endpoint_max_quantity.as_ref().map_or(true, |l| { + l.get(&endpoint) .map_or(true, |quantity| new_edge_count <= *quantity) - ); + }); let is_allowed_quantity = is_endpoint_quantity; - + if !is_allowed_quantity { return Err(DisAllowedEdge::ToMany); } @@ -147,37 +158,66 @@ impl Sch Ok(()) } - fn allow_node( - &self, - node_ty: ::Type - ) -> Result<(), crate::DisAllowedNode> { - let is_whitelist = self.node_whitelist.as_ref().map_or(true, |l| l.contains(&node_ty)); - let is_blacklist = self.node_blacklist.as_ref().map_or(true, |l| !l.contains(&node_ty)); + fn allow_node(&self, node_ty: ::Type) -> Result<(), crate::DisAllowedNode> { + let is_whitelist = self + .node_whitelist + .as_ref() + .map_or(true, |l| l.contains(&node_ty)); + let is_blacklist = self + .node_blacklist + .as_ref() + .map_or(true, |l| !l.contains(&node_ty)); let is_allowed = is_whitelist && is_blacklist; if !is_allowed { - return Err(DisAllowedNode::InvalidType) + return Err(DisAllowedNode::InvalidType); } Ok(()) } } -impl GenericGraph { +impl + GenericGraph +{ pub fn assert_eq(&self, other: &Self) -> GenericResult<(), NK, EK, NT, ET> { - assert_eq!(self.node_count(), other.node_count(), "Inconsistent node count"); - assert_eq!(self.edge_count(), other.edge_count(), "Inconsistent edge count"); + assert_eq!( + self.node_count(), + other.node_count(), + "Inconsistent node count" + ); + assert_eq!( + self.edge_count(), + other.edge_count(), + "Inconsistent edge count" + ); for node in self.nodes() { let other_node = other.get_node(node.get_id())?; - assert_eq!(node.get_type(), other_node.get_type(), "Inconsistent node type"); + assert_eq!( + node.get_type(), + other_node.get_type(), + "Inconsistent node type" + ); } for edge in self.edges_full() { let other_edge = other.get_edge_full(edge.get_id())?; - assert_eq!(edge.get_type(), other_edge.get_type(), "Inconsistent edge type"); - assert_eq!(edge.get_target(), other_edge.get_target(), "Inconsistent edge target"); - assert_eq!(edge.get_source(), other_edge.get_source(), "Inconsistent edge source"); + assert_eq!( + edge.get_type(), + other_edge.get_type(), + "Inconsistent edge type" + ); + assert_eq!( + edge.get_target(), + other_edge.get_target(), + "Inconsistent edge target" + ); + assert_eq!( + edge.get_source(), + other_edge.get_source(), + "Inconsistent edge source" + ); } for node in self.node_ids() { @@ -193,4 +233,4 @@ impl Gen Ok(()) } -} \ No newline at end of file +} diff --git a/src/generic_graph/generic_weight.rs b/src/generic_graph/generic_weight.rs index 24ce6b2..d0a3963 100644 --- a/src/generic_graph/generic_weight.rs +++ b/src/generic_graph/generic_weight.rs @@ -1,9 +1,11 @@ -use std::ops::{DerefMut, Deref}; -use serde::{Serialize, Deserialize}; -use crate::{Id, Typed, Key}; use super::GenericTypeIdentifier; +use crate::{Id, Key, Typed}; +use serde::{Deserialize, Serialize}; +use std::ops::{Deref, DerefMut}; -#[derive(Debug, Default, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Serialize, Deserialize)] +#[derive( + Debug, Default, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash, Serialize, Deserialize, +)] pub struct GenericWeight((K, T)); impl Deref for GenericWeight { @@ -27,17 +29,17 @@ impl From<(K, T)> for GenericWeight { impl Id for GenericWeight { fn get_id(&self) -> K { - self.0.0 + self.0 .0 } fn set_id(&mut self, new_id: K) { - self.0.0 = new_id; + self.0 .0 = new_id; } } impl PartialEq for GenericWeight { fn eq(&self, other: &T) -> bool { - &self.0.1 == other + &self.0 .1 == other } } @@ -45,6 +47,6 @@ impl Typed for GenericWeight { type Type = T; fn get_type(&self) -> Self::Type { - self.0.1.clone() + self.0 .1.clone() } -} \ No newline at end of file +} diff --git a/src/generic_graph/mod.rs b/src/generic_graph/mod.rs index f9b4f7e..35a955a 100644 --- a/src/generic_graph/mod.rs +++ b/src/generic_graph/mod.rs @@ -1,5 +1,5 @@ mod generic_graph; mod generic_weight; +pub use generic_graph::*; pub use generic_weight::GenericWeight; -pub use generic_graph::*; \ No newline at end of file diff --git a/src/graph/edge_ref.rs b/src/graph/edge_ref.rs index 231389c..ba1a468 100644 --- a/src/graph/edge_ref.rs +++ b/src/graph/edge_ref.rs @@ -1,32 +1,32 @@ -use crate::{Direction, SchemaExt, Key, Downcast, SchemaResult}; +use crate::{Direction, Downcast, Key, SchemaExt, SchemaResult}; use std::ops::Deref; #[derive(Debug, Clone)] -pub struct EdgeRef<'a, NK, EK, S> +pub struct EdgeRef<'a, NK, EK, S> where NK: Key, EK: Key, - S: SchemaExt + S: SchemaExt, { pub(crate) weight: &'a S::E, pub(crate) source: NK, pub(crate) target: NK, - pub(crate) direction: Direction + pub(crate) direction: Direction, } -impl<'a, NK, EK, S> EdgeRef<'a, NK, EK, S> +impl<'a, NK, EK, S> EdgeRef<'a, NK, EK, S> where NK: Key, EK: Key, - S: SchemaExt + S: SchemaExt, { pub fn get_weight(&self) -> &'a S::E { self.weight } - pub fn get_weight_downcast(&self) -> SchemaResult<&'a E, NK, EK, S> + pub fn get_weight_downcast(&self) -> SchemaResult<&'a E, NK, EK, S> where - S::E: Downcast + S::E: Downcast<'a, NK, EK, &'a E, S>, { self.weight.downcast() } @@ -58,14 +58,14 @@ where } } -impl<'a, NK, EK, S> Deref for EdgeRef<'a, NK, EK, S> +impl<'a, NK, EK, S> Deref for EdgeRef<'a, NK, EK, S> where NK: Key, EK: Key, - S: SchemaExt + S: SchemaExt, { type Target = S::E; fn deref(&self) -> &Self::Target { &self.weight } -} \ No newline at end of file +} diff --git a/src/graph/graph_traits.rs b/src/graph/graph_traits.rs index 9d49e93..6854ad4 100644 --- a/src/graph/graph_traits.rs +++ b/src/graph/graph_traits.rs @@ -11,7 +11,7 @@ pub trait EdgeExt: Typed + Id + Clone + Debug {} pub trait SchemaExt: Sized where NK: Key, - EK: Key + EK: Key, { /// Type of node weights used by the schema type N: NodeExt; @@ -20,49 +20,43 @@ where /// Get the name of the schema in order to provide better error messages fn name(&self) -> String; - + /// Before adding a new node, check if the new node is allowed - /// + /// /// Upon encountering an invalid edge return Err(InvalidType) - fn allow_node( - &self, - node_ty: ::Type - ) -> Result<(), DisAllowedNode>; + fn allow_node(&self, node_ty: ::Type) -> Result<(), DisAllowedNode>; /// Before adding a new edge check if the new edge is allowed - /// + /// /// Upon encountering an invalid edge return Err(InvalidType) /// If the quantity limit is reached return Err(TomMny) fn allow_edge( - &self, + &self, new_edge_count: usize, - edge_ty: ::Type, - source: ::Type, + edge_ty: ::Type, + source: ::Type, target: ::Type, ) -> Result<(), DisAllowedEdge>; } #[derive(Debug)] pub enum DisAllowedNode { - InvalidType + InvalidType, } #[derive(Debug)] pub enum DisAllowedEdge { ToMany, - InvalidType + InvalidType, } /// Trait indicating a type can be used as a key in the graph -/// -/// Mostly common key types is integers and uuid's. +/// +/// Mostly common key types is integers and uuid's. /// By implementing this trait more exotic types can be used aswell -pub trait Key: Hash + Debug + PartialEq + Eq + Copy{} +pub trait Key: Hash + Debug + PartialEq + Eq + Copy {} -impl Key for K -where - K: Hash + Debug + Eq + Copy -{} +impl Key for K where K: Hash + Debug + Eq + Copy {} /// Provide a getter and setter for the id of a node or edge pub trait Id { @@ -82,47 +76,44 @@ impl Id for T { pub trait Name { type Name; - + /// Retrieve the name of a node or edge fn get_name(&self) -> Option<&Self::Name>; } pub trait TypeIdentifier: PartialEq + Display + Debug + Clone {} -impl TypeIdentifier for T -where - T: PartialEq + Display + Debug + Clone -{} +impl TypeIdentifier for T where T: PartialEq + Display + Debug + Clone {} pub trait Typed: PartialEq { type Type: TypeIdentifier; /// Retrieve a runtime representation of the type of the node or edge - /// + /// /// PartialEq can then be used on the returned type to check if other nodes has the same type fn get_type(&self) -> Self::Type; } -pub trait Downcast +pub trait Downcast<'a, NK, EK, T, S> where -NK: Key, -EK: Key, -S: SchemaExt, + NK: Key, + EK: Key, + S: SchemaExt, { /// Cast a node or edge into a more specific type - /// + /// /// The call will fail if the requested type is not a suptype of the current one - fn downcast(&self) -> SchemaResult<&T, NK, EK, S>; + fn downcast(&'a self) -> SchemaResult; } -pub trait DowncastMut +pub trait DowncastMut<'a, NK, EK, T, S> where -NK: Key, -EK: Key, -S: SchemaExt + NK: Key, + EK: Key, + S: SchemaExt, { /// Cast a node or edge into a more specific type - /// + /// /// The call will fail if the requested type is not a suptype of the current one - fn downcast_mut(&mut self) -> SchemaResult<&mut T, NK, EK, S>; -} \ No newline at end of file + fn downcast_mut(&'a mut self) -> SchemaResult; +} diff --git a/src/graph/graph_walker.rs b/src/graph/graph_walker.rs index b1f2646..493384a 100644 --- a/src/graph/graph_walker.rs +++ b/src/graph/graph_walker.rs @@ -1,23 +1,23 @@ -use crate::{TypedGraph, SchemaExt, Key, Id, SchemaResult, Typed}; +use crate::{Id, Key, SchemaExt, SchemaResult, Typed, TypedGraph}; use std::iter::{once, Once}; #[derive(Clone)] -pub struct GraphWalker<'a, T, State, NK, EK, S, Front> +pub struct GraphWalker<'a, T, State, NK, EK, S, Front> where NK: Key, EK: Key, S: SchemaExt, State: Clone, - Front: Iterator)> + Front: Iterator)>, { g: &'a TypedGraph, - front: Front + front: Front, } /// Type storing the value and state of the walker at a given point pub struct WalkerTarget { pub val: T, - pub state: State + pub state: State, } impl<'a, T, State, NK, EK, S, Progress> GraphWalker<'a, T, State, NK, EK, S, Progress> @@ -26,41 +26,70 @@ where EK: Key, S: SchemaExt, State: Clone, - Progress: Iterator)> + 'a + Progress: Iterator)> + 'a, { - pub fn new(g: &'a TypedGraph) -> GraphWalker<'a, (), (), NK, EK, S, impl Iterator)> > { + pub fn new( + g: &'a TypedGraph, + ) -> GraphWalker<'a, (), (), NK, EK, S, impl Iterator)>> + { GraphWalker { g, - front: once(((), Ok(()))) + front: once(((), Ok(()))), } } - pub fn new_from(g: &'a TypedGraph, start: NK) -> GraphWalker<'a, &'a S::N, (), NK, EK, S, impl Iterator)>> { + pub fn new_from( + g: &'a TypedGraph, + start: NK, + ) -> GraphWalker< + 'a, + &'a S::N, + (), + NK, + EK, + S, + impl Iterator)>, + > { GraphWalker { g, - front: g.get_node_safe(start).map(|n| ((), Ok(n))).into_iter() + front: g.get_node_safe(start).map(|n| ((), Ok(n))).into_iter(), } } pub fn set_state( - self, - new_state: NewState - ) -> GraphWalker<'a, T, NewState, NK, EK, S, impl Iterator)>> + self, + new_state: NewState, + ) -> GraphWalker< + 'a, + T, + NewState, + NK, + EK, + S, + impl Iterator)>, + > where - NewState: Clone + NewState: Clone, { GraphWalker { g: self.g, - front: self.front - .map(move |(_, res)| (new_state.clone(), res)) + front: self.front.map(move |(_, res)| (new_state.clone(), res)), } } /// Moves the walker forward without changing the state of the branch pub fn progress<'b, NewT, NextStep, StateAddition, WalkerStep>( - self, - walker_step: WalkerStep - ) -> GraphWalker<'b, NewT, State, NK, EK, S, impl Iterator)> + 'b> + self, + walker_step: WalkerStep, + ) -> GraphWalker< + 'b, + NewT, + State, + NK, + EK, + S, + impl Iterator)> + 'b, + > where 'a: 'b, 'b: 'a, @@ -68,7 +97,7 @@ where StateAddition: Clone + 'b, ::IntoIter: 'b, NextStep: IntoIterator, - WalkerStep: Fn(T, &'a TypedGraph) -> SchemaResult + 'b + WalkerStep: Fn(T, &'a TypedGraph) -> SchemaResult + 'b, { GraphWalker { g: self.g, @@ -78,7 +107,7 @@ where |e| { Box::new(once(Err(e))) }, - |t| { + |t| { walker_step(t, self.g) .map_or_else::>>, _, _>( |e| Box::new(once(Err(e))), @@ -86,7 +115,7 @@ where ) } ) - .map(move |res| + .map(move |res| res.map_or_else( |e| (state.clone(), Err(e)), |(_, t)| (state.clone(), Ok(t)) @@ -99,10 +128,18 @@ where /// Moves the walker forward and adds more data to the state of the branch pub fn progress_with_state<'b, NewT, NextStep, StateAddition, WalkerStep, UpdateState>( - self, - walker_step: WalkerStep, - update_state: UpdateState - ) -> GraphWalker<'b, NewT, State, NK, EK, S, impl Iterator)> + 'b> + self, + walker_step: WalkerStep, + update_state: UpdateState, + ) -> GraphWalker< + 'b, + NewT, + State, + NK, + EK, + S, + impl Iterator)> + 'b, + > where 'a: 'b, 'b: 'a, @@ -111,7 +148,7 @@ where ::IntoIter: 'b, NextStep: IntoIterator, WalkerStep: Fn(T, &'a TypedGraph) -> SchemaResult + 'b, - UpdateState: Fn(State, StateAddition) -> State + 'b + Copy + UpdateState: Fn(State, StateAddition) -> State + 'b + Copy, { GraphWalker { g: self.g, @@ -121,7 +158,7 @@ where |e| { Box::new(once(Err(e))) }, - |t| { + |t| { walker_step(t, self.g) .map_or_else::>>, _, _>( |e| Box::new(once(Err(e))), @@ -129,7 +166,7 @@ where ) } ) - .map(move |res| + .map(move |res| res.map_or_else( |e| (state.clone(), Err(e)), |(new_state, t)| (update_state(state.clone(), new_state), Ok(t)) @@ -141,36 +178,24 @@ where } pub fn one(mut self) -> SchemaResult, NK, EK, S> { - self - .front - .next() - .map_or_else( - || Ok(None), - |(_state, res)| - res.map_or_else( - |e| Err(e), - |t| Ok(Some(t)) - ) - ) - + self.front.next().map_or_else( + || Ok(None), + |(_state, res)| res.map_or_else(|e| Err(e), |t| Ok(Some(t))), + ) } pub fn one_with_state(mut self) -> SchemaResult>, NK, EK, S> { - self.front - .next() - .map_or_else( - || Ok(None), - |(state, res)| - res.map_or_else( - |e| Err(e), - |t| Ok(Some(WalkerTarget { val: t, state })) - ) - ) + self.front.next().map_or_else( + || Ok(None), + |(state, res)| { + res.map_or_else(|e| Err(e), |t| Ok(Some(WalkerTarget { val: t, state }))) + }, + ) } - pub fn many(self) -> SchemaResult + pub fn many(self) -> SchemaResult where - TCollection: FromIterator + TCollection: FromIterator, { let mut results = Vec::new(); for (_, res) in self.front { @@ -180,9 +205,9 @@ where Ok(results.into_iter().collect()) } - pub fn many_with_state(self) -> SchemaResult + pub fn many_with_state(self) -> SchemaResult where - TStateCollection: FromIterator> + TStateCollection: FromIterator>, { let mut results = Vec::new(); for (state, res) in self.front { @@ -193,28 +218,44 @@ where } } -pub trait ToGraphWalker: Id +pub trait ToGraphWalker: Id where NK: Key, EK: Key, - S: SchemaExt + S: SchemaExt, { - fn to_walker<'a>(&'a self, g: &'a TypedGraph) -> SchemaResult)>>, NK, EK, S>; + fn to_walker<'a>( + &'a self, + g: &'a TypedGraph, + ) -> SchemaResult< + GraphWalker<'a, &Self, (), NK, EK, S, Once<((), SchemaResult<&'a Self, NK, EK, S>)>>, + NK, + EK, + S, + >; } -impl ToGraphWalker for T +impl ToGraphWalker for T where - T: Typed::Type> + Id, + T: Typed::Type> + Id, NK: Key, EK: Key, - S: SchemaExt, + S: SchemaExt, { - fn to_walker<'a>(&'a self, g: &'a TypedGraph) -> SchemaResult)>>, NK, EK, S> { + fn to_walker<'a>( + &'a self, + g: &'a TypedGraph, + ) -> SchemaResult< + GraphWalker<'a, &Self, (), NK, EK, S, Once<((), SchemaResult<&'a Self, NK, EK, S>)>>, + NK, + EK, + S, + > { // Make sure that the node has been inserted into the graph g.get_node(self.get_id())?; Ok(GraphWalker { g, - front: once(((), Ok(self))) - }) + front: once(((), Ok(self))), + }) } -} \ No newline at end of file +} diff --git a/src/graph/metadata.rs b/src/graph/metadata.rs index 4061a7a..451c176 100644 --- a/src/graph/metadata.rs +++ b/src/graph/metadata.rs @@ -1,9 +1,8 @@ +use crate::{EdgeKey, NodeKey}; use indexmap::IndexSet; use serde::{Deserialize, Serialize}; use std::fmt::Debug; -use std::ops::{DerefMut, Deref}; -use crate::{NodeKey, EdgeKey}; - +use std::ops::{Deref, DerefMut}; #[derive(Debug, Default, Clone, Serialize, Deserialize)] pub(crate) struct EdgeMetadata { @@ -50,9 +49,7 @@ impl DerefMut for NodeMetada { } } -impl EdgeMetadata { - -} +impl EdgeMetadata {} impl AsRef for EdgeMetadata { fn as_ref(&self) -> &E { @@ -71,4 +68,4 @@ impl DerefMut for EdgeMetadata { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.weight } -} \ No newline at end of file +} diff --git a/src/graph/migration/either.rs b/src/graph/migration/either.rs index c0a45b7..2434011 100644 --- a/src/graph/migration/either.rs +++ b/src/graph/migration/either.rs @@ -5,17 +5,17 @@ use crate::*; #[derive(Debug, Clone)] pub enum Either { Old(Old), - New(New) + New(New), } -impl From> for Option +impl From> for Option where - Option: From + Option: From, { fn from(value: Either) -> Self { match value { Either::Old(o) => o.into(), - Either::New(n) => Some(n) + Either::New(n) => Some(n), } } } @@ -39,8 +39,7 @@ impl, New: Id> Id for Either { } } -impl Copy for Either { -} +impl Copy for Either {} impl Typed for Either { type Type = Either<::Type, ::Type>; @@ -61,16 +60,16 @@ impl Display for Either { } } -impl PartialEq> for Either +impl PartialEq> for Either where Old: PartialEq, - New: PartialEq + New: PartialEq, { fn eq(&self, other: &Either) -> bool { match (self, other) { (Either::Old(old), Either::Old(ty)) => old == ty, (Either::New(new), Either::New(ty)) => new == ty, - _ => false + _ => false, } } } diff --git a/src/graph/migration/inbetween.rs b/src/graph/migration/inbetween.rs index f121ddb..b3d40da 100644 --- a/src/graph/migration/inbetween.rs +++ b/src/graph/migration/inbetween.rs @@ -9,7 +9,7 @@ pub struct InBetween { nk: PhantomData, ek: PhantomData, old: Old, - new: New + new: New, } impl InBetween { @@ -18,16 +18,16 @@ impl InBetween { nk: PhantomData, ek: PhantomData, old, - new + new, } - } + } } -impl SchemaExt for InBetween +impl SchemaExt for InBetween where - NK: Key, - EK: Key, - OldVersion: SchemaExt + MigrateSchema, + NK: Key, + EK: Key, + OldVersion: SchemaExt + MigrateSchema, NewVersion: SchemaExt, { type N = Either<>::N, >::N>; @@ -45,38 +45,29 @@ where } fn allow_edge( - &self, + &self, new_edge_count: usize, - edge_ty: ::Type, - source: ::Type, - target: ::Type + edge_ty: ::Type, + source: ::Type, + target: ::Type, ) -> Result<(), DisAllowedEdge> { - match (edge_ty, source, target) { // The edge is within the old graph - ( - Either::Old(edge_ty), - Either::Old(source), - Either::Old(target) - ) => self.old.allow_edge(new_edge_count, edge_ty, source, target), + (Either::Old(edge_ty), Either::Old(source), Either::Old(target)) => { + self.old.allow_edge(new_edge_count, edge_ty, source, target) + } // The edge is within the new graph - ( - Either::New(edge_ty), - Either::New(source), - Either::New(target) - ) => self.new.allow_edge(new_edge_count, edge_ty, source, target), + (Either::New(edge_ty), Either::New(source), Either::New(target)) => { + self.new.allow_edge(new_edge_count, edge_ty, source, target) + } // The edge is somewhere inbetween the two graphs - ( - edge_ty, - source, - target - ) => { + (edge_ty, source, target) => { // Only allow the edge if everything can be converted into the new graph let updated_content = ( self.update_edge_type(&self.new, edge_ty), self.update_node_type(&self.new, source), - self.update_node_type(&self.new, target) + self.update_node_type(&self.new, target), ); if let (Some(edge_ty), Some(source), Some(target)) = updated_content { self.new.allow_edge(new_edge_count, edge_ty, source, target) @@ -88,38 +79,55 @@ where } } -impl MigrateSchema for InBetween +impl MigrateSchema + for InBetween where - NK: Key, - EK: Key, + NK: Key, + EK: Key, OldVersion: SchemaExt + MigrateSchema, - NewVersion: SchemaExt + NewVersion: SchemaExt, { - fn update_edge(&self, new_schema: &NewVersion, edge: Self::E) -> Option<>::E> { + fn update_edge( + &self, + new_schema: &NewVersion, + edge: Self::E, + ) -> Option<>::E> { match edge { Either::New(e) => Some(e), - Either::Old(e) => self.old.update_edge(new_schema, e) + Either::Old(e) => self.old.update_edge(new_schema, e), } } - fn update_node(&self, new_schema: &NewVersion, node: Self::N) -> Option<>::N> { + fn update_node( + &self, + new_schema: &NewVersion, + node: Self::N, + ) -> Option<>::N> { match node { Either::New(n) => Some(n), - Either::Old(n) => self.old.update_node(new_schema, n) + Either::Old(n) => self.old.update_node(new_schema, n), } } - fn update_edge_type(&self, new_schema: &NewVersion, edge_type: ::Type) -> Option<<>::E as Typed>::Type> { + fn update_edge_type( + &self, + new_schema: &NewVersion, + edge_type: ::Type, + ) -> Option<<>::E as Typed>::Type> { match edge_type { Either::New(ty) => Some(ty), - Either::Old(ty) => self.old.update_edge_type(new_schema, ty) + Either::Old(ty) => self.old.update_edge_type(new_schema, ty), } } - fn update_node_type(&self, new_schema: &NewVersion, node_type: ::Type) -> Option<<>::N as Typed>::Type> { + fn update_node_type( + &self, + new_schema: &NewVersion, + node_type: ::Type, + ) -> Option<<>::N as Typed>::Type> { match node_type { Either::New(ty) => Some(ty), - Either::Old(ty) => self.old.update_node_type(new_schema, ty) + Either::Old(ty) => self.old.update_node_type(new_schema, ty), } } -} \ No newline at end of file +} diff --git a/src/graph/migration/migrate.rs b/src/graph/migration/migrate.rs index 9c3477e..b4cc0a3 100644 --- a/src/graph/migration/migrate.rs +++ b/src/graph/migration/migrate.rs @@ -5,18 +5,21 @@ use crate::*; pub type MigrationGraph = TypedGraph>; pub type MigrationResult = GenericTypedResult; -pub trait Migrationhandler +pub trait Migrationhandler where NK: Key, EK: Key, OldVersion: SchemaExt, NewVersion: SchemaExt, - OldVersion: MigrateSchema + OldVersion: MigrateSchema, { - fn update_data(&self, g: &mut MigrationGraph) -> SchemaResult<(), NK, EK, InBetween>; + fn update_data( + &self, + g: &mut MigrationGraph, + ) -> SchemaResult<(), NK, EK, InBetween>; } -pub trait MigrateSchema +pub trait MigrateSchema where NK: Key, EK: Key, @@ -24,21 +27,29 @@ where Self: SchemaExt, { /// Update a node from its old type to the new one - /// + /// /// Returning None indicates that there exists no equivalent in the new schema fn update_node(&self, new_schema: &NewVersion, node: Self::N) -> Option; /// Update an edge from its old type to the new one - /// + /// /// Returning None indicates that there exists no equivalent in the new schema fn update_edge(&self, new_schema: &NewVersion, edge: Self::E) -> Option; /// Update a node type from its old version to its new one - /// + /// /// Returning None indicates that there exists no equivalent in the new schema - fn update_node_type(&self, new_schema: &NewVersion, node_type: ::Type) -> Option<::Type>; + fn update_node_type( + &self, + new_schema: &NewVersion, + node_type: ::Type, + ) -> Option<::Type>; /// Update an edge type from its old version to its new one - /// + /// /// Returning None indicates that there exists no equivalent in the new schema - fn update_edge_type(&self, new_schema: &NewVersion, edge_type: ::Type) -> Option<::Type>; + fn update_edge_type( + &self, + new_schema: &NewVersion, + edge_type: ::Type, + ) -> Option<::Type>; } pub trait Migration: SchemaExt @@ -46,60 +57,62 @@ where NK: Key, EK: Key, NewVersion: SchemaExt + Clone, - Self: MigrateSchema + Clone + Self: MigrateSchema + Clone, { - type Handler: Migrationhandler; /// mirgate the data store in one schema to another - /// + /// /// Most of the time the default implementation is used as it uses an InBetween representation of the shemas to ensure type safety all throughout the migration process fn migrate( g: TypedGraph, handler: &Self::Handler, - new_schema: NewVersion + new_schema: NewVersion, ) -> GenericTypedResult, NK, EK> { // Setup migration enviroment let old_schema = g.get_schema().clone(); let old_name = old_schema.name(); let new_name = new_schema.name(); - let to_generic_error = |e: SchemaError>| + let to_generic_error = |e: SchemaError>| { e.map( - identity, - identity, + identity, + identity, |nt| match nt { Either::Old(nt) => format!("{}::{}", old_name, nt), Either::New(nt) => format!("{}::{}", new_name, nt), - }, + }, |et| match et { Either::Old(et) => format!("{}::{}", old_name, et), Either::New(et) => format!("{}::{}", new_name, et), - }, - ); - - let mut migration_g: MigrationGraph = g.update_schema( - InBetween::new(old_schema, new_schema.clone()), - |_, _, n| Some(Either::Old(n)), - |_, _, e| Some(Either::Old(e)), - ).map_err(to_generic_error)?; + }, + ) + }; + let mut migration_g: MigrationGraph = g + .update_schema( + InBetween::new(old_schema, new_schema.clone()), + |_, _, n| Some(Either::Old(n)), + |_, _, e| Some(Either::Old(e)), + ) + .map_err(to_generic_error)?; - handler.update_data(&mut migration_g).map_err(to_generic_error)?; + handler + .update_data(&mut migration_g) + .map_err(to_generic_error)?; // Finalize migration - let new_g = migration_g.update_schema( - new_schema, - |current_schema, new_schema, node| current_schema.update_node(&new_schema, node), - |current_schema, new_schema, edge| current_schema.update_edge(&new_schema, edge), - ) - // filter_map returns an error for the new schema - // So we have to convert it into an error for the joined schema - .map_err(|e| e - .map(identity, identity, Either::New, Either::New) - ) - // And then we can use the same formatter as for the other results - .map_err(to_generic_error)?; + let new_g = migration_g + .update_schema( + new_schema, + |current_schema, new_schema, node| current_schema.update_node(&new_schema, node), + |current_schema, new_schema, edge| current_schema.update_edge(&new_schema, edge), + ) + // filter_map returns an error for the new schema + // So we have to convert it into an error for the joined schema + .map_err(|e| e.map(identity, identity, Either::New, Either::New)) + // And then we can use the same formatter as for the other results + .map_err(to_generic_error)?; Ok(new_g) } @@ -109,9 +122,9 @@ pub trait DirectMigration: SchemaExt + Sized where NK: Key, EK: Key, - NewVersion: SchemaExt + NewVersion: SchemaExt, { fn migrate( - g: TypedGraph + g: TypedGraph, ) -> GenericTypedResult, NK, EK>; -} \ No newline at end of file +} diff --git a/src/graph/migration/migration_handler.rs b/src/graph/migration/migration_handler.rs index ccf6db5..d50ea92 100644 --- a/src/graph/migration/migration_handler.rs +++ b/src/graph/migration/migration_handler.rs @@ -3,7 +3,8 @@ use crate::*; /// The default migration handler does not alter any of the data in the graph pub struct DefaultMigrationHandler; -impl Migrationhandler for DefaultMigrationHandler +impl Migrationhandler + for DefaultMigrationHandler where NK: Key, EK: Key, @@ -11,10 +12,12 @@ where NewVersion: SchemaExt + Clone, // Everything should be convertable to the new schema - OldVersion: MigrateSchema + OldVersion: MigrateSchema, { - fn update_data(&self, _g: &mut MigrationGraph) -> SchemaResult<(), NK, EK, InBetween> { + fn update_data( + &self, + _g: &mut MigrationGraph, + ) -> SchemaResult<(), NK, EK, InBetween> { Ok(()) } - -} \ No newline at end of file +} diff --git a/src/graph/migration/mod.rs b/src/graph/migration/mod.rs index bc3d3eb..35ca349 100644 --- a/src/graph/migration/mod.rs +++ b/src/graph/migration/mod.rs @@ -1,9 +1,9 @@ -mod migrate; mod either; mod inbetween; +mod migrate; mod migration_handler; pub use either::*; pub use inbetween::*; +pub use migrate::*; pub use migration_handler::*; -pub use migrate::*; \ No newline at end of file diff --git a/src/graph/mod.rs b/src/graph/mod.rs index a817841..77b38ab 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -1,13 +1,13 @@ -mod typed_graph; +mod edge_ref; mod graph_traits; +mod graph_walker; mod metadata; mod migration; -mod graph_walker; -mod edge_ref; +mod typed_graph; pub use edge_ref::*; +pub use graph_traits::*; pub use graph_walker::*; +pub(crate) use metadata::*; pub use migration::*; pub use typed_graph::*; -pub use graph_traits::*; -pub(crate) use metadata::*; \ No newline at end of file diff --git a/src/graph/typed_graph.rs b/src/graph/typed_graph.rs index 0e106e5..bbc4c6d 100644 --- a/src/graph/typed_graph.rs +++ b/src/graph/typed_graph.rs @@ -1,23 +1,23 @@ use super::*; -use crate::{TypedError, SchemaResult, GenericTypedResult}; +use crate::{GenericTypedResult, SchemaResult, TypedError}; +use serde::{Deserialize, Serialize}; +use slotmap::{new_key_type, HopSlotMap}; use std::collections::HashMap; +use std::fmt::{self, Debug, Display}; +use std::hash::Hash; use std::marker::PhantomData; use std::ops::Deref; -use serde::{Deserialize, Serialize}; -use slotmap::{HopSlotMap, new_key_type}; -use std::fmt::{Debug, self, Display}; -use std::hash::Hash; #[derive(Copy, Clone, Debug)] pub enum Direction { Outgoing, - Incoming + Incoming, } #[derive(Copy, Clone, Debug, PartialEq)] pub enum InsertPosition { Before, - After + After, } new_key_type! { @@ -29,37 +29,37 @@ new_key_type! { pub struct TypedGraph> where NK: Key, - EK: Key + EK: Key, { /// Mapping from node ids to node keys node_lut: HashMap, /// Mapping from edge ids to edge keys edge_lut: HashMap, /// Contains the node weights and adjecency list - /// + /// /// Since the nodes stores its own id this can be used to convert node keys to node ids nodes: HopSlotMap>, /// Contains the edge weights, and edge endpoints - /// + /// /// Since the edges stores its own id this can be used to convert edge keys to edge ids edges: HopSlotMap>, - schema: S + schema: S, } -impl TypedGraph +impl TypedGraph where NK: Key, EK: Key, - S: SchemaExt + S: SchemaExt, { pub fn new(schema: S) -> Self { TypedGraph { node_lut: Default::default(), - edge_lut: Default::default(), - nodes: HopSlotMap::with_key(), - edges: HopSlotMap::with_key(), - schema: schema + edge_lut: Default::default(), + nodes: HopSlotMap::with_key(), + edges: HopSlotMap::with_key(), + schema: schema, } } @@ -76,27 +76,47 @@ where } fn get_node_key(&self, node_id: NK) -> SchemaResult { - self.node_lut.get(&node_id).copied().ok_or_else(|| TypedError::MissingNode(node_id)) + self.node_lut + .get(&node_id) + .copied() + .ok_or_else(|| TypedError::MissingNode(node_id)) } fn get_edge_key(&self, edge_id: EK) -> SchemaResult { - self.edge_lut.get(&edge_id).copied().ok_or_else(|| TypedError::MissingEdge(edge_id)) + self.edge_lut + .get(&edge_id) + .copied() + .ok_or_else(|| TypedError::MissingEdge(edge_id)) } fn get_node_internal(&self, node_key: NodeKey) -> SchemaResult<&NodeMetada, NK, EK, S> { - self.nodes.get(node_key).ok_or_else(|| TypedError::MissingNodeKey(node_key)) + self.nodes + .get(node_key) + .ok_or_else(|| TypedError::MissingNodeKey(node_key)) } fn get_edge_internal(&self, edge_key: EdgeKey) -> SchemaResult<&EdgeMetadata, NK, EK, S> { - self.edges.get(edge_key).ok_or_else(|| TypedError::MissingEdgeKey(edge_key)) + self.edges + .get(edge_key) + .ok_or_else(|| TypedError::MissingEdgeKey(edge_key)) } - fn get_node_mut_internal(&mut self, node_key: NodeKey) -> SchemaResult<&mut NodeMetada, NK, EK, S> { - self.nodes.get_mut(node_key).ok_or_else(|| TypedError::MissingNodeKey(node_key)) + fn get_node_mut_internal( + &mut self, + node_key: NodeKey, + ) -> SchemaResult<&mut NodeMetada, NK, EK, S> { + self.nodes + .get_mut(node_key) + .ok_or_else(|| TypedError::MissingNodeKey(node_key)) } - fn get_edge_mut_internal(&mut self, edge_key: EdgeKey) -> SchemaResult<&mut EdgeMetadata, NK, EK, S> { - self.edges.get_mut(edge_key).ok_or_else(|| TypedError::MissingEdgeKey(edge_key)) + fn get_edge_mut_internal( + &mut self, + edge_key: EdgeKey, + ) -> SchemaResult<&mut EdgeMetadata, NK, EK, S> { + self.edges + .get_mut(edge_key) + .ok_or_else(|| TypedError::MissingEdgeKey(edge_key)) } pub fn get_nodes(&self) -> impl Iterator + '_ { @@ -130,11 +150,13 @@ where } pub fn get_node(&self, node_id: NK) -> SchemaResult<&S::N, NK, EK, S> { - self.get_node_safe(node_id).ok_or_else(|| TypedError::MissingNode(node_id)) + self.get_node_safe(node_id) + .ok_or_else(|| TypedError::MissingNode(node_id)) } pub fn get_edge(&self, edge_id: EK) -> SchemaResult<&S::E, NK, EK, S> { - self.get_edge_safe(edge_id).ok_or_else(|| TypedError::MissingEdge(edge_id)) + self.get_edge_safe(edge_id) + .ok_or_else(|| TypedError::MissingEdge(edge_id)) } pub fn get_edge_full(&self, edge_id: EK) -> SchemaResult, NK, EK, S> { @@ -144,42 +166,44 @@ where weight: &edge.weight, source: self.nodes.get(edge.source).unwrap().get_id(), target: self.nodes.get(edge.target).unwrap().get_id(), - direction: Direction::Outgoing + direction: Direction::Outgoing, }) } pub fn get_node_mut(&mut self, node_id: NK) -> SchemaResult<&mut S::N, NK, EK, S> { - self.get_node_safe_mut(node_id).ok_or_else(|| TypedError::MissingNode(node_id)) + self.get_node_safe_mut(node_id) + .ok_or_else(|| TypedError::MissingNode(node_id)) } pub fn get_edge_mut(&mut self, edge_id: EK) -> SchemaResult<&mut S::E, NK, EK, S> { - self.get_edge_safe_mut(edge_id).ok_or_else(|| TypedError::MissingEdge(edge_id)) + self.get_edge_safe_mut(edge_id) + .ok_or_else(|| TypedError::MissingEdge(edge_id)) } - pub fn get_node_downcast(&self, node_id: NK) -> SchemaResult<&N, NK, EK, S> + pub fn get_node_downcast<'a, N>(&'a self, node_id: NK) -> SchemaResult where - S::N: Downcast + S::N: Downcast<'a, NK, EK, N, S>, { self.get_node(node_id).and_then(|n| n.downcast()) } - pub fn get_node_downcast_mut(&mut self, node_id: NK) -> SchemaResult<&mut N, NK, EK, S> + pub fn get_node_downcast_mut<'a, N>(&'a mut self, node_id: NK) -> SchemaResult where - S::N: DowncastMut + S::N: DowncastMut<'a, NK, EK, N, S>, { self.get_node_mut(node_id).and_then(|n| n.downcast_mut()) } - pub fn get_edge_downcast(&self, edge_id: EK) -> SchemaResult<&E, NK, EK, S> + pub fn get_edge_downcast<'a, E>(&'a self, edge_id: EK) -> SchemaResult where - S::E: Downcast + S::E: Downcast<'a, NK, EK, E, S>, { self.get_edge(edge_id).and_then(|e| e.downcast()) } - pub fn get_edge_downcast_mut(&mut self, edge_id: EK) -> SchemaResult<&mut E, NK, EK, S> + pub fn get_edge_downcast_mut<'a, E>(&'a mut self, edge_id: EK) -> SchemaResult where - S::E: DowncastMut + S::E: DowncastMut<'a, NK, EK, E, S>, { self.get_edge_mut(edge_id).and_then(|e| e.downcast_mut()) } @@ -193,12 +217,17 @@ where } /// Reorder the outgoing edge order by inserting the source edge and shifting all other edges - /// + /// /// This will fail if both edges do not have the same source node - /// + /// /// InsertPosition::Before the source edge will have the same index as the target edge currently has /// InsertPosition::After the source edge will be placed at target_idx + 1 instead - pub fn move_edge_order(&mut self, source_id: EK, target_id: EK, insert_position: InsertPosition) -> SchemaResult<(), NK, EK, S> { + pub fn move_edge_order( + &mut self, + source_id: EK, + target_id: EK, + insert_position: InsertPosition, + ) -> SchemaResult<(), NK, EK, S> { if source_id == target_id { return Ok(()); } @@ -222,24 +251,26 @@ where return Err(TypedError::InvalidEdgeMove(source_id, target_id)); } - let source_idx = node.outgoing_edges + let source_idx = node + .outgoing_edges .get_index_of(&source_key) .ok_or_else(|| TypedError::InvalidEdgeMove(source_id, target_id))?; // Then we find out where in the order we want the node to be places - let mut target_idx = node.outgoing_edges + let mut target_idx = node + .outgoing_edges .get_index_of(&target_key) .ok_or_else(|| TypedError::InvalidEdgeMove(source_id, target_id))?; // Figure out where to place the source relative to the target - // The position has to be adjusted as to not go out of bounds and + // The position has to be adjusted as to not go out of bounds and // play nicely with the behaviour of move_index match insert_position { InsertPosition::After => { if target_idx + 1 != node.outgoing_edges.len() && source_idx > target_idx { target_idx += 1; } - }, + } InsertPosition::Before => { if target_idx != 0 && source_idx < target_idx { target_idx -= 1; @@ -255,9 +286,9 @@ where /// Add a node or update an existing one /// Updating a node is only allowed if all the connected edges allows for the new type - pub fn add_node(&mut self, node: N) -> SchemaResult + pub fn add_node(&mut self, node: N) -> SchemaResult where - N: Into + N: Into, { let weight: S::N = node.into(); @@ -265,7 +296,7 @@ where let weight_type = weight.get_type(); let allowed = self.schema.allow_node(weight.get_type()); if let Err(e) = allowed { - return Err(TypedError::InvalidNodeType(weight_type, e)) + return Err(TypedError::InvalidNodeType(weight_type, e)); } let node_id = weight.get_id(); @@ -313,20 +344,24 @@ where if out_target_node.get_type() != target_node.get_type() { continue; } - + quantity += 1; } - + let allowed = self.schema.allow_edge( // Account for the new type adding a new edge quantity + 1, - weight_type.clone(), - source_node.get_type(), + weight_type.clone(), + source_node.get_type(), target_node.get_type(), - ); if let Err(e) = allowed { - return Err(TypedError::InvalidEdgeType(weight_type, source_node.get_type(), target_node.get_type(), e)) + return Err(TypedError::InvalidEdgeType( + weight_type, + source_node.get_type(), + target_node.get_type(), + e, + )); } } @@ -337,14 +372,13 @@ where // Just replace the node let node = self.get_node_mut_internal(node_key)?; node.weight = weight; - } } else { // Add the node to the graph let node_key = self.nodes.insert(NodeMetada { weight: weight, outgoing_edges: Default::default(), - incoming_edges: Default::default() + incoming_edges: Default::default(), }); self.node_lut.insert(node_id, node_key); } @@ -354,9 +388,9 @@ where /// Add an edge and if it already exists update the weight and enpoints of the edge /// The edge will preserve the order of endpoints that does not change - pub fn add_edge(&mut self, source: NK, target: NK, edge: E) -> SchemaResult + pub fn add_edge(&mut self, source: NK, target: NK, edge: E) -> SchemaResult where - E: Into + E: Into, { let weight: S::E = edge.into(); let edge_id = weight.get_id(); @@ -387,12 +421,17 @@ where let allowed = self.schema.allow_edge( quantity + 1, - weight_type.clone(), - source_node.get_type(), + weight_type.clone(), + source_node.get_type(), target_node.get_type(), ); if let Err(e) = allowed { - return Err(TypedError::InvalidEdgeType(weight_type, source_node.get_type(), target_node.get_type(), e)) + return Err(TypedError::InvalidEdgeType( + weight_type, + source_node.get_type(), + target_node.get_type(), + e, + )); } if let Some(edge_key) = self.edge_lut.get(&edge_id).copied() { @@ -432,37 +471,43 @@ where let full_weight = EdgeMetadata { weight, source: source_key, - target: target_key + target: target_key, }; let edge_key = self.edges.insert(full_weight); self.edge_lut.insert(edge_id, edge_key); - + // Add the edge to the source self.get_node_mut_internal(source_key)? .outgoing_edges .insert(edge_key); - + // Add the edge to the target self.get_node_mut_internal(target_key)? .incoming_edges .insert(edge_key); - } - Ok(edge_id) } /// Remove a node and all edges to and from it pub fn remove_node(&mut self, node_id: NK) -> SchemaResult { - let node_key = self.node_lut.remove(&node_id).ok_or_else(|| TypedError::NodeIdMissing(node_id))?; + let node_key = self + .node_lut + .remove(&node_id) + .ok_or_else(|| TypedError::NodeIdMissing(node_id))?; let node = self.nodes.remove(node_key).unwrap(); for edge_key in node.outgoing_edges { - let edge = self.edges.remove(edge_key).ok_or_else(|| TypedError::InvalidInternalState)?; + let edge = self + .edges + .remove(edge_key) + .ok_or_else(|| TypedError::InvalidInternalState)?; self.edge_lut.remove(&edge.weight.get_id()); if edge.target != node_key { - self.get_node_mut_internal(edge.target)?.incoming_edges.remove(&edge_key); + self.get_node_mut_internal(edge.target)? + .incoming_edges + .remove(&edge_key); } } @@ -473,10 +518,15 @@ where continue; } - let edge = self.edges.remove(edge_key).ok_or_else(|| TypedError::InvalidInternalState)?; + let edge = self + .edges + .remove(edge_key) + .ok_or_else(|| TypedError::InvalidInternalState)?; self.edge_lut.remove(&edge.weight.get_id()); if edge.source != node_key { - self.get_node_mut_internal(edge.source)?.outgoing_edges.shift_remove(&edge_key); + self.get_node_mut_internal(edge.source)? + .outgoing_edges + .shift_remove(&edge_key); } } @@ -485,19 +535,35 @@ where /// Remove an edge. pub fn remove_edge(&mut self, edge_id: EK) -> SchemaResult { - let edge_key = self.edge_lut.remove(&edge_id).ok_or_else(|| TypedError::EdgeIdMissing(edge_id))?; + let edge_key = self + .edge_lut + .remove(&edge_id) + .ok_or_else(|| TypedError::EdgeIdMissing(edge_id))?; // Remove the edge itself. - let edge = self.edges.remove(edge_key).ok_or_else(|| TypedError::InvalidInternalState)?; - self.get_node_mut_internal(edge.source)?.outgoing_edges.shift_remove(&edge_key); - self.get_node_mut_internal(edge.target)?.incoming_edges.remove(&edge_key); + let edge = self + .edges + .remove(edge_key) + .ok_or_else(|| TypedError::InvalidInternalState)?; + self.get_node_mut_internal(edge.source)? + .outgoing_edges + .shift_remove(&edge_key); + self.get_node_mut_internal(edge.target)? + .incoming_edges + .remove(&edge_key); Ok(edge.weight) } - /// Get all incoming edges - pub fn get_incoming<'a>(&'a self, node_id: NK) -> SchemaResult>, NK, EK, S> { - let node_key = *self.node_lut.get(&node_id).ok_or_else(|| TypedError::NodeIdMissing(node_id))?; + /// Get all incoming edges + pub fn get_incoming<'a>( + &'a self, + node_id: NK, + ) -> SchemaResult>, NK, EK, S> { + let node_key = *self + .node_lut + .get(&node_id) + .ok_or_else(|| TypedError::NodeIdMissing(node_id))?; Ok(self .get_node_internal(node_key)? .incoming_edges @@ -508,14 +574,19 @@ where weight: &edge.weight, source: self.nodes.get(edge.source).unwrap().get_id(), target: self.nodes.get(edge.target).unwrap().get_id(), - direction: Direction::Incoming + direction: Direction::Incoming, } - }) - ) + })) } - pub fn get_outgoing<'a>(&'a self, node_id: NK) -> SchemaResult>, NK, EK, S> { - let node_key = *self.node_lut.get(&node_id).ok_or_else(|| TypedError::NodeIdMissing(node_id))?; + pub fn get_outgoing<'a>( + &'a self, + node_id: NK, + ) -> SchemaResult>, NK, EK, S> { + let node_key = *self + .node_lut + .get(&node_id) + .ok_or_else(|| TypedError::NodeIdMissing(node_id))?; Ok(self .get_node_internal(node_key)? .outgoing_edges @@ -526,42 +597,57 @@ where weight: &edge.weight, source: self.nodes.get(edge.source).unwrap().get_id(), target: self.nodes.get(edge.target).unwrap().get_id(), - direction: Direction::Outgoing + direction: Direction::Outgoing, } - }) - ) + })) } - pub fn get_incoming_and_outgoing<'a>(&'a self, node_id: NK) -> SchemaResult>, NK, EK, S> { - self - .get_incoming(node_id) - .and_then(|inc| self - .get_outgoing(node_id) - .map(|out|inc.chain(out)) - ) + pub fn get_incoming_and_outgoing<'a>( + &'a self, + node_id: NK, + ) -> SchemaResult>, NK, EK, S> { + self.get_incoming(node_id) + .and_then(|inc| self.get_outgoing(node_id).map(|out| inc.chain(out))) } - pub fn get_outgoing_filter_edge<'a, F>(&'a self, node_id: NK, filter: F) -> SchemaResult>, NK, EK, S> + pub fn get_outgoing_filter<'a, F>( + &'a self, + node_id: NK, + filter: F, + ) -> SchemaResult>, NK, EK, S> where - F: Fn(&S::E) -> bool - { - Ok( - self - .get_outgoing(node_id)? - .filter(move |e| filter(&e.weight)) - ) + F: Fn(&S::E) -> bool, + { + Ok(self + .get_outgoing(node_id)? + .filter(move |e| filter(&e.weight))) } - - pub fn get_incoming_filter_edge<'a, F>(&'a self, node_id: NK, filter: F) -> SchemaResult>, NK, EK, S> + + pub fn get_incoming_filter<'a, F>( + &'a self, + node_id: NK, + filter: F, + ) -> SchemaResult>, NK, EK, S> where - F: Fn(&S::E) -> bool - { - Ok( - self - .get_incoming(node_id)? - .filter(move |e| filter(&e.weight)) - ) + F: Fn(&S::E) -> bool, + { + Ok(self + .get_incoming(node_id)? + .filter(move |e| filter(&e.weight))) + } + + pub fn get_incoming_and_outgoing_filter<'a, F>( + &'a self, + node_id: NK, + filter: F, + ) -> SchemaResult>, NK, EK, S> + where + F: Fn(&S::E) -> bool, + { + Ok(self + .get_incoming_and_outgoing(node_id)? + .filter(move |e| filter(&e.weight))) } pub fn nodes<'a>(&'a self) -> impl Iterator + 'a { @@ -573,15 +659,12 @@ where } pub fn edges_full<'a>(&'a self) -> impl Iterator> + 'a { - self - .edges - .values() - .map(|edge| EdgeRef { - weight: &edge.weight, - source: self.nodes.get(edge.source).unwrap().get_id(), - target: self.nodes.get(edge.target).unwrap().get_id(), - direction: Direction::Outgoing - }) + self.edges.values().map(|edge| EdgeRef { + weight: &edge.weight, + source: self.nodes.get(edge.source).unwrap().get_id(), + target: self.nodes.get(edge.target).unwrap().get_id(), + direction: Direction::Outgoing, + }) } pub fn node_ids<'a>(&'a self) -> impl Iterator + 'a { @@ -593,14 +676,18 @@ where } /// Apply a Migration to the current graph - pub fn migrate(self, new_schema: NS, handler: &S::Handler) -> GenericTypedResult, NK, EK> + pub fn migrate( + self, + new_schema: NS, + handler: &S::Handler, + ) -> GenericTypedResult, NK, EK> where S: Migration, NS: SchemaExt + Clone, { Migration::migrate(self, handler, new_schema) } - + /// Migrate directly from one version to another pub fn migrate_direct(self) -> GenericTypedResult, NK, EK> where @@ -612,14 +699,19 @@ where /// Convert the graph from one schema to another using two mapping functions /// The mapping functions are not allowed to change the id of any of the nodes only their data - /// - /// + /// + /// /// Returning None from a mapping function will delete the node from the resulting graph - /// - /// - /// When mapping to a schema with lower bounds on the number of edges allowed from a node. + /// + /// + /// When mapping to a schema with lower bounds on the number of edges allowed from a node. /// Edges higher in the outgoing edge order will be removed - pub fn update_schema(mut self, schema: NS, node_map: NF, edge_map: EF) -> SchemaResult, NK, EK, NS> + pub fn update_schema( + mut self, + schema: NS, + node_map: NF, + edge_map: EF, + ) -> SchemaResult, NK, EK, NS> where NS: SchemaExt, NF: Fn(&S, &NS, S::N) -> Option, @@ -627,12 +719,16 @@ where { let old_schema = self.schema; let mut new_graph = TypedGraph::new(schema); - + // Create a list of all the edges that stores them in outgoing order let mut edges = Vec::new(); for (_, node) in &self.nodes { for e in &node.outgoing_edges { - edges.push(self.edges.remove(*e).ok_or_else(|| TypedError::InvalidInternalState)?); + edges.push( + self.edges + .remove(*e) + .ok_or_else(|| TypedError::InvalidInternalState)?, + ); } } @@ -664,8 +760,12 @@ where return Err(TypedError::InconsistentEdgeIds(old_id, e.get_id())); } - let source_id = *node_id_lut.get(&edge.source).ok_or_else(|| TypedError::InvalidInternalState)?; - let target_id = *node_id_lut.get(&edge.target).ok_or_else(|| TypedError::InvalidInternalState)?; + let source_id = *node_id_lut + .get(&edge.source) + .ok_or_else(|| TypedError::InvalidInternalState)?; + let target_id = *node_id_lut + .get(&edge.target) + .ok_or_else(|| TypedError::InvalidInternalState)?; // Don't include the edge if the source or target has been removed if new_graph.has_node(source_id) && new_graph.has_node(target_id) { @@ -676,7 +776,7 @@ where // Since egdes are updated in outgoing order this will remove the last edges in the outgoing order Err(TypedError::InvalidEdgeType(_, _, _, DisAllowedEdge::ToMany)) => (), Err(e) => Err(e)?, - Ok(_) => () + Ok(_) => (), } } } @@ -686,54 +786,50 @@ where } } -impl Default for TypedGraph +impl Default for TypedGraph where NK: Key, EK: Key, - S: SchemaExt + Default + S: SchemaExt + Default, { fn default() -> Self { - TypedGraph { + TypedGraph { node_lut: Default::default(), edge_lut: Default::default(), - nodes: HopSlotMap::with_key(), - edges: HopSlotMap::with_key(), - schema: S::default() + nodes: HopSlotMap::with_key(), + edges: HopSlotMap::with_key(), + schema: S::default(), } } } -use serde::ser::*; -use serde::de::*; use serde::de::Error; +use serde::de::*; +use serde::ser::*; /// A reference to an edge with its source and target id #[derive(Serialize)] struct EdgeWriteDTO<'a, NK, E> { weight: &'a E, source: NK, - target: NK + target: NK, } // This is what #[derive(Serialize)] would generate. -impl Serialize for TypedGraph +impl Serialize for TypedGraph where NK: Key + Serialize, EK: Key + Serialize, N: Serialize + NodeExt, E: Serialize + EdgeExt, - S: SchemaExt + Serialize + S: SchemaExt + Serialize, { fn serialize(&self, serializer: Ser) -> Result where Ser: Serializer, { // Serialize the nodes as is - let nodes: Vec<_> = self - .nodes - .values() - .map(|n| &n.weight) - .collect(); + let nodes: Vec<_> = self.nodes.values().map(|n| &n.weight).collect(); // Wrap the edges in write dto let edges: Vec<_> = self @@ -749,7 +845,7 @@ where target: self.nodes.get(e.target).unwrap().get_id(), }) .collect(); - + // Serialize the graph as a map with 3 fields let mut s = serializer.serialize_map(Some(3))?; s.serialize_entry("schema", &self.schema)?; @@ -764,14 +860,14 @@ where struct EdgeReadDTO { weight: E, source: NK, - target: NK + target: NK, } /// A deserialize visitor that can generate a TypedGraph -/// -/// this contains all the generics used by the TypeGraph since they would otherwise be seen as not used +/// +/// this contains all the generics used by the TypeGraph since they would otherwise be seen as not used #[derive(Default)] -struct TypedGraphVisitor +struct TypedGraphVisitor where NK: Key, EK: Key, @@ -806,17 +902,20 @@ where where M: MapAccess<'de>, { - // Step 1: Deserialize the schema - let (schema_field, schema): (&'de str, S) = access.next_entry()?.ok_or_else(|| M::Error::missing_field("schema"))?; + let (schema_field, schema): (&'de str, S) = access + .next_entry()? + .ok_or_else(|| M::Error::missing_field("schema"))?; if schema_field != "schema" { return Err(M::Error::unknown_field(schema_field, &["schema"])); } - + let mut g = TypedGraph::new(schema); // Step 2: Deserialize the nodes - let (nodes_field, nodes): (&'de str, Vec) = access.next_entry()?.ok_or_else(|| M::Error::missing_field("nodes"))?; + let (nodes_field, nodes): (&'de str, Vec) = access + .next_entry()? + .ok_or_else(|| M::Error::missing_field("nodes"))?; if nodes_field != "nodes" { return Err(M::Error::unknown_field(nodes_field, &["nodes"])); } @@ -827,23 +926,25 @@ where } // Step 3: Deserialize the edges - let (edges_field, edges): (&'de str, Vec>) = access.next_entry()?.ok_or_else(|| M::Error::missing_field("edges"))?; + let (edges_field, edges): (&'de str, Vec>) = access + .next_entry()? + .ok_or_else(|| M::Error::missing_field("edges"))?; if edges_field != "edges" { return Err(M::Error::unknown_field(edges_field, &["edges"])); } // Check for id collisions and propper edge types for e in edges { - g.add_edge(e.source, e.target, e.weight).map_err(|e| M::Error::custom(e))?; + g.add_edge(e.source, e.target, e.weight) + .map_err(|e| M::Error::custom(e))?; } Ok(g) - } } /// Use the visitor to deserialize the TypedGraph -impl<'de, NK, EK, N, E, S> Deserialize<'de> for TypedGraph +impl<'de, NK, EK, N, E, S> Deserialize<'de> for TypedGraph where NK: Key + Display + Deserialize<'de>, EK: Key + Display + Deserialize<'de>, @@ -860,7 +961,7 @@ where ek: PhantomData, n: PhantomData, e: PhantomData, - s: PhantomData + s: PhantomData, }) } } @@ -873,7 +974,7 @@ fn graph_compose_test() -> crate::test::TestResult<()> { let config = CompleteGraph { nodes: 100, node_types: 5, - edge_types: 5 + edge_types: 5, }; let g = TestGraph::dummy(&config); @@ -894,11 +995,14 @@ fn graph_quantity_test() -> crate::test::TestResult<()> { const EDGE_TYPE1: usize = 0; const EDGE_TYPE2: usize = 2; - let s = TestSchema::new() - .endpoint_max_quantity(Some(vec![ + let s = TestSchema::new().endpoint_max_quantity(Some( + vec![ // Set a max capacity of 2 edges of type EDGE_TYPE1 from NODE_TYPE to NODE_TYPE1 - ((EDGE_TYPE1, NODE_TYPE, NODE_TYPE1), 2) - ].into_iter().collect())); + ((EDGE_TYPE1, NODE_TYPE, NODE_TYPE1), 2), + ] + .into_iter() + .collect(), + )); let mut g = TestGraph::new(s); let source_id = g.add_node((0, NODE_TYPE))?; @@ -930,7 +1034,7 @@ fn graph_quantity_test() -> crate::test::TestResult<()> { let source_id2 = g.add_node((2, NODE_TYPE2))?; let target_id2 = g.add_node((3, NODE_TYPE2))?; g.add_edge(target_id2, source_id2, (2, EDGE_TYPE2))?; - + let e = g.add_edge(source_id, target_id, (2, EDGE_TYPE1)); assert!(e.is_err(), "Updating completly different edge"); g.remove_edge(2)?; @@ -989,8 +1093,6 @@ fn edge_order() -> crate::test::TestResult<()> { let ids: Vec = ng.get_outgoing(a)?.map(|e| e.get_type()).collect(); assert_eq!(ids, &[0, 1, 2, 4, 3]); - - // Swap before self is identity let mut ng = g.clone(); ng.move_edge_order(2, 2, InsertPosition::Before)?; @@ -1003,7 +1105,6 @@ fn edge_order() -> crate::test::TestResult<()> { let ids: Vec = ng.get_outgoing(a)?.map(|e| e.get_type()).collect(); assert_eq!(ids, &[0, 1, 2, 3, 4]); - // Swap before when before let mut ng = g.clone(); ng.move_edge_order(1, 2, InsertPosition::Before)?; diff --git a/src/lib.rs b/src/lib.rs index 5ef7696..8b37d8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,8 @@ +pub mod generic_graph; mod graph; -mod typed_error; #[cfg(any(test, bench))] pub mod test; -pub mod generic_graph; +mod typed_error; +pub use graph::*; pub use typed_error::*; -pub use graph::*; \ No newline at end of file diff --git a/src/test/mod.rs b/src/test/mod.rs index d41afd6..4728ec9 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,6 +1,6 @@ -mod test_graph; mod test_dummy; +mod test_graph; mod test_py; +pub use test_dummy::*; pub use test_graph::*; -pub use test_dummy::*; \ No newline at end of file diff --git a/src/test/test_dummy.rs b/src/test/test_dummy.rs index 9ee937d..b884766 100644 --- a/src/test/test_dummy.rs +++ b/src/test/test_dummy.rs @@ -1,35 +1,37 @@ +use super::TestGraph; use fake::*; use rand::Rng; -use super::TestGraph; /// Create a complete graph comprised of nodes and edges of random types pub struct CompleteGraph { pub nodes: usize, pub node_types: usize, - pub edge_types: usize + pub edge_types: usize, } impl Dummy for TestGraph { fn dummy_with_rng(config: &CompleteGraph, rng: &mut R) -> Self { let mut g = TestGraph::default(); for i in 0..config.nodes { - g.add_node((i, rng.gen_range(0..config.node_types))).unwrap(); + g.add_node((i, rng.gen_range(0..config.node_types))) + .unwrap(); } - + for x in 0..config.nodes { for y in 0..config.nodes { - if x == y { + if x == y { continue; } g.add_edge( - x, - y, - (x + y * config.nodes, rng.gen_range(0..config.edge_types)) - ).unwrap(); + x, + y, + (x + y * config.nodes, rng.gen_range(0..config.edge_types)), + ) + .unwrap(); } } g } -} \ No newline at end of file +} diff --git a/src/test/test_graph.rs b/src/test/test_graph.rs index 4504ad8..db34620 100644 --- a/src/test/test_graph.rs +++ b/src/test/test_graph.rs @@ -2,4 +2,4 @@ use crate::generic_graph::{GenericGraph, GenericResult, GenericSchema}; pub type TestSchema = GenericSchema; pub type TestGraph = GenericGraph; -pub type TestResult = GenericResult; \ No newline at end of file +pub type TestResult = GenericResult; diff --git a/src/test/test_py.rs b/src/test/test_py.rs index 12d9c2a..66af1de 100644 --- a/src/test/test_py.rs +++ b/src/test/test_py.rs @@ -1,14 +1,14 @@ use std::collections::HashSet; -use fake::{Dummy, Faker, Fake}; -use pyo3::*; +use fake::{Dummy, Fake, Faker}; use pyo3::types::PyModule; -use rand::seq::{SliceRandom, IteratorRandom}; -use serde::{Serialize, Deserialize}; +use pyo3::*; +use rand::seq::{IteratorRandom, SliceRandom}; +use serde::{Deserialize, Serialize}; use crate::Id; -use super::{TestGraph, TestSchema, TestResult}; +use super::{TestGraph, TestResult, TestSchema}; const TYPED_GRAPH_PY: &'static str = include_str!("rust_test.py"); @@ -16,7 +16,7 @@ const TYPED_GRAPH_PY: &'static str = include_str!("rust_test.py"); enum Action { AddNode { id: usize, - ty: usize + ty: usize, }, AddEdge { id: usize, @@ -28,8 +28,8 @@ enum Action { id: usize, }, RemoveEdge { - id: usize - } + id: usize, + }, } impl Action { @@ -37,13 +37,18 @@ impl Action { match self { Action::AddNode { id, ty } => { g.add_node((*id, *ty))?; - }, - Action::AddEdge { id, ty, source, target } => { + } + Action::AddEdge { + id, + ty, + source, + target, + } => { g.add_edge(*source, *target, (*id, *ty))?; - }, + } Action::RemoveNode { id } => { g.remove_node(*id)?; - }, + } Action::RemoveEdge { id } => { g.remove_edge(*id)?; } @@ -55,12 +60,11 @@ impl Action { struct TestProject { g: TestGraph, - actions: Vec + actions: Vec, } impl Dummy for TestProject { fn dummy_with_rng(_config: &Faker, rng: &mut R) -> Self { - // Create 5-10 node types let node_whitelist: Vec<_> = (0..rng.gen_range(5..10)).collect(); let node_whitelist_len = node_whitelist.len(); @@ -71,9 +75,9 @@ impl Dummy for TestProject { for edge_type in 0..rng.gen_range(5..10) { for _ in 0..rng.gen_range(0..max_edge_len) { edge_whitelist.insert(( - edge_type, - *node_whitelist.choose(rng).unwrap(), - *node_whitelist.choose(rng).unwrap() + edge_type, + *node_whitelist.choose(rng).unwrap(), + *node_whitelist.choose(rng).unwrap(), )); } } @@ -95,39 +99,62 @@ impl Dummy for TestProject { // First figure out which actions are possible let mut possible_actions: Vec = Vec::new(); - let add_node = Action::AddNode { id: next_node_id, ty: node_whitelist.choose(rng).unwrap().clone() }; + let add_node = Action::AddNode { + id: next_node_id, + ty: node_whitelist.choose(rng).unwrap().clone(), + }; possible_actions.push(add_node.clone()); possible_actions.push(add_node); if g.node_count() > 0 { let all_nodes = g.node_ids(); - let remove_node = Action::RemoveNode { id: all_nodes.choose(rng).unwrap().clone() }; + let remove_node = Action::RemoveNode { + id: all_nodes.choose(rng).unwrap().clone(), + }; possible_actions.push(remove_node); let current_node_types: HashSet = g.nodes().map(|n| n.1).collect(); let possible_edge_types: HashSet<&(usize, usize, usize)> = edge_whitelist .iter() - .filter(|(_, source, target)| current_node_types.contains(source) && current_node_types.contains(target)) + .filter(|(_, source, target)| { + current_node_types.contains(source) && current_node_types.contains(target) + }) .collect(); if let Some((edge_type, source, target)) = possible_edge_types.iter().choose(rng) { - let source_id = g.nodes().filter(|n| &n.1 == source).choose(rng).map(|n| n.0).unwrap(); - let target_id = g.nodes().filter(|n| &n.1 == target).choose(rng).map(|n| n.0).unwrap(); - - let add_edge = Action::AddEdge { id: next_edge_id, ty: *edge_type, source: source_id, target: target_id }; + let source_id = g + .nodes() + .filter(|n| &n.1 == source) + .choose(rng) + .map(|n| n.0) + .unwrap(); + let target_id = g + .nodes() + .filter(|n| &n.1 == target) + .choose(rng) + .map(|n| n.0) + .unwrap(); + + let add_edge = Action::AddEdge { + id: next_edge_id, + ty: *edge_type, + source: source_id, + target: target_id, + }; possible_actions.push(add_edge.clone()); possible_actions.push(add_edge); } } if g.edge_count() > 0 { - let remove_edge = Action::RemoveEdge { id: g.edge_ids().choose(rng).unwrap() }; + let remove_edge = Action::RemoveEdge { + id: g.edge_ids().choose(rng).unwrap(), + }; possible_actions.push(remove_edge) } // Finally pick one of the actions to do if let Some(action) = possible_actions.choose(rng) { - if matches!(action, Action::AddNode { .. }) { next_node_id += 1; } @@ -135,29 +162,26 @@ impl Dummy for TestProject { if matches!(action, Action::AddEdge { .. }) { next_edge_id += 1; } - + action.apply(&mut g).unwrap(); actions.push(action.clone()); } } - TestProject { - g, - actions - } + TestProject { g, actions } } } -fn run_py_test( - json_schema: String, - json_actions: String -) -> String { +fn run_py_test(json_schema: String, json_actions: String) -> String { Python::with_gil(|py| -> PyResult { let rust_test_mod: &PyModule = PyModule::from_code(py, TYPED_GRAPH_PY, "", "")?; - let json_py_graph = rust_test_mod.call_method("run", (json_schema, json_actions), None)?.extract()?; + let json_py_graph = rust_test_mod + .call_method("run", (json_schema, json_actions), None)? + .extract()?; Ok(json_py_graph) - }).unwrap() + }) + .unwrap() } #[test] @@ -173,11 +197,8 @@ fn test_typed_graph_py() { println!(); println!(); - let json_py_graph = run_py_test( - json_schema, - json_actions - ); - + let json_py_graph = run_py_test(json_schema, json_actions); + let py_graph: TestGraph = serde_json::from_str(&json_py_graph).unwrap(); prj.g.assert_eq(&py_graph).unwrap(); } @@ -196,16 +217,16 @@ fn run_single() { for action in actions { action.apply(&mut g).unwrap(); } - - let json_py_graph = run_py_test( - json_schema.to_string(), - json_actions.to_string() - ); + + let json_py_graph = run_py_test(json_schema.to_string(), json_actions.to_string()); let py_graph: TestGraph = serde_json::from_str(&json_py_graph).unwrap(); - let out: Vec = py_graph.get_outgoing(2).unwrap().map(|e| e.get_id()).collect(); + let out: Vec = py_graph + .get_outgoing(2) + .unwrap() + .map(|e| e.get_id()) + .collect(); dbg!(out); g.assert_eq(&py_graph).unwrap(); - } diff --git a/src/typed_error.rs b/src/typed_error.rs index 11f404e..16dd7aa 100644 --- a/src/typed_error.rs +++ b/src/typed_error.rs @@ -1,17 +1,17 @@ use std::fmt::Debug; use thiserror::Error; -use crate::{NodeKey, EdgeKey, SchemaExt, Typed, DisAllowedEdge, DisAllowedNode}; +use crate::{DisAllowedEdge, DisAllowedNode, EdgeKey, NodeKey, SchemaExt, Typed}; pub type TypedResult = Result>; pub type GenericTypedError = TypedError; pub type GenericTypedResult = Result>; pub type SchemaError = TypedError< - NK, - EK, - <>::N as Typed>::Type, - <>::E as Typed>::Type + NK, + EK, + <>::N as Typed>::Type, + <>::E as Typed>::Type, >; /// Helper type for errors. @@ -58,7 +58,7 @@ pub enum TypedError { #[error("Node id was changed from {0} to {1} which was not expected")] InconsistentNodeIds(NK, NK), - + #[error("Edge id was changed from {0} to {1} which was not expected")] InconsistentEdgeIds(EK, EK), @@ -78,12 +78,12 @@ pub enum TypedError { impl TypedError { pub fn map( - self, + self, nk_map: NKF, ek_map: EKF, nt_map: NTF, et_map: ETF, - ) -> TypedError + ) -> TypedError where NKF: Fn(NK) -> NK1, EKF: Fn(EK) -> EK1, @@ -99,12 +99,18 @@ impl TypedError { TypedError::EdgeIdMissing(a) => TypedError::EdgeIdMissing(ek_map(a)), TypedError::MissingNode(a) => TypedError::MissingNode(nk_map(a)), TypedError::MissingEdge(a) => TypedError::MissingEdge(ek_map(a)), - TypedError::InvalidEdgeType(a, b, c, e) => TypedError::InvalidEdgeType(et_map(a), nt_map(b), nt_map(c), e), + TypedError::InvalidEdgeType(a, b, c, e) => { + TypedError::InvalidEdgeType(et_map(a), nt_map(b), nt_map(c), e) + } TypedError::InvalidNodeType(a, e) => TypedError::InvalidNodeType(nt_map(a), e), TypedError::InvalidInternalState => TypedError::InvalidInternalState, TypedError::DownCastFailed(a, b) => TypedError::DownCastFailed(a, b), - TypedError::InconsistentNodeIds(a, b) => TypedError::InconsistentNodeIds(nk_map(a), nk_map(b)), - TypedError::InconsistentEdgeIds(a, b) => TypedError::InconsistentEdgeIds(ek_map(a), ek_map(b)), + TypedError::InconsistentNodeIds(a, b) => { + TypedError::InconsistentNodeIds(nk_map(a), nk_map(b)) + } + TypedError::InconsistentEdgeIds(a, b) => { + TypedError::InconsistentEdgeIds(ek_map(a), ek_map(b)) + } TypedError::InvalidEdgeMove(a, b) => TypedError::InvalidEdgeMove(ek_map(a), ek_map(b)), TypedError::MissingNodeKey(a) => TypedError::MissingNodeKey(a), TypedError::MissingEdgeKey(a) => TypedError::MissingEdgeKey(a), @@ -112,4 +118,4 @@ impl TypedError { TypedError::SerdeJsonError(a) => TypedError::SerdeJsonError(a), } } -} \ No newline at end of file +} diff --git a/typed_graph_py/typed_graph/__init__.py b/typed_graph_py/typed_graph/__init__.py index 41b1fa8..1d57e5f 100644 --- a/typed_graph_py/typed_graph/__init__.py +++ b/typed_graph_py/typed_graph/__init__.py @@ -1,4 +1,4 @@ -from typed_graph.dependency_traits import StrEnum, IntEnum, RustModel, RustRootModel +from typed_graph.dependency_traits import StrEnum, IntEnum, NestedEnum, RustModel, RustRootModel from typed_graph.typed_traits import SchemaExt, NodeExt, EdgeExt, TypeStatus, GraphData from typed_graph.typed_graph import TypedGraph from typed_graph.generic_graph import GenericGraph, GenericSchema, GenericWeight @@ -9,6 +9,7 @@ 'EdgeExt', 'StrEnum', 'IntEnum', + 'NestedEnum', 'TypeStatus', 'RustModel', 'RustRootModel', diff --git a/typed_graph_py/typed_graph/dependency_traits.py b/typed_graph_py/typed_graph/dependency_traits.py index 77a352b..1999c75 100644 --- a/typed_graph_py/typed_graph/dependency_traits.py +++ b/typed_graph_py/typed_graph/dependency_traits.py @@ -1,6 +1,8 @@ from enum import Enum, EnumMeta from pydantic import BaseModel, RootModel, model_serializer, model_validator from typing import Any, Callable +from pydantic._internal._model_construction import ModelMetaclass +import inspect class Enum_M(EnumMeta): def __new__(metacls, name: str, bases, classdict, **kwds): @@ -49,7 +51,7 @@ class vars: @model_serializer(mode = 'wrap') def _serialize( self, - default: Callable[['RustModel'], dict[str, Any]] + default: Callable [['RustModel'], dict[str, Any]] ) -> dict[str, Any] | Any: """ Serialize the model to a dict. @@ -107,4 +109,148 @@ def _deserialize( return ModelInstance RustModel = make_model(BaseModel) -RustRootModel = make_model(RootModel) \ No newline at end of file +RustRootModel = make_model(RootModel) + +class NestedEnumMeta(ModelMetaclass): + def __new__(metacls, name, bases, class_dct, *args, **kwargs): + """ + Create a new enum class with a number og varients as attributes + Each varient has their own class + that inherits all the base classes of the enum except for its pydantic model + """ + + # Retrieve varient annotations + annotations = None + for k, v in class_dct.items(): + if k == '__annotations__': + annotations = v + # Stop the varients from being made as fields in the enum base model + if '__annotations__' in class_dct: + del class_dct['__annotations__'] + + enum_class = super().__new__(metacls, name, bases, class_dct, *args, **kwargs) + + # Create a constructor on the enum that prevents it from being initialized + def __new__(self, *args, **kwarg): + raise Exception(f'Can\'t initialize enum type {name}') + setattr(enum_class, '__new__', __new__) + + # Find all bases clases that the varients should also inherit + varient_bases = [] + for enum_base in bases: + if enum_base.__name__ != 'NestedEnum' and not issubclass(enum_base, BaseModel): + varient_bases.append(enum_base) + + enum_varients = {} + + # Create varients if any are provided + if annotations: + for varient_name, varient_type in annotations.items(): + varient_class = NestedEnumMeta.create_varient(varient_name, varient_type, varient_bases, class_dct) + + setattr(enum_class, varient_name, varient_class) + enum_varients[varient_name] = varient_class + + setattr(enum_class, '_members', enum_varients) + return enum_class + + @staticmethod + def create_varient(varient_name, varient_type, varient_bases, class_dct, ): + varient_type_name = f"{class_dct['__qualname__']}.{varient_name}" + + if varient_type == str: + # Handle unit varients + class_bases = [RootModel, *varient_bases] + variation_class = ModelMetaclass.__new__( + ModelMetaclass, + varient_type_name, + (RootModel, ), + { + '__module__': class_dct['__module__'], + '__qualname__': varient_type_name, + '__annotations__': { + 'root': str, + }, + } + ) + + return variation_class(varient_name) + + elif isinstance(varient_type, dict): + # Handle struct varients + class_bases = [RustModel, *varient_bases] + + varient_dict = { + '__module__': class_dct['__module__'], + '__qualname__': varient_type_name, + '__annotations__': varient_type + } + + # pass information about generic along + if '__orig_bases__' in class_dct: + varient_dict['__orig_bases__'] = class_dct['__orig_bases__'] + + variation_class = ModelMetaclass.__new__( + ModelMetaclass, + varient_type_name, + tuple(class_bases), + varient_dict + ) + + return variation_class + else: + raise Exception(f"Unsupported varient type {varient_type} expected {str(str)} or {str(dict)}") + +class NestedEnum(BaseModel, metaclass=NestedEnumMeta): + + @model_validator(mode = 'wrap') + def _deserialize( + cls, + d: dict[str, Any] | Any, + default: Callable[[dict[str, Any]], 'RustModel'] + ) -> 'RustModel': + # Handle unit varients + if isinstance(d, str) and d in cls._members: + varient = cls._members[d] + if not inspect.isclass(varient): + return varient.model_validate(d) + + # If it is neither, then it must just be the enum + if not isinstance(d, dict): + return default(d) + + if len(d) != 1: + return default(d) + + # Handle dict varient + varient_name = next(iter(d.keys())) + if varient_name in cls._members: + varient = cls._members[varient_name] + if inspect.isclass(varient): + return varient.model_validate(d[varient_name]) + + return default(d) + + @classmethod + def __class_getitem__(cls, ty): + instance = super().__class_getitem__(ty) + + # We only populate _members if it is empty + # This is because Generic reuses the same class + if not hasattr(instance, '_members') or not instance._members : + instance._members = {} + + # We now need to propergate the generics from the enum to its varients + for name, _ in cls._members.items(): + + varient_instance = getattr(instance, name) + + # Do not propergate generics to unit enums + if inspect.isclass(varient_instance): + varient_instance = varient_instance[ty] + + # Update varients on instance class + setattr(instance, name, varient_instance) + instance._members[name] = varient_instance + + return instance \ No newline at end of file diff --git a/typed_graph_py/typed_graph/typed_graph.py b/typed_graph_py/typed_graph/typed_graph.py index 45d88da..e963cb7 100644 --- a/typed_graph_py/typed_graph/typed_graph.py +++ b/typed_graph_py/typed_graph/typed_graph.py @@ -438,19 +438,19 @@ def get_incoming_and_outgoing(self, node_id: NK) -> Iterator[EdgeRef[E, NK]]: """ return chain(self.get_outgoing(node_id), self.get_incoming(node_id)) - def get_incoming_filter(self, node_id: NK, f: Callable[[ET], bool]) -> Iterator[EdgeRef[E, NK]]: + def get_incoming_filter(self, node_id: NK, f: Callable[[E], bool]) -> Iterator[EdgeRef[E, NK]]: """ Filter the incoming edges based on edge type """ return filter(lambda e: f(e.weight), self.get_incoming(node_id)) - def get_outgoing_filter(self, node_id: NK, f: Callable[[ET], bool]) -> Iterator[EdgeRef[E, NK]]: + def get_outgoing_filter(self, node_id: NK, f: Callable[[E], bool]) -> Iterator[EdgeRef[E, NK]]: """ Filter the outgoing edges based on edge type """ return filter(lambda e: f(e.weight), self.get_outgoing(node_id)) - def get_incoming_and_outgoing_filter(self, node_id: NK, f: Callable[[ET], bool]) -> Iterator[EdgeRef[E, NK]]: + def get_incoming_and_outgoing_filter(self, node_id: NK, f: Callable[[E], bool]) -> Iterator[EdgeRef[E, NK]]: """ Filter incoming and outgoing edges based on edge type """