-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/lab-v2/pyreason
- Loading branch information
Showing
13 changed files
with
299 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Custom Thresholds Example | ||
|
||
Here is an example that utilizes custom thresholds. | ||
|
||
The following graph represents a network of People and a Text Message in their group chat. | ||
<img src="../media/group_chat_graph.png"/> | ||
|
||
In this case, we want to know when a text message has been viewed by all members of the group chat. | ||
|
||
## Graph | ||
First, lets create the group chat. | ||
|
||
```python | ||
import networkx as nx | ||
|
||
# Create an empty graph | ||
G = nx.Graph() | ||
|
||
# Add nodes | ||
nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"] | ||
G.add_nodes_from(nodes) | ||
|
||
# Add edges with attribute 'HaveAccess' | ||
edges = [ | ||
("Zach", "TextMessage", {"HaveAccess": 1}), | ||
("Justin", "TextMessage", {"HaveAccess": 1}), | ||
("Michelle", "TextMessage", {"HaveAccess": 1}), | ||
("Amy", "TextMessage", {"HaveAccess": 1}) | ||
] | ||
G.add_edges_from(edges) | ||
|
||
``` | ||
|
||
## Rules and Custom Thresholds | ||
Considering that we only want a text message to be considered viewed by all if it has been viewed by everyone that can view it, we define the rule as follows: | ||
|
||
```text | ||
ViewedByAll(x) <- HaveAccess(x,y), Viewed(y) | ||
``` | ||
|
||
The `head` of the rule is `ViewedByAll(x)` and the body is `HaveAccess(x,y), Viewed(y)`. The head and body are separated by an arrow which means the rule will start evaluating from | ||
timestep 0. | ||
|
||
We add the rule into pyreason with: | ||
|
||
```python | ||
import pyreason as pr | ||
from pyreason import Threshold | ||
|
||
user_defined_thresholds = [ | ||
Threshold("greater_equal", ("number", "total"), 1), | ||
Threshold("greater_equal", ("percent", "total"), 100), | ||
] | ||
|
||
pr.add_rule(pr.Rule('ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)', 'viewed_by_all_rule', user_defined_thresholds)) | ||
``` | ||
Where `viewed_by_all_rule` is the name of the rule. This helps to understand which rule/s are fired during reasoning later on. | ||
|
||
The `user_defined_thresholds` are a list of custom thresholds of the format: (quantifier, quantifier_type, thresh) where: | ||
- quantifier can be greater_equal, greater, less_equal, less, equal | ||
- quantifier_type is a tuple where the first element can be either number or percent and the second element can be either total or available | ||
- thresh represents the numerical threshold value to compare against | ||
|
||
The custom thresholds are created corresponding to the two clauses (HaveAccess(x,y) and Viewed(y)) as below: | ||
- ('greater_equal', ('number', 'total'), 1) (there needs to be at least one person who has access to TextMessage for the first clause to be satisfied) | ||
- ('greater_equal', ('percent', 'total'), 100) (100% of people who have access to TextMessage need to view the message for second clause to be satisfied) | ||
|
||
## Facts | ||
The facts determine the initial conditions of elements in the graph. They can be specified from the graph attributes but in that | ||
case they will be immutable later on. Adding PyReason facts gives us more flexibility. | ||
|
||
In our case we want one person to view the TextMessage in a particular interval of timestep. | ||
For example, we create facts stating: | ||
- Zach and Justin view the TextMessage from at timestep 0 | ||
- Michelle views the TextMessage at timestep 1 | ||
- Amy views the TextMessage at timestep 2 | ||
|
||
We add the facts in PyReason as below: | ||
```python | ||
import pyreason as pr | ||
|
||
pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 0, static=True)) | ||
pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 0, static=True)) | ||
pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 1, static=True)) | ||
pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 2, static=True)) | ||
``` | ||
|
||
This allows us to specify the component that has an initial condition, the initial condition itself in the form of bounds | ||
as well as the start and end time of this condition. | ||
|
||
## Running PyReason | ||
Find the full code for this example [here](../tests/test_custom_thresholds.py) | ||
|
||
The main line that runs the reasoning in that file is: | ||
```python | ||
interpretation = pr.reason(timesteps=3) | ||
``` | ||
This specifies how many timesteps to run for. | ||
|
||
## Expected Output | ||
After running the python file, the expected output is: | ||
|
||
``` | ||
TIMESTEP - 0 | ||
Empty DataFrame | ||
Columns: [component, ViewedByAll] | ||
Index: [] | ||
TIMESTEP - 1 | ||
Empty DataFrame | ||
Columns: [component, ViewedByAll] | ||
Index: [] | ||
TIMESTEP - 2 | ||
component ViewedByAll | ||
0 TextMessage [1.0, 1.0] | ||
TIMESTEP - 3 | ||
component ViewedByAll | ||
0 TextMessage [1.0, 1.0] | ||
``` | ||
|
||
1. For timestep 0, we set `Zach -> Viewed: [1,1]` and `Justin -> Viewed: [1,1]` in the facts | ||
2. For timestep 1, Michelle views the TextMessage as stated in facts `Michelle -> Viewed: [1,1]` | ||
3. For timestep 2, since Amy has just viewed the TextMessage, therefore `Amy -> Viewed: [1,1]`. As per the rule, | ||
since all the people have viewed the TextMessage, the message is marked as ViewedByAll. | ||
|
||
|
||
We also output two CSV files detailing all the events that took place during reasoning (one for nodes, one for edges) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
class Threshold: | ||
""" | ||
A class representing a threshold for a clause in a rule. | ||
Attributes: | ||
quantifier (str): The comparison operator, e.g., 'greater_equal', 'less', etc. | ||
quantifier_type (tuple): A tuple indicating the type of quantifier, e.g., ('number', 'total'). | ||
thresh (int): The numerical threshold value to compare against. | ||
Methods: | ||
to_tuple(): Converts the Threshold instance into a tuple compatible with numba types. | ||
""" | ||
|
||
def __init__(self, quantifier, quantifier_type, thresh): | ||
""" | ||
Initializes a Threshold instance. | ||
Args: | ||
quantifier (str): The comparison operator for the threshold. | ||
quantifier_type (tuple): The type of quantifier ('number' or 'percent', 'total' or 'available'). | ||
thresh (int): The numerical value for the threshold. | ||
""" | ||
|
||
if quantifier not in ("greater_equal", "greater", "less_equal", "less", "equal"): | ||
raise ValueError("Invalid quantifier") | ||
|
||
if quantifier_type[0] not in ("number", "percent") or quantifier_type[1] not in ("total", "available"): | ||
raise ValueError("Invalid quantifier type") | ||
|
||
self.quantifier = quantifier | ||
self.quantifier_type = quantifier_type | ||
self.thresh = thresh | ||
|
||
def to_tuple(self): | ||
""" | ||
Converts the Threshold instance into a tuple compatible with numba types. | ||
Returns: | ||
tuple: A tuple representation of the Threshold instance. | ||
""" | ||
return (self.quantifier, self.quantifier_type, self.thresh) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
<?xml version='1.0' encoding='utf-8'?> | ||
<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd"> | ||
|
||
<key id="HaveAccess" for="edge" attr.name="HaveAccess" attr.type="long" /> | ||
<graph edgedefault="undirected"> | ||
<node id="TextMessage" /> | ||
<node id="Zach" /> | ||
<node id="Justin" /> | ||
<node id="Michelle" /> | ||
<node id="Amy" /> | ||
|
||
<edge source="Zach" target="TextMessage"> | ||
<data key="HaveAccess">1</data> | ||
</edge> | ||
<edge source="Justin" target="TextMessage"> | ||
<data key="HaveAccess">1</data> | ||
</edge> | ||
|
||
<edge source="Amy" target="TextMessage"> | ||
<data key="HaveAccess">1</data> | ||
</edge> | ||
<edge source="Michelle" target="TextMessage"> | ||
<data key="HaveAccess">1</data> | ||
</edge> | ||
</graph> | ||
</graphml> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Test if the simple program works with thresholds defined | ||
import pyreason as pr | ||
from pyreason import Threshold | ||
|
||
|
||
def test_custom_thresholds(): | ||
# Reset PyReason | ||
pr.reset() | ||
pr.reset_rules() | ||
|
||
# Modify the paths based on where you've stored the files we made above | ||
graph_path = "./tests/group_chat_graph.graphml" | ||
|
||
# Modify pyreason settings to make verbose and to save the rule trace to a file | ||
pr.settings.verbose = True # Print info to screen | ||
|
||
# Load all the files into pyreason | ||
pr.load_graphml(graph_path) | ||
|
||
# add custom thresholds | ||
user_defined_thresholds = [ | ||
Threshold("greater_equal", ("number", "total"), 1), | ||
Threshold("greater_equal", ("percent", "total"), 100), | ||
] | ||
|
||
pr.add_rule( | ||
pr.Rule( | ||
"ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)", | ||
"viewed_by_all_rule", | ||
custom_thresholds=user_defined_thresholds, | ||
) | ||
) | ||
|
||
pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 3)) | ||
pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 3)) | ||
pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 3)) | ||
pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 3)) | ||
|
||
# Run the program for three timesteps to see the diffusion take place | ||
interpretation = pr.reason(timesteps=3) | ||
|
||
# Display the changes in the interpretation for each timestep | ||
dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"]) | ||
for t, df in enumerate(dataframes): | ||
print(f"TIMESTEP - {t}") | ||
print(df) | ||
print() | ||
|
||
assert ( | ||
len(dataframes[0]) == 0 | ||
), "At t=0 the TextMessage should not have been ViewedByAll" | ||
assert ( | ||
len(dataframes[2]) == 1 | ||
), "At t=2 the TextMessage should have been ViewedByAll" | ||
|
||
# TextMessage should be ViewedByAll in t=2 | ||
assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[ | ||
0 | ||
].ViewedByAll == [ | ||
1, | ||
1, | ||
], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters