From 92a292a0aa8840c769c22420d21040fa00c095db Mon Sep 17 00:00:00 2001 From: David Rajaratnam Date: Sun, 3 Mar 2024 16:22:15 +1100 Subject: [PATCH] Add support for ampersand in join clause --- clorm/orm/query.py | 12 ++- .../join_with_ampersand.lp | 90 +++++++++++++++++++ runexamples.sh | 1 + tests/test_orm_query.py | 32 +++++++ 4 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 examples/join_with_ampersand/join_with_ampersand.lp diff --git a/clorm/orm/query.py b/clorm/orm/query.py index 321fccf..a6b5c5e 100644 --- a/clorm/orm/query.py +++ b/clorm/orm/query.py @@ -1829,6 +1829,16 @@ def is_join_qcondition(cond): # ------------------------------------------------------------------------------ +def _process_join_with_andop(join_expressions): + tmp = [] + for jexp in join_expressions: + if isinstance(jexp, QCondition) and jexp.operator == operator.and_: + tmp.extend(_process_join_with_andop(jexp.args)) + else: + tmp.append(jexp) + return tmp + + def validate_join_expression(qconds, roots): jroots = set() # The set of all roots in the join clauses joins = [] # The list of joins @@ -1893,7 +1903,7 @@ def visit(r): ).format(p) ) - for qcond in qconds: + for qcond in _process_join_with_andop(qconds): if not is_join_qcondition(qcond): if not isinstance(qcond, QCondition): raise ValueError( diff --git a/examples/join_with_ampersand/join_with_ampersand.lp b/examples/join_with_ampersand/join_with_ampersand.lp new file mode 100644 index 0000000..838580c --- /dev/null +++ b/examples/join_with_ampersand/join_with_ampersand.lp @@ -0,0 +1,90 @@ +%---------------------------------------------------------------------------------- +% Domain encoding for a simple scheduling problem. Drivers need to make +% deliveries. Every driver has a fixed base cost and every delivery has a +% cost. We also need deliveries within a time limit. +% ---------------------------------------------------------------------------------- + + +time(1..4). + +1 { assignment(I, D, T) : driver(D,_), time(T) } 1 :- item(I,_). +:- assignment(I1, D, T), assignment(I2, D, T), I1 != I2. + +working_driver(D) :- assignment(_,D,_). + +#minimize { 1@2,D : working_driver(D) }. +#minimize { T@1,D : assignment(_,D,T) }. + + +#script(python) + +from clorm.clingo import Control +from clorm import Predicate, ConstantStr, FactBase +from clorm import ph1_ + + +#-------------------------------------------------------------------------- +# Define a data model - we only care about defining the input and output +# predicates. +#-------------------------------------------------------------------------- + +class Driver(Predicate): + driverid: ConstantStr + name: str + +class Item(Predicate): + itemid: ConstantStr + description: str + +class Assignment(Predicate): + itemid: ConstantStr + driverid: ConstantStr + time: int + +#-------------------------------------------------------------------------- +# main +#-------------------------------------------------------------------------- + +def main(ctrl_): + # For better integration with Clorm wrap the clingo.Control object with a + # clorm.clingo.Control object and pass the unifier list of predicates that + # are used to unify the symbols and predicates. + ctrl = Control(control_=ctrl_, unifier=[Driver,Item,Assignment]) + + # Dynamically generate the instance data + drivers = [ + Driver(driverid="dave", name="Dave X"), + Driver(driverid="morri", name="Morri Y"), + Driver(driverid="michael", name="Michael Z"), + ] + + items = [ Item(itemid=f"item{i}", description=f"Item {i}") for i in range(1,6) ] + instance = FactBase(drivers + items) + + # Add the instance data and ground the ASP program + ctrl.add_facts(instance) + ctrl.ground([("base",[])]) + + # Generate a solution - use a call back that saves the solution + solution=None + def on_model(model): + nonlocal solution + solution = model.facts(atoms=True) + + ctrl.solve(on_model=on_model) + if not solution: + raise ValueError("No solution found") + + # Do something with the solution - create a query so we can print out the + # assignments for each driver. + query=solution.query(Driver, Item, Assignment)\ + .join((Driver.driverid == Assignment.driverid) & + (Item.itemid == Assignment.itemid))\ + .group_by(Driver.name)\ + .order_by(Assignment.time)\ + .select(Item.description, Assignment.time) + for dname, assiter in query.all(): + print(f"Driver: {dname}:") + for idesc, atime in assiter: + print(f" {atime}: {idesc}") +#end. diff --git a/runexamples.sh b/runexamples.sh index 7215e3c..4c23b72 100755 --- a/runexamples.sh +++ b/runexamples.sh @@ -89,3 +89,4 @@ run_clingo examples/combine_fields/combine_fields.lp run_clingo examples/nested_list/nested_list.lp +run_clingo examples/join_with_ampersand/join_with_ampersand.lp diff --git a/tests/test_orm_query.py b/tests/test_orm_query.py index d4b772b..46d0b93 100644 --- a/tests/test_orm_query.py +++ b/tests/test_orm_query.py @@ -1304,6 +1304,38 @@ def test_nonapi_validate_join_expression(self): vje([F.anum == G.anum, X.anum == Y.anum], [F, G, X, Y, Z]) check_errmsg("Invalid join specification: missing joins", ctx) + # ------------------------------------------------------------------------------ + # Test validating a join expression connected by & + # ------------------------------------------------------------------------------ + def test_nonapi_validate_join_expression_with_ampersand(self): + F = path(self.F) + G = path(self.G) + FA = alias(F) + GA = alias(G) + SC = StandardComparator + vje = validate_join_expression + tmp1 = SC(operator.eq, [F.anum, G.anum]) + tmp2 = SC(operator.eq, [F.anum, GA.anum]) + tmp3 = SC(operator.eq, [G.anum, FA.anum]) + + joins = vje([(F.anum == G.anum) & (F.anum == GA.anum)], [F, G, GA]) + self.assertEqual([tmp1, tmp2], joins) + + joins = vje( + [(F.anum == G.anum) & (F.anum == GA.anum) & (G.anum == FA.anum)], [F, G, GA, FA] + ) + self.assertEqual([tmp1, tmp2, tmp3], joins) + + joins = vje( + [(F.anum == G.anum) & ((F.anum == GA.anum) & (G.anum == FA.anum))], [F, G, GA, FA] + ) + self.assertEqual([tmp1, tmp2, tmp3], joins) + + # Joining with a non-ampersand operator + with self.assertRaises(ValueError) as ctx: + vje([(F.anum == F.anum) | (F.anum == GA.anum)], [F, G, GA]) + check_errmsg("Invalid join ", ctx) + # ------------------------------------------------------------------------------ # Tests OrderBy, OrderByBlock, and related functions