Skip to content

Commit

Permalink
Train model v0.11.1
Browse files Browse the repository at this point in the history
  • Loading branch information
sdesabbata committed Nov 1, 2024
1 parent 3b2b2b2 commit eaeb0c9
Show file tree
Hide file tree
Showing 17 changed files with 82,330 additions and 14,649 deletions.
68,110 changes: 68,110 additions & 0 deletions analysis/gnnuf_exploratory_analysis_v0-11-emb_Leicester.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,15 @@
print("Run on CPU")

# Load model
model_name = "gnnuf_model_v0-8-1"
model_name = "gnnuf_model_v0-11-1"
model = GAE(
GINEEncoder(
SimpleSparseGINEEncoder(
in_channels=1,
edge_dim=1,
gine_mlp_channels=[
[1, 2, 4, 8],
[8, 16, 32, 64],
[64, 128, 256, 512]
],
out_channels=512
out_channels=8
)
)
out_channels_for_df = 512
out_channels_for_df = 8
model.load_state_dict(torch.load(this_repo_directory + "/models/" + model_name + ".pt", map_location=device))
model = model.to(device)
model.eval()
Expand All @@ -53,7 +48,7 @@
print(model_info_str)

# Load Leciester's graph
leicester = ox.io.load_graphml(bulk_storage_directory + "/osmnx/raw_excluded/leicester-1864.graphml")
leicester = ox.io.load_graphml(bulk_storage_directory + "/osmnx/raw/leicester-1864.graphml")

leicester_embs = {}
neighbourhood_min_nodes = 8
Expand Down
575 changes: 104 additions & 471 deletions code/gnnuf_models.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,12 @@
osmnx_loader_test = DataLoader(osmnx_dataset_test, batch_size=32, shuffle=True)

# Define the model
model_name = "gnnuf_model_v0-8-1"
model_name = "gnnuf_model_v0-11-1"
model = GAE(
GINEEncoder(
SimpleSparseGINEEncoder(
in_channels=1,
edge_dim=1,
gine_mlp_channels=[
[1, 2, 4, 8],
[8, 16, 32, 64],
[64, 128, 256, 512]
],
out_channels=512
out_channels=8
)
)
model = model.to(device)
Expand Down
14 changes: 9 additions & 5 deletions code/osmnx_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Base libraries
import os
from os.path import isfile, join
import glob
from datetime import datetime
import sys
import math
Expand Down Expand Up @@ -78,11 +79,14 @@ def download(self):
print(r"Please download the data manually")

def process(self):
graphml_file_names = [
join(self.root + "/raw", f)
for f in os.listdir(self.root + "/raw")
if f[-8:] == ".graphml"
if isfile(join(self.root + "/raw", f))]
# graphml_file_names = [
# join(self.root + "/raw", f)
# for f in os.listdir(self.root + "/raw")
# if f[-8:] == ".graphml"
# if isfile(join(self.root + "/raw", f))]

graphml_file_names = glob.glob(join(self.root + "/raw", '*.graphml'))
graphml_file_names = sorted(graphml_file_names, key=os.path.getsize)

neighbourhoods_list = []

Expand Down
Loading

0 comments on commit eaeb0c9

Please sign in to comment.