Skip to content

Commit

Permalink
added query feature and prevented nodes/edges from being added to gra…
Browse files Browse the repository at this point in the history
…ph when not necessary due to ground rules
  • Loading branch information
dyumanaditya committed Aug 4, 2024
1 parent ce61648 commit aae4943
Showing 1 changed file with 83 additions and 5 deletions.
88 changes: 83 additions & 5 deletions pyreason/scripts/interpretation/interpretation.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,69 @@ def get_dict(self):

return interpretations

def query(self, query, return_bool=True):
"""
This function is used to query the graph after reasoning
:param query: The query string of for `pred(node)` or `pred(edge)` or `pred(node) : [l, u]`
:param return_bool: If True, returns boolean of query, else the bounds associated with it
:return: bool, or bounds
"""
# Parse the query
query = query.replace(' ', '')

if ':' in query:
pred_comp, bounds = query.split(':')
l, u = bounds.split(',')
l, u = float(l), float(u)
else:
if query[0] == '~':
pred_comp = query[1:]
l, u = 0, 0
else:
pred_comp = query
l, u = 1, 1

bnd = interval.closed(l, u)

# Split predicate and component
idx = pred_comp.find('(')
pred = label.Label(pred_comp[:idx])
component = pred_comp[idx + 1:-1]

if ',' in component:
component = tuple(component.split(','))
comp_type = 'edge'
else:
comp_type = 'node'

# Check if the component exists
if comp_type == 'node':
if component not in self.nodes:
return False if return_bool else (0, 0)
else:
if component not in self.edges:
return False if return_bool else (0, 0)

# Check if the predicate exists
if comp_type == 'node':
if pred not in self.interpretations_node[component].world:
return False if return_bool else (0, 0)
else:
if pred not in self.interpretations_edge[component].world:
return False if return_bool else (0, 0)

# Check if the bounds are satisfied
if comp_type == 'node':
if self.interpretations_node[component].world[pred] in bnd:
return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper)
else:
return False if return_bool else (0, 0)
else:
if self.interpretations_edge[component].world[pred] in bnd:
return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper)
else:
return False if return_bool else (0, 0)


@numba.njit(cache=True)
def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_atoms):
Expand Down Expand Up @@ -798,11 +861,12 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,

# If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph
head_var_1_in_nodes = head_var_1 in nodes
add_head_var_node_to_graph = False
if allow_ground_atoms and head_var_1_in_nodes:
groundings[head_var_1] = numba.typed.List([head_var_1])
elif head_var_1 not in groundings:
if not head_var_1_in_nodes:
_add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
add_head_var_node_to_graph = True
groundings[head_var_1] = numba.typed.List([head_var_1])

for head_grounding in groundings[head_var_1]:
Expand Down Expand Up @@ -874,6 +938,10 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,
# Comparison clause (we do not handle for now)
pass

# Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
if add_head_var_node_to_graph:
_add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)

# For each grounding add a rule to be applied
applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added))

Expand All @@ -884,23 +952,26 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,
# If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph
head_var_1_in_nodes = head_var_1 in nodes
head_var_2_in_nodes = head_var_2 in nodes
add_head_var_1_node_to_graph = False
add_head_var_2_node_to_graph = False
add_head_edge_to_graph = False
if allow_ground_atoms and head_var_1_in_nodes:
groundings[head_var_1] = numba.typed.List([head_var_1])
if allow_ground_atoms and head_var_2_in_nodes:
groundings[head_var_2] = numba.typed.List([head_var_2])

if head_var_1 not in groundings:
if not head_var_1_in_nodes:
_add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
add_head_var_1_node_to_graph = True
groundings[head_var_1] = numba.typed.List([head_var_1])
if head_var_2 not in groundings:
if not head_var_2_in_nodes:
_add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node)
add_head_var_2_node_to_graph = True
groundings[head_var_2] = numba.typed.List([head_var_2])

# Artificially connect the head variables with an edge if both of them were not in the graph
if not head_var_1_in_nodes and not head_var_2_in_nodes:
_add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge)
add_head_edge_to_graph = True

head_var_1_groundings = groundings[head_var_1]
head_var_2_groundings = groundings[head_var_2]
Expand All @@ -922,7 +993,6 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,

# Loop through the head variable groundings
for valid_e in valid_edge_groundings:
satisfaction = True
head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
Expand Down Expand Up @@ -1055,6 +1125,14 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,
a.append(interpretations_edge[qe].world[clause_label])
annotations.append(a)

# Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1:
_add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2:
_add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node)
if add_head_edge_to_graph and (head_var_1, head_var_2) == (head_var_1_grounding, head_var_2_grounding):
_add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge)

# For each grounding combination add a rule to be applied
# Only if all the clauses have valid groundings
# if satisfaction:
Expand Down

0 comments on commit aae4943

Please sign in to comment.