Skip to content

Commit

Permalink
[DAPHNE-#775] Support for unary minus in DaphneLib
Browse files Browse the repository at this point in the history
- DaphneDSL recently supported the additive inverse operator, but DaphneLib didn't support it yet.
- Python has the __neg__ method to override the additive inverse operator
- Added script-level test cases
  • Loading branch information
ldirry authored and corepointer committed Aug 8, 2024
1 parent a1268b2 commit bed9bcd
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/api/python/daphne/operator/nodes/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ def __setitem__(self, key, value):
consumer.update_node_in_input_list(new_node, self)
self.__dict__ = Matrix(new_node.daphne_context, None, [new_node, value, row_index, column_index], left_brackets=True).__dict__

def __neg__(self) -> 'OperationNode':
return Matrix(self.daphne_context, 'minus', [self])

def sum(self, axis: int = None) -> 'OperationNode':
"""Calculate sum of matrix.
:param axis: can be 0 or 1 to do either row or column sums
Expand Down
3 changes: 3 additions & 0 deletions src/api/python/daphne/operator/nodes/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def __ne__(self, other) -> 'Scalar':

def __rne__(self, other) -> 'Scalar':
return Scalar(self.daphne_context, '!=', [other, self])

def __neg__(self) -> 'Scalar':
return Scalar(self.daphne_context, 'minus', [self])

def abs(self) -> 'Scalar':
return Scalar(self.daphne_context, 'abs', [self])
Expand Down
2 changes: 2 additions & 0 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f
// Arithmetic/general math
// --------------------------------------------------------------------

if(func == "minus")
return createUnaryOp<EwMinusOp>(loc, func, args);
if(func == "abs")
return createUnaryOp<EwAbsOp>(loc, func, args);
if(func == "sign")
Expand Down
5 changes: 5 additions & 0 deletions test/api/python/matrix_ewunary.daphne
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

print(-[1]);
print(-[0]);
print(-[-3.3]);

print(abs([1]));
print(abs([0]));
print(abs([-3.3]));
Expand All @@ -22,6 +26,7 @@ print(sign([-3.3]));

m = [0.99];

print(-m);
print(exp(m));
print(ln(m));
print(sqrt(m));
Expand Down
5 changes: 5 additions & 0 deletions test/api/python/matrix_ewunary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

dc = DaphneContext()

(-dc.fill(1, 1, 1)).print().compute()
(-dc.fill(0, 1, 1)).print().compute()
(-dc.fill(-3.3, 1, 1)).print().compute()

dc.fill(1, 1, 1).abs().print().compute()
dc.fill(0, 1, 1).abs().print().compute()
dc.fill(-3.3, 1, 1).abs().print().compute()
Expand All @@ -27,6 +31,7 @@

m = dc.fill(0.99, 1, 1)

(-m).print().compute()
m.exp().print().compute()
m.ln().print().compute()
m.sqrt().print().compute()
Expand Down
5 changes: 5 additions & 0 deletions test/api/python/scalar_ewunary.daphne
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

print(-1);
print(-0);
print(--3.3);

print(abs(1));
print(abs(0));
print(abs(-3.3));
Expand All @@ -22,6 +26,7 @@ print(sign(-3.3));

s = 0.99;

print(-s);
print(exp(s));
print(ln(s));
print(sqrt(s));
Expand Down
5 changes: 5 additions & 0 deletions test/api/python/scalar_ewunary.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
# TODO Currently, we cannot simply construct a DaphneLib scalar from a Python scalar.
# Thus, we use a work-around here by taking the sum of a 1x1 matrix with the desired value.

(-dc.fill(1, 1, 1)).sum().print().compute()
(-dc.fill(0, 1, 1)).sum().print().compute()
(-dc.fill(-3.3, 1, 1)).sum().print().compute()

dc.fill(1, 1, 1).sum().abs().print().compute()
dc.fill(0, 1, 1).sum().abs().print().compute()
dc.fill(-3.3, 1, 1).sum().abs().print().compute()
Expand All @@ -30,6 +34,7 @@

s = dc.fill(0.99, 1, 1)

(-s.sum()).print().compute()
s.sum().exp().print().compute()
s.sum().ln().print().compute()
s.sum().sqrt().print().compute()
Expand Down

0 comments on commit bed9bcd

Please sign in to comment.