Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numpy compatible dot funciton and override __matmul__ operator #899

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions python/src/nnabla/_arithmetic_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,26 @@ cdef object pow(object x, object y, object z):
return F.r_pow_scalar(y, x)
else:
return x ** y


cdef object matmul(object x, object y):
"""
Matlix multiplication

Implements the matmul operator expression ``x @ y``.
When both of ``x`` and ``y`` are either :obj:`~nnabla.Variable` or
:obj:`~nnabla.NdArray`, :func:`~nnabla.functions.affine`` is
internally called.

Args:
x (~nnabla.Variable or ~nnabla.NdArray): Left operand. It must be 2-dimensional.
y (~nnabla.Variable or ~nnabla.NdArray): Right operand. It must be 2-dimensional.

Returns: :class:`~nnabla.Variable` or :class:`~nnabla.NdArray`.

"""
import nnabla.functions as F
assert x.ndim == 2 and y.ndim == 2, "Both of x and y must be matrices."
assert isinstance(x, (NdArray, Variable))
assert isinstance(y, (NdArray, Variable))
return F.affine(x, y)
3 changes: 3 additions & 0 deletions python/src/nnabla/_variable.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,9 @@ cdef class Variable:
def __pow__(x, y, z):
return AOP.pow(x, y, z)

def __matmul__(x, y):
return AOP.matmul(x, y)

def __iadd__(self, x):
import nnabla.functions as F
if isinstance(x, (NdArray, Variable)):
Expand Down
3 changes: 2 additions & 1 deletion python/src/nnabla/functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
# Copyright (c) 2017-2021 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@
import nnabla as nn
import numpy as np
from .normalization_functions import *
from .numpy_compat_functions import *


def sum(x, axis=None, keepdims=False):
Expand Down
53 changes: 53 additions & 0 deletions python/src/nnabla/numpy_compat_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2021 Sony Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def dot(a, b, out=None):
'''
A compatible operation with ``numpy.dot``.

Note:
Any operation between nnabla's Variable/NdArray and numpy array is not supported.

Args:
a (Variable, NdArray or scalar): Left input array.
b (Variable, NdArray or scalar): Right input array.
out: Not supported so far.

Returns:
~nnabla.Variable: N-D array.

'''
import nnabla as nn
import nnabla.fucntions as F
assert out is None, "The `out` option is not supported."

def _chk(x):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please consider making this a descriptive name, such as _is_nnabla_data_type.

return isinstance(x, (nn.NdArray, nn.Variable))

if _chk(a) and _chk(b):
if a.ndim == 1 and b.ndim == 1:
return return F.sum(a * b)
if a.ndim == 2 and b.ndim >= 2:
return F.affine(a, b)
if a.ndim == 0 or b.ndim == 0:
return a * b
if a.ndim > 2 and b.ndim == 1:
h = F.affine(x, F.reshape(y, (-1, 1)), base_axis=x.ndim - 1)
return F.reshape(h, h.shape[:-1])
raise ValueError(f'Undefined configuration: a.ndim={a.ndim}, b.ndim:{b.ndim}')

return x * y
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be a runtime error. Function inputs are a and b and no variables x and y are defined at this point.