Skip to content

Commit

Permalink
Add support for ampersand in join clause
Browse files Browse the repository at this point in the history
  • Loading branch information
daveraja committed Mar 3, 2024
1 parent 3754ac0 commit 92a292a
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 1 deletion.
12 changes: 11 additions & 1 deletion clorm/orm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,6 +1829,16 @@ def is_join_qcondition(cond):
# ------------------------------------------------------------------------------


def _process_join_with_andop(join_expressions):
tmp = []
for jexp in join_expressions:
if isinstance(jexp, QCondition) and jexp.operator == operator.and_:
tmp.extend(_process_join_with_andop(jexp.args))
else:
tmp.append(jexp)
return tmp


def validate_join_expression(qconds, roots):
jroots = set() # The set of all roots in the join clauses
joins = [] # The list of joins
Expand Down Expand Up @@ -1893,7 +1903,7 @@ def visit(r):
).format(p)
)

for qcond in qconds:
for qcond in _process_join_with_andop(qconds):
if not is_join_qcondition(qcond):
if not isinstance(qcond, QCondition):
raise ValueError(
Expand Down
90 changes: 90 additions & 0 deletions examples/join_with_ampersand/join_with_ampersand.lp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
%----------------------------------------------------------------------------------
% Domain encoding for a simple scheduling problem. Drivers need to make
% deliveries. Every driver has a fixed base cost and every delivery has a
% cost. We also need deliveries within a time limit.
% ----------------------------------------------------------------------------------


time(1..4).

1 { assignment(I, D, T) : driver(D,_), time(T) } 1 :- item(I,_).
:- assignment(I1, D, T), assignment(I2, D, T), I1 != I2.

working_driver(D) :- assignment(_,D,_).

#minimize { 1@2,D : working_driver(D) }.
#minimize { T@1,D : assignment(_,D,T) }.


#script(python)

from clorm.clingo import Control
from clorm import Predicate, ConstantStr, FactBase
from clorm import ph1_


#--------------------------------------------------------------------------
# Define a data model - we only care about defining the input and output
# predicates.
#--------------------------------------------------------------------------

class Driver(Predicate):
driverid: ConstantStr
name: str

class Item(Predicate):
itemid: ConstantStr
description: str

class Assignment(Predicate):
itemid: ConstantStr
driverid: ConstantStr
time: int

#--------------------------------------------------------------------------
# main
#--------------------------------------------------------------------------

def main(ctrl_):
# For better integration with Clorm wrap the clingo.Control object with a
# clorm.clingo.Control object and pass the unifier list of predicates that
# are used to unify the symbols and predicates.
ctrl = Control(control_=ctrl_, unifier=[Driver,Item,Assignment])

# Dynamically generate the instance data
drivers = [
Driver(driverid="dave", name="Dave X"),
Driver(driverid="morri", name="Morri Y"),
Driver(driverid="michael", name="Michael Z"),
]

items = [ Item(itemid=f"item{i}", description=f"Item {i}") for i in range(1,6) ]
instance = FactBase(drivers + items)

# Add the instance data and ground the ASP program
ctrl.add_facts(instance)
ctrl.ground([("base",[])])

# Generate a solution - use a call back that saves the solution
solution=None
def on_model(model):
nonlocal solution
solution = model.facts(atoms=True)

ctrl.solve(on_model=on_model)
if not solution:
raise ValueError("No solution found")

# Do something with the solution - create a query so we can print out the
# assignments for each driver.
query=solution.query(Driver, Item, Assignment)\
.join((Driver.driverid == Assignment.driverid) &
(Item.itemid == Assignment.itemid))\
.group_by(Driver.name)\
.order_by(Assignment.time)\
.select(Item.description, Assignment.time)
for dname, assiter in query.all():
print(f"Driver: {dname}:")
for idesc, atime in assiter:
print(f" {atime}: {idesc}")
#end.
1 change: 1 addition & 0 deletions runexamples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,4 @@ run_clingo examples/combine_fields/combine_fields.lp

run_clingo examples/nested_list/nested_list.lp

run_clingo examples/join_with_ampersand/join_with_ampersand.lp
32 changes: 32 additions & 0 deletions tests/test_orm_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,38 @@ def test_nonapi_validate_join_expression(self):
vje([F.anum == G.anum, X.anum == Y.anum], [F, G, X, Y, Z])
check_errmsg("Invalid join specification: missing joins", ctx)

# ------------------------------------------------------------------------------
# Test validating a join expression connected by &
# ------------------------------------------------------------------------------
def test_nonapi_validate_join_expression_with_ampersand(self):
F = path(self.F)
G = path(self.G)
FA = alias(F)
GA = alias(G)
SC = StandardComparator
vje = validate_join_expression
tmp1 = SC(operator.eq, [F.anum, G.anum])
tmp2 = SC(operator.eq, [F.anum, GA.anum])
tmp3 = SC(operator.eq, [G.anum, FA.anum])

joins = vje([(F.anum == G.anum) & (F.anum == GA.anum)], [F, G, GA])
self.assertEqual([tmp1, tmp2], joins)

joins = vje(
[(F.anum == G.anum) & (F.anum == GA.anum) & (G.anum == FA.anum)], [F, G, GA, FA]
)
self.assertEqual([tmp1, tmp2, tmp3], joins)

joins = vje(
[(F.anum == G.anum) & ((F.anum == GA.anum) & (G.anum == FA.anum))], [F, G, GA, FA]
)
self.assertEqual([tmp1, tmp2, tmp3], joins)

# Joining with a non-ampersand operator
with self.assertRaises(ValueError) as ctx:
vje([(F.anum == F.anum) | (F.anum == GA.anum)], [F, G, GA])
check_errmsg("Invalid join ", ctx)


# ------------------------------------------------------------------------------
# Tests OrderBy, OrderByBlock, and related functions
Expand Down

0 comments on commit 92a292a

Please sign in to comment.