Skip to content

Commit

Permalink
first code update
Browse files Browse the repository at this point in the history
  • Loading branch information
mv-lab committed Jul 7, 2023
1 parent 5a631d7 commit 2253d71
Show file tree
Hide file tree
Showing 11 changed files with 599 additions and 10 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Custom

*.ipynb
!nilut-multiblend.ipynb

dataset/*.png

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
58 changes: 48 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,67 @@
# NILUT: Conditional Neural Implicit 3D Lookup Tables for Image Enhancement
# [NILUT: Conditional Neural Implicit 3D Lookup Tables for Image Enhancement](https://arxiv.org/abs/2306.11920)

[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2306.11920)
[<a href=""><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="colab demo"></a>]()
[<a href="https://www.kaggle.com/code/jesucristo/super-resolution-demo-swin2sr-official/"><img src="https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png?20140912155123" alt="kaggle demo" width=50></a>]()

[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)]()

[Marcos V. Conde](https://scholar.google.com/citations?user=NtB1kjYAAAAJ&hl=en), [Javier Vazquez-Corral](https://scholar.google.com/citations?user=gjnuPMoAAAAJ&hl=en), [Michael S. Brown](https://scholar.google.com/citations?hl=en&user=Gv1QGSMAAAAJ), [Radu Timofte](https://scholar.google.com/citations?user=u3MwH5kAAAAJ&hl=en)


**TL;DR** NILUT uses neural representations for controllable photorealistic image enhancement.
**TL;DR** NILUT uses neural representations for controllable photorealistic image enhancement. 🚀 Demo Tutorial and pretrained models available.


<img src="media/nilut-intro.gif" alt="NILUT" width="800">

----

3D lookup tables (3D LUTs) are a key component for image enhancement. Modern image signal processors (ISPs) have dedicated support for these as part of the camera rendering pipeline. Cameras typically provide multiple options for picture styles, where each style is usually obtained by applying a unique handcrafted 3D LUT. Current approaches for learning and applying 3D LUTs are notably fast, yet not so memory-efficient, as storing multiple 3D LUTs is required. For this reason and other implementation limitations, their use on mobile devices is less popular.

In this work, we propose a Neural Implicit LUT (NILUT), an implicitly defined continuous 3D color transformation parameterized by a neural network. We show that NILUTs are capable of accurately emulating real 3D LUTs. Moreover, a NILUT can be extended to incorporate multiple styles into a single network with the ability to blend styles implicitly. Our novel approach is memory-efficient, controllable and can complement previous methods, including learned ISPs.


**Topics** Image Enhancement, Image Editing, Color Manipulation, Tone Mapping, Presets

***Website and repo in progress.*** **See also [AISP](https://github.com/mv-lab/AISP)** for image signal processing code and papers.

----

<br>
**Pre-trained** sample models are available at `models/`. We provide `nilutx3style.pt` a NILUT that encodes three 3D LUT styles (1,3,4) with high accuracy.

<img src="nilut-intro.gif" alt="NILUT" width="800">
**Demo Tutorial** in [nilut-multiblend.ipynb](nilut-multiblend.ipynb) we provide a simple tutorial on how to use NILUT for multi-style image enhancement and blending. The corresponding training code will be released soon.

<br>
**Dataset** The folder `dataset/` includes 100 images from the Adobe MIT 5K Dataset. The images were processed using professional 3D LUTs on Adobe Lightroom. The structure of the dataset is:

----
```
dataset/
├── 001_blend.png
├── 001_LUT01.png
├── 001_LUT02.png
├── 001_LUT03.png
├── 001_LUT04.png
├── 001_LUT05.png
├── 001_LUT08.png
├── 001_LUT10.png
└── 001.png
...
```

3D lookup tables (3D LUTs) are a key component for image enhancement. Modern image signal processors (ISPs) have dedicated support for these as part of the camera rendering pipeline. Cameras typically provide multiple options for picture styles, where each style is usually obtained by applying a unique handcrafted 3D LUT. Current approaches for learning and applying 3D LUTs are notably fast, yet not so memory-efficient, as storing multiple 3D LUTs is required. For this reason and other implementation limitations, their use on mobile devices is less popular.
where `001.png` is the input unprocessed image, `001_LUTXX.png` is the result of applying each corresponding LUT and `001_blend.png` is the example target for evaluating sytle-blending (in the example the blending is between styles 1,3, and 4 with equal weights 0.33).
The complete dataset includes 100 images `aaa.png` and their enhanced variants for each 3D LUT.

In this work, we propose a Neural Implicit LUT (NILUT), an implicitly defined continuous 3D color transformation parameterized by a neural network. We show that NILUTs are capable of accurately emulating real 3D LUTs. Moreover, a NILUT can be extended to incorporate multiple styles into a single network with the ability to blend styles implicitly. Our novel approach is memory-efficient, controllable and can complement previous methods, including learned ISPs.

----

**Contact** marcos.conde[at]uni-wuerzburg.de
Hope you like it 🤗 If you find this interesting/insightful/inspirational or you use it, do not forget to acknowledge our work:

```
@article{conde2023nilut,
title={NILUT: Conditional Neural Implicit 3D Lookup Tables for Image Enhancement},
author={Conde, Marcos V and Vazquez-Corral, Javier and Brown, Michael S and Timofte, Radu},
journal={arXiv preprint arXiv:2306.11920},
year={2023}
}
```

**Contact** marcos.conde[at]uni-wuerzburg.de

70 changes: 70 additions & 0 deletions dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np

from utils import load_img, np_psnr


class EvalMultiLUTBlending (Dataset):
"""
Dataloader to load the input image <inp_img> and the reference target images <list_out_imgs>.
The order of the target images must be: ground-truth 3D LUT outputs (the first <nluts> elements in the list), following by gt blending results.
We will load each reference, and include the corresponding style vector a sinput to the network
Example:
test_images = EvalMultiLUTFitting('./DatasetLUTs_100images/001.png',
['./DatasetLUTs_100images/001_LUT01.png',
'./DatasetLUTs_100images/001_LUT03.png',
'./DatasetLUTs_100images/001_LUT04.png',
'./DatasetLUTs_100images/001_blend.png'], nluts=3)
test_dataloader = DataLoader(test_images, batch_size=1, pin_memory=True, num_workers=0)
"""

def __init__(self, inp_img, list_out_img, nluts):
super().__init__()

self.inp_imgs = load_img(inp_img)
self.out_imgs = []
self.error = []
self.shape = self.inp_imgs.shape
self.nluts = nluts

for fout in list_out_img:
lut = load_img(fout)
assert self.inp_imgs.shape == lut.shape
assert (self.inp_imgs.max() <= 1) and (lut.max() <= 1)
self.out_imgs.append(lut)
self.error.append(np_psnr(self.inp_imgs,lut))
del lut

self.references = len(list_out_img)

def __len__(self):
return self.references

def __getitem__(self, idx):
if idx > self.references: raise IndexError

style_vector = np.zeros(self.nluts).astype(np.float32)

if idx < self.nluts:
style_vector[idx] = 1.
else:
style_vector = np.array([0.33, 0.33, 0.33]).astype(np.float32)

# Convert images to pytorch tensors
img = torch.from_numpy(self.inp_imgs)
lut = torch.from_numpy(self.out_imgs[idx])

img = img.reshape((img.shape[0]*img.shape[1],3)) # [hw, 3]
lut = lut.reshape((lut.shape[0]*lut.shape[1],3)) # [hw, 3]

style_vector = torch.from_numpy(style_vector)
style_vector_re = style_vector.repeat(img.shape[0]).view(img.shape[0],self.nluts)

img = torch.cat([img,style_vector_re], dim=-1)

return img, lut, style_vector
Empty file added dataset/.gitkeep
Empty file.
Binary file added media/cnilut.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added media/header.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes
Binary file added models/nilutx3style.pt
Binary file not shown.
369 changes: 369 additions & 0 deletions nilut-multiblend.ipynb

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
imageio==2.30.0
matplotlib==3.7.1
numpy==1.24.3
opencv-python==4.7.0.72
Pillow==9.4.0
scikit-image==0.20.0
scipy==1.10.1
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
95 changes: 95 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
NILUT: Conditional Neural Implicit 3D Lookup Tables for Image Enhancement
Utils for training and ploting
"""

import torch
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gc
import time
from skimage import io, color


# Timing utilities

def start_timer():
global start_time
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.synchronize()
start_time = time.time()

def end_timer_and_print(local_msg):
torch.cuda.synchronize()
end_time = time.time()
print("\n" + local_msg)
print("Total execution time = {:.3f} sec".format(end_time - start_time))
print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

def clean_mem():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()

# Model

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Load/save and plot images

def load_img (filename, norm=True,):

img = np.array(Image.open(filename))
if norm:
img = img / 255.
img = img.astype(np.float32)
return img

def save_rgb (img, filename):
if np.max(img) <= 1:
img = img * 255

img = img.astype(np.float32)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

cv2.imwrite(filename, img)

def plot_all (images, figsize=(20,10), axis='off'):
fig = plt.figure(figsize=figsize, dpi=80)
nplots = len(images)
for i in range(nplots):
plt.subplot(1,nplots,i+1)
plt.axis(axis)
plt.imshow(images[i])

plt.show()

# Metrics

def np_psnr(y_true, y_pred):
mse = np.mean((y_true - y_pred) ** 2)
if(mse == 0): return np.inf
return 20 * np.log10(1 / np.sqrt(mse))

def pt_psnr (y_true, y_pred):
mse = torch.mean((y_true - y_pred) ** 2)
return 20 * torch.log10(1 / torch.sqrt(mse))

def deltae_dist (y_true, y_pred):
"""
Calcultae DeltaE discance in the LAB color space.
Images must numpy arrays.
"""

gt_lab = color.rgb2lab((y_true*255).astype('uint8'))
out_lab = color.rgb2lab((y_pred*255).astype('uint8'))
l2_lab = ((gt_lab - out_lab)**2).mean()
l2_lab = np.sqrt(((gt_lab - out_lab)**2).sum(axis=-1)).mean()
return l2_lab

0 comments on commit 2253d71

Please sign in to comment.