diff --git a/.gitignore b/.gitignore
index de239947..4d45f685 100755
--- a/.gitignore
+++ b/.gitignore
@@ -160,4 +160,9 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
+
+*env/
+
+# Sphinx Documentation
/docs/source/_static/css/fonts/
+
diff --git a/docs/group-chat-example.md b/docs/group-chat-example.md
new file mode 100755
index 00000000..c2aaa0ce
--- /dev/null
+++ b/docs/group-chat-example.md
@@ -0,0 +1,130 @@
+# Custom Thresholds Example
+
+Here is an example that utilizes custom thresholds.
+
+The following graph represents a network of People and a Text Message in their group chat.
+
+
+In this case, we want to know when a text message has been viewed by all members of the group chat.
+
+## Graph
+First, lets create the group chat.
+
+```python
+import networkx as nx
+
+# Create an empty graph
+G = nx.Graph()
+
+# Add nodes
+nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"]
+G.add_nodes_from(nodes)
+
+# Add edges with attribute 'HaveAccess'
+edges = [
+ ("Zach", "TextMessage", {"HaveAccess": 1}),
+ ("Justin", "TextMessage", {"HaveAccess": 1}),
+ ("Michelle", "TextMessage", {"HaveAccess": 1}),
+ ("Amy", "TextMessage", {"HaveAccess": 1})
+]
+G.add_edges_from(edges)
+
+```
+
+## Rules and Custom Thresholds
+Considering that we only want a text message to be considered viewed by all if it has been viewed by everyone that can view it, we define the rule as follows:
+
+```text
+ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)
+```
+
+The `head` of the rule is `ViewedByAll(x)` and the body is `HaveAccess(x,y), Viewed(y)`. The head and body are separated by an arrow which means the rule will start evaluating from
+timestep 0.
+
+We add the rule into pyreason with:
+
+```python
+import pyreason as pr
+from pyreason import Threshold
+
+user_defined_thresholds = [
+ Threshold("greater_equal", ("number", "total"), 1),
+ Threshold("greater_equal", ("percent", "total"), 100),
+]
+
+pr.add_rule(pr.Rule('ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)', 'viewed_by_all_rule', user_defined_thresholds))
+```
+Where `viewed_by_all_rule` is the name of the rule. This helps to understand which rule/s are fired during reasoning later on.
+
+The `user_defined_thresholds` are a list of custom thresholds of the format: (quantifier, quantifier_type, thresh) where:
+- quantifier can be greater_equal, greater, less_equal, less, equal
+- quantifier_type is a tuple where the first element can be either number or percent and the second element can be either total or available
+- thresh represents the numerical threshold value to compare against
+
+The custom thresholds are created corresponding to the two clauses (HaveAccess(x,y) and Viewed(y)) as below:
+- ('greater_equal', ('number', 'total'), 1) (there needs to be at least one person who has access to TextMessage for the first clause to be satisfied)
+- ('greater_equal', ('percent', 'total'), 100) (100% of people who have access to TextMessage need to view the message for second clause to be satisfied)
+
+## Facts
+The facts determine the initial conditions of elements in the graph. They can be specified from the graph attributes but in that
+case they will be immutable later on. Adding PyReason facts gives us more flexibility.
+
+In our case we want one person to view the TextMessage in a particular interval of timestep.
+For example, we create facts stating:
+- Zach and Justin view the TextMessage from at timestep 0
+- Michelle views the TextMessage at timestep 1
+- Amy views the TextMessage at timestep 2
+
+We add the facts in PyReason as below:
+```python
+import pyreason as pr
+
+pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 0, static=True))
+pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 0, static=True))
+pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 1, static=True))
+pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 2, static=True))
+```
+
+This allows us to specify the component that has an initial condition, the initial condition itself in the form of bounds
+as well as the start and end time of this condition.
+
+## Running PyReason
+Find the full code for this example [here](../tests/test_custom_thresholds.py)
+
+The main line that runs the reasoning in that file is:
+```python
+interpretation = pr.reason(timesteps=3)
+```
+This specifies how many timesteps to run for.
+
+## Expected Output
+After running the python file, the expected output is:
+
+```
+TIMESTEP - 0
+Empty DataFrame
+Columns: [component, ViewedByAll]
+Index: []
+
+TIMESTEP - 1
+Empty DataFrame
+Columns: [component, ViewedByAll]
+Index: []
+
+TIMESTEP - 2
+ component ViewedByAll
+0 TextMessage [1.0, 1.0]
+
+TIMESTEP - 3
+ component ViewedByAll
+0 TextMessage [1.0, 1.0]
+
+```
+
+1. For timestep 0, we set `Zach -> Viewed: [1,1]` and `Justin -> Viewed: [1,1]` in the facts
+2. For timestep 1, Michelle views the TextMessage as stated in facts `Michelle -> Viewed: [1,1]`
+3. For timestep 2, since Amy has just viewed the TextMessage, therefore `Amy -> Viewed: [1,1]`. As per the rule,
+since all the people have viewed the TextMessage, the message is marked as ViewedByAll.
+
+
+We also output two CSV files detailing all the events that took place during reasoning (one for nodes, one for edges)
diff --git a/media/group_chat_graph.png b/media/group_chat_graph.png
new file mode 100644
index 00000000..dc9afac6
Binary files /dev/null and b/media/group_chat_graph.png differ
diff --git a/pyreason/__init__.py b/pyreason/__init__.py
index 56e7df41..85f8319a 100755
--- a/pyreason/__init__.py
+++ b/pyreason/__init__.py
@@ -23,6 +23,9 @@
add_fact(Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
reason(timesteps=2)
+ reset()
+ reset_rules()
+
# Update cache status
cache_status['initialized'] = True
with open(cache_status_path, 'w') as file:
diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py
index 94ac1573..48ae672f 100755
--- a/pyreason/pyreason.py
+++ b/pyreason/pyreason.py
@@ -17,6 +17,7 @@
import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.rules.rule import Rule
+from pyreason.scripts.threshold.threshold import Threshold
import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node
import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
diff --git a/pyreason/scripts/rules/rule.py b/pyreason/scripts/rules/rule.py
index 97cca58d..73824c4d 100755
--- a/pyreason/scripts/rules/rule.py
+++ b/pyreason/scripts/rules/rule.py
@@ -6,12 +6,10 @@ class Rule:
Example text:
`'pred1(x,y) : [0.2, 1] <- pred2(a, b) : [1,1], pred3(b, c)'`
- 1. It is not possible to specify thresholds. Threshold is greater than or equal to 1 by default
- 2. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
- TODO: Add threshold class where we can pass this as a parameter
+ 1. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
TODO: Add weights as a parameter
"""
- def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False):
+ def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False, custom_thresholds=None):
"""
:param rule_text: The rule in text format
:param name: The name of the rule. This will appear in the rule trace
@@ -19,4 +17,6 @@ def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_sta
:param set_static: Whether to set the atom in the head as static if the rule fires. The bounds will no longer change
:param immediate_rule: Whether the rule is immediate. Immediate rules check for more applicable rules immediately after being applied
"""
- self.rule = rule_parser.parse_rule(rule_text, name, infer_edges, set_static, immediate_rule)
+ if custom_thresholds is None:
+ custom_thresholds = []
+ self.rule = rule_parser.parse_rule(rule_text, name, custom_thresholds, infer_edges, set_static, immediate_rule)
diff --git a/pyreason/scripts/threshold/__init__.py b/pyreason/scripts/threshold/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pyreason/scripts/threshold/threshold.py b/pyreason/scripts/threshold/threshold.py
new file mode 100644
index 00000000..39722631
--- /dev/null
+++ b/pyreason/scripts/threshold/threshold.py
@@ -0,0 +1,41 @@
+class Threshold:
+ """
+ A class representing a threshold for a clause in a rule.
+
+ Attributes:
+ quantifier (str): The comparison operator, e.g., 'greater_equal', 'less', etc.
+ quantifier_type (tuple): A tuple indicating the type of quantifier, e.g., ('number', 'total').
+ thresh (int): The numerical threshold value to compare against.
+
+ Methods:
+ to_tuple(): Converts the Threshold instance into a tuple compatible with numba types.
+ """
+
+ def __init__(self, quantifier, quantifier_type, thresh):
+ """
+ Initializes a Threshold instance.
+
+ Args:
+ quantifier (str): The comparison operator for the threshold.
+ quantifier_type (tuple): The type of quantifier ('number' or 'percent', 'total' or 'available').
+ thresh (int): The numerical value for the threshold.
+ """
+
+ if quantifier not in ("greater_equal", "greater", "less_equal", "less", "equal"):
+ raise ValueError("Invalid quantifier")
+
+ if quantifier_type[0] not in ("number", "percent") or quantifier_type[1] not in ("total", "available"):
+ raise ValueError("Invalid quantifier type")
+
+ self.quantifier = quantifier
+ self.quantifier_type = quantifier_type
+ self.thresh = thresh
+
+ def to_tuple(self):
+ """
+ Converts the Threshold instance into a tuple compatible with numba types.
+
+ Returns:
+ tuple: A tuple representation of the Threshold instance.
+ """
+ return (self.quantifier, self.quantifier_type, self.thresh)
\ No newline at end of file
diff --git a/pyreason/scripts/utils/rule_parser.py b/pyreason/scripts/utils/rule_parser.py
index fc2db4bb..36bdede0 100644
--- a/pyreason/scripts/utils/rule_parser.py
+++ b/pyreason/scripts/utils/rule_parser.py
@@ -6,7 +6,7 @@
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
-def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule:
+def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule:
# First remove all spaces from line
r = rule_text.replace(' ', '')
@@ -152,7 +152,23 @@ def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static:
# Array to store clauses for nodes: node/edge, [subset]/[subset1, subset2], label, interval, operator
clauses = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string), interval.interval_type, numba.types.string)))
- # Loop though clauses
+ # gather count of clauses for threshold validation
+ num_clauses = len(body_clauses)
+
+ if custom_thresholds and (len(custom_thresholds) != num_clauses):
+ raise Exception('The length of custom thresholds {} is not equal to number of clauses {}'
+ .format(len(custom_thresholds), num_clauses))
+
+ # If no custom thresholds provided, use defaults
+ # otherwise loop through user-defined thresholds and convert to numba compatible format
+ if not custom_thresholds:
+ for _ in range(num_clauses):
+ thresholds.append(('greater_equal', ('number', 'total'), 1.0))
+ else:
+ for threshold in custom_thresholds:
+ thresholds.append(threshold.to_tuple())
+
+ # # Loop though clauses
for body_clause, predicate, variables, bounds in zip(body_clauses, body_predicates, body_variables, body_bounds):
# Neigh criteria
clause_type = 'node' if len(variables) == 1 else 'edge'
@@ -165,12 +181,6 @@ def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static:
bnd = interval.closed(bounds[0], bounds[1])
clauses.append((clause_type, l, subset, bnd, op))
- # Threshold.
- quantifier = 'greater_equal'
- quantifier_type = ('number', 'total')
- thresh = 1
- thresholds.append((quantifier, quantifier_type, thresh))
-
# Assert that there are two variables in the head of the rule if we infer edges
# Add edges between head variables if necessary
if infer_edges:
diff --git a/tests/group_chat_graph.graphml b/tests/group_chat_graph.graphml
new file mode 100644
index 00000000..7c76e29b
--- /dev/null
+++ b/tests/group_chat_graph.graphml
@@ -0,0 +1,26 @@
+
+
+
+
+
+
+
+
+
+
+
+
+ 1
+
+
+ 1
+
+
+
+ 1
+
+
+ 1
+
+
+
diff --git a/tests/test_custom_thresholds.py b/tests/test_custom_thresholds.py
new file mode 100644
index 00000000..b982bf36
--- /dev/null
+++ b/tests/test_custom_thresholds.py
@@ -0,0 +1,62 @@
+# Test if the simple program works with thresholds defined
+import pyreason as pr
+from pyreason import Threshold
+
+
+def test_custom_thresholds():
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+
+ # Modify the paths based on where you've stored the files we made above
+ graph_path = "./tests/group_chat_graph.graphml"
+
+ # Modify pyreason settings to make verbose and to save the rule trace to a file
+ pr.settings.verbose = True # Print info to screen
+
+ # Load all the files into pyreason
+ pr.load_graphml(graph_path)
+
+ # add custom thresholds
+ user_defined_thresholds = [
+ Threshold("greater_equal", ("number", "total"), 1),
+ Threshold("greater_equal", ("percent", "total"), 100),
+ ]
+
+ pr.add_rule(
+ pr.Rule(
+ "ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)",
+ "viewed_by_all_rule",
+ custom_thresholds=user_defined_thresholds,
+ )
+ )
+
+ pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 3))
+ pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 3))
+ pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 3))
+ pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 3))
+
+ # Run the program for three timesteps to see the diffusion take place
+ interpretation = pr.reason(timesteps=3)
+
+ # Display the changes in the interpretation for each timestep
+ dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"])
+ for t, df in enumerate(dataframes):
+ print(f"TIMESTEP - {t}")
+ print(df)
+ print()
+
+ assert (
+ len(dataframes[0]) == 0
+ ), "At t=0 the TextMessage should not have been ViewedByAll"
+ assert (
+ len(dataframes[2]) == 1
+ ), "At t=2 the TextMessage should have been ViewedByAll"
+
+ # TextMessage should be ViewedByAll in t=2
+ assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[
+ 0
+ ].ViewedByAll == [
+ 1,
+ 1,
+ ], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps"
diff --git a/tests/test_hello_world.py b/tests/test_hello_world.py
index c2213458..c932daff 100644
--- a/tests/test_hello_world.py
+++ b/tests/test_hello_world.py
@@ -3,6 +3,10 @@
def test_hello_world():
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+
# Modify the paths based on where you've stored the files we made above
graph_path = './tests/friends_graph.graphml'
diff --git a/tests/test_hello_world_parallel.py b/tests/test_hello_world_parallel.py
index cd16111c..1b7ee03c 100644
--- a/tests/test_hello_world_parallel.py
+++ b/tests/test_hello_world_parallel.py
@@ -3,6 +3,10 @@
def test_hello_world_parallel():
+ # Reset PyReason
+ pr.reset()
+ pr.reset_rules()
+
# Modify the paths based on where you've stored the files we made above
graph_path = './tests/friends_graph.graphml'