Skip to content

Commit

Permalink
Merge pull request #154 from aai-institute/feature/deep-cat-operator
Browse files Browse the repository at this point in the history
Feature: Deep Cat Operator
  • Loading branch information
Samuel Burbulla authored Aug 19, 2024
2 parents 535c574 + b7c0b10 commit 9f08539
Show file tree
Hide file tree
Showing 7 changed files with 705 additions and 31 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
lfs: "true"
- name: Setup Python
uses: actions/setup-python@v5
with:
Expand All @@ -52,8 +50,6 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
lfs: "true"
- name: Setup Python
uses: actions/setup-python@v5
with:
Expand All @@ -77,8 +73,6 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
lfs: "true"
- name: Setup Python
uses: actions/setup-python@v5
with:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
- Add `branch_network` and `trunk_network` arguments to `DeepONet` to allow for custom network architectures.
- Add `MaskedOperator` base class.
- Add `DeepCatOperator`.

## 0.1.0

Expand Down
523 changes: 498 additions & 25 deletions examples/meshes.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/continuiti/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .fno import FourierNeuralOperator
from .shape import OperatorShapes
from .cnn import ConvolutionalNeuralNetwork
from .dco import DeepCatOperator

__all__ = [
"Operator",
Expand All @@ -29,4 +30,5 @@
"DeepNeuralOperator",
"FourierNeuralOperator",
"ConvolutionalNeuralNetwork",
"DeepCatOperator",
]
150 changes: 150 additions & 0 deletions src/continuiti/operators/dco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
`continuiti.operators.dco`
The DeepCatOperator (DCO) architecture.
"""

import torch
import torch.nn as nn
from typing import Optional
from math import ceil, prod
from .operator import Operator, OperatorShapes
from continuiti.networks import DeepResidualNetwork


class DeepCatOperator(Operator):
"""Deep Cat Operator.
This class implements the DeepCatOperator, a neural operator inspired by the DeepONet.
It consists of three main parts:
1. **Branch Network**: Processes the sensor inputs (`u`).
2. **Trunk Network**: Processes the evaluation locations (`y`).
3. **Cat Network**: Combines the outputs from the Branch- and Trunk-Network to produce the final output.
The architecture has the following structure:
┌─────────────────────┐ ┌────────────────────┐
│ *Branch Network* │ │ *Trunk Network* │
│ Input (u) │ │ Input (y) │
│ Output (b) │ │ Output (t) │
└─────────────────┬───┘ └──┬─────────────────┘
┌ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ─ ┐
│ *Concatenation* │
│ Input (b, t) │
│ Output (c) │
│ branch_cat_ratio = b.numel() / cat_net_width │
└ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘
┌─────────┴────────┐
│ *Cat Network* │
│ Input (c) │
│ Output (v) │
└──────────────────┘
This allows the operator to integrate evaluation locations earlier, while ensuring that both the sensor inputs and
the evaluation location contribute in a predictable form to the flow of information. Directly stacking both the
sensors and evaluation location can lead to an imbalance in the number of features in the neural operator. The
arg `branch_cat_ratio` dictates how this fraction is set (defaults to 50/50). The cat-network does not require the
neural operator to learn good basis functions with the trunk network only. The information from the input space and
the evaluation locations can be taken into account early, allowing for better abstraction.
Args:
shapes: Operator shapes.
branch_width: Width of the branch net (deep residual network). Defaults to 32.
branch_depth: Depth of the branch net (deep residual network). Defaults to 4.
trunk_width: Width of the trunk net (deep residual network). Defaults to 32.
trunk_depth: Depth of the trunk net (deep residual network). Defaults to 4.
branch_cat_ratio: Ratio indicating which fraction of the concatenated tensor originates from the branch net.
Controls flow of information into branch- and trunk-net. Defaults to 0.5.
cat_net_width: Width of the cat net (deep residual network). Defaults to 32.
cat_net_depth: Depth of the cat net (deep residual network). Defaults to 4.
act: Activation function. Defaults to Tanh.
device: Device.
"""

def __init__(
self,
shapes: OperatorShapes,
branch_width: int = 32,
branch_depth: int = 4,
trunk_width: int = 32,
trunk_depth: int = 4,
branch_cat_ratio: float = 0.5,
cat_net_width: int = 32,
cat_net_depth: int = 4,
act: Optional[nn.Module] = None,
device: Optional[torch.device] = None,
):
super().__init__(shapes=shapes, device=device)

if act is None:
act = nn.Tanh()

assert (
0.0 < branch_cat_ratio < 1.0
), f"Ratio has to be in (0, 1), but found {branch_cat_ratio}"
branch_out_width = ceil(cat_net_width * branch_cat_ratio)
assert (
branch_out_width != cat_net_width
), f"Input cat ratio {branch_cat_ratio} results in eval net width equal zero."

input_in_width = prod(shapes.u.size) * shapes.u.dim
self.branch_net = DeepResidualNetwork(
input_size=input_in_width,
output_size=branch_out_width,
width=branch_width,
depth=branch_depth,
act=act,
device=device,
)

eval_out_width = cat_net_width - branch_out_width
self.trunk_net = DeepResidualNetwork(
input_size=shapes.y.dim,
output_size=eval_out_width,
width=trunk_width,
depth=trunk_depth,
act=act,
device=device,
)

self.cat_act = act # no activation before first and after last layer

self.cat_net = DeepResidualNetwork(
input_size=cat_net_width,
output_size=shapes.v.dim,
width=cat_net_width,
depth=cat_net_depth,
act=act,
device=device,
)

def forward(
self, _: torch.Tensor, u: torch.Tensor, y: torch.Tensor
) -> torch.Tensor:
"""Forward pass through the operator.
Args:
_: Tensor containing sensor locations. Ignored.
u: Tensor containing values of sensors of shape (batch_size, u_dim, num_sensors...).
y: Tensor containing evaluation locations of shape (batch_size, y_dim, num_evaluations...).
Returns:
Tensor of predicted evaluation values of shape (batch_size, v_dim, num_evaluations...).
"""
ipt = torch.flatten(u, start_dim=1)
ipt = self.branch_net(ipt)

y_num = y.shape[2:]
eval = y.flatten(start_dim=2).transpose(1, -1)
eval = self.trunk_net(eval)

ipt = ipt.unsqueeze(1).expand(-1, eval.size(1), -1)
cat = torch.cat([ipt, eval], dim=-1)
out = self.cat_act(cat)
out = self.cat_net(out)

return out.reshape(-1, self.shapes.v.dim, *y_num)
5 changes: 5 additions & 0 deletions tests/operators/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,8 @@ def random_shape_operator_datasets() -> List[OperatorDataset]:
)

return datasets


@pytest.fixture(scope="session")
def random_operator_dataset(random_shape_operator_datasets) -> OperatorDataset:
return random_shape_operator_datasets[0]
49 changes: 49 additions & 0 deletions tests/operators/test_dco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from typing import List

from continuiti.operators import DeepCatOperator
from continuiti.benchmarks.sine import SineBenchmark
from continuiti.trainer import Trainer
from continuiti.operators.losses import MSELoss
from continuiti.networks import DeepResidualNetwork

from .util import get_shape_mismatches


@pytest.fixture(scope="module")
def dcos(random_shape_operator_datasets) -> List[DeepCatOperator]:
return [
DeepCatOperator(dataset.shapes) for dataset in random_shape_operator_datasets
]


class TestDeepCatOperator:
def test_can_initialize(self, random_operator_dataset):
operator = DeepCatOperator(random_operator_dataset.shapes)

assert isinstance(operator, DeepCatOperator)

def test_can_initialize_default_networks(self, random_operator_dataset):
operator = DeepCatOperator(shapes=random_operator_dataset.shapes)

assert isinstance(operator.branch_net, DeepResidualNetwork)
assert isinstance(operator.trunk_net, DeepResidualNetwork)
assert isinstance(operator.cat_net, DeepResidualNetwork)

def test_forward_shapes_correct(self, dcos, random_shape_operator_datasets):
assert get_shape_mismatches(dcos, random_shape_operator_datasets) == []

@pytest.mark.slow
def test_can_overfit(self):
# Data set
dataset = SineBenchmark(n_train=1).train_dataset

# Operator
operator = DeepCatOperator(dataset.shapes)

# Train
Trainer(operator).fit(dataset, tol=1e-2)

# Check solution
x, u, y, v = dataset.x, dataset.u, dataset.y, dataset.v
assert MSELoss()(operator, x, u, y, v) < 1e-2

0 comments on commit 9f08539

Please sign in to comment.