Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Join: feature dataset #35

Merged
merged 17 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 121 additions & 75 deletions notebooks/basics.ipynb

Large diffs are not rendered by default.

156 changes: 102 additions & 54 deletions notebooks/physicsinformed.ipynb

Large diffs are not rendered by default.

96 changes: 61 additions & 35 deletions notebooks/selfsupervised.ipynb

Large diffs are not rendered by default.

121 changes: 74 additions & 47 deletions notebooks/superresolution.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions src/continuity/benchmarks/sine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Sine benchmark."""

from continuity.benchmarks import Benchmark
from continuity.data import split
from continuity.data.datasets import Sine
from continuity.data import Sine, split
from continuity.operators.losses import Loss, MSELoss
from torch.utils.data import Dataset

Expand Down
25 changes: 18 additions & 7 deletions src/continuity/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
import os
import torch

from .dataset import OperatorDataset, SelfSupervisedOperatorDataset
from .shape import DatasetShapes
from .sine import Sine
from .flame import Flame, FlameDataLoader

__all__ = [
"OperatorDataset",
"SelfSupervisedOperatorDataset",
"DatasetShapes",
"Sine",
"Flame",
"FlameDataLoader",
"device",
"split",
]


def get_device() -> torch.device:
"""Get torch device.
Expand All @@ -34,11 +50,6 @@ def get_device() -> torch.device:
device = get_device()


def tensor(x):
"""Default conversion for tensors."""
return torch.tensor(x, dtype=torch.float32)


def split(dataset, split=0.5, seed=None):
"""
Split data set into two parts.
Expand Down Expand Up @@ -70,7 +81,7 @@ def dataset_loss(dataset, operator, loss_fn):
loss = 0.0

for x, u, y, v in dataset:
batch_size = x.shape[0]
loss += loss_fn(operator, x, u, y, v) / batch_size
x, u, y, v = x.unsqueeze(0), u.unsqueeze(0), y.unsqueeze(0), v.unsqueeze(0)
loss += loss_fn(operator, x, u, y, v)

return loss
161 changes: 161 additions & 0 deletions src/continuity/data/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""
`continuity.data`

Data sets in Continuity.
Every data set is a list of `(x, u, y, v)` tuples.
"""

import torch
import torch.utils.data as td
from typing import Tuple

from .shape import DatasetShapes, TensorShape


class OperatorDataset(td.Dataset):
"""A dataset for operator training.

In operator training, at least one function is mapped onto a second one. To fulfill the properties discretization
invariance, domain independence and learn operators with physics-based loss access to at least four different
discretized spaces is necessary. One on which the input is sampled (x), the input function sampled on these points
(u), the discretization of the output space (y), and the output of the operator (v) sampled on these points. Not
all loss functions and/or operators need access to all of these attributes.

Args:
x: Tensor of shape (#observations, #sensors, x-dim) with sensor positions.
u: Tensor of shape (#observations, #sensors, u-dim) with evaluations of the input functions at sensor positions.
y: Tensor of shape (#observations, #evaluations, y-dim) with evaluation positions.
v: Tensor of shape (#observations, #evaluations, v-dim) with ground truth operator mappings.

Attributes:
shapes (dataclass): Shape of all tensors.
transform (dict): Transformations for each tensor.
"""

def __init__(
self,
x: torch.Tensor,
u: torch.Tensor,
y: torch.Tensor,
v: torch.Tensor,
x_transform=None,
u_transform=None,
y_transform=None,
v_transform=None,
):
assert x.ndim == u.ndim == y.ndim == v.ndim == 3, "Wrong number of dimensions."
assert (
x.size(0) == u.size(0) == y.size(0) == v.size(0)
), "Inconsistent number of observations."
assert x.size(1) == u.size(1), "Inconsistent number of sensors."
assert y.size(1) == v.size(1), "Inconsistent number of evaluations."

super().__init__()

self.x = x
self.u = u
self.y = y
self.v = v

# used to initialize architectures
self.shapes = DatasetShapes(
num_observations=int(x.size(0)),
x=TensorShape(*x.size()[1:]),
u=TensorShape(*u.size()[1:]),
y=TensorShape(*y.size()[1:]),
v=TensorShape(*v.size()[1:]),
)

self.transform = {
dim: tf
for dim, tf in [
("x", x_transform),
("u", u_transform),
("y", y_transform),
("v", v_transform),
]
if tf is not None
}

def __len__(self) -> int:
"""Return the number of samples.

Returns:
number of samples in the entire set.
"""
return self.shapes.num_observations

def __getitem__(
self, idx
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Retrieves the input-output pair at the specified index and applies transformations.

Parameters:
- idx: The index of the sample to retrieve.

Returns:
A tuple containing the three input tensors and the output tensor for the given index.
"""
sample = {
"x": self.x[idx],
"u": self.u[idx],
"y": self.y[idx],
"v": self.v[idx],
}

# transform
for dim, val in sample.items():
if dim in self.transform:
sample[dim] = self.transform[dim](val)

return sample["x"], sample["u"], sample["y"], sample["v"]


class SelfSupervisedOperatorDataset(OperatorDataset):
"""
A `SelfSupervisedOperatorDataset` is a data set that contains data for self-supervised learning.
Every data point is created by taking one sensor as a label.

Every observation consists of tuples `(x, u, y, v)`, where `x` contains the sensor
positions, `u` the sensor values, and `y = x_i` and `v = u_i` are
the label's coordinate its value for all `i`.

Args:
x: Sensor positions of shape (num_observations, num_sensors, coordinate_dim)
u: Sensor values of shape (num_observations, num_sensors, num_channels)
"""

def __init__(self, x: torch.Tensor, u: torch.Tensor):
self.num_observations = u.shape[0]
JakobEliasWagner marked this conversation as resolved.
Show resolved Hide resolved
self.num_sensors = u.shape[1]
self.coordinate_dim = x.shape[-1]
self.num_channels = u.shape[-1]

# Check consistency across observations
for i in range(self.num_observations):
assert (
x[i].shape[-1] == self.coordinate_dim
), "Inconsistent coordinate dimension."
assert (
u[i].shape[-1] == self.num_channels
), "Inconsistent number of channels."

xs, us, ys, vs = [], [], [], []

for i in range(self.num_observations):
# Add one data point for every sensor
for j in range(self.num_sensors):
y = x[i][j].unsqueeze(0)
v = u[i][j].unsqueeze(0)

xs.append(x[i])
us.append(u[i])
ys.append(y)
vs.append(v)

xs = torch.stack(xs)
us = torch.stack(us)
ys = torch.stack(ys)
vs = torch.stack(vs)

super().__init__(xs, us, ys, vs)
143 changes: 0 additions & 143 deletions src/continuity/data/datasets.py

This file was deleted.

Loading
Loading