Custom Thresholds #51

Merged
merged 12 commits into from
May 16, 2024
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,9 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

*env/

# Sphinx Documentation
/docs/source/_static/css/fonts/

130 changes: 130 additions & 0 deletions docs/group-chat-example.md
@@ -0,0 +1,130 @@
# Custom Thresholds Example

Here is an example that utilizes custom thresholds.

The following graph represents a network of People and a Text Message in their group chat.
<img src="../media/group_chat_graph.png"/>

In this case, we want to know when a text message has been viewed by all members of the group chat.

## Graph
First, let's create the group chat.

```python
import networkx as nx

# Create an empty graph
G = nx.Graph()

# Add nodes
nodes = ["TextMessage", "Zach", "Justin", "Michelle", "Amy"]
G.add_nodes_from(nodes)

# Add edges with attribute 'HaveAccess'
edges = [
("Zach", "TextMessage", {"HaveAccess": 1}),
("Justin", "TextMessage", {"HaveAccess": 1}),
("Michelle", "TextMessage", {"HaveAccess": 1}),
("Amy", "TextMessage", {"HaveAccess": 1})
]
G.add_edges_from(edges)

```

## Rules and Custom Thresholds
We want a text message to count as viewed by all only once everyone who can view it has viewed it, so we define the rule as follows:

```text
ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)
```

The `head` of the rule is `ViewedByAll(x)` and the `body` is `HaveAccess(x,y), Viewed(y)`. The head and body are separated by an arrow, which means the rule will start evaluating from timestep 0.

We add the rule into pyreason with:

```python
import pyreason as pr
from pyreason import Threshold

user_defined_thresholds = [
Threshold("greater_equal", ("number", "total"), 1),
Threshold("greater_equal", ("percent", "total"), 100),
]

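# Pass the thresholds by keyword: the third positional parameter of Rule is infer_edges
pr.add_rule(pr.Rule('ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)', 'viewed_by_all_rule', custom_thresholds=user_defined_thresholds))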
```
Here `viewed_by_all_rule` is the name of the rule. The name helps identify which rules fired during reasoning later on.

`user_defined_thresholds` is a list of custom thresholds, one per clause in the rule body. Each threshold has the form `(quantifier, quantifier_type, thresh)`, where:
- `quantifier` can be `greater_equal`, `greater`, `less_equal`, `less`, or `equal`
- `quantifier_type` is a tuple whose first element is either `number` or `percent` and whose second element is either `total` or `available`
- `thresh` is the numerical threshold value to compare against

The custom thresholds correspond, in order, to the two clauses (`HaveAccess(x,y)` and `Viewed(y)`):
- `('greater_equal', ('number', 'total'), 1)`: at least one person must have access to TextMessage for the first clause to be satisfied
- `('greater_equal', ('percent', 'total'), 100)`: 100% of the people who have access to TextMessage must have viewed the message for the second clause to be satisfied
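
To make the comparison concrete, here is a minimal sketch of how a single threshold could be checked against counts. It is not PyReason's internal implementation: `threshold_met` and `COMPARATORS` are hypothetical names, the caller is assumed to pick the denominator according to the `total`/`available` setting, and treating an empty denominator as "not satisfied" is also an assumption.

```python
import operator

# Map each quantifier name from the list above to a comparison function.
COMPARATORS = {
    "greater_equal": operator.ge,
    "greater": operator.gt,
    "less_equal": operator.le,
    "less": operator.lt,
    "equal": operator.eq,
}


def threshold_met(quantifier, quantifier_type, thresh, satisfied, denominator):
    """Hypothetical helper: check one threshold against counts.

    satisfied   -- how many candidates satisfy the clause
    denominator -- how many candidates are counted for 'percent'
                   (assumed to be chosen by the 'total'/'available' setting)
    """
    mode, _scope = quantifier_type
    compare = COMPARATORS[quantifier]
    if mode == "number":
        return compare(satisfied, thresh)
    # 'percent': an empty denominator is treated as not satisfied (assumption).
    if denominator == 0:
        return False
    return compare(100 * satisfied / denominator, thresh)


# Two of four people have viewed the message: 50% is below 100%.
print(threshold_met("greater_equal", ("percent", "total"), 100, 2, 4))  # False
# All four have viewed it: 100% meets the threshold.
print(threshold_met("greater_equal", ("percent", "total"), 100, 4, 4))  # True
# At least one person has access.
print(threshold_met("greater_equal", ("number", "total"), 1, 4, 4))  # True
```

With the default threshold (`greater_equal`, `number`, 1) a single viewer would be enough for the `Viewed(y)` clause, which is why this example needs the `percent` threshold of 100.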

## Facts
The facts determine the initial conditions of elements in the graph. They can be specified from the graph attributes but in that
case they will be immutable later on. Adding PyReason facts gives us more flexibility.

In our case we want each person to view the TextMessage at a particular timestep.
For example, we create facts stating:
- Zach and Justin view the TextMessage at timestep 0
- Michelle views the TextMessage at timestep 1
- Amy views the TextMessage at timestep 2

We add the facts in PyReason as below:
```python
import pyreason as pr

pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 0, static=True))
pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 0, static=True))
pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 1, static=True))
pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 2, static=True))
```

Each fact specifies the component that has an initial condition, the initial condition itself in the form of bounds,
and the start and end time of this condition.

## Running PyReason
Find the full code for this example [here](../tests/test_custom_thresholds.py).

The main line that runs the reasoning in that file is:
```python
interpretation = pr.reason(timesteps=3)
```
This specifies how many timesteps to run for.

## Expected Output
After running the Python file, the expected output is:

```
TIMESTEP - 0
Empty DataFrame
Columns: [component, ViewedByAll]
Index: []

TIMESTEP - 1
Empty DataFrame
Columns: [component, ViewedByAll]
Index: []

TIMESTEP - 2
component ViewedByAll
0 TextMessage [1.0, 1.0]

TIMESTEP - 3
component ViewedByAll
0 TextMessage [1.0, 1.0]

```

1. At timestep 0, the facts set `Zach -> Viewed: [1,1]` and `Justin -> Viewed: [1,1]`.
2. At timestep 1, Michelle views the TextMessage, as stated in the fact `Michelle -> Viewed: [1,1]`.
3. At timestep 2, Amy views the TextMessage, so `Amy -> Viewed: [1,1]`. Everyone with access has now viewed the TextMessage,
so the rule fires and the message is marked as `ViewedByAll`.
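
The timeline can be checked by hand with a small sketch in plain Python, independent of PyReason. The `view_time` dictionary and `percent_viewed` function are hypothetical names introduced for illustration, and the sketch assumes that a view persists once it has happened.

```python
# Timestep at which each person first views the message (from the facts above).
view_time = {"Zach": 0, "Justin": 0, "Michelle": 1, "Amy": 2}


def percent_viewed(t):
    """Percentage of people with access who have viewed the message by timestep t."""
    viewed = sum(1 for first_view in view_time.values() if first_view <= t)
    return 100 * viewed / len(view_time)


for t in range(4):
    pct = percent_viewed(t)
    # The second threshold requires at least 100 percent.
    print(f"t={t}: {pct:.0f}% viewed, ViewedByAll={pct >= 100}")
```

This prints 50% and 75% for timesteps 0 and 1, and 100% for timesteps 2 and 3, matching the timesteps at which `TextMessage` appears in the output.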


We also output two CSV files detailing all the events that took place during reasoning (one for nodes, one for edges).
Binary file added media/group_chat_graph.png
3 changes: 3 additions & 0 deletions pyreason/__init__.py
Expand Up @@ -23,6 +23,9 @@
add_fact(Fact('popular-fact', 'Mary', 'popular', [1, 1], 0, 2))
reason(timesteps=2)

reset()
reset_rules()

# Update cache status
cache_status['initialized'] = True
with open(cache_status_path, 'w') as file:
1 change: 1 addition & 0 deletions pyreason/pyreason.py
Expand Up @@ -17,6 +17,7 @@
import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
from pyreason.scripts.facts.fact import Fact
from pyreason.scripts.rules.rule import Rule
from pyreason.scripts.threshold.threshold import Threshold
import pyreason.scripts.numba_wrapper.numba_types.fact_node_type as fact_node
import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
10 changes: 5 additions & 5 deletions pyreason/scripts/rules/rule.py
Expand Up @@ -6,17 +6,17 @@ class Rule:
Example text:
`'pred1(x,y) : [0.2, 1] <- pred2(a, b) : [1,1], pred3(b, c)'`

1. It is not possible to specify thresholds. Threshold is greater than or equal to 1 by default
2. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
TODO: Add threshold class where we can pass this as a parameter
1. It is not possible to have weights for different clauses. Weights are 1 by default with bias 0
TODO: Add weights as a parameter
"""
def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False):
def __init__(self, rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False, custom_thresholds=None):
"""
:param rule_text: The rule in text format
:param name: The name of the rule. This will appear in the rule trace
:param infer_edges: Whether to infer new edges after edge rule fires
:param set_static: Whether to set the atom in the head as static if the rule fires. The bounds will no longer change
:param immediate_rule: Whether the rule is immediate. Immediate rules check for more applicable rules immediately after being applied
"""
self.rule = rule_parser.parse_rule(rule_text, name, infer_edges, set_static, immediate_rule)
if custom_thresholds is None:
custom_thresholds = []
self.rule = rule_parser.parse_rule(rule_text, name, custom_thresholds, infer_edges, set_static, immediate_rule)
Empty file.
41 changes: 41 additions & 0 deletions pyreason/scripts/threshold/threshold.py
@@ -0,0 +1,41 @@
class Threshold:
"""
A class representing a threshold for a clause in a rule.

Attributes:
quantifier (str): The comparison operator, e.g., 'greater_equal', 'less', etc.
quantifier_type (tuple): A tuple indicating the type of quantifier, e.g., ('number', 'total').
thresh (int): The numerical threshold value to compare against.

Methods:
to_tuple(): Converts the Threshold instance into a tuple compatible with numba types.
"""

def __init__(self, quantifier, quantifier_type, thresh):
"""
Initializes a Threshold instance.

Args:
quantifier (str): The comparison operator for the threshold.
quantifier_type (tuple): The type of quantifier ('number' or 'percent', 'total' or 'available').
thresh (int): The numerical value for the threshold.
"""

if quantifier not in ("greater_equal", "greater", "less_equal", "less", "equal"):
raise ValueError("Invalid quantifier")

if quantifier_type[0] not in ("number", "percent") or quantifier_type[1] not in ("total", "available"):
raise ValueError("Invalid quantifier type")

self.quantifier = quantifier
self.quantifier_type = quantifier_type
self.thresh = thresh

def to_tuple(self):
"""
Converts the Threshold instance into a tuple compatible with numba types.

Returns:
tuple: A tuple representation of the Threshold instance.
"""
return (self.quantifier, self.quantifier_type, self.thresh)
26 changes: 18 additions & 8 deletions pyreason/scripts/utils/rule_parser.py
Expand Up @@ -6,7 +6,7 @@
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval


def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule:
def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule:
# First remove all spaces from line
r = rule_text.replace(' ', '')

Expand Down Expand Up @@ -152,7 +152,23 @@ def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static:
# Array to store clauses for nodes: node/edge, [subset]/[subset1, subset2], label, interval, operator
clauses = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string), interval.interval_type, numba.types.string)))

# Loop though clauses
# gather count of clauses for threshold validation
num_clauses = len(body_clauses)

if custom_thresholds and (len(custom_thresholds) != num_clauses):
raise Exception('The length of custom thresholds {} is not equal to number of clauses {}'
.format(len(custom_thresholds), num_clauses))

# If no custom thresholds provided, use defaults
# otherwise loop through user-defined thresholds and convert to numba compatible format
if not custom_thresholds:
for _ in range(num_clauses):
thresholds.append(('greater_equal', ('number', 'total'), 1.0))
else:
for threshold in custom_thresholds:
thresholds.append(threshold.to_tuple())

# # Loop though clauses
for body_clause, predicate, variables, bounds in zip(body_clauses, body_predicates, body_variables, body_bounds):
# Neigh criteria
clause_type = 'node' if len(variables) == 1 else 'edge'
Expand All @@ -165,12 +181,6 @@ def parse_rule(rule_text: str, name: str, infer_edges: bool = False, set_static:
bnd = interval.closed(bounds[0], bounds[1])
clauses.append((clause_type, l, subset, bnd, op))

# Threshold.
quantifier = 'greater_equal'
quantifier_type = ('number', 'total')
thresh = 1
thresholds.append((quantifier, quantifier_type, thresh))

# Assert that there are two variables in the head of the rule if we infer edges
# Add edges between head variables if necessary
if infer_edges:
26 changes: 26 additions & 0 deletions tests/group_chat_graph.graphml
@@ -0,0 +1,26 @@
<?xml version='1.0' encoding='utf-8'?>
<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">

<key id="HaveAccess" for="edge" attr.name="HaveAccess" attr.type="long" />
<graph edgedefault="undirected">
<node id="TextMessage" />
<node id="Zach" />
<node id="Justin" />
<node id="Michelle" />
<node id="Amy" />

<edge source="Zach" target="TextMessage">
<data key="HaveAccess">1</data>
</edge>
<edge source="Justin" target="TextMessage">
<data key="HaveAccess">1</data>
</edge>

<edge source="Amy" target="TextMessage">
<data key="HaveAccess">1</data>
</edge>
<edge source="Michelle" target="TextMessage">
<data key="HaveAccess">1</data>
</edge>
</graph>
</graphml>
62 changes: 62 additions & 0 deletions tests/test_custom_thresholds.py
@@ -0,0 +1,62 @@
# Test if the simple program works with thresholds defined
import pyreason as pr
from pyreason import Threshold


def test_custom_thresholds():
# Reset PyReason
pr.reset()
pr.reset_rules()

# Modify the paths based on where you've stored the files we made above
graph_path = "./tests/group_chat_graph.graphml"

# Modify pyreason settings to make verbose and to save the rule trace to a file
pr.settings.verbose = True # Print info to screen

# Load all the files into pyreason
pr.load_graphml(graph_path)

# add custom thresholds
user_defined_thresholds = [
Threshold("greater_equal", ("number", "total"), 1),
Threshold("greater_equal", ("percent", "total"), 100),
]

pr.add_rule(
pr.Rule(
"ViewedByAll(x) <- HaveAccess(x,y), Viewed(y)",
"viewed_by_all_rule",
custom_thresholds=user_defined_thresholds,
)
)

pr.add_fact(pr.Fact("seen-fact-zach", "Zach", "Viewed", [1, 1], 0, 3))
pr.add_fact(pr.Fact("seen-fact-justin", "Justin", "Viewed", [1, 1], 0, 3))
pr.add_fact(pr.Fact("seen-fact-michelle", "Michelle", "Viewed", [1, 1], 1, 3))
pr.add_fact(pr.Fact("seen-fact-amy", "Amy", "Viewed", [1, 1], 2, 3))

# Run the program for three timesteps to see the diffusion take place
interpretation = pr.reason(timesteps=3)

# Display the changes in the interpretation for each timestep
dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"])
for t, df in enumerate(dataframes):
print(f"TIMESTEP - {t}")
print(df)
print()

assert (
len(dataframes[0]) == 0
), "At t=0 the TextMessage should not have been ViewedByAll"
assert (
len(dataframes[2]) == 1
), "At t=2 the TextMessage should have been ViewedByAll"

# TextMessage should be ViewedByAll in t=2
assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[
0
].ViewedByAll == [
1,
1,
], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps"
4 changes: 4 additions & 0 deletions tests/test_hello_world.py
Expand Up @@ -3,6 +3,10 @@


def test_hello_world():
# Reset PyReason
pr.reset()
pr.reset_rules()

# Modify the paths based on where you've stored the files we made above
graph_path = './tests/friends_graph.graphml'

4 changes: 4 additions & 0 deletions tests/test_hello_world_parallel.py
Expand Up @@ -3,6 +3,10 @@


def test_hello_world_parallel():
# Reset PyReason
pr.reset()
pr.reset_rules()

# Modify the paths based on where you've stored the files we made above
graph_path = './tests/friends_graph.graphml'
