diff --git a/.gitignore b/.gitignore index de239947..4d45f685 100755 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,9 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ + +*env/ + +# Sphinx Documentation /docs/source/_static/css/fonts/ + diff --git a/docs/group-chat-example.md b/docs/group-chat-example.md new file mode 100755 index 00000000..c2aaa0ce --- /dev/null +++ b/docs/group-chat-example.md @@ -0,0 +1,130 @@ +# Custom Thresholds Example + +Here is an example that utilizes custom thresholds. + +The following graph represents a network of People and a Text Message in their group chat. + + +In this case, we want to know when a text message has been viewed by all members of the group chat. + +## Graph +First, lets create the group chat. + +```python +import networkx as nx + +# Create an empty graph +G = nx.Graph() + +# Add nodes +nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"] +G.add_nodes_from(nodes) + +# Add edges with attribute 'HaveAccess' +edges = [ + ("Zach", "TextMessage", {"HaveAccess": 1}), + ("Justin", "TextMessage", {"HaveAccess": 1}), + ("Michelle", "TextMessage", {"HaveAccess": 1}), + ("Amy", "TextMessage", {"HaveAccess": 1}) +] +G.add_edges_from(edges) + +``` + +## Rules and Custom Thresholds +Considering that we only want a text message to be considered viewed by all if it has been viewed by everyone that can view it, we define the rule as follows: + +```text +ViewedByAll(x) <- HaveAccess(x,y), Viewed(y) +``` + +The `head` of the rule is `ViewedByAll(x)` and the body is `HaveAccess(x,y), Viewed(y)`. The head and body are separated by an arrow which means the rule will start evaluating from +timestep 0. 
+ +We add the rule into pyreason with: + +```python +import pyreason as pr +from pyreason import Threshold + +user_defined_thresholds = [ + Threshold("greater_equal", ("number", "total"), 1), + Threshold("greater_equal", ("percent", "total"), 100), +] + +pr.add_rule(pr.Rule('ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)', 'viewed_by_all_rule', custom_thresholds=user_defined_thresholds)) +``` +Where `viewed_by_all_rule` is the name of the rule. This helps to understand which rules fired during reasoning later on. Note that the thresholds must be passed with the `custom_thresholds` keyword, because the third positional parameter of `pr.Rule` is `infer_edges`. + +The `user_defined_thresholds` is a list of custom thresholds of the format: (quantifier, quantifier_type, thresh) where: +- quantifier can be greater_equal, greater, less_equal, less, equal +- quantifier_type is a tuple where the first element can be either number or percent and the second element can be either total or available +- thresh represents the numerical threshold value to compare against + +The custom thresholds are created corresponding to the two clauses (HaveAccess(x,y) and Viewed(y)) as below: +- ('greater_equal', ('number', 'total'), 1) (there needs to be at least one person who has access to TextMessage for the first clause to be satisfied) +- ('greater_equal', ('percent', 'total'), 100) (100% of people who have access to TextMessage need to view the message for the second clause to be satisfied) + +## Facts +The facts determine the initial conditions of elements in the graph. They can be specified from the graph attributes but in that +case they will be immutable later on. Adding PyReason facts gives us more flexibility. + +In our case we want each person to view the TextMessage during a particular interval of timesteps. 
+For example, we create facts stating: +- Zach and Justin view the TextMessage at timestep 0 +- Michelle views the TextMessage at timestep 1 +- Amy views the TextMessage at timestep 2 + +We add the facts in PyReason as below: +```python +import pyreason as pr + +pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 0, static=True)) +pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 0, static=True)) +pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 1, static=True)) +pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 2, static=True)) +``` + +This allows us to specify the component that has an initial condition, the initial condition itself in the form of bounds +as well as the start and end time of this condition. + +## Running PyReason +Find the full code for this example [here](../tests/test_custom_thresholds.py) + +The main line that runs the reasoning in that file is: +```python +interpretation = pr.reason(timesteps=3) +``` +This specifies how many timesteps to run for. + +## Expected Output +After running the Python file, the expected output is: + +``` +TIMESTEP - 0 +Empty DataFrame +Columns: [component, ViewedByAll] +Index: [] + +TIMESTEP - 1 +Empty DataFrame +Columns: [component, ViewedByAll] +Index: [] + +TIMESTEP - 2 + component ViewedByAll +0 TextMessage [1.0, 1.0] + +TIMESTEP - 3 + component ViewedByAll +0 TextMessage [1.0, 1.0] + +``` + +1. For timestep 0, we set `Zach -> Viewed: [1,1]` and `Justin -> Viewed: [1,1]` in the facts +2. For timestep 1, Michelle views the TextMessage as stated in facts `Michelle -> Viewed: [1,1]` +3. For timestep 2, Amy views the TextMessage, so `Amy -> Viewed: [1,1]`. As per the rule, +since all the people have now viewed the TextMessage, the message is marked as ViewedByAll. 
+ + +We also output two CSV files detailing all the events that took place during reasoning (one for nodes, one for edges) diff --git a/media/group_chat_graph.png b/media/group_chat_graph.png new file mode 100644 index 00000000..dc9afac6 Binary files /dev/null and b/media/group_chat_graph.png differ diff --git a/pyreason/__init__.py b/pyreason/__init__.py index 56e7df41..85f8319a 100755 --- a/pyreason/__init__.py +++ b/pyreason/__init__.py @@ -23,6 +23,9 @@ add_fact(Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2)) reason(timesteps=2) + reset() + reset_rules() + # Update cache status cache_status['initialized'] = True with open(cache_status_path, 'w') as file: diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py index 94ac1573..48ae672f 100755 --- a/pyreason/pyreason.py +++ b/pyreason/pyreason.py @@ -17,6 +17,7 @@ import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule from pyreason.scripts.facts.fact import Fact from pyreason.scripts.rules.rule import Rule +from pyreason.scripts.threshold.threshold import Threshold import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval diff --git a/pyreason/scripts/rules/rule.py b/pyreason/scripts/rules/rule.py index 97cca58d..73824c4d 100755 --- a/pyreason/scripts/rules/rule.py +++ b/pyreason/scripts/rules/rule.py @@ -6,12 +6,10 @@ class Rule: Example text: `'pred1(x,y) : [0.2, 1] <- pred2(a, b) : [1,1], pred3(b, c)'` - 1. It is not possible to specify thresholds. Threshold is greater than or equal to 1 by default - 2. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0 - TODO: Add threshold class where we can pass this as a parameter + 1. It is not possible to have weights for different clauses. 
class Threshold:
    """
    A class representing a threshold for a clause in a rule.

    Attributes:
        quantifier (str): The comparison operator, e.g., 'greater_equal', 'less', etc.
        quantifier_type (tuple): A pair ``(mode, scope)``, e.g., ('number', 'total').
            Always stored as a tuple, even if a list was passed in.
        thresh (int | float): The numerical threshold value to compare against.

    Methods:
        to_tuple(): Converts the Threshold instance into a plain tuple.
    """

    # Accepted vocabulary, kept in one place so validation and error messages agree.
    _QUANTIFIERS = ("greater_equal", "greater", "less_equal", "less", "equal")
    _MODES = ("number", "percent")
    _SCOPES = ("total", "available")

    def __init__(self, quantifier, quantifier_type, thresh):
        """
        Initializes a Threshold instance.

        Args:
            quantifier (str): The comparison operator for the threshold. One of
                'greater_equal', 'greater', 'less_equal', 'less', 'equal'.
            quantifier_type (tuple): A pair whose first element is 'number' or 'percent'
                and whose second element is 'total' or 'available'.
            thresh (int | float): The numerical value for the threshold.

        Raises:
            ValueError: If the quantifier is unknown, if quantifier_type is not a valid
                pair, or if thresh is not a real number.
        """
        # Function-scope import: this module has no import block of its own.
        import numbers

        if quantifier not in self._QUANTIFIERS:
            raise ValueError(
                f"Invalid quantifier {quantifier!r}: expected one of {self._QUANTIFIERS}"
            )

        # Unpacking (rather than indexing [0] and [1]) rejects non-iterables and
        # wrong-length values up front; those previously escaped as TypeError/IndexError
        # or, for over-long sequences, were silently accepted.
        try:
            mode, scope = quantifier_type
        except (TypeError, ValueError):
            raise ValueError(
                f"Invalid quantifier type {quantifier_type!r}: expected a pair such as ('number', 'total')"
            ) from None

        if mode not in self._MODES or scope not in self._SCOPES:
            raise ValueError(
                f"Invalid quantifier type {quantifier_type!r}: first element must be one of "
                f"{self._MODES} and second element one of {self._SCOPES}"
            )

        # Fail here with a clear message instead of later, wherever the value is consumed.
        if not isinstance(thresh, numbers.Real):
            raise ValueError(f"Invalid threshold {thresh!r}: expected a number")

        self.quantifier = quantifier
        # Normalise to a real tuple so to_tuple() never returns a nested list.
        self.quantifier_type = (mode, scope)
        self.thresh = thresh

    def __repr__(self):
        """Return a constructor-style representation, useful when debugging rules."""
        return f"{type(self).__name__}({self.quantifier!r}, {self.quantifier_type!r}, {self.thresh!r})"

    def to_tuple(self):
        """
        Converts the Threshold instance into a plain tuple.

        The result has the layout ``(quantifier, quantifier_type, thresh)``.
        NOTE(review): intended for the numba-typed threshold list built by the rule
        parser; the exact numba element type is not visible here -- confirm.

        Returns:
            tuple: A tuple representation of the Threshold instance.
        """
        return (self.quantifier, self.quantifier_type, self.thresh)
numba.types.ListType(numba.types.string), interval.interval_type, numba.types.string))) - # Loop though clauses + # gather count of clauses for threshold validation + num_clauses = len(body_clauses) + + if custom_thresholds and (len(custom_thresholds) != num_clauses): + raise Exception('The length of custom thresholds {} is not equal to number of clauses {}' + .format(len(custom_thresholds), num_clauses)) + + # If no custom thresholds provided, use defaults + # otherwise loop through user-defined thresholds and convert to numba compatible format + if not custom_thresholds: + for _ in range(num_clauses): + thresholds.append(('greater_equal', ('number', 'total'), 1.0)) + else: + for threshold in custom_thresholds: + thresholds.append(threshold.to_tuple()) + + # # Loop though clauses for body_clause, predicate, variables, bounds in zip(body_clauses, body_predicates, body_variables, body_bounds): # Neigh criteria clause_type = 'node' if len(variables) == 1 else 'edge' @@ -165,12 +181,6 @@ def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static: bnd = interval.closed(bounds[0], bounds[1]) clauses.append((clause_type, l, subset, bnd, op)) - # Threshold. 
# Test if the simple program works with thresholds defined
import pyreason as pr
from pyreason import Threshold


def test_custom_thresholds():
    """Check that a rule with per-clause custom thresholds fires only once all of them are met.

    The rule marks the message as ViewedByAll when at least one person has access to it
    and 100% of the people with access have viewed it. Viewers are added over timesteps
    0, 1 and 2, so the rule must stay silent until timestep 2.
    """
    # Reset PyReason so state left over from other tests cannot leak in
    pr.reset()
    pr.reset_rules()

    # Relative path -- presumably the tests are run from the repository root
    graph_path = "./tests/group_chat_graph.graphml"

    # Make PyReason verbose
    pr.settings.verbose = True  # Print info to screen

    # Load the graph into pyreason
    pr.load_graphml(graph_path)

    # One threshold per body clause, in clause order: HaveAccess(x,y) then Viewed(y)
    user_defined_thresholds = [
        Threshold("greater_equal", ("number", "total"), 1),
        Threshold("greater_equal", ("percent", "total"), 100),
    ]

    pr.add_rule(
        pr.Rule(
            "ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)",
            "viewed_by_all_rule",
            custom_thresholds=user_defined_thresholds,
        )
    )

    # Zach and Justin view at t=0, Michelle at t=1, Amy at t=2; every view holds until t=3
    pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 3))
    pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 3))
    pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 3))
    pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 3))

    # Run the program for three timesteps to see the diffusion take place
    interpretation = pr.reason(timesteps=3)

    # Display the changes in the interpretation for each timestep
    dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"])
    for t, df in enumerate(dataframes):
        print(f"TIMESTEP - {t}")
        print(df)
        print()

    assert (
        len(dataframes[0]) == 0
    ), "At t=0 the TextMessage should not have been ViewedByAll"
    # Three of the four members have viewed by t=1; this is the state that
    # distinguishes the 100% threshold from any weaker one.
    assert (
        len(dataframes[1]) == 0
    ), "At t=1 the TextMessage should not have been ViewedByAll"
    assert (
        len(dataframes[2]) == 1
    ), "At t=2 the TextMessage should have been ViewedByAll"

    # TextMessage should be ViewedByAll in t=2
    assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[
        0
    ].ViewedByAll == [
        1,
        1,
    ], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps"