Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/distance #13

Merged
merged 8 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nada_ai/linear_model/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, in_features: int, include_bias: bool = True) -> None:
include_bias (bool, optional): Whether or not to include a bias term. Defaults to True.
"""
self.coef = Parameter(in_features)
self.intercept = Parameter(1) if include_bias else None
self.intercept = Parameter() if include_bias else None

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""
Expand Down
7 changes: 2 additions & 5 deletions nada_ai/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,14 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
"""
...

def __call__(self, x: na.NadaArray) -> na.NadaArray:
def __call__(self, *args, **kwargs) -> na.NadaArray:
"""
Proxy for forward pass.

Args:
x (na.NadaArray): Input array.

Returns:
na.NadaArray: Output array.
"""
return self.forward(x)
return self.forward(*args, **kwargs)

def __named_parameters(self, prefix: str) -> Iterator[Tuple[str, Parameter]]:
"""
Expand Down
3 changes: 2 additions & 1 deletion nada_ai/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .linear import Linear
from .pooling import AvgPool2d
from .relu import ReLU
from .distance import DotProductSimilarity

__all__ = ["Conv2d", "Flatten", "Linear", "AvgPool2d", "ReLU"]
__all__ = ["Conv2d", "Flatten", "Linear", "AvgPool2d", "ReLU", "DotProductSimilarity"]
23 changes: 23 additions & 0 deletions nada_ai/nn/modules/distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Distance implementation"""

import nada_algebra as na
from nada_ai.nn.module import Module

__all__ = ["DotProductSimilarity"]


class DotProductSimilarity(Module):
"""Dot product similarity module"""

def forward(self, x_1: na.NadaArray, x_2: na.NadaArray) -> na.NadaArray:
"""
Forward pass logic.

Args:
x_1 (na.NadaArray): First input array.
x_2 (na.NadaArray): Second input array.

Returns:
na.NadaArray: Dot product between input arrays.
"""
return x_1 @ x_2.T
4 changes: 2 additions & 2 deletions nada_ai/nn/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
class Parameter(na.NadaArray):
"""Parameter class"""

def __init__(self, shape: ShapeLike) -> None:
def __init__(self, shape: ShapeLike = 1) -> None:
"""
Initializes light NadaArray wrapper.

Args:
shape (ShapeLike, optional): Parameter array shape.
shape (ShapeLike, optional): Parameter array shape. Defaults to 1.
"""
super().__init__(inner=np.empty(shape))

Expand Down
4 changes: 4 additions & 0 deletions tests/nada-tests/nada-project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ prime_size = 128
[[programs]]
path = "src/prophet.py"
prime_size = 128

[[programs]]
path = "src/distance.py"
prime_size = 128
19 changes: 19 additions & 0 deletions tests/nada-tests/src/distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from nada_dsl import *
from nada_ai.nn import DotProductSimilarity
import nada_algebra as na


def nada_main():
party = Party("party")

# 2 queries, each of size 3
queries = na.array((2, 3), party, "input_x", SecretInteger)
# 5 values, each also of size 3
values = na.array((5, 3), party, "input_y", SecretInteger)

dps = DotProductSimilarity()

similarities = dps(queries, values)
assert similarities.shape == (2, 5) # for each query, the similarity to each value

return similarities.output(party, "my_output")
68 changes: 68 additions & 0 deletions tests/nada-tests/tests/distance.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
---
program: distance
inputs:
secrets:
input_x_0_0:
SecretInteger: "1"
input_x_0_1:
SecretInteger: "2"
input_x_0_2:
SecretInteger: "3"
input_x_1_0:
SecretInteger: "4"
input_x_1_1:
SecretInteger: "5"
input_x_1_2:
SecretInteger: "6"
input_y_0_0:
SecretInteger: "1"
input_y_0_1:
SecretInteger: "2"
input_y_0_2:
SecretInteger: "3"
input_y_1_0:
SecretInteger: "4"
input_y_1_1:
SecretInteger: "5"
input_y_1_2:
SecretInteger: "6"
input_y_2_0:
SecretInteger: "7"
input_y_2_1:
SecretInteger: "8"
input_y_2_2:
SecretInteger: "9"
input_y_3_0:
SecretInteger: "10"
input_y_3_1:
SecretInteger: "11"
input_y_3_2:
SecretInteger: "12"
input_y_4_0:
SecretInteger: "13"
input_y_4_1:
SecretInteger: "14"
input_y_4_2:
SecretInteger: "15"
public_variables: {}
expected_outputs:
my_output_0_0:
SecretInteger: "14"
my_output_0_1:
SecretInteger: "32"
my_output_0_2:
SecretInteger: "50"
my_output_0_3:
SecretInteger: "68"
my_output_0_4:
SecretInteger: "86"
my_output_1_0:
SecretInteger: "32"
my_output_1_1:
SecretInteger: "77"
my_output_1_2:
SecretInteger: "122"
my_output_1_3:
SecretInteger: "167"
my_output_1_4:
SecretInteger: "212"
1 change: 1 addition & 0 deletions tests/test_all_nada.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"pool",
"linear_regression",
"end-to-end",
"distance",
"prophet",
]

Expand Down
Loading