Skip to content

Commit

Permalink
Merge pull request #16 from aai-institute/feature/refactor-package
Browse files Browse the repository at this point in the history
Refactor model submodule.
  • Loading branch information
Samuel Burbulla authored Dec 7, 2023
2 parents 54a2581 + 076b031 commit 68779c5
Show file tree
Hide file tree
Showing 27 changed files with 1,209 additions and 787 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
test_*.png

# Translations
*.mo
Expand Down
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ repos:
additional_dependencies: [flake8-bugbear]
args: ["--max-line-length=80"]

# Jupyter notebook cell output clearing
- repo: https://github.com/kynan/nbstripout
rev: 0.6.0
hooks:
- id: nbstripout

# Formatting yaml
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
Expand Down
109 changes: 82 additions & 27 deletions notebooks/sine.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ install_requires =
# wandb
# neptune-client
# mlflow
# nbstripout # remove output from jupyter notebooks
# comet-ml
tensorboard

Expand Down
130 changes: 130 additions & 0 deletions src/continuity/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
In Continuity, data is given by *observations*. Every observation is a set of
function evaluations, so-called *sensors*. Every data set is a set of
observations, evaluation coordinates and labels.
"""

import torch
from torch import Tensor
from numpy import ndarray
from typing import List, Tuple
from abc import abstractmethod


def get_device() -> torch.device:
"""Get torch device.
Returns:
Device.
"""
device = torch.device("cpu")

# if torch.backends.mps.is_available():
# device = torch.device("mps")

if torch.cuda.is_available():
device = torch.device("cuda")

return device


device = get_device()


def tensor(x):
"""Default conversion for tensors."""
return torch.tensor(x, device=device, dtype=torch.float32)


class Sensor:
"""
A sensor is a function evaluation.
Args:
x: spatial coordinate of shape (coordinate_dim)
u: function value of shape (num_channels)
"""

def __init__(self, x: ndarray, u: ndarray):
self.x = x
self.u = u

self.coordinate_dim = x.shape[0]
self.num_channels = u.shape[0]

def __str__(self) -> str:
return f"Sensor(x={self.x}, u={self.u})"


class Observation:
"""
An observation is a set of sensors.
Args:
sensors: List of sensors. Used to derive 'num_sensors', 'coordinate_dim' and 'num_channels'.
"""

def __init__(self, sensors: List[Sensor]):
self.sensors = sensors

self.num_sensors = len(sensors)
assert self.num_sensors > 0

self.coordinate_dim = self.sensors[0].coordinate_dim
self.num_channels = self.sensors[0].num_channels

# Check consistency across sensors
for sensor in self.sensors:
assert (
sensor.coordinate_dim == self.coordinate_dim
), "Inconsistent coordinate dimension."
assert (
sensor.num_channels == self.num_channels
), "Inconsistent number of channels."

def __str__(self) -> str:
s = "Observation(sensors=\n"
for sensor in self.sensors:
s += f" {sensor}, \n"
s += ")"
return s

def to_tensor(self) -> torch.Tensor:
"""Convert observation to tensor.
Returns:
Tensor of shape (num_sensors, coordinate_dim + num_channels)
"""
u = torch.zeros((self.num_sensors, self.coordinate_dim + self.num_channels))
for i, sensor in enumerate(self.sensors):
u[i] = torch.concat([tensor(sensor.x), tensor(sensor.u)])

# Move to device
u.to(device)

return u


class DataSet:
"""Data set base class."""

@abstractmethod
def __len__(self) -> int:
"""Return number of batches.
Returns:
Number of batches.
"""

@abstractmethod
def __getitem__(self, i: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Return i-th batch as a tuple `(u, x, v)` with tensors for
observations `u`, coordinates `x` and labels `v`.
Args:
i: Index of batch.
Returns:
Batch tuple `(u, x, v)`.
"""
130 changes: 0 additions & 130 deletions src/continuity/data/dataset.py

This file was deleted.

Loading

0 comments on commit 68779c5

Please sign in to comment.