Merge pull request #37 from aai-institute/join/feature-dataset
Follow-up in join/feature dataset
Samuel Burbulla authored Feb 14, 2024
2 parents f7f371e + cf3f83c commit 1e1e0bc
Showing 8 changed files with 51 additions and 43 deletions.
5 changes: 2 additions & 3 deletions src/continuity/operators/common.py
@@ -1,5 +1,4 @@
 import torch
-from torch import Tensor


 class ResidualLayer(torch.nn.Module):
@@ -14,7 +13,7 @@ def __init__(self, width: int):
         self.layer = torch.nn.Linear(width, width)
         self.act = torch.nn.Tanh()

-    def forward(self, x: Tensor):
+    def forward(self, x: torch.Tensor):
         """Forward pass."""
         return self.act(self.layer(x)) + x

@@ -73,7 +72,7 @@ def __init__(
             kernel_depth,
         )

-    def forward(self, x: Tensor, y: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """Compute kernel value.

         Args:
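For reference, a minimal usage sketch of the ResidualLayer changed above (import path from this diff; sizes are illustrative):

import torch
from continuity.operators.common import ResidualLayer

layer = ResidualLayer(width=8)  # Linear(width, width) followed by Tanh
x = torch.rand(4, 8)            # batch of 4 inputs
y = layer(x)                    # act(layer(x)) + x, a residual connection
assert y.shape == x.shape       # width is preserved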
14 changes: 5 additions & 9 deletions src/continuity/operators/deeponet.py
@@ -37,7 +37,7 @@ def __init__(
     ):
         super().__init__()

-        self.dataset_shape = shapes
+        self.shapes = shapes

         self.basis_functions = basis_functions
         self.dot_dim = shapes.v.dim * basis_functions
@@ -74,25 +74,21 @@ def forward(

         # flatten inputs for both trunk and branch network
         u = u.flatten(1, -1)
-        assert u.shape[1:] == torch.Size(
-            [self.dataset_shape.u.num * self.dataset_shape.u.dim]
-        )
+        assert u.shape[1:] == torch.Size([self.shapes.u.num * self.shapes.u.dim])

         y = y.flatten(0, 1)
-        assert u.shape[1:] == torch.Size(
-            [self.dataset_shape.u.num * self.dataset_shape.u.dim]
-        )
+        assert u.shape[1:] == torch.Size([self.shapes.u.num * self.shapes.u.dim])

         # Pass through branch and trunk networks
         b = self.branch(u)
         t = self.trunk(y)

         # dot product
-        b = b.reshape(-1, self.dataset_shape.v.dim, self.basis_functions)
+        b = b.reshape(-1, self.shapes.v.dim, self.basis_functions)
         t = t.reshape(
             b.size(0),
             -1,
-            self.dataset_shape.v.dim,
+            self.shapes.v.dim,
             self.basis_functions,
         )
         dot_prod = torch.einsum("abcd,acd->abc", t, b)
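The einsum at the end of forward contracts the basis-function dimension shared by the trunk and branch outputs; a standalone shape sketch (dimension sizes are illustrative, not from the source):

import torch

batch, n_y, v_dim, n_basis = 4, 10, 3, 16
b = torch.rand(batch, v_dim, n_basis)           # branch output after reshape
t = torch.rand(batch, n_y, v_dim, n_basis)      # trunk output after reshape
dot_prod = torch.einsum("abcd,acd->abc", t, b)  # sum over basis functions (d)
assert dot_prod.shape == (batch, n_y, v_dim)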
19 changes: 14 additions & 5 deletions src/continuity/operators/losses.py
@@ -1,7 +1,6 @@
 """Loss functions."""

 import torch
-from torch import Tensor
 from abc import abstractmethod
 from typing import TYPE_CHECKING

@@ -14,8 +13,13 @@ class Loss:

     @abstractmethod
     def __call__(
-        self, op: "Operator", x: Tensor, u: Tensor, y: Tensor, v: Tensor
-    ) -> Tensor:
+        self,
+        op: "Operator",
+        x: torch.Tensor,
+        u: torch.Tensor,
+        y: torch.Tensor,
+        v: torch.Tensor,
+    ) -> torch.Tensor:
         """Evaluate loss.

         Args:
@@ -34,8 +38,13 @@ def __init__(self):
         self.mse = torch.nn.MSELoss()

     def __call__(
-        self, op: "Operator", x: Tensor, u: Tensor, y: Tensor, v: Tensor
-    ) -> Tensor:
+        self,
+        op: "Operator",
+        x: torch.Tensor,
+        u: torch.Tensor,
+        y: torch.Tensor,
+        v: torch.Tensor,
+    ) -> torch.Tensor:
         """Evaluate MSE loss.

         Args:
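A sketch of the updated call signature with a toy stand-in operator; it assumes MSELoss compares op(x, u, y) against v, which matches the docstring but is not shown in this hunk:

import torch
from continuity.operators.losses import MSELoss

class ToyOperator(torch.nn.Module):
    """Illustrative stand-in for an Operator: ignores coordinates, returns u."""

    def forward(self, x, u, y):
        return u

loss_fn = MSELoss()
x = u = y = v = torch.rand(2, 5, 1)
print(loss_fn(ToyOperator(), x, u, y, v))  # tensor(0.) here, since op(x, u, y) == v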
26 changes: 11 additions & 15 deletions src/continuity/operators/neuraloperator.py
@@ -82,36 +82,36 @@ class NeuralOperator(Operator):

     def __init__(
         self,
-        dataset_shape: DatasetShapes,
+        shapes: DatasetShapes,
         depth: int = 1,
         kernel_width: int = 32,
         kernel_depth: int = 3,
     ):
         super().__init__()

-        self.dataset_shape = dataset_shape
+        self.shapes = shapes

         self.lifting = ContinuousConvolution(
             NeuralNetworkKernel(kernel_width, kernel_depth),
-            dataset_shape.x.dim,
-            dataset_shape.u.dim,
+            shapes.x.dim,
+            shapes.u.dim,
         )

         self.hidden_layers = torch.nn.ModuleList(
             [
                 ContinuousConvolution(
                     NeuralNetworkKernel(kernel_width, kernel_depth),
-                    dataset_shape.x.dim,
-                    dataset_shape.u.dim,
+                    shapes.x.dim,
+                    shapes.u.dim,
                 )
                 for _ in range(depth)
             ]
         )

         self.projection = ContinuousConvolution(
             NeuralNetworkKernel(kernel_width, kernel_depth),
-            dataset_shape.x.dim,
-            dataset_shape.u.dim,
+            shapes.x.dim,
+            shapes.u.dim,
         )

     def forward(
@@ -129,22 +129,18 @@ def forward(
         """
         # Lifting layer (we use x as evaluation coordinates for now)
         v = self.lifting(x, u, x)
-        assert v.shape[1:] == torch.Size(
-            [self.dataset_shape.x.num, self.dataset_shape.u.dim]
-        )
+        assert v.shape[1:] == torch.Size([self.shapes.x.num, self.shapes.u.dim])

         # Hidden layers
         for layer in self.hidden_layers:
             # Layer operation (with residual connection)
             v = layer(x, v, x) + v
-            assert v.shape[1:] == torch.Size(
-                [self.dataset_shape.x.num, self.dataset_shape.u.dim]
-            )
+            assert v.shape[1:] == torch.Size([self.shapes.x.num, self.shapes.u.dim])

             # Activation
             v = torch.tanh(v)

         # Projection layer
         w = self.projection(x, v, y)
-        assert w.shape[1:] == torch.Size([y.size(1), self.dataset_shape.u.dim])
+        assert w.shape[1:] == torch.Size([y.size(1), self.shapes.u.dim])
         return w
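Call sites only need the keyword rename. A hedged construction sketch; DatasetShapes is stubbed with the fields this diff actually reads (shapes.x.num/dim, shapes.u.dim), and the import path is an assumption:

from types import SimpleNamespace

from continuity.operators import NeuralOperator  # import path assumed

# Stand-in for DatasetShapes, sized arbitrarily:
shapes = SimpleNamespace(
    x=SimpleNamespace(num=32, dim=1),
    u=SimpleNamespace(num=32, dim=1),
)

operator = NeuralOperator(
    shapes=shapes,  # was: dataset_shape=...
    depth=1,
    kernel_width=32,
    kernel_depth=3,
)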
3 changes: 1 addition & 2 deletions src/continuity/operators/operator.py
@@ -41,9 +41,9 @@ def compile(
         """Compile operator.

         Args:
-            verbose: Print number of model parameters to stdout.
             optimizer: Torch-like optimizer.
             loss_fn: Loss function taking (x, u, y, v). Defaults to MSELoss.
+            verbose: Print number of model parameters to stdout.
         """
         self.optimizer = optimizer
         self.loss_fn = loss_fn or MSELoss()
@@ -62,7 +62,6 @@ def fit(
         """Fit operator to data set.

         Args:
-            batch_size: Batch size.
             dataset: Data set.
             epochs: Number of epochs.
             callbacks: List of callbacks.
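A hedged sketch of the compile/fit flow these docstrings describe, reusing the operator from the previous sketch; `dataset` stands in for a Continuity data set and is not constructed here:

import torch

operator.compile(
    optimizer=torch.optim.Adam(operator.parameters(), lr=1e-3),
    loss_fn=None,  # falls back to MSELoss, per `loss_fn or MSELoss()` above
    verbose=True,
)
operator.fit(dataset, epochs=100)  # dataset/epochs/callbacks per the docstring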
20 changes: 15 additions & 5 deletions src/continuity/pde/__init__.py
@@ -6,7 +6,7 @@
 Every PDE is implemented using a physics-informed loss function.
 """

-from torch import Tensor
+import torch
 from abc import abstractmethod

 from continuity.operators.operator import Operator
@@ -17,8 +17,13 @@ class PDE:

     @abstractmethod
     def __call__(
-        self, op: Operator, x: Tensor, u: Tensor, y: Tensor, v: Tensor
-    ) -> Tensor:
+        self,
+        op: Operator,
+        x: torch.Tensor,
+        u: torch.Tensor,
+        y: torch.Tensor,
+        v: torch.Tensor,
+    ) -> torch.Tensor:
         """Computes PDE loss."""


@@ -33,8 +38,13 @@ def __init__(self, pde: PDE):
         self.pde = pde

     def __call__(
-        self, op: Operator, x: Tensor, u: Tensor, y: Tensor, v: Tensor
-    ) -> Tensor:
+        self,
+        op: Operator,
+        x: torch.Tensor,
+        u: torch.Tensor,
+        y: torch.Tensor,
+        _: torch.Tensor,
+    ) -> torch.Tensor:
         """Evaluate loss.

         Args:
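PDE is abstract, so a concrete physics-informed loss implements __call__ with the signature above; a toy sketch (the residual used here is purely illustrative):

import torch
from continuity.pde import PDE

class ZeroResidual(PDE):
    """Toy PDE loss that drives the operator output toward zero."""

    def __call__(self, op, x, u, y, v):
        return torch.mean(op(x, u, y) ** 2)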
5 changes: 2 additions & 3 deletions src/continuity/plotting/__init__.py
@@ -6,14 +6,13 @@

 import torch
 import numpy as np
-from torch import Tensor
 from typing import Optional
 from matplotlib.axis import Axis
 import matplotlib.pyplot as plt
 from continuity.operators import Operator


-def plot(x: Tensor, u: Tensor, ax: Optional[Axis] = None):
+def plot(x: torch.Tensor, u: torch.Tensor, ax: Optional[Axis] = None):
     """Plots a function $u(x)$.

     Currently only supports coordinate dimensions of $d = 1,2$.
@@ -44,7 +43,7 @@ def plot(x: Tensor, u: Tensor, ax: Optional[Axis] = None):


 def plot_evaluation(
-    operator: Operator, x: Tensor, u: Tensor, ax: Optional[Axis] = None
+    operator: Operator, x: torch.Tensor, u: torch.Tensor, ax: Optional[Axis] = None
 ):
     """Plots the mapped function `operator(observation)` evaluated on a $[-1, 1]^d$ grid.
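A usage sketch for the updated plot signature (the tensor shapes for d = 1 are an assumption; the docstring only states that d = 1, 2 are supported):

import torch
import matplotlib.pyplot as plt
from continuity.plotting import plot

x = torch.linspace(-1, 1, 32).reshape(-1, 1)  # coordinates, d = 1
u = torch.sin(torch.pi * x)                   # function values u(x)
plot(x, u)  # ax defaults to None per the signature
plt.show()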
2 changes: 1 addition & 1 deletion tests/operators/test_neuraloperator.py
@@ -19,7 +19,7 @@ def test_neuraloperator():

     # Operator
     operator = NeuralOperator(
-        dataset_shape=dataset.shapes,
+        shapes=dataset.shapes,
         depth=1,
         kernel_width=32,
         kernel_depth=3,
