-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/prophet #11
Feature/prophet #11
Changes from 12 commits
3f05e19
0217be8
eb99afa
65b433e
e76f593
bc5eee6
688e0b0
a771eed
eee449a
50e0853
8e43b30
1ef4145
5b73d51
56add95
83d7206
6726a7a
efab681
c60da03
7610e74
868c9d4
8cc6489
3dca259
d7c205b
4a3ce9e
b1b64f0
56ca1b8
1911508
074cde1
3830420
295a3ce
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,5 +3,5 @@ version = "0.1.0" | |
authors = [""] | ||
|
||
[[programs]] | ||
path = "src/main.py" | ||
path = "src/complex_model.py" | ||
prime_size = 128 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "complex_model" | ||
name = "neural_net" | ||
version = "0.1.0" | ||
authors = [""] | ||
|
||
[[programs]] | ||
path = "src/main.py" | ||
path = "src/neural_net.py" | ||
prime_size = 128 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,7 +167,13 @@ def forward(self, x: na.NadaArray) -> na.NadaArray: | |
) | ||
|
||
# Sort & rescale the obtained results by the quantization scale (here: 16) | ||
outputs = [result[1] / 2**16 for result in sorted(result.items())] | ||
outputs = [ | ||
result[1] / 2**16 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment as above: |
||
for result in sorted( | ||
result.items(), | ||
key=lambda x: int(x[0].replace("my_output", "").replace("_", "")), | ||
) | ||
] | ||
|
||
print(f"🖥️ The result is {outputs}") | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
name = "time_series" | ||
version = "0.1.0" | ||
authors = [""] | ||
|
||
[[programs]] | ||
path = "src/time_series.py" | ||
prime_size = 128 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import asyncio | ||
import py_nillion_client as nillion | ||
import os | ||
import sys | ||
import time | ||
import numpy as np | ||
import nada_algebra as na | ||
import pandas as pd | ||
from nada_ai import ProphetClient | ||
from prophet import Prophet | ||
from dotenv import load_dotenv | ||
|
||
# Add the parent directory to the system path to import modules from it | ||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | ||
|
||
# Import helper functions for creating nillion client and getting keys | ||
from neural_net.network.helpers.nillion_client_helper import create_nillion_client | ||
from neural_net.network.helpers.nillion_keypath_helper import ( | ||
getUserKeyFromFile, | ||
getNodeKeyFromFile, | ||
) | ||
import nada_algebra.client as na_client | ||
|
||
# Load environment variables from a .env file | ||
load_dotenv() | ||
|
||
|
||
# Decorator function to measure and log the execution time of asynchronous functions | ||
def async_timer(file_path): | ||
def decorator(func): | ||
async def wrapper(*args, **kwargs): | ||
start_time = time.time() | ||
result = await func(*args, **kwargs) | ||
end_time = time.time() | ||
elapsed_time = end_time - start_time | ||
|
||
# Log the execution time to a file | ||
with open(file_path, "a") as file: | ||
file.write(f"{elapsed_time:.6f},\n") | ||
return result | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
# Asynchronous function to store a program on the nillion client | ||
@async_timer("bench/store_program.txt") | ||
async def store_program(client, user_id, cluster_id, program_name, program_mir_path): | ||
action_id = await client.store_program(cluster_id, program_name, program_mir_path) | ||
program_id = f"{user_id}/{program_name}" | ||
print("Stored program. action_id:", action_id) | ||
print("Stored program_id:", program_id) | ||
return program_id | ||
|
||
|
||
# Asynchronous function to store secrets on the nillion client | ||
@async_timer("bench/store_secrets.txt") | ||
async def store_secrets(client, cluster_id, program_id, party_id, party_name, secrets): | ||
secret_bindings = nillion.ProgramBindings(program_id) | ||
secret_bindings.add_input_party(party_name, party_id) | ||
|
||
# Store the secret for the specified party | ||
store_id = await client.store_secrets(cluster_id, secret_bindings, secrets, None) | ||
return store_id | ||
|
||
|
||
# Asynchronous function to perform computation on the nillion client | ||
@async_timer("bench/compute.txt") | ||
async def compute( | ||
client, cluster_id, compute_bindings, store_ids, computation_time_secrets | ||
): | ||
compute_id = await client.compute( | ||
cluster_id, | ||
compute_bindings, | ||
store_ids, | ||
computation_time_secrets, | ||
nillion.PublicVariables({}), | ||
) | ||
|
||
# Monitor and print the computation result | ||
print(f"The computation was sent to the network. compute_id: {compute_id}") | ||
while True: | ||
compute_event = await client.next_compute_event() | ||
if isinstance(compute_event, nillion.ComputeFinishedEvent): | ||
print(f"✅ Compute complete for compute_id {compute_event.uuid}") | ||
return compute_event.result.value | ||
|
||
|
||
# Main asynchronous function to coordinate the process | ||
async def main(): | ||
cluster_id = os.getenv("NILLION_CLUSTER_ID") | ||
userkey = getUserKeyFromFile(os.getenv("NILLION_USERKEY_PATH_PARTY_1")) | ||
nodekey = getNodeKeyFromFile(os.getenv("NILLION_NODEKEY_PATH_PARTY_1")) | ||
client = create_nillion_client(userkey, nodekey) | ||
party_id = client.party_id | ||
user_id = client.user_id | ||
party_names = na_client.parties(2) | ||
program_name = "main" | ||
program_mir_path = f"./target/{program_name}.nada.bin" | ||
|
||
if not os.path.exists("bench"): | ||
os.mkdir("bench") | ||
|
||
na.set_log_scale(50) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jcabrero this is indeed stupidly high - reason is that some prophet parameters are extremely small (ie 1e-13) which means that they get rounded to zero - which is not yet supported. |
||
|
||
# Store the program | ||
program_id = await store_program( | ||
client, user_id, cluster_id, program_name, program_mir_path | ||
) | ||
|
||
# Train prophet model | ||
model = Prophet() | ||
|
||
ds = pd.date_range("2024-05-01", "2024-05-17").tolist() | ||
y = np.arange(1, 18).tolist() | ||
|
||
fit_model = model.fit(df=pd.DataFrame({"ds": ds, "y": y})) | ||
|
||
print("Model params are:", fit_model.params) | ||
print("Number of detected changepoints:", fit_model.n_changepoints) | ||
|
||
# Create and store model secrets via ModelClient | ||
model_client = ProphetClient(fit_model) | ||
model_secrets = nillion.Secrets( | ||
model_client.export_state_as_secrets("my_prophet", na.SecretRational) | ||
) | ||
|
||
model_store_id = await store_secrets( | ||
client, cluster_id, program_id, party_id, party_names[0], model_secrets | ||
) | ||
|
||
# Store inputs to perform inference for | ||
future_df = fit_model.make_future_dataframe(periods=3) | ||
inference_ds = fit_model.setup_dataframe(future_df.copy()) | ||
|
||
my_input = {} | ||
my_input.update( | ||
na_client.array(inference_ds["floor"].to_numpy(), "floor", na.SecretRational) | ||
) | ||
my_input.update( | ||
na_client.array(inference_ds["t"].to_numpy(), "t", na.SecretRational) | ||
) | ||
|
||
input_secrets = nillion.Secrets(my_input) | ||
|
||
data_store_id = await store_secrets( | ||
client, cluster_id, program_id, party_id, party_names[1], input_secrets | ||
) | ||
|
||
# Set up the compute bindings for the parties | ||
compute_bindings = nillion.ProgramBindings(program_id) | ||
[ | ||
compute_bindings.add_input_party(party_name, party_id) | ||
for party_name in party_names | ||
] | ||
compute_bindings.add_output_party(party_names[1], party_id) | ||
|
||
print(f"Computing using program {program_id}") | ||
print(f"Use secret store_id: {model_store_id} {data_store_id}") | ||
|
||
# Perform the computation and return the result | ||
result = await compute( | ||
client, | ||
cluster_id, | ||
compute_bindings, | ||
[model_store_id, data_store_id], | ||
nillion.Secrets({}), | ||
) | ||
|
||
# Sort & rescale the obtained results by the quantization scale | ||
outputs = [ | ||
result[1] / 2**50 | ||
for result in sorted( | ||
result.items(), | ||
key=lambda x: int(x[0].replace("my_output", "").replace("_", "")), | ||
) | ||
] | ||
|
||
print(f"🖥️ The result is {outputs}") | ||
|
||
expected = fit_model.predict(inference_ds)["yhat"].to_numpy() | ||
print(f"🖥️ VS expected plain-text result {expected}") | ||
return result | ||
|
||
|
||
# Run the main function if the script is executed directly | ||
if __name__ == "__main__": | ||
asyncio.run(main()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
import py_nillion_client as nillion | ||
from helpers.nillion_payments_helper import create_payments_config | ||
|
||
|
||
def create_nillion_client(userkey, nodekey): | ||
bootnodes = [os.getenv("NILLION_BOOTNODE_MULTIADDRESS")] | ||
payments_config = create_payments_config() | ||
|
||
return nillion.NillionClient( | ||
nodekey, bootnodes, nillion.ConnectionMode.relay(), userkey, payments_config | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import os | ||
import py_nillion_client as nillion | ||
|
||
|
||
def getUserKeyFromFile(userkey_filepath): | ||
return nillion.UserKey.from_file(userkey_filepath) | ||
|
||
|
||
def getNodeKeyFromFile(nodekey_filepath): | ||
return nillion.NodeKey.from_file(nodekey_filepath) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
import py_nillion_client as nillion | ||
|
||
|
||
def create_payments_config(): | ||
return nillion.PaymentsConfig( | ||
os.getenv("NILLION_BLOCKCHAIN_RPC_ENDPOINT"), | ||
os.getenv("NILLION_WALLET_PRIVATE_KEY"), | ||
int(os.getenv("NILLION_CHAIN_ID")), | ||
os.getenv("NILLION_PAYMENTS_SC_ADDRESS"), | ||
os.getenv("NILLION_BLINDING_FACTORS_MANAGER_SC_ADDRESS"), | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import nada_algebra as na | ||
import numpy as np | ||
from nada_ai.time_series import Prophet | ||
|
||
|
||
def nada_main(): | ||
na.set_log_scale(50) | ||
|
||
# Step 1: We use Nada Algebra wrapper to create "Party0" and "Party1" | ||
parties = na.parties(2) | ||
|
||
# Step 2: Instantiate model object | ||
my_prophet = Prophet( | ||
n_changepoints=12, # NOTE: this is a learned hyperparameter | ||
yearly_seasonality=False, | ||
weekly_seasonality=True, | ||
daily_seasonality=False, | ||
) | ||
|
||
# Step 3: Load model weights from Nillion network by passing model name (acts as ID) | ||
# In this examples Party0 provides the model and Party1 runs inference | ||
my_prophet.load_state_from_network("my_prophet", parties[0], na.SecretRational) | ||
|
||
# Step 4: Load input data to be used for inference (provided by Party1) | ||
dates = np.arange(np.datetime64("2024-05-01"), np.datetime64("2024-05-21")) | ||
|
||
floor = na.array((20,), parties[1], "floor", na.SecretRational) | ||
t = na.array((20,), parties[1], "t", na.SecretRational) | ||
|
||
# Step 5: Compute inference | ||
# Note: completely equivalent to `my_model.forward(...)` or `model.predict(...)` | ||
result = my_prophet(dates, floor, t) | ||
|
||
# Step 6: We can use result.output() to produce the output for Party1 and variable name "my_output" | ||
return result.output(parties[1], "my_output") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably that is not the way we want to do it, but (if you consider it relevant) we can put here
na_client.float_from_rational(result[1])
.