Skip to content

Commit

Permalink
Merge pull request #13 from NillionNetwork/feature/distance
Browse files Browse the repository at this point in the history
Adds dot product distance function
  • Loading branch information
mathias-nillion authored Jun 12, 2024
2 parents 2a3cb50 + 12f5d78 commit 23ca09a
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 9 deletions.
2 changes: 1 addition & 1 deletion nada_ai/linear_model/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, in_features: int, include_bias: bool = True) -> None:
include_bias (bool, optional): Whether or not to include a bias term. Defaults to True.
"""
self.coef = Parameter(in_features)
self.intercept = Parameter(1) if include_bias else None
self.intercept = Parameter() if include_bias else None

def forward(self, x: na.NadaArray) -> na.NadaArray:
"""
Expand Down
7 changes: 2 additions & 5 deletions nada_ai/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,14 @@ def forward(self, x: na.NadaArray) -> na.NadaArray:
"""
...

def __call__(self, x: na.NadaArray) -> na.NadaArray:
def __call__(self, *args, **kwargs) -> na.NadaArray:
"""
Proxy for forward pass.
Args:
x (na.NadaArray): Input array.
Returns:
na.NadaArray: Output array.
"""
return self.forward(x)
return self.forward(*args, **kwargs)

def __named_parameters(self, prefix: str) -> Iterator[Tuple[str, Parameter]]:
"""
Expand Down
3 changes: 2 additions & 1 deletion nada_ai/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .linear import Linear
from .pooling import AvgPool2d
from .relu import ReLU
from .distance import DotProductSimilarity

__all__ = ["Conv2d", "Flatten", "Linear", "AvgPool2d", "ReLU"]
__all__ = ["Conv2d", "Flatten", "Linear", "AvgPool2d", "ReLU", "DotProductSimilarity"]
23 changes: 23 additions & 0 deletions nada_ai/nn/modules/distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Distance implementation"""

import nada_algebra as na
from nada_ai.nn.module import Module

__all__ = ["DotProductSimilarity"]


class DotProductSimilarity(Module):
"""Dot product similarity module"""

def forward(self, x_1: na.NadaArray, x_2: na.NadaArray) -> na.NadaArray:
"""
Forward pass logic.
Args:
x_1 (na.NadaArray): First input array.
x_2 (na.NadaArray): Second input array.
Returns:
na.NadaArray: Dot product between input arrays.
"""
return x_1 @ x_2.T
4 changes: 2 additions & 2 deletions nada_ai/nn/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
class Parameter(na.NadaArray):
"""Parameter class"""

def __init__(self, shape: ShapeLike) -> None:
def __init__(self, shape: ShapeLike = 1) -> None:
"""
Initializes light NadaArray wrapper.
Args:
shape (ShapeLike, optional): Parameter array shape.
shape (ShapeLike, optional): Parameter array shape. Defaults to 1.
"""
super().__init__(inner=np.empty(shape))

Expand Down
4 changes: 4 additions & 0 deletions tests/nada-tests/nada-project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ prime_size = 128
[[programs]]
path = "src/prophet.py"
prime_size = 128

[[programs]]
path = "src/distance.py"
prime_size = 128
19 changes: 19 additions & 0 deletions tests/nada-tests/src/distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from nada_dsl import *
from nada_ai.nn import DotProductSimilarity
import nada_algebra as na


def nada_main():
party = Party("party")

# 2 queries, each of size 3
queries = na.array((2, 3), party, "input_x", SecretInteger)
# 5 values, each also of size 3
values = na.array((5, 3), party, "input_y", SecretInteger)

dps = DotProductSimilarity()

similarities = dps(queries, values)
assert similarities.shape == (2, 5) # for each query, the similarity to each value

return similarities.output(party, "my_output")
68 changes: 68 additions & 0 deletions tests/nada-tests/tests/distance.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
---
program: distance
inputs:
secrets:
input_x_0_0:
SecretInteger: "1"
input_x_0_1:
SecretInteger: "2"
input_x_0_2:
SecretInteger: "3"
input_x_1_0:
SecretInteger: "4"
input_x_1_1:
SecretInteger: "5"
input_x_1_2:
SecretInteger: "6"
input_y_0_0:
SecretInteger: "1"
input_y_0_1:
SecretInteger: "2"
input_y_0_2:
SecretInteger: "3"
input_y_1_0:
SecretInteger: "4"
input_y_1_1:
SecretInteger: "5"
input_y_1_2:
SecretInteger: "6"
input_y_2_0:
SecretInteger: "7"
input_y_2_1:
SecretInteger: "8"
input_y_2_2:
SecretInteger: "9"
input_y_3_0:
SecretInteger: "10"
input_y_3_1:
SecretInteger: "11"
input_y_3_2:
SecretInteger: "12"
input_y_4_0:
SecretInteger: "13"
input_y_4_1:
SecretInteger: "14"
input_y_4_2:
SecretInteger: "15"
public_variables: {}
expected_outputs:
my_output_0_0:
SecretInteger: "14"
my_output_0_1:
SecretInteger: "32"
my_output_0_2:
SecretInteger: "50"
my_output_0_3:
SecretInteger: "68"
my_output_0_4:
SecretInteger: "86"
my_output_1_0:
SecretInteger: "32"
my_output_1_1:
SecretInteger: "77"
my_output_1_2:
SecretInteger: "122"
my_output_1_3:
SecretInteger: "167"
my_output_1_4:
SecretInteger: "212"
1 change: 1 addition & 0 deletions tests/test_all_nada.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"pool",
"linear_regression",
"end-to-end",
"distance",
"prophet",
]

Expand Down

0 comments on commit 23ca09a

Please sign in to comment.