diff --git a/jax3d/projects/mobilenerf/README.md b/jax3d/projects/mobilenerf/README.md new file mode 100644 index 0000000..cc8d99d --- /dev/null +++ b/jax3d/projects/mobilenerf/README.md @@ -0,0 +1,74 @@ +# MobileNeRF + +This repository contains the source code for the paper MobileNeRF: Exploiting the Polygon Rasterization Pipeline for Efficient Neural Field Rendering on Mobile Architectures. + +This code was created by [Zhiqin Chen](https://czq142857.github.io/) when he was a student researcher at Google. + +*Please note that this is not an officially supported Google product.* + + +## Installation + +You will need 8 V100 GPUs to successfully train the model. + +We recommend using [Anaconda](https://www.anaconda.com/products/individual) to set up the environment. Clone the repo, go to the mobilenerf folder, and run the following commands: + +``` +conda create --name mobilenerf python=3.6.13; conda activate mobilenerf +conda install pip; pip install --upgrade pip +pip install -r requirements.txt +``` + +Please make sure that your jax installation supports GPUs. You might need to re-install jax by following the [jax installation guide](https://github.com/google/jax#installation). + + + +## Data + +Please download the datasets from the [NeRF official Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). +Please download and unzip `nerf_synthetic.zip` and `nerf_llff_data.zip`. + +**(TODO: how to download unbounded scenes from Mip-NeRF 360?)** + +## Training + +The training code is in three .py files, corresponding to the three training stages: 1. continuous opacity, 2. binarization and supersampling, 3. extracting meshes and textures. + +First, modify the parameters in all .py files: +``` +scene_type = "synthetic" +object_name = "chair" +scene_dir = "datasets/nerf_synthetic/"+object_name +``` +*scene_type* can be synthetic, forwardfacing, or real360. *object_name* is the name of the scene to be trained; the available names are listed in the code. 
*scene_dir* is the folder holding the training data. + +Afterwards, run the three .py files consecutively: +``` +python stage1.py +python stage2.py +python stage3.py +``` +The intermediate weights will be saved in folder *weights* and the intermediate outputs (sample testing images, loss curves) will be written to folder *samples*. The output meshes+textures will be saved in folder *obj_phone*. + +Note: the stage-1 training could occasionally fail for unknown reasons (especially on the bicycle scene); switching to a different set of GPUs may solve this. + +Note: For unbounded 360 degree scenes, ```stage3.py``` will only extract meshes of the center unit cube. To extract the entire scene, use ```python stage3_with_box.py```. + +It takes 8 hours to train the first stage, 12-16 hours to train the second, and 1-3 hours to run the third. + +## Running the viewer + +The viewer code is provided in this repo, as three .html files for three types of scenes. + +You can set up a local server on your machine, e.g., +``` +cd folder_containing_the_html +python -m http.server +``` +Then open +``` +localhost:8000/view_synthetic.html?obj=chair +``` +Note that you should put the meshes+textures of the chair model in a folder *chair_phone*. The folder should be in the same directory as the HTML file. + +Please allow some time for the scenes to load. Use the left mouse button to rotate, the right mouse button to pan (especially for forward-facing scenes), and the scroll wheel to zoom. On phones, use your fingers to rotate, pan, or zoom. Resize the window (or switch your phone between landscape and portrait) to show the resolution. 
\ No newline at end of file diff --git a/jax3d/projects/mobilenerf/requirements.txt b/jax3d/projects/mobilenerf/requirements.txt new file mode 100644 index 0000000..ec789f7 --- /dev/null +++ b/jax3d/projects/mobilenerf/requirements.txt @@ -0,0 +1,7 @@ +numpy>=1.16.4 +jax>=0.2.6 +jaxlib>=0.1.69 +flax>=0.2.2 +opencv-python>=4.4.0 +Pillow>=7.2.0 +matplotlib>=3.3.4 diff --git a/jax3d/projects/mobilenerf/stage1.py b/jax3d/projects/mobilenerf/stage1.py new file mode 100644 index 0000000..a9e402c --- /dev/null +++ b/jax3d/projects/mobilenerf/stage1.py @@ -0,0 +1,1696 @@ +# Copyright 2022 The jax3d Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +scene_type = "synthetic" +object_name = "chair" +scene_dir = "datasets/nerf_synthetic/"+object_name + +# synthetic +# chair drums ficus hotdog lego materials mic ship + +# forwardfacing +# fern flower fortress horns leaves orchids room trex + +# real360 +# bicycle flowerbed gardenvase stump treehill +# fulllivingroom kitchencounter kitchenlego officebonsai + +#%% -------------------------------------------------------------------------------- +# ## General imports +#%% +import copy +import gc +import json +import os +import numpy +import cv2 +from tqdm import tqdm +import pickle +import jax +import jax.numpy as np +from jax import random +import flax +import flax.linen as nn +import functools +import math +from typing import Sequence, Callable +import time +import matplotlib.pyplot as plt +from PIL import Image +from multiprocessing.pool import ThreadPool + +print(jax.local_devices()) +if len(jax.local_devices())!=8: + print("ERROR: need 8 v100 GPUs") + 1/0 +weights_dir = "weights" +samples_dir = "samples" +if not os.path.exists(weights_dir): + os.makedirs(weights_dir) +if not os.path.exists(samples_dir): + os.makedirs(samples_dir) +def write_floatpoint_image(name,img): + img = numpy.clip(numpy.array(img)*255,0,255).astype(numpy.uint8) + cv2.imwrite(name,img[:,:,::-1]) +#%% -------------------------------------------------------------------------------- +# ## Load the dataset +#%% +# """ Load dataset """ + +if scene_type=="synthetic": + white_bkgd = True +elif scene_type=="forwardfacing": + white_bkgd = False +elif scene_type=="real360": + white_bkgd = False + + +#https://github.com/google-research/google-research/blob/master/snerg/nerf/datasets.py + + +if scene_type=="synthetic": + + def load_blender(data_dir, split): + with open( + os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp: + meta = json.load(fp) + + cams = [] + paths = [] + for i in range(len(meta["frames"])): + frame = meta["frames"][i] + 
cams.append(np.array(frame["transform_matrix"], dtype=np.float32)) + + fname = os.path.join(data_dir, frame["file_path"] + ".png") + paths.append(fname) + + def image_read_fn(fname): + with open(fname, "rb") as imgin: + image = np.array(Image.open(imgin), dtype=np.float32) / 255. + return image + with ThreadPool() as pool: + images = pool.map(image_read_fn, paths) + pool.close() + pool.join() + + images = np.stack(images, axis=0) + if white_bkgd: + images = (images[..., :3] * images[..., -1:] + (1. - images[..., -1:])) + else: + images = images[..., :3] * images[..., -1:] + + h, w = images.shape[1:3] + camera_angle_x = float(meta["camera_angle_x"]) + focal = .5 * w / np.tan(.5 * camera_angle_x) + + hwf = np.array([h, w, focal], dtype=np.float32) + poses = np.stack(cams, axis=0) + return {'images' : images, 'c2w' : poses, 'hwf' : hwf} + + data = {'train' : load_blender(scene_dir, 'train'), + 'test' : load_blender(scene_dir, 'test')} + + splits = ['train', 'test'] + for s in splits: + print(s) + for k in data[s]: + print(f' {k}: {data[s][k].shape}') + + images, poses, hwf = data['train']['images'], data['train']['c2w'], data['train']['hwf'] + write_floatpoint_image(samples_dir+"/training_image_sample.png",images[0]) + + for i in range(3): + plt.figure() + plt.scatter(poses[:,i,3], poses[:,(i+1)%3,3]) + plt.axis('equal') + plt.savefig(samples_dir+"/training_camera"+str(i)+".png") + +elif scene_type=="forwardfacing" or scene_type=="real360": + + import numpy as np #temporarily use numpy as np, then switch back to jax.numpy + import jax.numpy as jnp + + def _viewmatrix(z, up, pos): + """Construct lookat view matrix.""" + vec2 = _normalize(z) + vec1_avg = up + vec0 = _normalize(np.cross(vec1_avg, vec2)) + vec1 = _normalize(np.cross(vec2, vec0)) + m = np.stack([vec0, vec1, vec2, pos], 1) + return m + + def _normalize(x): + """Normalization helper function.""" + return x / np.linalg.norm(x) + + def _poses_avg(poses): + """Average poses according to the original NeRF 
code.""" + hwf = poses[0, :3, -1:] + center = poses[:, :3, 3].mean(0) + vec2 = _normalize(poses[:, :3, 2].sum(0)) + up = poses[:, :3, 1].sum(0) + c2w = np.concatenate([_viewmatrix(vec2, up, center), hwf], 1) + return c2w + + def _recenter_poses(poses): + """Recenter poses according to the original NeRF code.""" + poses_ = poses.copy() + bottom = np.reshape([0, 0, 0, 1.], [1, 4]) + c2w = _poses_avg(poses) + c2w = np.concatenate([c2w[:3, :4], bottom], -2) + bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) + poses = np.concatenate([poses[:, :3, :4], bottom], -2) + poses = np.linalg.inv(c2w) @ poses + poses_[:, :3, :4] = poses[:, :3, :4] + poses = poses_ + return poses + + def _transform_poses_pca(poses): + """Transforms poses so principal components lie on XYZ axes.""" + poses_ = poses.copy() + t = poses[:, :3, 3] + t_mean = t.mean(axis=0) + t = t - t_mean + + eigval, eigvec = np.linalg.eig(t.T @ t) + # Sort eigenvectors in order of largest to smallest eigenvalue. + inds = np.argsort(eigval)[::-1] + eigvec = eigvec[:, inds] + rot = eigvec.T + if np.linalg.det(rot) < 0: + rot = np.diag(np.array([1, 1, -1])) @ rot + + transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) + bottom = np.broadcast_to([0, 0, 0, 1.], poses[..., :1, :4].shape) + pad_poses = np.concatenate([poses[..., :3, :4], bottom], axis=-2) + poses_recentered = transform @ pad_poses + poses_recentered = poses_recentered[..., :3, :4] + transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) + + # Flip coordinate system if z component of y-axis is negative + if poses_recentered.mean(axis=0)[2, 1] < 0: + poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered + transform = np.diag(np.array([1, -1, -1, 1])) @ transform + + # Just make sure it's it in the [-1, 1]^3 cube + scale_factor = 1. 
/ np.max(np.abs(poses_recentered[:, :3, 3])) + poses_recentered[:, :3, 3] *= scale_factor + transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform + + poses_[:, :3, :4] = poses_recentered[:, :3, :4] + poses_recentered = poses_ + return poses_recentered, transform + + def load_LLFF(data_dir, split, factor = 4, llffhold = 8): + # Load images. + imgdir_suffix = "" + if factor > 0: + imgdir_suffix = "_{}".format(factor) + imgdir = os.path.join(data_dir, "images" + imgdir_suffix) + if not os.path.exists(imgdir): + raise ValueError("Image folder {} doesn't exist.".format(imgdir)) + imgfiles = [ + os.path.join(imgdir, f) + for f in sorted(os.listdir(imgdir)) + if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") + ] + def image_read_fn(fname): + with open(fname, "rb") as imgin: + image = np.array(Image.open(imgin), dtype=np.float32) / 255. + return image + with ThreadPool() as pool: + images = pool.map(image_read_fn, imgfiles) + pool.close() + pool.join() + images = np.stack(images, axis=-1) + + # Load poses and bds. + with open(os.path.join(data_dir, "poses_bounds.npy"), + "rb") as fp: + poses_arr = np.load(fp) + poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) + bds = poses_arr[:, -2:].transpose([1, 0]) + if poses.shape[-1] != images.shape[-1]: + raise RuntimeError("Mismatch between imgs {} and poses {}".format( + images.shape[-1], poses.shape[-1])) + + # Update poses according to downsampling. + poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1]) + poses[2, 4, :] = poses[2, 4, :] * 1. / factor + + # Correct rotation matrix ordering and move variable dim to axis 0. + poses = np.concatenate( + [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) + poses = np.moveaxis(poses, -1, 0).astype(np.float32) + images = np.moveaxis(images, -1, 0) + bds = np.moveaxis(bds, -1, 0).astype(np.float32) + + + if scene_type=="real360": + # Rotate/scale poses to align ground with xy plane and fit to unit cube. 
+ poses, _ = _transform_poses_pca(poses) + else: + # Rescale according to a default bd factor. + scale = 1. / (bds.min() * .75) + poses[:, :3, 3] *= scale + bds *= scale + # Recenter poses + poses = _recenter_poses(poses) + + # Select the split. + i_test = np.arange(images.shape[0])[::llffhold] + i_train = np.array( + [i for i in np.arange(int(images.shape[0])) if i not in i_test]) + if split == "train": + indices = i_train + else: + indices = i_test + images = images[indices] + poses = poses[indices] + + camtoworlds = poses[:, :3, :4] + focal = poses[0, -1, -1] + h, w = images.shape[1:3] + + hwf = np.array([h, w, focal], dtype=np.float32) + + return {'images' : jnp.array(images), 'c2w' : jnp.array(camtoworlds), 'hwf' : jnp.array(hwf)} + + data = {'train' : load_LLFF(scene_dir, 'train'), + 'test' : load_LLFF(scene_dir, 'test')} + + splits = ['train', 'test'] + for s in splits: + print(s) + for k in data[s]: + print(f' {k}: {data[s][k].shape}') + + images, poses, hwf = data['train']['images'], data['train']['c2w'], data['train']['hwf'] + write_floatpoint_image(samples_dir+"/training_image_sample.png",images[0]) + + for i in range(3): + plt.figure() + plt.scatter(poses[:,i,3], poses[:,(i+1)%3,3]) + plt.axis('equal') + plt.savefig(samples_dir+"/training_camera"+str(i)+".png") + + bg_color = jnp.mean(images) + + import jax.numpy as np +#%% -------------------------------------------------------------------------------- +# ## Helper functions +#%% +adam_kwargs = { + 'beta1': 0.9, + 'beta2': 0.999, + 'eps': 1e-15, +} + +n_device = jax.local_device_count() + +rng = random.PRNGKey(1) + + + +# General math functions. 
+ +def matmul(a, b): + """jnp.matmul defaults to bfloat16, but this helper function doesn't.""" + return np.matmul(a, b, precision=jax.lax.Precision.HIGHEST) + +def normalize(x): + """Normalization helper function.""" + return x / np.linalg.norm(x, axis=-1, keepdims=True) + +def sinusoidal_encoding(position, minimum_frequency_power, + maximum_frequency_power,include_identity = False): + # Compute the sinusoidal encoding components + frequency = 2.0**np.arange(minimum_frequency_power, maximum_frequency_power) + angle = position[..., None, :] * frequency[:, None] + encoding = np.sin(np.stack([angle, angle + 0.5 * np.pi], axis=-2)) + # Flatten encoding dimensions + encoding = encoding.reshape(*position.shape[:-1], -1) + # Add identity component + if include_identity: + encoding = np.concatenate([position, encoding], axis=-1) + return encoding + +# Pose/ray math. + +def generate_rays(pixel_coords, pix2cam, cam2world): + """Generate camera rays from pixel coordinates and poses.""" + homog = np.ones_like(pixel_coords[..., :1]) + pixel_dirs = np.concatenate([pixel_coords + .5, homog], axis=-1)[..., None] + cam_dirs = matmul(pix2cam, pixel_dirs) + ray_dirs = matmul(cam2world[..., :3, :3], cam_dirs)[..., 0] + ray_origins = np.broadcast_to(cam2world[..., :3, 3], ray_dirs.shape) + + #f = 1./pix2cam[0,0] + #w = -2. * f * pix2cam[0,2] + #h = 2. 
* f * pix2cam[1,2] + + return ray_origins, ray_dirs + +def pix2cam_matrix(height, width, focal): + """Inverse intrinsic matrix for a pinhole camera.""" + return np.array([ + [1./focal, 0, -.5 * width / focal], + [0, -1./focal, .5 * height / focal], + [0, 0, -1.], + ]) + +def camera_ray_batch(cam2world, hwf): + """Generate rays for a pinhole camera with given extrinsic and intrinsic.""" + height, width = int(hwf[0]), int(hwf[1]) + pix2cam = pix2cam_matrix(*hwf) + pixel_coords = np.stack(np.meshgrid(np.arange(width), np.arange(height)), axis=-1) + return generate_rays(pixel_coords, pix2cam, cam2world) + +def random_ray_batch(rng, batch_size, data): + """Generate a random batch of ray data.""" + keys = random.split(rng, 3) + cam_ind = random.randint(keys[0], [batch_size], 0, data['c2w'].shape[0]) + y_ind = random.randint(keys[1], [batch_size], 0, data['images'].shape[1]) + x_ind = random.randint(keys[2], [batch_size], 0, data['images'].shape[2]) + pixel_coords = np.stack([x_ind, y_ind], axis=-1) + pix2cam = pix2cam_matrix(*data['hwf']) + cam2world = data['c2w'][cam_ind, :3, :4] + rays = generate_rays(pixel_coords, pix2cam, cam2world) + pixels = data['images'][cam_ind, y_ind, x_ind] + return rays, pixels + + +# Learning rate helpers. + +def log_lerp(t, v0, v1): + """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1).""" + if v0 <= 0 or v1 <= 0: + raise ValueError(f'Interpolants {v0} and {v1} must be positive.') + lv0 = np.log(v0) + lv1 = np.log(v1) + return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0) + +def lr_fn(step, max_steps, lr0, lr1, lr_delay_steps=20000, lr_delay_mult=0.1): + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)) + else: + delay_rate = 1. 
+ return delay_rate * log_lerp(step / max_steps, lr0, lr1) + +#%% -------------------------------------------------------------------------------- +# ## Plane parameters and setup +#%% +#scene scales + +if scene_type=="synthetic": + scene_grid_scale = 1.2 + if "hotdog" in scene_dir or "mic" in scene_dir or "ship" in scene_dir: + scene_grid_scale = 1.5 + grid_min = np.array([-1, -1, -1]) * scene_grid_scale + grid_max = np.array([ 1, 1, 1]) * scene_grid_scale + point_grid_size = 128 + + def get_taper_coord(p): + return p + def inverse_taper_coord(p): + return p + +elif scene_type=="forwardfacing": + scene_grid_taper = 1.25 + scene_grid_zstart = 25.0 + scene_grid_zend = 1.0 + scene_grid_scale = 0.7 + grid_min = np.array([-scene_grid_scale, -scene_grid_scale, 0]) + grid_max = np.array([ scene_grid_scale, scene_grid_scale, 1]) + point_grid_size = 128 + + def get_taper_coord(p): + pz = np.maximum(-p[..., 2:3],1e-10) + px = p[..., 0:1]/(pz*scene_grid_taper) + py = p[..., 1:2]/(pz*scene_grid_taper) + pz = (np.log(pz) - np.log(scene_grid_zend))/(np.log(scene_grid_zstart) - np.log(scene_grid_zend)) + return np.concatenate([px,py,pz],axis=-1) + def inverse_taper_coord(p): + pz = np.exp( p[..., 2:3] * \ + (np.log(scene_grid_zstart) - np.log(scene_grid_zend)) + \ + np.log(scene_grid_zend) ) + px = p[..., 0:1]*(pz*scene_grid_taper) + py = p[..., 1:2]*(pz*scene_grid_taper) + pz = -pz + return np.concatenate([px,py,pz],axis=-1) + +elif scene_type=="real360": + scene_grid_zmax = 16.0 + if object_name == "gardenvase": + scene_grid_zmax = 9.0 + grid_min = np.array([-1, -1, -1]) + grid_max = np.array([ 1, 1, 1]) + point_grid_size = 128 + + def get_taper_coord(p): + return p + def inverse_taper_coord(p): + return p + + #approximate solution of e^x = ax+b + #(np.exp( x ) + (x-1)) / x = scene_grid_zmax + #np.exp( x ) - scene_grid_zmax*x + (x-1) = 0 + scene_grid_zcc = -1 + for i in range(10000): + j = numpy.log(scene_grid_zmax)+i/1000.0 + if numpy.exp(j) - scene_grid_zmax*j + (j-1) >0: + 
scene_grid_zcc = j + break + if scene_grid_zcc<0: + print("ERROR: approximate solution of e^x = ax+b failed") + 1/0 + + + +grid_dtype = np.float32 + +#plane parameter grid +point_grid = np.zeros( + (point_grid_size, point_grid_size, point_grid_size, 3), + dtype=grid_dtype) +acc_grid = np.zeros( + (point_grid_size, point_grid_size, point_grid_size), + dtype=grid_dtype) +point_grid_diff_lr_scale = 16.0/point_grid_size + + + +def get_acc_grid_masks(taper_positions, acc_grid): + grid_positions = (taper_positions - grid_min) * \ + (point_grid_size / (grid_max - grid_min) ) + grid_masks = (grid_positions[..., 0]>=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=0) & (b>=0) & (c>=0) & np.logical_not(denominator_mask) + return a,b,c,mask + + + + +cell_size_x = (grid_max[0] - grid_min[0])/point_grid_size +half_cell_size_x = cell_size_x/2 +neg_half_cell_size_x = -half_cell_size_x +cell_size_y = (grid_max[1] - grid_min[1])/point_grid_size +half_cell_size_y = cell_size_y/2 +neg_half_cell_size_y = -half_cell_size_y +cell_size_z = (grid_max[2] - grid_min[2])/point_grid_size +half_cell_size_z = cell_size_z/2 +neg_half_cell_size_z = -half_cell_size_z + +def get_inside_cell_mask(P,ooxyz): + P_ = get_taper_coord(P) - ooxyz + return (P_[..., 0]>=neg_half_cell_size_x) \ + & (P_[..., 0]=neg_half_cell_size_y) \ + & (P_[..., 1]=neg_half_cell_size_z) \ + & (P_[..., 2]tx_n,tx_p,tx_n) + ty = np.where(ty_p>ty_n,ty_p,ty_n) + + tx_py = oy + dy * tx + ty_px = ox + dx * ty + t = np.where(np.abs(tx_py) 0: + net = np.concatenate([net, inputs], axis=-1) + + net = dense_layer(self.out_dim)(net) + + return net + +# Set up the MLPs for color and density. 
+class MLP(nn.Module): + features: Sequence[int] + + @nn.compact + def __call__(self, x): + for feat in self.features[:-1]: + x = nn.relu(nn.Dense(feat)(x)) + x = nn.Dense(self.features[-1])(x) + return x + + +density_model = RadianceField(1) +feature_model = RadianceField(num_bottleneck_features) +color_model = MLP([16,16,3]) + +# These are the variables we will be optimizing during trianing. +model_vars = [point_grid, acc_grid, + density_model.init( + jax.random.PRNGKey(0), + np.zeros([1, 3])), + feature_model.init( + jax.random.PRNGKey(0), + np.zeros([1, 3])), + color_model.init( + jax.random.PRNGKey(0), + np.zeros([1, 3+num_bottleneck_features])), + ] + +#avoid bugs +point_grid = None +acc_grid = None +#%% -------------------------------------------------------------------------------- +# ## Main rendering functions +#%% +def compute_volumetric_rendering_weights_with_alpha(alpha): + density_exp = 1. - alpha + density_exp_shifted = np.concatenate([np.ones_like(density_exp[..., :1]), + density_exp[..., :-1]], axis=-1) + trans = np.cumprod(density_exp_shifted, axis=-1) + weights = alpha * trans + return weights + +def render_rays(rays, vars, keep_num, threshold, wbgcolor, rng): + + #---------- ray-plane intersection points + grid_indices, grid_masks = gridcell_from_rays(rays, vars[1], keep_num, threshold) + + pts, grid_masks, points, fake_t = compute_undc_intersection(vars[0], + grid_indices, grid_masks, rays, keep_num) + + if scene_type=="forwardfacing": + fake_t = compute_t_forwardfacing(pts,grid_masks) + elif scene_type=="real360": + skybox_positions, skybox_masks = compute_box_intersection(rays) + pts = np.concatenate([pts,skybox_positions], axis=-2) + grid_masks = np.concatenate([grid_masks,skybox_masks], axis=-1) + pts, grid_masks, fake_t = sort_and_compute_t_real360(pts,grid_masks) + + # Now use the MLP to compute density and features + mlp_alpha = density_model.apply(vars[-3], pts) + mlp_alpha = jax.nn.sigmoid(mlp_alpha[..., 0]-8) + mlp_alpha = mlp_alpha * 
grid_masks + + weights = compute_volumetric_rendering_weights_with_alpha(mlp_alpha) + acc = np.sum(weights, axis=-1) + + mlp_alpha_b = mlp_alpha + jax.lax.stop_gradient( + np.clip((mlp_alpha>0.5).astype(mlp_alpha.dtype), 0.00001,0.99999) - mlp_alpha) + weights_b = compute_volumetric_rendering_weights_with_alpha(mlp_alpha_b) + acc_b = np.sum(weights_b, axis=-1) + + # ... as well as view-dependent colors. + dirs = normalize(rays[1]) + dirs = np.broadcast_to(dirs[..., None, :], pts.shape) + + #previous: (features+dirs)->MLP->(RGB) + mlp_features = jax.nn.sigmoid(feature_model.apply(vars[-2], pts)) + features_dirs_enc = np.concatenate([mlp_features, dirs], axis=-1) + colors = jax.nn.sigmoid(color_model.apply(vars[-1], features_dirs_enc)) + + rgb = np.sum(weights[..., None] * colors, axis=-2) + rgb_b = np.sum(weights_b[..., None] * colors, axis=-2) + + # Composite onto the background color. + if white_bkgd: + rgb = rgb + (1. - acc[..., None]) + rgb_b = rgb_b + (1. - acc_b[..., None]) + else: + bgc = random.randint(rng, [1], 0, 2).astype(bg_color.dtype) * wbgcolor + \ + bg_color * (1-wbgcolor) + rgb = rgb + (1. - acc[..., None]) * bgc + rgb_b = rgb_b + (1. - acc_b[..., None]) * bgc + + #get acc_grid_masks to update acc_grid + acc_grid_masks = get_acc_grid_masks(pts, vars[1]) + acc_grid_masks = acc_grid_masks*grid_masks + + return rgb, acc, rgb_b, acc_b, mlp_alpha, weights, points, fake_t, acc_grid_masks +#%% -------------------------------------------------------------------------------- +# ## Set up pmap'd rendering for test time evaluation. 
+#%% +test_batch_size = 1024*n_device +test_keep_num = point_grid_size*3//4 +test_threshold = 0.1 +test_wbgcolor = 0.0 + + +render_test_p = jax.pmap(lambda rays, vars: render_rays( + rays, vars, test_keep_num, test_threshold, test_wbgcolor, rng), + in_axes=(0, None)) + + +def render_test(rays, vars): + sh = rays[0].shape + rays = [x.reshape((jax.local_device_count(), -1) + sh[1:]) for x in rays] + out = render_test_p(rays, vars) + out = [numpy.reshape(numpy.array(x),sh[:-1]+(-1,)) for x in out] + return out + +def render_loop(rays, vars, chunk): + sh = list(rays[0].shape[:-1]) + rays = [x.reshape([-1, 3]) for x in rays] + l = rays[0].shape[0] + n = jax.local_device_count() + p = ((l - 1) // n + 1) * n - l + rays = [np.pad(x, ((0,p),(0,0))) for x in rays] + outs = [render_test([x[i:i+chunk] for x in rays], vars) + for i in range(0, rays[0].shape[0], chunk)] + outs = [np.reshape( + np.concatenate([z[i] for z in outs])[:l], sh + [-1]) for i in range(4)] + return outs + +# Make sure that everything works, by rendering an image from the test set + +if scene_type=="synthetic": + selected_test_index = 97 + preview_image_height = 800 + +elif scene_type=="forwardfacing": + selected_test_index = 0 + preview_image_height = 756//2 + +elif scene_type=="real360": + selected_test_index = 0 + preview_image_height = 840//2 + +rays = camera_ray_batch( + data['test']['c2w'][selected_test_index], data['test']['hwf']) +gt = data['test']['images'][selected_test_index] +out = render_loop(rays, model_vars, test_batch_size) +rgb = out[0] +acc = out[1] +rgb_b = out[2] +acc_b = out[3] +write_floatpoint_image(samples_dir+"/s1_"+str(0)+"_rgb.png",rgb) +write_floatpoint_image(samples_dir+"/s1_"+str(0)+"_rgb_binarized.png",rgb_b) +write_floatpoint_image(samples_dir+"/s1_"+str(0)+"_gt.png",gt) +write_floatpoint_image(samples_dir+"/s1_"+str(0)+"_acc.png",acc) +write_floatpoint_image(samples_dir+"/s1_"+str(0)+"_acc_binarized.png",acc_b) +#%% 
-------------------------------------------------------------------------------- +# ## Training loop +#%% + +def lossfun_distortion(x, w): + """Compute iint w_i w_j |x_i - x_j| d_i d_j.""" + # The loss incurred between all pairs of intervals. + dux = np.abs(x[..., :, None] - x[..., None, :]) + losses_cross = np.sum(w * np.sum(w[..., None, :] * dux, axis=-1), axis=-1) + + # The loss incurred within each individual interval with itself. + losses_self = np.sum((w[..., 1:]**2 + w[..., :-1]**2) * \ + (x[..., 1:] - x[..., :-1]), axis=-1) / 6 + + return losses_cross + losses_self + +def compute_TV(acc_grid): + dx = acc_grid[:-1,:,:] - acc_grid[1:,:,:] + dy = acc_grid[:,:-1,:] - acc_grid[:,1:,:] + dz = acc_grid[:,:,:-1] - acc_grid[:,:,1:] + TV = np.mean(np.square(dx))+np.mean(np.square(dy))+np.mean(np.square(dz)) + return TV + +def train_step(state, rng, traindata, lr, wdistortion, wbinary, wbgcolor, batch_size, keep_num, threshold): + key, rng = random.split(rng) + rays, pixels = random_ray_batch( + key, batch_size // n_device, traindata) + + def loss_fn(vars): + rgb_est, _, rgb_est_b, _, mlp_alpha, weights, points, fake_t, acc_grid_masks = render_rays( + rays, vars, keep_num, threshold, wbgcolor, rng) + + loss_color_l2 = np.mean(np.square(rgb_est - pixels)) + #loss_color_l2_b = np.mean(np.square(rgb_est_b - pixels)) + + loss_acc = np.mean(np.maximum(jax.lax.stop_gradient(weights) - acc_grid_masks,0)) + loss_acc += np.mean(np.abs(vars[1])) *1e-5 + loss_acc += compute_TV(vars[1]) *1e-5 + + loss_distortion = np.mean(lossfun_distortion(fake_t, weights)) *wdistortion + + point_loss = np.abs(points) + point_loss_out = point_loss *1000.0 + point_loss_in = point_loss *0.01 + point_mask = point_loss<(grid_max - grid_min)/point_grid_size/2 + point_loss = np.mean(np.where(point_mask, point_loss_in, point_loss_out)) + + return loss_color_l2 + loss_distortion + loss_acc + point_loss, loss_color_l2 + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (total_loss, color_loss_l2), 
grad = grad_fn(state.target) + total_loss = jax.lax.pmean(total_loss, axis_name='batch') + color_loss_l2 = jax.lax.pmean(color_loss_l2, axis_name='batch') + + grad = jax.lax.pmean(grad, axis_name='batch') + state = state.apply_gradient(grad, learning_rate=lr) + + return state, color_loss_l2 + +train_pstep = jax.pmap(train_step, axis_name='batch', + in_axes=(0, 0, 0, None, None, None, None, None, None, None), + static_broadcasted_argnums = (7,8,)) +traindata_p = flax.jax_utils.replicate(data['train']) +state = flax.optim.Adam(**adam_kwargs).create(model_vars) + +step_init = state.state.step +state = flax.jax_utils.replicate(state) +print(f'starting at {step_init}') + +# Training loop +psnrs = [] +iters = [] +psnrs_test = [] +iters_test = [] +t_total = 0.0 +t_last = 0.0 +i_last = step_init + +training_iters = 200000 +train_iters_cont = 300000 +if scene_type=="real360": + training_iters = 300000 + +print("Training") +for i in tqdm(range(step_init, training_iters + 1)): + t = time.time() + + lr = lr_fn(i,train_iters_cont, 1e-3, 1e-5) + wbgcolor = min(1.0, float(i)/50000) + wbinary = 0.0 + + if scene_type=="synthetic": + wdistortion = 0.0 + elif scene_type=="forwardfacing": + wdistortion = 0.0 if i<10000 else 0.01 + elif scene_type=="real360": + wdistortion = 0.0 if i<10000 else 0.001 + + if i<=50000: + batch_size = test_batch_size//4 + keep_num = test_keep_num*4 + threshold = -100000.0 + elif i<=100000: + batch_size = test_batch_size//2 + keep_num = test_keep_num*2 + threshold = test_threshold + else: + batch_size = test_batch_size + keep_num = test_keep_num + threshold = test_threshold + + rng, key1, key2 = random.split(rng, 3) + key2 = random.split(key2, n_device) + state, color_loss_l2 = train_pstep( + state, key2, traindata_p, + lr, + wdistortion, + wbinary, + wbgcolor, + batch_size, + keep_num, + threshold + ) + + psnrs.append(-10. 
* np.log10(color_loss_l2[0])) + iters.append(i) + + if i > 0: + t_total += time.time() - t + + # Logging + if (i % 10000 == 0) and i > 0: + gc.collect() + + unreplicated_state = flax.jax_utils.unreplicate(state) + pickle.dump(unreplicated_state, open(weights_dir+"/s1_"+"tmp_state"+str(i)+".pkl", "wb")) + + print('Current iteration %d, elapsed training time: %d min %d sec.' + % (i, t_total // 60, int(t_total) % 60)) + + print('Batch size: %d' % batch_size) + print('Keep num: %d' % keep_num) + t_elapsed = t_total - t_last + i_elapsed = i - i_last + t_last = t_total + i_last = i + print("Speed:") + print(' %0.3f secs per iter.' % (t_elapsed / i_elapsed)) + print(' %0.3f iters per sec.' % (i_elapsed / t_elapsed)) + + vars = unreplicated_state.target + rays = camera_ray_batch( + data['test']['c2w'][selected_test_index], data['test']['hwf']) + gt = data['test']['images'][selected_test_index] + out = render_loop(rays, vars, test_batch_size) + rgb = out[0] + acc = out[1] + rgb_b = out[2] + acc_b = out[3] + psnrs_test.append(-10 * np.log10(np.mean(np.square(rgb - gt)))) + iters_test.append(i) + + print("PSNR:") + print(' Training running average: %0.3f' % np.mean(np.array(psnrs[-200:]))) + print(' Selected test image: %0.3f' % psnrs_test[-1]) + + plt.figure() + plt.title(i) + plt.plot(iters, psnrs) + plt.plot(iters_test, psnrs_test) + p = np.array(psnrs) + plt.ylim(np.min(p) - .5, np.max(p) + .5) + plt.legend() + plt.savefig(samples_dir+"/s1_"+str(i)+"_loss.png") + + write_floatpoint_image(samples_dir+"/s1_"+str(i)+"_rgb.png",rgb) + write_floatpoint_image(samples_dir+"/s1_"+str(i)+"_rgb_binarized.png",rgb_b) + write_floatpoint_image(samples_dir+"/s1_"+str(i)+"_gt.png",gt) + write_floatpoint_image(samples_dir+"/s1_"+str(i)+"_acc.png",acc) + write_floatpoint_image(samples_dir+"/s1_"+str(i)+"_acc_binarized.png",acc_b) + +#%% +#%% -------------------------------------------------------------------------------- +# ## Run test-set evaluation +#%% +gc.collect() + +render_poses = 
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
  """Structural similarity (SSIM) between two images.

  Modeled after tf.image.ssim (copied from SNeRG) and intended to produce
  comparable output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Size of the Gaussian window.
    filter_sigma: float > 0. Bandwidth of the Gaussian window.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: bool. If True, return the per-pixel SSIM map instead of its
      mean.

  Returns:
    The mean SSIM of each image (averaged over the last three axes), or the
    per-pixel map when `return_map` is set. The blur uses "valid"
    convolutions, so the map is smaller than the input by `filter_size - 1`
    along both spatial axes.
  """
  # Normalized 1D Gaussian window.
  half_width = filter_size // 2
  center_shift = (2 * half_width - filter_size + 1) / 2
  offsets = jnp.arange(filter_size) - half_width + center_shift
  kernel = jnp.exp(-0.5 * (offsets / filter_sigma)**2)
  kernel = kernel / jnp.sum(kernel)

  # Separable blur: two 1D passes instead of one 2D convolution.
  def blur_axis0(plane):
    return jsp.signal.convolve2d(plane, kernel[:, None], mode="valid")

  def blur_axis1(plane):
    return jsp.signal.convolve2d(plane, kernel[None, :], mode="valid")

  # convolve2d works on a single 2D plane, so vmap over every leading (batch)
  # axis and over the trailing channel axis.
  rank = len(img0.shape)
  for axis in (*range(rank - 3), rank - 1):
    blur_axis0 = jax.vmap(blur_axis0, in_axes=axis, out_axes=axis)
    blur_axis1 = jax.vmap(blur_axis1, in_axes=axis, out_axes=axis)

  def blur(image):
    return blur_axis0(blur_axis1(image))

  # Local first and second moments.
  mean0 = blur(img0)
  mean1 = blur(img1)
  mean00 = mean0 * mean0
  mean11 = mean1 * mean1
  mean01 = mean0 * mean1
  var0 = blur(img0**2) - mean00
  var1 = blur(img1**2) - mean11
  covar = blur(img0 * img1) - mean01

  # Variances must be non-negative, and the covariance magnitude may not
  # exceed the geometric mean of the two variances.
  var0 = jnp.maximum(0., var0)
  var1 = jnp.maximum(0., var1)
  covar = jnp.sign(covar) * jnp.minimum(
      jnp.sqrt(var0 * var1), jnp.abs(covar))

  c1 = (k1 * max_val)**2
  c2 = (k2 * max_val)**2
  numerator = (2 * mean01 + c1) * (2 * covar + c2)
  denominator = (mean00 + mean11 + c1) * (var0 + var1 + c2)
  ssim_map = numerator / denominator
  if return_map:
    return ssim_map
  return jnp.mean(ssim_map, list(range(rank - 3, rank)))
def write_floatpoint_image(name,img):
  """Write a floating-point image to the file `name` as 8-bit data.

  `img` is converted to a numpy array, multiplied by 255, clipped to
  [0, 255] and cast to uint8 (the cast truncates, it does not round). The
  last axis is reversed before writing because OpenCV stores colour channels
  in BGR order.

  Args:
    name: output file path handed to `cv2.imwrite`.
    img: array indexed as [row, column, channel]; presumably holds values in
      [0, 1] -- anything outside is clipped after scaling.
  """
  scaled = numpy.array(img)*255
  pixels_u8 = numpy.clip(scaled,0,255).astype(numpy.uint8)
  bgr = pixels_u8[:,:,::-1]
  cv2.imwrite(name,bgr)
+#%% +# """ Load dataset """ + +if scene_type=="synthetic": + white_bkgd = True +elif scene_type=="forwardfacing": + white_bkgd = False +elif scene_type=="real360": + white_bkgd = False + + +#https://github.com/google-research/google-research/blob/master/snerg/nerf/datasets.py + + +if scene_type=="synthetic": + + def load_blender(data_dir, split): + with open( + os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp: + meta = json.load(fp) + + cams = [] + paths = [] + for i in range(len(meta["frames"])): + frame = meta["frames"][i] + cams.append(np.array(frame["transform_matrix"], dtype=np.float32)) + + fname = os.path.join(data_dir, frame["file_path"] + ".png") + paths.append(fname) + + def image_read_fn(fname): + with open(fname, "rb") as imgin: + image = np.array(Image.open(imgin), dtype=np.float32) / 255. + return image + with ThreadPool() as pool: + images = pool.map(image_read_fn, paths) + pool.close() + pool.join() + + images = np.stack(images, axis=0) + if white_bkgd: + images = (images[..., :3] * images[..., -1:] + (1. 
def _normalize(x):
  """Return `x` divided by its norm (taken over the whole array)."""
  length = np.linalg.norm(x)
  return x / length


def _viewmatrix(z, up, pos):
  """Construct a look-at view matrix.

  Builds an orthonormal frame whose third axis is the normalized `z`, using
  `up` as a hint for the second axis, and appends `pos` as a fourth column.

  Returns:
    A [3, 4] array whose columns are (axis0, axis1, axis2, pos).
  """
  forward = _normalize(z)
  right = _normalize(np.cross(up, forward))
  true_up = _normalize(np.cross(forward, right))
  columns = [right, true_up, forward, pos]
  return np.stack(columns, 1)


def _poses_avg(poses):
  """Average poses according to the original NeRF code.

  Args:
    poses: array indexed as [pose, row, column]; columns 0-3 of the first
      three rows are read as the camera-to-world matrix, and the last column
      is carried through as-is (the surrounding loader keeps h/w/focal
      there).

  Returns:
    A [3, 5] array: the look-at matrix built from the mean camera position,
    the summed z-axes and the summed y-axes, followed by the last column of
    the first pose.
  """
  last_column = poses[0, :3, -1:]
  mean_position = poses[:, :3, 3].mean(0)
  mean_forward = _normalize(poses[:, :3, 2].sum(0))
  summed_up = poses[:, :3, 1].sum(0)
  look_at = _viewmatrix(mean_forward, summed_up, mean_position)
  return np.concatenate([look_at, last_column], 1)
def _transform_poses_pca(poses):
  """Transforms poses so principal components lie on XYZ axes.

  The camera positions are centred on their mean, rotated so that their
  principal directions (largest eigenvalue first) align with the coordinate
  axes, optionally flipped, and finally scaled so that the largest absolute
  position coordinate equals 1.

  Args:
    poses: array indexed as [pose, row, column]; the leading [3, 4] part of
      each pose is the camera-to-world matrix. Any further columns are copied
      through unchanged. The input is not modified.

  Returns:
    A tuple (new_poses, transform): a copy of `poses` with the [3, 4] part
    replaced, and the [4, 4] matrix that maps original world coordinates to
    the new ones.
  """
  positions = poses[:, :3, 3]
  mean_position = positions.mean(axis=0)
  centered = positions - mean_position

  eigenvalues, eigenvectors = np.linalg.eig(centered.T @ centered)
  # Sort eigenvectors in order of largest to smallest eigenvalue.
  order = np.argsort(eigenvalues)[::-1]
  rot = eigenvectors[:, order].T
  # Keep a proper rotation (determinant +1) by negating the last axis.
  if np.linalg.det(rot) < 0:
    rot = np.diag(np.array([1, 1, -1])) @ rot

  transform = np.concatenate([rot, rot @ -mean_position[:, None]], -1)
  last_row = np.broadcast_to([0, 0, 0, 1.], poses[..., :1, :4].shape)
  homogeneous = np.concatenate([poses[..., :3, :4], last_row], axis=-2)
  recentered = (transform @ homogeneous)[..., :3, :4]
  transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)

  # Flip coordinate system if z component of y-axis is negative
  if recentered.mean(axis=0)[2, 1] < 0:
    recentered = np.diag(np.array([1, -1, -1])) @ recentered
    transform = np.diag(np.array([1, -1, -1, 1])) @ transform

  # Make sure the camera positions lie in the [-1, 1]^3 cube.
  scale_factor = 1. / np.max(np.abs(recentered[:, :3, 3]))
  recentered[:, :3, 3] *= scale_factor
  transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform

  result = poses.copy()
  result[:, :3, :4] = recentered[:, :3, :4]
  return result, transform
+ imgdir_suffix = "" + if factor > 0: + imgdir_suffix = "_{}".format(factor) + imgdir = os.path.join(data_dir, "images" + imgdir_suffix) + if not os.path.exists(imgdir): + raise ValueError("Image folder {} doesn't exist.".format(imgdir)) + imgfiles = [ + os.path.join(imgdir, f) + for f in sorted(os.listdir(imgdir)) + if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png") + ] + def image_read_fn(fname): + with open(fname, "rb") as imgin: + image = np.array(Image.open(imgin), dtype=np.float32) / 255. + return image + with ThreadPool() as pool: + images = pool.map(image_read_fn, imgfiles) + pool.close() + pool.join() + images = np.stack(images, axis=-1) + + # Load poses and bds. + with open(os.path.join(data_dir, "poses_bounds.npy"), + "rb") as fp: + poses_arr = np.load(fp) + poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) + bds = poses_arr[:, -2:].transpose([1, 0]) + if poses.shape[-1] != images.shape[-1]: + raise RuntimeError("Mismatch between imgs {} and poses {}".format( + images.shape[-1], poses.shape[-1])) + + # Update poses according to downsampling. + poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1]) + poses[2, 4, :] = poses[2, 4, :] * 1. / factor + + # Correct rotation matrix ordering and move variable dim to axis 0. + poses = np.concatenate( + [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) + poses = np.moveaxis(poses, -1, 0).astype(np.float32) + images = np.moveaxis(images, -1, 0) + bds = np.moveaxis(bds, -1, 0).astype(np.float32) + + + if scene_type=="real360": + # Rotate/scale poses to align ground with xy plane and fit to unit cube. + poses, _ = _transform_poses_pca(poses) + else: + # Rescale according to a default bd factor. + scale = 1. / (bds.min() * .75) + poses[:, :3, 3] *= scale + bds *= scale + # Recenter poses + poses = _recenter_poses(poses) + + # Select the split. 
+ i_test = np.arange(images.shape[0])[::llffhold] + i_train = np.array( + [i for i in np.arange(int(images.shape[0])) if i not in i_test]) + if split == "train": + indices = i_train + else: + indices = i_test + images = images[indices] + poses = poses[indices] + + camtoworlds = poses[:, :3, :4] + focal = poses[0, -1, -1] + h, w = images.shape[1:3] + + hwf = np.array([h, w, focal], dtype=np.float32) + + return {'images' : jnp.array(images), 'c2w' : jnp.array(camtoworlds), 'hwf' : jnp.array(hwf)} + + data = {'train' : load_LLFF(scene_dir, 'train'), + 'test' : load_LLFF(scene_dir, 'test')} + + splits = ['train', 'test'] + for s in splits: + print(s) + for k in data[s]: + print(f' {k}: {data[s][k].shape}') + + images, poses, hwf = data['train']['images'], data['train']['c2w'], data['train']['hwf'] + write_floatpoint_image(samples_dir+"/training_image_sample.png",images[0]) + + for i in range(3): + plt.figure() + plt.scatter(poses[:,i,3], poses[:,(i+1)%3,3]) + plt.axis('equal') + plt.savefig(samples_dir+"/training_camera"+str(i)+".png") + + bg_color = jnp.mean(images) + + import jax.numpy as np +#%% -------------------------------------------------------------------------------- +# ## Helper functions +#%% +adam_kwargs = { + 'beta1': 0.9, + 'beta2': 0.999, + 'eps': 1e-15, +} + +n_device = jax.local_device_count() + +rng = random.PRNGKey(1) + + + +# General math functions. 
def matmul(a, b):
  """jnp.matmul defaults to bfloat16, but this helper function doesn't.

  The product is requested with `jax.lax.Precision.HIGHEST`.
  """
  highest = jax.lax.Precision.HIGHEST
  return np.matmul(a, b, precision=highest)


def normalize(x):
  """Scale `x` to unit L2 norm along its last axis."""
  length = np.linalg.norm(x, axis=-1, keepdims=True)
  return x / length


def sinusoidal_encoding(position, minimum_frequency_power,
    maximum_frequency_power,include_identity = False):
  """Encode `position` with sines and cosines at power-of-two frequencies.

  Args:
    position: array whose last axis holds the coordinates to encode.
    minimum_frequency_power: first exponent p of the frequencies 2**p.
    maximum_frequency_power: end of the exponent range (exclusive).
    include_identity: if True, prepend the raw `position` to the encoding.

  Returns:
    An array with the leading axes of `position` and a flattened last axis.
    For each frequency it holds the sines of all coordinates followed by
    their cosines.
  """
  powers = np.arange(minimum_frequency_power, maximum_frequency_power)
  scales = 2.0**powers
  phase = position[..., None, :] * scales[:, None]
  # cos(a) == sin(a + pi/2), so a single sin call yields both components.
  sin_and_cos = np.sin(np.stack([phase, phase + 0.5 * np.pi], axis=-2))
  flat = sin_and_cos.reshape(*position.shape[:-1], -1)
  if not include_identity:
    return flat
  return np.concatenate([position, flat], axis=-1)


# Pose/ray math.

def generate_rays(pixel_coords, pix2cam, cam2world):
  """Generate camera rays from pixel coordinates and poses.

  Args:
    pixel_coords: array whose last axis is (x, y); 0.5 is added to both
      before unprojection.
    pix2cam: [3, 3] matrix mapping homogeneous pixel coordinates to camera
      space directions.
    cam2world: array whose trailing [3, 4] part is rotation | translation.

  Returns:
    (ray_origins, ray_dirs), both shaped like the ray directions. The
    directions are not normalized here.
  """
  shifted = pixel_coords + .5
  ones = np.ones_like(pixel_coords[..., :1])
  homogeneous = np.concatenate([shifted, ones], axis=-1)[..., None]
  in_camera = matmul(pix2cam, homogeneous)
  rotation = cam2world[..., :3, :3]
  directions = matmul(rotation, in_camera)[..., 0]
  origins = np.broadcast_to(cam2world[..., :3, 3], directions.shape)
  return origins, directions


def pix2cam_matrix(height, width, focal):
  """Inverse intrinsic matrix for a pinhole camera."""
  row_x = [1./focal, 0, -.5 * width / focal]
  row_y = [0, -1./focal, .5 * height / focal]
  row_z = [0, 0, -1.]
  return np.array([row_x, row_y, row_z])


def camera_ray_batch_xxxxx_original(cam2world, hwf):
  """Generate one ray per pixel for a pinhole camera.

  Args:
    cam2world: camera-to-world pose passed on to `generate_rays`.
    hwf: (height, width, focal).

  Returns:
    (ray_origins, ray_dirs) on a [height, width] pixel grid.
  """
  height = int(hwf[0])
  width = int(hwf[1])
  intrinsics_inv = pix2cam_matrix(*hwf)
  grid = np.meshgrid(np.arange(width), np.arange(height))
  xy = np.stack(grid, axis=-1)
  return generate_rays(xy, intrinsics_inv, cam2world)


def camera_ray_batch(cam2world, hwf): ### antialiasing by supersampling
  """Generate four rays per pixel for a pinhole camera.

  Each pixel is sampled at offsets of +/-0.25 in x and y, in the order
  (-,-), (+,-), (-,+), (+,+).

  Args:
    cam2world: camera-to-world pose passed on to `generate_rays`.
    hwf: (height, width, focal).

  Returns:
    (ray_origins, ray_dirs) for pixel coordinates shaped
    [height, width, 4, 2].
  """
  height = int(hwf[0])
  width = int(hwf[1])
  intrinsics_inv = pix2cam_matrix(*hwf)
  cols, rows = np.meshgrid(np.arange(width), np.arange(height))
  sub_offsets = ((-0.25, -0.25), (0.25, -0.25), (-0.25, 0.25), (0.25, 0.25))
  interleaved = []
  for dx, dy in sub_offsets:
    interleaved.append(cols + dx)
    interleaved.append(rows + dy)
  xy = np.reshape(np.stack(interleaved, axis=-1), [height,width,4,2])
  return generate_rays(xy, intrinsics_inv, cam2world)


def random_ray_batch_xxxxx_original(rng, batch_size, data):
  """Generate a random batch of ray data (one ray per sampled pixel).

  Args:
    rng: JAX PRNG key; split into three keys for camera, row and column.
    batch_size: number of rays to draw.
    data: dict with 'c2w' (poses), 'images' and 'hwf' entries.

  Returns:
    ((ray_origins, ray_dirs), pixels) where `pixels` are the image values at
    the sampled locations.
  """
  cam_key, row_key, col_key = random.split(rng, 3)
  images = data['images']
  cam_ind = random.randint(cam_key, [batch_size], 0, data['c2w'].shape[0])
  y_ind = random.randint(row_key, [batch_size], 0, images.shape[1])
  x_ind = random.randint(col_key, [batch_size], 0, images.shape[2])
  xy = np.stack([x_ind, y_ind], axis=-1)
  intrinsics_inv = pix2cam_matrix(*data['hwf'])
  poses = data['c2w'][cam_ind, :3, :4]
  rays = generate_rays(xy, intrinsics_inv, poses)
  pixels = images[cam_ind, y_ind, x_ind]
  return rays, pixels
# Learning rate helpers.

def log_lerp(t, v0, v1):
  """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1).

  `t` is clipped to [0, 1] before interpolating.

  Raises:
    ValueError: if either endpoint is not strictly positive.
  """
  if not (v0 > 0 and v1 > 0):
    raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
  log_start = np.log(v0)
  log_end = np.log(v1)
  fraction = np.clip(t, 0, 1)
  return np.exp(fraction * (log_end - log_start) + log_start)


def lr_fn(step, max_steps, lr0, lr1, lr_delay_steps=20000, lr_delay_mult=0.1):
  """Learning rate at `step`: log-linear decay with an optional warm-up.

  Args:
    step: current training step.
    max_steps: step at which the decay reaches `lr1`.
    lr0: learning rate at step 0 (before the warm-up factor).
    lr1: learning rate at `max_steps` and beyond.
    lr_delay_steps: length of the warm-up; 0 or less disables it.
    lr_delay_mult: warm-up factor applied at step 0.

  Returns:
    The decayed learning rate multiplied by the warm-up factor, which rises
    from `lr_delay_mult` to 1 along a quarter sine wave.
  """
  decayed = log_lerp(step / max_steps, lr0, lr1)
  if lr_delay_steps > 0:
    # A kind of reverse cosine decay.
    progress = np.clip(step / lr_delay_steps, 0, 1)
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * progress)
  else:
    delay_rate = 1.
  return delay_rate * decayed
scene_grid_zcc = j + break + if scene_grid_zcc<0: + print("ERROR: approximate solution of e^x = ax+b failed") + 1/0 + + + +grid_dtype = np.float32 + +#plane parameter grid +point_grid = np.zeros( + (point_grid_size, point_grid_size, point_grid_size, 3), + dtype=grid_dtype) +acc_grid = np.zeros( + (point_grid_size, point_grid_size, point_grid_size), + dtype=grid_dtype) +point_grid_diff_lr_scale = 16.0/point_grid_size + + + +def get_acc_grid_masks(taper_positions, acc_grid): + grid_positions = (taper_positions - grid_min) * \ + (point_grid_size / (grid_max - grid_min) ) + grid_masks = (grid_positions[..., 0]>=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=0) & (b>=0) & (c>=0) & np.logical_not(denominator_mask) + return a,b,c,mask + + + + +cell_size_x = (grid_max[0] - grid_min[0])/point_grid_size +half_cell_size_x = cell_size_x/2 +neg_half_cell_size_x = -half_cell_size_x +cell_size_y = (grid_max[1] - grid_min[1])/point_grid_size +half_cell_size_y = cell_size_y/2 +neg_half_cell_size_y = -half_cell_size_y +cell_size_z = (grid_max[2] - grid_min[2])/point_grid_size +half_cell_size_z = cell_size_z/2 +neg_half_cell_size_z = -half_cell_size_z + +def get_inside_cell_mask(P,ooxyz): + P_ = get_taper_coord(P) - ooxyz + return (P_[..., 0]>=neg_half_cell_size_x) \ + & (P_[..., 0]=neg_half_cell_size_y) \ + & (P_[..., 1]=neg_half_cell_size_z) \ + & (P_[..., 2]tx_n,tx_p,tx_n) + ty = np.where(ty_p>ty_n,ty_p,ty_n) + + tx_py = oy + dy * tx + ty_px = ox + dx * ty + t = np.where(np.abs(tx_py) 0: + net = np.concatenate([net, inputs], axis=-1) + + net = dense_layer(self.out_dim)(net) + + return net + +# Set up the MLPs for color and density. 
def compute_volumetric_rendering_weights_with_alpha(alpha):
  """Turn per-sample opacities into front-to-back compositing weights.

  Args:
    alpha: array of opacities; the last axis runs over the samples of a ray
      in compositing order.

  Returns:
    An array shaped like `alpha` where each entry is the sample's alpha
    multiplied by the product of (1 - alpha) over all earlier samples on the
    same ray.
  """
  pass_through = 1. - alpha
  # Nothing lies in front of the first sample, so its transmittance is 1.
  leading_ones = np.ones_like(pass_through[..., :1])
  before_sample = np.concatenate(
      [leading_ones, pass_through[..., :-1]], axis=-1)
  transmittance = np.cumprod(before_sample, axis=-1)
  return alpha * transmittance
as well as view-dependent colors. + dirs = normalize(rays[1]) #[N,4,3] + dirs = np.mean(dirs, axis=-2) #[N,3] + features_dirs_enc = np.concatenate([mlp_features, dirs], axis=-1) #[N,C+3] + features_dirs_enc_b = np.concatenate([mlp_features_b, dirs], axis=-1) #[N,C+3] + rgb = jax.nn.sigmoid(color_model.apply(vars[-1], features_dirs_enc)) + rgb_b = jax.nn.sigmoid(color_model.apply(vars[-1], features_dirs_enc_b)) + + # Composite onto the background color. + if white_bkgd: + rgb = rgb * acc[..., None] + (1. - acc[..., None]) + rgb_b = rgb_b * acc_b[..., None] + (1. - acc_b[..., None]) + else: + bgc = random.randint(rng, [1], 0, 2).astype(bg_color.dtype) * wbgcolor + \ + bg_color * (1-wbgcolor) + rgb = rgb * acc[..., None] + (1. - acc[..., None]) * bgc + rgb_b = rgb_b * acc_b[..., None] + (1. - acc_b[..., None]) * bgc + + #get acc_grid_masks to update acc_grid + acc_grid_masks = get_acc_grid_masks(pts, vars[1]) + acc_grid_masks = acc_grid_masks*grid_masks + + return rgb, acc, rgb_b, acc_b, mlp_alpha, weights, points, fake_t, acc_grid_masks +#%% -------------------------------------------------------------------------------- +# ## Set up pmap'd rendering for test time evaluation. 
def render_test(rays, vars): ### antialiasing by supersampling
  """Render one chunk of rays on all local devices.

  Args:
    rays: sequence of ray arrays (origins, directions) sharing a leading ray
      axis, which is split evenly across the local devices.
    vars: model variables forwarded to the pmapped renderer `render_test_p`.

  Returns:
    A list with the first four renderer outputs as numpy arrays, each
    reshaped to the leading ray axes of the input plus one trailing axis.
  """
  in_shape = rays[0].shape
  per_device_shape = (jax.local_device_count(), -1) + in_shape[1:]
  sharded = [r.reshape(per_device_shape) for r in rays]
  rendered = render_test_p(sharded, vars)
  out_shape = in_shape[:-2]+(-1,)
  return [numpy.reshape(numpy.array(rendered[k]), out_shape)
          for k in range(4)]


def render_loop(rays, vars, chunk): ### antialiasing by supersampling
  """Render an arbitrary number of rays in chunks of `chunk` rays.

  Args:
    rays: sequence of ray arrays whose last two axes are [4, 3] (four
      supersamples per pixel).
    vars: model variables forwarded to `render_test`.
    chunk: number of rays rendered per call.

  Returns:
    A list of four arrays, each with the leading axes of the input rays plus
    one trailing axis.
  """
  lead_shape = list(rays[0].shape[:-2])
  flat = [r.reshape([-1, 4, 3]) for r in rays]
  num_rays = flat[0].shape[0]
  # Pad the ray count up to the next multiple of the device count so every
  # chunk can be split evenly across devices.
  pad = (-num_rays) % jax.local_device_count()
  flat = [np.pad(r, ((0,pad),(0,0),(0,0))) for r in flat]
  pieces = []
  for start in range(0, flat[0].shape[0], chunk):
    pieces.append(render_test([r[start:start+chunk] for r in flat], vars))
  merged = []
  for k in range(4):
    # Drop the padded rays again before restoring the leading shape.
    joined = np.concatenate([piece[k] for piece in pieces])[:num_rays]
    merged.append(np.reshape(joined, lead_shape + [-1]))
  return merged
def lossfun_distortion(x, w):
  """Compute iint w_i w_j |x_i - x_j| d_i d_j.

  Args:
    x: positions along each ray; the last axis runs over samples.
    w: weights with the same shape as `x`.

  Returns:
    One loss value per ray (the last axis is reduced).
  """
  # The loss incurred between all pairs of samples.
  pairwise_gap = np.abs(x[..., :, None] - x[..., None, :])
  weighted_gap = np.sum(w[..., None, :] * pairwise_gap, axis=-1)
  cross_term = np.sum(w * weighted_gap, axis=-1)

  # The loss incurred within each individual interval with itself.
  squared_pairs = w[..., 1:]**2 + w[..., :-1]**2
  spacing = x[..., 1:] - x[..., :-1]
  self_term = np.sum(squared_pairs * spacing, axis=-1) / 6

  return cross_term + self_term


def compute_TV(acc_grid):
  """Sum of mean squared neighbour differences along each axis of a 3D grid."""
  diffs = (
      acc_grid[:-1,:,:] - acc_grid[1:,:,:],
      acc_grid[:,:-1,:] - acc_grid[:,1:,:],
      acc_grid[:,:,:-1] - acc_grid[:,:,1:],
  )
  mean_sq_x, mean_sq_y, mean_sq_z = [np.mean(np.square(d)) for d in diffs]
  return mean_sq_x+mean_sq_y+mean_sq_z
point_loss_out)) + + return loss_color_l2 + loss_color_l2_b + loss_distortion + loss_acc + point_loss, loss_color_l2_b_ + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (total_loss, color_loss_l2), grad = grad_fn(state.target) + total_loss = jax.lax.pmean(total_loss, axis_name='batch') + color_loss_l2 = jax.lax.pmean(color_loss_l2, axis_name='batch') + + grad = jax.lax.pmean(grad, axis_name='batch') + state = state.apply_gradient(grad, learning_rate=lr) + + return state, color_loss_l2 + +train_pstep = jax.pmap(train_step, axis_name='batch', + in_axes=(0, 0, 0, None, None, None, None, None, None, None), + static_broadcasted_argnums = (7,8,)) +traindata_p = flax.jax_utils.replicate(data['train']) +state = flax.optim.Adam(**adam_kwargs).create(model_vars) + +step_init = state.state.step +state = flax.jax_utils.replicate(state) +print(f'starting at {step_init}') + +# Training loop +psnrs = [] +iters = [] +psnrs_test = [] +psnrs_b_test = [] +iters_test = [] +t_total = 0.0 +t_last = 0.0 +i_last = step_init + +training_iters = 400000 +train_psnr_max = 0.0 + + +print("Training") +for i in tqdm(range(step_init, training_iters + 1)): + t = time.time() + + batch_size = test_batch_size + keep_num = test_keep_num + threshold = test_threshold + lr = 1e-5 + wbinary = float(i)/training_iters + wbgcolor = 1.0 + + if scene_type=="synthetic": + wdistortion = 0.0 + elif scene_type=="forwardfacing": + wdistortion = 0.01 + elif scene_type=="real360": + wdistortion = 0.001 + + rng, key1, key2 = random.split(rng, 3) + key2 = random.split(key2, n_device) + state, color_loss_l2 = train_pstep( + state, key2, traindata_p, + lr, + wdistortion, + wbinary, + wbgcolor, + batch_size, + keep_num, + threshold + ) + + psnrs.append(-10. 
* np.log10(color_loss_l2[0])) + iters.append(i) + + if i > 0: + t_total += time.time() - t + + # Logging + if (i % 10000 == 0) and i > 0: + this_train_psnr = np.mean(np.array(psnrs[-5000:])) + + #stop when iteration>200000 and the training psnr drops + if i>200000 and this_train_psnr<=train_psnr_max-0.001: + unreplicated_state = pickle.load(open(weights_dir+"/s2_0_"+"tmp_state"+str(i-10000)+".pkl", "rb")) + vars = unreplicated_state.target + break + + train_psnr_max = max(this_train_psnr,train_psnr_max) + gc.collect() + + unreplicated_state = flax.jax_utils.unreplicate(state) + pickle.dump(unreplicated_state, open(weights_dir+"/s2_0_"+"tmp_state"+str(i)+".pkl", "wb")) + + print('Current iteration %d, elapsed training time: %d min %d sec.' + % (i, t_total // 60, int(t_total) % 60)) + + print('Batch size: %d' % batch_size) + print('Keep num: %d' % keep_num) + t_elapsed = t_total - t_last + i_elapsed = i - i_last + t_last = t_total + i_last = i + print("Speed:") + print(' %0.3f secs per iter.' % (t_elapsed / i_elapsed)) + print(' %0.3f iters per sec.' 
% (i_elapsed / t_elapsed)) + + vars = unreplicated_state.target + rays = camera_ray_batch( + data['test']['c2w'][selected_test_index], data['test']['hwf']) + gt = data['test']['images'][selected_test_index] + out = render_loop(rays, vars, test_batch_size) + rgb = out[0] + acc = out[1] + rgb_b = out[2] + acc_b = out[3] + psnrs_test.append(-10 * np.log10(np.mean(np.square(rgb - gt)))) + psnrs_b_test.append(-10 * np.log10(np.mean(np.square(rgb_b - gt)))) + iters_test.append(i) + + print("PSNR:") + print(' Training running average: %0.3f' % this_train_psnr) + print(' Test average: %0.3f' % psnrs_test[-1]) + print(' Test binary average: %0.3f' % psnrs_b_test[-1]) + + plt.figure() + plt.title(i) + plt.plot(iters, psnrs) + plt.plot(iters_test, psnrs_test) + plt.plot(iters_test, psnrs_b_test) + p = np.array(psnrs) + plt.ylim(np.min(p) - .5, np.max(p) + .5) + plt.legend() + plt.savefig(samples_dir+"/s2_0_"+str(i)+"_loss.png") + + write_floatpoint_image(samples_dir+"/s2_0_"+str(i)+"_rgb.png",rgb) + write_floatpoint_image(samples_dir+"/s2_0_"+str(i)+"_rgb_binarized.png",rgb_b) + write_floatpoint_image(samples_dir+"/s2_0_"+str(i)+"_gt.png",gt) + write_floatpoint_image(samples_dir+"/s2_0_"+str(i)+"_acc.png",acc) + write_floatpoint_image(samples_dir+"/s2_0_"+str(i)+"_acc_binarized.png",acc_b) + +#%% +#%% -------------------------------------------------------------------------------- +# ## Run test-set evaluation +#%% +gc.collect() + +render_poses = data['test']['c2w'][:len(data['test']['images'])] +frames = [] +framemasks = [] +print("Testing") +for p in tqdm(render_poses): + out = render_loop(camera_ray_batch(p, hwf), vars, test_batch_size) + frames.append(out[2]) + framemasks.append(out[3]) +psnrs_test = [-10 * np.log10(np.mean(np.square(rgb - gt))) for (rgb, gt) in zip(frames, data['test']['images'])] +print("Test set average PSNR: %f" % np.array(psnrs_test).mean()) + +#%% +import jax.numpy as jnp +import jax.scipy as jsp + +def compute_ssim(img0, + img1, + max_val, + 
filter_size=11, + filter_sigma=1.5, + k1=0.01, + k2=0.03, + return_map=False): + """Computes SSIM from two images. + This function was modeled after tf.image.ssim, and should produce comparable + output. + Args: + img0: array. An image of size [..., width, height, num_channels]. + img1: array. An image of size [..., width, height, num_channels]. + max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. + filter_size: int >= 1. Window size. + filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. + k1: float > 0. One of the SSIM dampening parameters. + k2: float > 0. One of the SSIM dampening parameters. + return_map: Bool. If True, will cause the per-pixel SSIM "map" to returned + Returns: + Each image's mean SSIM, or a tensor of individual values if `return_map`. + """ + # Construct a 1D Gaussian blur filter. + hw = filter_size // 2 + shift = (2 * hw - filter_size + 1) / 2 + f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2 + filt = jnp.exp(-0.5 * f_i) + filt /= jnp.sum(filt) + + # Blur in x and y (faster than the 2D convolution). + filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid") + filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid") + + # Vmap the blurs to the tensor size, and then compose them. + num_dims = len(img0.shape) + map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1]) + for d in map_axes: + filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d) + filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d) + filt_fn = lambda z: filt_fn1(filt_fn2(z)) + + mu0 = filt_fn(img0) + mu1 = filt_fn(img1) + mu00 = mu0 * mu0 + mu11 = mu1 * mu1 + mu01 = mu0 * mu1 + sigma00 = filt_fn(img0**2) - mu00 + sigma11 = filt_fn(img1**2) - mu11 + sigma01 = filt_fn(img0 * img1) - mu01 + + # Clip the variances and covariances to valid values. 
+ # Variance must be non-negative: + sigma00 = jnp.maximum(0., sigma00) + sigma11 = jnp.maximum(0., sigma11) + sigma01 = jnp.sign(sigma01) * jnp.minimum( + jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01)) + + c1 = (k1 * max_val)**2 + c2 = (k2 * max_val)**2 + numer = (2 * mu01 + c1) * (2 * sigma01 + c2) + denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) + ssim_map = numer / denom + ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims))) + return ssim_map if return_map else ssim + +# Compiling to the CPU because it's faster and more accurate. +ssim_fn = jax.jit( + functools.partial(compute_ssim, max_val=1.), backend="cpu") + +ssim_values = [] +for i in range(len(data['test']['images'])): + ssim = ssim_fn(frames[i], data['test']['images'][i]) + ssim_values.append(float(ssim)) + +print("Test set average SSIM: %f" % np.array(ssim_values).mean()) +#%% +#%% -------------------------------------------------------------------------------- +# ## Save weights +#%% +pickle.dump(vars, open(weights_dir+"/"+"weights_stage2_0.pkl", "wb")) + +#%% -------------------------------------------------------------------------------- +# # Fix density and finetune color +#%% -------------------------------------------------------------------------------- +# ## Load weights +#%% +vars = pickle.load(open(weights_dir+"/"+"weights_stage2_0.pkl", "rb")) +model_vars = vars +#%% -------------------------------------------------------------------------------- +# ## Main rendering functions +#%% +def render_rays(rays, vars, keep_num, threshold, wbgcolor, rng): ### antialiasing by supersampling + + #---------- ray-plane intersection points + grid_indices, grid_masks = gridcell_from_rays(rays, vars[1], keep_num, threshold) + + pts, grid_masks, points, fake_t = compute_undc_intersection(vars[0], + grid_indices, grid_masks, rays, keep_num) + + if scene_type=="forwardfacing": + fake_t = compute_t_forwardfacing(pts,grid_masks) + elif scene_type=="real360": + skybox_positions, skybox_masks = 
compute_box_intersection(rays) + pts = np.concatenate([pts,skybox_positions], axis=-2) + grid_masks = np.concatenate([grid_masks,skybox_masks], axis=-1) + pts, grid_masks, fake_t = sort_and_compute_t_real360(pts,grid_masks) + + pts = jax.lax.stop_gradient(pts) + + # Now use the MLP to compute density and features + mlp_alpha = density_model.apply(vars[-3], pts) + mlp_alpha = jax.nn.sigmoid(mlp_alpha[..., 0]-8) + mlp_alpha = mlp_alpha * grid_masks + + mlp_alpha_b = (mlp_alpha>0.5).astype(mlp_alpha.dtype) #no gradient to density + weights_b = compute_volumetric_rendering_weights_with_alpha(mlp_alpha_b) #[N,4,P] + acc_b = np.sum(weights_b, axis=-1) #[N,4] + acc_b = np.mean(acc_b, axis=-1) #[N] + + #deferred features + mlp_features_ = jax.nn.sigmoid(feature_model.apply(vars[-2], pts)) #[N,4,P,C] + mlp_features_b = np.sum(weights_b[..., None] * mlp_features_, axis=-2) #[N,4,C] + mlp_features_b = np.mean(mlp_features_b, axis=-2) #[N,C] + + + # ... as well as view-dependent colors. + dirs = normalize(rays[1]) #[N,4,3] + dirs = np.mean(dirs, axis=-2) #[N,3] + features_dirs_enc_b = np.concatenate([mlp_features_b, dirs], axis=-1) #[N,C+3] + rgb_b = jax.nn.sigmoid(color_model.apply(vars[-1], features_dirs_enc_b)) + + # Composite onto the background color. + if white_bkgd: + rgb_b = rgb_b * acc_b[..., None] + (1. - acc_b[..., None]) + else: + bgc = random.randint(rng, [1], 0, 2).astype(bg_color.dtype) * wbgcolor + \ + bg_color * (1-wbgcolor) + rgb_b = rgb_b * acc_b[..., None] + (1. - acc_b[..., None]) * bgc + + return rgb_b, acc_b +#%% -------------------------------------------------------------------------------- +# ## Set up pmap'd rendering for test time evaluation. 
+#%% +test_batch_size = 256*n_device +test_keep_num = point_grid_size*3//4 +test_threshold = 0.1 +test_wbgcolor = 0.0 + + +render_test_p = jax.pmap(lambda rays, vars: render_rays( + rays, vars, test_keep_num, test_threshold, test_wbgcolor, rng), + in_axes=(0, None)) + +import numpy + +def render_test(rays, vars): ### antialiasing by supersampling + sh = rays[0].shape + rays = [x.reshape((jax.local_device_count(), -1) + sh[1:]) for x in rays] + out = render_test_p(rays, vars) + out = [numpy.reshape(numpy.array(out[i]),sh[:-2]+(-1,)) for i in range(2)] + return out + +def render_loop(rays, vars, chunk): ### antialiasing by supersampling + sh = list(rays[0].shape[:-2]) + rays = [x.reshape([-1, 4, 3]) for x in rays] + l = rays[0].shape[0] + n = jax.local_device_count() + p = ((l - 1) // n + 1) * n - l + rays = [np.pad(x, ((0,p),(0,0),(0,0))) for x in rays] + outs = [render_test([x[i:i+chunk] for x in rays], vars) + for i in range(0, rays[0].shape[0], chunk)] + outs = [np.reshape( + np.concatenate([z[i] for z in outs])[:l], sh + [-1]) for i in range(2)] + return outs + +# Make sure that everything works, by rendering an image from the test set + +if scene_type=="synthetic": + selected_test_index = 97 + preview_image_height = 800 + +elif scene_type=="forwardfacing": + selected_test_index = 0 + preview_image_height = 756//2 + +elif scene_type=="real360": + selected_test_index = 0 + preview_image_height = 840//2 + +rays = camera_ray_batch( + data['test']['c2w'][selected_test_index], data['test']['hwf']) +gt = data['test']['images'][selected_test_index] +out = render_loop(rays, model_vars, test_batch_size) +rgb_b = out[0] +acc_b = out[1] +write_floatpoint_image(samples_dir+"/s2_1_"+str(0)+"_rgb_binarized.png",rgb_b) +write_floatpoint_image(samples_dir+"/s2_1_"+str(0)+"_gt.png",gt) +write_floatpoint_image(samples_dir+"/s2_1_"+str(0)+"_acc_binarized.png",acc_b) +#%% -------------------------------------------------------------------------------- +# ## Training loop +#%% +def 
train_step(state, rng, traindata, lr, wdistortion, wbinary, wbgcolor, batch_size, keep_num, threshold): + key, rng = random.split(rng) + rays, pixels = random_ray_batch( + key, batch_size // n_device, traindata) + + def loss_fn(vars): + rgb_est_b, _ = render_rays( + rays, vars, keep_num, threshold, wbgcolor, rng) + + loss_color_l2_b = np.mean(np.square(rgb_est_b - pixels)) + + return loss_color_l2_b, loss_color_l2_b + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (total_loss, color_loss_l2), grad = grad_fn(state.target) + total_loss = jax.lax.pmean(total_loss, axis_name='batch') + color_loss_l2 = jax.lax.pmean(color_loss_l2, axis_name='batch') + + grad = jax.lax.pmean(grad, axis_name='batch') + state = state.apply_gradient(grad, learning_rate=lr) + + return state, color_loss_l2 + +train_pstep = jax.pmap(train_step, axis_name='batch', + in_axes=(0, 0, 0, None, None, None, None, None, None, None), + static_broadcasted_argnums = (7,8,)) +traindata_p = flax.jax_utils.replicate(data['train']) +state = flax.optim.Adam(**adam_kwargs).create(model_vars) + +step_init = state.state.step +state = flax.jax_utils.replicate(state) +print(f'starting at {step_init}') + +# Training loop +psnrs = [] +iters = [] +psnrs_test = [] +psnrs_b_test = [] +iters_test = [] +t_total = 0.0 +t_last = 0.0 +i_last = step_init + +training_iters = 100000 + +print("Training") +for i in tqdm(range(step_init, training_iters + 1)): + t = time.time() + + batch_size = test_batch_size + keep_num = test_keep_num + threshold = test_threshold + lr = 1e-5 + wbinary = 0.5 + wbgcolor = 1.0 + + if scene_type=="synthetic": + wdistortion = 0.0 + elif scene_type=="forwardfacing": + wdistortion = 0.01 + elif scene_type=="real360": + wdistortion = 0.001 + + rng, key1, key2 = random.split(rng, 3) + key2 = random.split(key2, n_device) + state, color_loss_l2 = train_pstep( + state, key2, traindata_p, + lr, + wdistortion, + wbinary, + wbgcolor, + batch_size, + keep_num, + threshold + ) + + psnrs.append(-10. 
* np.log10(color_loss_l2[0])) + iters.append(i) + + if i > 0: + t_total += time.time() - t + + # Logging + if (i % 10000 == 0) and i > 0: + this_train_psnr = np.mean(np.array(psnrs[-5000:])) + gc.collect() + + unreplicated_state = flax.jax_utils.unreplicate(state) + pickle.dump(unreplicated_state, open(weights_dir+"/s2_1_"+"tmp_state"+str(i)+".pkl", "wb")) + + print('Current iteration %d, elapsed training time: %d min %d sec.' + % (i, t_total // 60, int(t_total) % 60)) + + print('Batch size: %d' % batch_size) + print('Keep num: %d' % keep_num) + t_elapsed = t_total - t_last + i_elapsed = i - i_last + t_last = t_total + i_last = i + print("Speed:") + print(' %0.3f secs per iter.' % (t_elapsed / i_elapsed)) + print(' %0.3f iters per sec.' % (i_elapsed / t_elapsed)) + + vars = unreplicated_state.target + rays = camera_ray_batch( + data['test']['c2w'][selected_test_index], data['test']['hwf']) + gt = data['test']['images'][selected_test_index] + out = render_loop(rays, vars, test_batch_size) + rgb_b = out[0] + acc_b = out[1] + psnrs_b_test.append(-10 * np.log10(np.mean(np.square(rgb_b - gt)))) + iters_test.append(i) + + print("PSNR:") + print(' Training running average: %0.3f' % this_train_psnr) + print(' Test binary average: %0.3f' % psnrs_b_test[-1]) + + plt.figure() + plt.title(i) + plt.plot(iters, psnrs) + plt.plot(iters_test, psnrs_b_test) + p = np.array(psnrs) + plt.ylim(np.min(p) - .5, np.max(p) + .5) + plt.legend() + plt.savefig(samples_dir+"/s2_1_"+str(i)+"_loss.png") + + write_floatpoint_image(samples_dir+"/s2_1_"+str(i)+"_rgb_binarized.png",rgb_b) + write_floatpoint_image(samples_dir+"/s2_1_"+str(i)+"_gt.png",gt) + write_floatpoint_image(samples_dir+"/s2_1_"+str(i)+"_acc_binarized.png",acc_b) + +#%% -------------------------------------------------------------------------------- +# ## Run test-set evaluation +#%% +gc.collect() + +render_poses = data['test']['c2w'][:len(data['test']['images'])] +frames = [] +framemasks = [] +print("Testing") +for p in 
tqdm(render_poses): + out = render_loop(camera_ray_batch(p, hwf), vars, test_batch_size) + frames.append(out[0]) + framemasks.append(out[1]) +psnrs_test = [-10 * np.log10(np.mean(np.square(rgb - gt))) for (rgb, gt) in zip(frames, data['test']['images'])] +print("Test set average PSNR: %f" % np.array(psnrs_test).mean()) + +#%% +ssim_values = [] +for i in range(len(data['test']['images'])): + ssim = ssim_fn(frames[i], data['test']['images'][i]) + ssim_values.append(float(ssim)) + +print("Test set average SSIM: %f" % np.array(ssim_values).mean()) +#%% +#%% -------------------------------------------------------------------------------- +# ## Save weights +#%% +pickle.dump(vars, open(weights_dir+"/"+"weights_stage2_1.pkl", "wb")) diff --git a/jax3d/projects/mobilenerf/stage3.py b/jax3d/projects/mobilenerf/stage3.py new file mode 100644 index 0000000..828cf70 --- /dev/null +++ b/jax3d/projects/mobilenerf/stage3.py @@ -0,0 +1,2599 @@ +# Copyright 2022 The jax3d Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 


scene_type = "synthetic"
object_name = "chair"
scene_dir = "datasets/nerf_synthetic/"+object_name

# synthetic
# chair drums ficus hotdog lego materials mic ship

# forwardfacing
# fern flower fortress horns leaves orchids room trex

# real360
# bicycle flowerbed gardenvase stump treehill
# fulllivingroom kitchencounter kitchenlego officebonsai

#%% --------------------------------------------------------------------------------
# ## General imports
#%%
import copy
import gc
import json
import os
import numpy
import cv2
from tqdm import tqdm
import pickle
import jax
import jax.numpy as np
from jax import random
import flax
import flax.linen as nn
import functools
import math
from typing import Sequence, Callable
import time
import matplotlib.pyplot as plt
from PIL import Image
from multiprocessing.pool import ThreadPool

print(jax.local_devices())
if len(jax.local_devices())!=8:
  # Fail with a clear error instead of a deliberate ZeroDivisionError.
  raise RuntimeError("ERROR: need 8 v100 GPUs")
weights_dir = "weights"
samples_dir = "samples"
# exist_ok avoids the check-then-create race.
os.makedirs(weights_dir, exist_ok=True)
os.makedirs(samples_dir, exist_ok=True)
def write_floatpoint_image(name,img):
  """Write a float image to `name` as 8-bit.

  Values are scaled by 255 and clipped to [0, 255]; the last axis is reversed
  before the array is handed to `cv2.imwrite`.

  Args:
    name: output file path.
    img: array-like indexed as [row, column, channel].
  """
  img = numpy.clip(numpy.array(img)*255,0,255).astype(numpy.uint8)
  # cv2.imwrite returns False on failure instead of raising; report it.
  if not cv2.imwrite(name,img[:,:,::-1]):
    print("WARNING: failed to write image "+name)
#%% --------------------------------------------------------------------------------
# ## Load the dataset.
#%%
# """ Load dataset """

if scene_type=="synthetic":
  white_bkgd = True
elif scene_type=="forwardfacing":
  white_bkgd = False
elif scene_type=="real360":
  white_bkgd = False
else:
  # Fail here rather than with a NameError further down the script.
  raise ValueError("Unknown scene_type: {}".format(scene_type))


#https://github.com/google-research/google-research/blob/master/snerg/nerf/datasets.py


if scene_type=="synthetic":

  def load_blender(data_dir, split):
    """Load one split ('train' or 'test') of a Blender-format scene.

    Reads `transforms_<split>.json` plus the PNG file of every frame in it.

    Returns:
      A dict with 'images', 'c2w' (stacked transform matrices) and
      'hwf' (height, width, focal).
    """
    with open(
        os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
      meta = json.load(fp)

    cams = []
    paths = []
    for frame in meta["frames"]:
      cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
      paths.append(os.path.join(data_dir, frame["file_path"] + ".png"))

    def image_read_fn(fname):
      with open(fname, "rb") as imgin:
        image = np.array(Image.open(imgin), dtype=np.float32) / 255.
      return image
    # pool.map blocks until every image is read; the `with` cleans up the pool.
    with ThreadPool() as pool:
      images = pool.map(image_read_fn, paths)

    images = np.stack(images, axis=0)
    # Composite the colour channels using the last (alpha) channel.
    if white_bkgd:
      images = (images[..., :3] * images[..., -1:] + (1. - images[..., -1:]))
    else:
      images = images[..., :3] * images[..., -1:]

    h, w = images.shape[1:3]
    camera_angle_x = float(meta["camera_angle_x"])
    focal = .5 * w / np.tan(.5 * camera_angle_x)

    hwf = np.array([h, w, focal], dtype=np.float32)
    poses = np.stack(cams, axis=0)
    return {'images' : images, 'c2w' : poses, 'hwf' : hwf}

  data = {'train' : load_blender(scene_dir, 'train'),
          'test' : load_blender(scene_dir, 'test')}

elif scene_type=="forwardfacing" or scene_type=="real360":

  # The pose pre-processing below runs on host NumPy. It uses the `numpy`
  # module directly instead of temporarily re-binding the global `np`
  # (which is jax.numpy everywhere else in this file).
  import jax.numpy as jnp

  def _viewmatrix(z, up, pos):
    """Construct lookat view matrix."""
    vec2 = _normalize(z)
    vec1_avg = up
    vec0 = _normalize(numpy.cross(vec1_avg, vec2))
    vec1 = _normalize(numpy.cross(vec2, vec0))
    m = numpy.stack([vec0, vec1, vec2, pos], 1)
    return m

  def _normalize(x):
    """Normalization helper function."""
    return x / numpy.linalg.norm(x)

  def _poses_avg(poses):
    """Average poses according to the original NeRF code."""
    hwf = poses[0, :3, -1:]
    center = poses[:, :3, 3].mean(0)
    vec2 = _normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = numpy.concatenate([_viewmatrix(vec2, up, center), hwf], 1)
    return c2w

  def _recenter_poses(poses):
    """Recenter poses according to the original NeRF code."""
    poses_ = poses.copy()
    bottom = numpy.reshape([0, 0, 0, 1.], [1, 4])
    c2w = _poses_avg(poses)
    c2w = numpy.concatenate([c2w[:3, :4], bottom], -2)
    bottom = numpy.tile(numpy.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = numpy.concatenate([poses[:, :3, :4], bottom], -2)
    poses = numpy.linalg.inv(c2w) @ poses
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses

  def _transform_poses_pca(poses):
    """Transforms poses so principal components lie on XYZ axes."""
    poses_ = poses.copy()
    t = poses[:, :3, 3]
    t_mean = t.mean(axis=0)
    t = t - t_mean

    eigval, eigvec = numpy.linalg.eig(t.T @ t)
    # Sort eigenvectors in order of largest to smallest eigenvalue.
    inds = numpy.argsort(eigval)[::-1]
    eigvec = eigvec[:, inds]
    rot = eigvec.T
    if numpy.linalg.det(rot) < 0:
      rot = numpy.diag(numpy.array([1, 1, -1])) @ rot

    transform = numpy.concatenate([rot, rot @ -t_mean[:, None]], -1)
    bottom = numpy.broadcast_to([0, 0, 0, 1.], poses[..., :1, :4].shape)
    pad_poses = numpy.concatenate([poses[..., :3, :4], bottom], axis=-2)
    poses_recentered = transform @ pad_poses
    poses_recentered = poses_recentered[..., :3, :4]
    transform = numpy.concatenate([transform, numpy.eye(4)[3:]], axis=0)

    # Flip coordinate system if z component of y-axis is negative
    if poses_recentered.mean(axis=0)[2, 1] < 0:
      poses_recentered = numpy.diag(numpy.array([1, -1, -1])) @ poses_recentered
      transform = numpy.diag(numpy.array([1, -1, -1, 1])) @ transform

    # Just make sure it's it in the [-1, 1]^3 cube
    scale_factor = 1. / numpy.max(numpy.abs(poses_recentered[:, :3, 3]))
    poses_recentered[:, :3, 3] *= scale_factor
    transform = numpy.diag(numpy.array([scale_factor] * 3 + [1])) @ transform

    poses_[:, :3, :4] = poses_recentered[:, :3, :4]
    poses_recentered = poses_
    return poses_recentered, transform

  def load_LLFF(data_dir, split, factor = 4, llffhold = 8):
    """Load one split of an LLFF-format scene.

    Args:
      data_dir: scene folder holding `images[_<factor>]` and
        `poses_bounds.npy`.
      split: "train" selects the training views; any other value selects the
        held-out views.
      factor: image downsampling factor; a value <= 0 reads the plain
        `images` folder and leaves the focal length unscaled.
      llffhold: every `llffhold`-th image is held out for testing.

    Returns:
      A dict of jax arrays: 'images', 'c2w' and 'hwf'.
    """
    # Load images.
    imgdir_suffix = ""
    if factor > 0:
      imgdir_suffix = "_{}".format(factor)
    imgdir = os.path.join(data_dir, "images" + imgdir_suffix)
    if not os.path.exists(imgdir):
      raise ValueError("Image folder {} doesn't exist.".format(imgdir))
    imgfiles = [
        os.path.join(imgdir, f)
        for f in sorted(os.listdir(imgdir))
        if f.endswith(("JPG", "jpg", "png"))
    ]
    def image_read_fn(fname):
      with open(fname, "rb") as imgin:
        image = numpy.array(Image.open(imgin), dtype=numpy.float32) / 255.
      return image
    # pool.map blocks until every image is read; the `with` cleans up the pool.
    with ThreadPool() as pool:
      images = pool.map(image_read_fn, imgfiles)
    images = numpy.stack(images, axis=-1)

    # Load poses and bds.
    with open(os.path.join(data_dir, "poses_bounds.npy"),
                         "rb") as fp:
      poses_arr = numpy.load(fp)
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])
    if poses.shape[-1] != images.shape[-1]:
      raise RuntimeError("Mismatch between imgs {} and poses {}".format(
          images.shape[-1], poses.shape[-1]))

    # Update poses according to downsampling.
    poses[:2, 4, :] = numpy.array(images.shape[:2]).reshape([2, 1])
    if factor > 0:
      # Guarded so that factor <= 0 (no downsampling) does not divide the
      # focal length by zero or flip its sign.
      poses[2, 4, :] = poses[2, 4, :] * 1. / factor

    # Correct rotation matrix ordering and move variable dim to axis 0.
    poses = numpy.concatenate(
        [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = numpy.moveaxis(poses, -1, 0).astype(numpy.float32)
    images = numpy.moveaxis(images, -1, 0)
    bds = numpy.moveaxis(bds, -1, 0).astype(numpy.float32)


    if scene_type=="real360":
      # Rotate/scale poses to align ground with xy plane and fit to unit cube.
      poses, _ = _transform_poses_pca(poses)
    else:
      # Rescale according to a default bd factor.
      scale = 1. / (bds.min() * .75)
      poses[:, :3, 3] *= scale
      bds *= scale
      # Recenter poses
      poses = _recenter_poses(poses)

    # Select the split.
    all_indices = numpy.arange(images.shape[0])
    i_test = all_indices[::llffhold]
    i_train = numpy.setdiff1d(all_indices, i_test)
    if split == "train":
      indices = i_train
    else:
      indices = i_test
    images = images[indices]
    poses = poses[indices]

    camtoworlds = poses[:, :3, :4]
    focal = poses[0, -1, -1]
    h, w = images.shape[1:3]

    hwf = numpy.array([h, w, focal], dtype=numpy.float32)

    return {'images' : jnp.array(images), 'c2w' : jnp.array(camtoworlds), 'hwf' : jnp.array(hwf)}

  data = {'train' : load_LLFF(scene_dir, 'train'),
          'test' : load_LLFF(scene_dir, 'test')}

# Dataset summary and sanity-check outputs, shared by every scene type.
splits = ['train', 'test']
for s in splits:
  print(s)
  for k in data[s]:
    print(f' {k}: {data[s][k].shape}')

images, poses, hwf = data['train']['images'], data['train']['c2w'], data['train']['hwf']
write_floatpoint_image(samples_dir+"/training_image_sample.png",images[0])

# Scatter plots of the training camera positions, one per pair of axes.
for i in range(3):
  plt.figure()
  plt.scatter(poses[:,i,3], poses[:,(i+1)%3,3])
  plt.axis('equal')
  plt.savefig(samples_dir+"/training_camera"+str(i)+".png")
  plt.close()

if scene_type=="forwardfacing" or scene_type=="real360":
  # Mean value over all training images.
  bg_color = jnp.mean(images)
#%% --------------------------------------------------------------------------------
# ## Helper functions
#%%
adam_kwargs = {
    'beta1': 0.9,
    'beta2': 0.999,
    'eps': 1e-15,
}

n_device = jax.local_device_count()

rng = random.PRNGKey(1)



# General math functions.

def matmul(a, b):
  """jnp.matmul defaults to bfloat16, but this helper function doesn't."""
  return np.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

def normalize(x):
  """Divide `x` by its Euclidean norm along the last axis."""
  return x / np.linalg.norm(x, axis=-1, keepdims=True)

def sinusoidal_encoding(position, minimum_frequency_power,
    maximum_frequency_power,include_identity = False):
  """Encode `position` with sines and cosines at power-of-two frequencies.

  Args:
    position: array whose last axis holds the coordinates to encode.
    minimum_frequency_power: first exponent p of the frequencies 2**p.
    maximum_frequency_power: exponent at which the range stops (exclusive).
    include_identity: if True, `position` itself is prepended to the encoding.

  Returns:
    Array with the leading axes of `position` and a flattened last axis.
  """
  # Compute the sinusoidal encoding components
  frequency = 2.0**np.arange(minimum_frequency_power, maximum_frequency_power)
  angle = position[..., None, :] * frequency[:, None]
  # sin(angle + pi/2) is cos(angle).
  encoding = np.sin(np.stack([angle, angle + 0.5 * np.pi], axis=-2))
  # Flatten encoding dimensions
  encoding = encoding.reshape(*position.shape[:-1], -1)
  # Add identity component
  if include_identity:
    encoding = np.concatenate([position, encoding], axis=-1)
  return encoding

# Pose/ray math.

def generate_rays(pixel_coords, pix2cam, cam2world):
  """Generate camera rays from pixel coordinates and poses.

  Returns:
    (ray_origins, ray_dirs), both with a trailing axis of size 3.
  """
  homog = np.ones_like(pixel_coords[..., :1])
  # The +.5 shifts integer pixel indices to pixel centres.
  pixel_dirs = np.concatenate([pixel_coords + .5, homog], axis=-1)[..., None]
  cam_dirs = matmul(pix2cam, pixel_dirs)
  ray_dirs = matmul(cam2world[..., :3, :3], cam_dirs)[..., 0]
  ray_origins = np.broadcast_to(cam2world[..., :3, 3], ray_dirs.shape)

  #f = 1./pix2cam[0,0]
  #w = -2. * f * pix2cam[0,2]
  #h = 2. * f * pix2cam[1,2]

  return ray_origins, ray_dirs

def pix2cam_matrix(height, width, focal):
  """Inverse intrinsic matrix for a pinhole camera."""
  return np.array([
      [1./focal, 0, -.5 * width / focal],
      [0, -1./focal, .5 * height / focal],
      [0, 0, -1.],
  ])

def _supersample_pixel_coords(x_ind, y_ind):
  """Return the 2x2 sub-pixel sample positions of every pixel.

  The result has shape `x_ind.shape + (4, 2)`; the four samples are offset by
  +-0.25 in x and y, ordered (-,-), (+,-), (-,+), (+,+).
  """
  offsets = ((-0.25, -0.25), (0.25, -0.25), (-0.25, 0.25), (0.25, 0.25))
  return np.stack(
      [np.stack([x_ind + dx, y_ind + dy], axis=-1) for dx, dy in offsets],
      axis=-2)

def camera_ray_batch_xxxxx_original(cam2world, hwf):
  """Generate rays for a pinhole camera with given extrinsic and intrinsic."""
  height, width = int(hwf[0]), int(hwf[1])
  pix2cam = pix2cam_matrix(*hwf)
  pixel_coords = np.stack(np.meshgrid(np.arange(width), np.arange(height)), axis=-1)
  return generate_rays(pixel_coords, pix2cam, cam2world)

def camera_ray_batch(cam2world, hwf): ### antialiasing by supersampling
  """Generate rays for a pinhole camera with given extrinsic and intrinsic.

  Four rays are generated per pixel, so both returned arrays have shape
  [height, width, 4, 3].
  """
  height, width = int(hwf[0]), int(hwf[1])
  pix2cam = pix2cam_matrix(*hwf)
  x_ind, y_ind = np.meshgrid(np.arange(width), np.arange(height))
  pixel_coords = _supersample_pixel_coords(x_ind, y_ind)

  return generate_rays(pixel_coords, pix2cam, cam2world)

def random_ray_batch_xxxxx_original(rng, batch_size, data):
  """Generate a random batch of ray data."""
  keys = random.split(rng, 3)
  cam_ind = random.randint(keys[0], [batch_size], 0, data['c2w'].shape[0])
  y_ind = random.randint(keys[1], [batch_size], 0, data['images'].shape[1])
  x_ind = random.randint(keys[2], [batch_size], 0, data['images'].shape[2])
  pixel_coords = np.stack([x_ind, y_ind], axis=-1)
  pix2cam = pix2cam_matrix(*data['hwf'])
  cam2world = data['c2w'][cam_ind, :3, :4]
  rays = generate_rays(pixel_coords, pix2cam, cam2world)
  pixels = data['images'][cam_ind, y_ind, x_ind]
  return rays, pixels

def random_ray_batch(rng, batch_size, data): ### antialiasing by supersampling
  """Generate a random batch of ray data.

  Args:
    rng: PRNG key; split into three keys for camera, row and column indices.
    batch_size: number of pixels to sample.
    data: dict with 'c2w', 'images' and 'hwf' entries.

  Returns:
    (rays, pixels): rays is (origins, directions), each [batch_size, 4, 3]
    with four sub-pixel rays per sampled pixel; pixels are the image values
    at the sampled pixels.
  """
  keys = random.split(rng, 3)
  cam_ind = random.randint(keys[0], [batch_size], 0, data['c2w'].shape[0])
  y_ind = random.randint(keys[1], [batch_size], 0, data['images'].shape[1])
  x_ind = random.randint(keys[2], [batch_size], 0, data['images'].shape[2])
  pixel_coords = _supersample_pixel_coords(
      x_ind.astype(np.float32), y_ind.astype(np.float32))
  pix2cam = pix2cam_matrix(*data['hwf'])
  # The size-1 axis broadcasts each camera matrix over its 4 sub-pixel rays,
  # which gives the same result as tiling it four times.
  cam2world = data['c2w'][cam_ind, :3, :4][:, None]
  rays = generate_rays(pixel_coords, pix2cam, cam2world)
  pixels = data['images'][cam_ind, y_ind, x_ind]
  return rays, pixels


# Learning rate helpers.

def log_lerp(t, v0, v1):
  """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1).

  `t` is clipped to [0, 1].

  Raises:
    ValueError: if `v0` or `v1` is not positive.
  """
  if v0 <= 0 or v1 <= 0:
    raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
  lv0 = np.log(v0)
  lv1 = np.log(v1)
  return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)

def lr_fn(step, max_steps, lr0, lr1, lr_delay_steps=20000, lr_delay_mult=0.1):
  # NOTE: this function's return statement is on the next line of the source
  # file, outside this block.
  if lr_delay_steps > 0:
    # A kind of reverse cosine decay.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.
#scene scales
#
# Per scene type this section defines:
#   grid_min / grid_max              - bounds of the grid in tapered coordinates
#   point_grid_size                  - grid resolution along each axis
#   get_taper_coord(p)               - maps points into tapered grid coordinates
#   inverse_taper_coord(p)           - inverse of get_taper_coord

if scene_type=="synthetic":
    scene_grid_scale = 1.2
    # these three scenes get a larger grid
    if "hotdog" in scene_dir or "mic" in scene_dir or "ship" in scene_dir:
        scene_grid_scale = 1.5
    grid_min = np.array([-1, -1, -1]) * scene_grid_scale
    grid_max = np.array([ 1,  1,  1]) * scene_grid_scale
    point_grid_size = 128

    # no tapering: grid coordinates are the input coordinates
    def get_taper_coord(p):
        return p
    def inverse_taper_coord(p):
        return p

elif scene_type=="forwardfacing":
    scene_grid_taper = 1.25
    scene_grid_zstart = 25.0
    scene_grid_zend = 1.0
    scene_grid_scale = 0.7
    grid_min = np.array([-scene_grid_scale, -scene_grid_scale, 0])
    grid_max = np.array([ scene_grid_scale,  scene_grid_scale, 1])
    point_grid_size = 128

    def get_taper_coord(p):
        # depth is -z, clamped away from zero so the divisions and log are defined
        pz = np.maximum(-p[..., 2:3],1e-10)
        # x and y are divided by depth (perspective taper)
        px = p[..., 0:1]/(pz*scene_grid_taper)
        py = p[..., 1:2]/(pz*scene_grid_taper)
        # log-depth, scaled so zend maps to 0 and zstart maps to 1
        pz = (np.log(pz) - np.log(scene_grid_zend))/(np.log(scene_grid_zstart) - np.log(scene_grid_zend))
        return np.concatenate([px,py,pz],axis=-1)
    def inverse_taper_coord(p):
        # undo the log-depth scaling first, then the perspective division
        pz = np.exp( p[..., 2:3] * \
              (np.log(scene_grid_zstart) - np.log(scene_grid_zend)) + \
              np.log(scene_grid_zend) )
        px = p[..., 0:1]*(pz*scene_grid_taper)
        py = p[..., 1:2]*(pz*scene_grid_taper)
        pz = -pz
        return np.concatenate([px,py,pz],axis=-1)

elif scene_type=="real360":
    scene_grid_zmax = 16.0
    if object_name == "gardenvase":
        scene_grid_zmax = 9.0
    grid_min = np.array([-1, -1, -1])
    grid_max = np.array([ 1,  1,  1])
    point_grid_size = 128

    # no tapering: grid coordinates are the input coordinates
    def get_taper_coord(p):
        return p
    def inverse_taper_coord(p):
        return p

    #approximate solution of e^x = ax+b
    #(np.exp( x ) + (x-1)) / x = scene_grid_zmax
    #np.exp( x ) - scene_grid_zmax*x + (x-1) = 0
    # Linear scan upward from log(zmax) in steps of 0.001; the first sample
    # where the expression turns positive is taken as the root.
    scene_grid_zcc = -1
    for i in range(10000):
        j = numpy.log(scene_grid_zmax)+i/1000.0
        if numpy.exp(j) - scene_grid_zmax*j + (j-1) >0:
            scene_grid_zcc = j
            break
    if scene_grid_zcc<0:
        # Fail with a real exception instead of the former print + 1/0.
        raise RuntimeError("ERROR: approximate solution of e^x = ax+b failed")

else:
    # Without this branch an unsupported scene type only surfaced later as a
    # NameError on grid_min / point_grid_size.
    raise ValueError("unsupported scene_type: %r "
                     "(expected synthetic, forwardfacing or real360)" % (scene_type,))
# ## Load weights
#
# Replace the freshly initialised model_vars with the weights saved on disk.
# The file is opened in a `with` block so the handle is closed once the
# weights are loaded (the previous pickle.load(open(...)) never closed it).
# NOTE(review): unpickling executes arbitrary code; this presumably reads a
# file written by the stage-2 script of this project - only load trusted files.
# `vars` shadows the builtin, but the rest of the script refers to it by this
# name, so it is kept.
with open(weights_dir+"/"+"weights_stage2_1.pkl", "rb") as weights_file:
    vars = pickle.load(weights_file)
model_vars = vars
def get_feature_png(feat):
    """Split an [H, W, C] feature map into C//4 four-channel uint8 images.

    Channel group ``g`` (input channels ``4g .. 4g+3``) becomes one
    ``[H, W, 4]`` image whose first and third channels are swapped relative
    to the input (OpenCV stores BGR rather than RGB); the fourth channel is
    copied unchanged.

    Args:
      feat: array of shape [H, W, C] with C a multiple of 4.

    Returns:
      A list of C//4 numpy uint8 arrays, each of shape [H, W, 4].

    Raises:
      ValueError: if C is not a multiple of 4.
    """
    h,w,c = feat.shape
    if c%4!=0:
        # A real exception replaces the former print + 1/0 crash.
        raise ValueError("ERROR: c%4!=0 (got c=" + str(c) + ")")

    #deal with opencv BGR->RGB
    channel_order = (2, 1, 0, 3)

    out = []
    # The group count comes from the input itself, not a module-level global,
    # so the function always agrees with the array it is given.
    for group in range(c//4):
        ff = numpy.zeros([h,w,4],numpy.uint8)
        for dst, src in enumerate(channel_order):
            ff[...,dst] = feat[..., group*4+src]
        out.append(ff)
    return out
grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = j_grid,k_grid+1,i_grid +p2 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = j_grid+1,k_grid+1,i_grid +p3 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +p0123 = numpy.stack([p0,p1,p2,p3],axis=-1) #[M,N,K,3,4] +p0123 = p0123 @ quad_weights #[M,N,K,3,quad_size*quad_size] +p0123 = numpy.reshape(p0123, [layer_num,layer_num,layer_num,3,quad_size,quad_size]) #[M,N,K,3,quad_size,quad_size] +p0123 = numpy.transpose(p0123, (0,1,4,2,5,3)) #[M,N,quad_size,K,quad_size,3] +#positions_z = numpy.reshape(numpy.ascontiguousarray(p0123), [layer_num,layer_num*quad_size,layer_num*quad_size,3]) +positions_z = numpy.reshape(numpy.ascontiguousarray(p0123), [-1,3]) + +p0 = None +p1 = None +p2 = None +p3 = None +p0123 = None + +total_len = len(positions_z) +batch_len = total_len//batch_num +coarse_feature_z = numpy.zeros([total_len,num_bottleneck_features],numpy.uint8) +for i in range(batch_num): + t0 = numpy.reshape(positions_z[i*batch_len:(i+1)*batch_len], [n_device,-1,3]) + t0 = get_density_color_p(t0,vars) + coarse_feature_z[i*batch_len:(i+1)*batch_len] = numpy.reshape(t0,[-1,num_bottleneck_features]) +coarse_feature_z = numpy.reshape(coarse_feature_z,[layer_num,texture_size,texture_size,num_bottleneck_features]) +coarse_feature_z[:,-quad_size:,:] = 0 +coarse_feature_z[:,:,-quad_size:] = 0 + +positions_z = None + +buffer_z = [] +for i in range(layer_num): + if not numpy.any(coarse_feature_z[i,:,:,0]>0): + buffer_z.append(None) + continue + feats = get_feature_png(coarse_feature_z[i]) + buffer_z.append(feats) + +coarse_feature_z = None + + + +##### x planes + +x,y,z = i_grid,j_grid,k_grid +p0 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + 
grid_min_numpy +x,y,z = i_grid,j_grid+1,k_grid +p1 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = i_grid,j_grid,k_grid+1 +p2 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = i_grid,j_grid+1,k_grid+1 +p3 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +p0123 = numpy.stack([p0,p1,p2,p3],axis=-1) #[M,N,K,3,4] +p0123 = p0123 @ quad_weights #[M,N,K,3,quad_size*quad_size] +p0123 = numpy.reshape(p0123, [layer_num,layer_num,layer_num,3,quad_size,quad_size]) #[M,N,K,3,quad_size,quad_size] +p0123 = numpy.transpose(p0123, (0,1,4,2,5,3)) #[M,N,quad_size,K,quad_size,3] +#positions_x = numpy.reshape(numpy.ascontiguousarray(p0123), [layer_num,layer_num*quad_size,layer_num*quad_size,3]) +positions_x = numpy.reshape(numpy.ascontiguousarray(p0123), [-1,3]) + +p0 = None +p1 = None +p2 = None +p3 = None +p0123 = None + +total_len = len(positions_x) +batch_len = total_len//batch_num +coarse_feature_x = numpy.zeros([total_len,num_bottleneck_features],numpy.uint8) +for i in range(batch_num): + t0 = numpy.reshape(positions_x[i*batch_len:(i+1)*batch_len], [n_device,-1,3]) + t0 = get_density_color_p(t0,vars) + coarse_feature_x[i*batch_len:(i+1)*batch_len] = numpy.reshape(t0,[-1,num_bottleneck_features]) +coarse_feature_x = numpy.reshape(coarse_feature_x,[layer_num,texture_size,texture_size,num_bottleneck_features]) +coarse_feature_x[:,-quad_size:,:] = 0 +coarse_feature_x[:,:,-quad_size:] = 0 + +positions_x = None + +buffer_x = [] +for i in range(layer_num): + if not numpy.any(coarse_feature_x[i,:,:,0]>0): + buffer_x.append(None) + continue + feats = get_feature_png(coarse_feature_x[i]) + buffer_x.append(feats) + +coarse_feature_x = None + + + +##### y planes + +x,y,z = j_grid,i_grid,k_grid 
+p0 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = j_grid+1,i_grid,k_grid +p1 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = j_grid,i_grid,k_grid+1 +p2 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = j_grid+1,i_grid,k_grid+1 +p3 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +p0123 = numpy.stack([p0,p1,p2,p3],axis=-1) #[M,N,K,3,4] +p0123 = p0123 @ quad_weights #[M,N,K,3,quad_size*quad_size] +p0123 = numpy.reshape(p0123, [layer_num,layer_num,layer_num,3,quad_size,quad_size]) #[M,N,K,3,quad_size,quad_size] +p0123 = numpy.transpose(p0123, (0,1,4,2,5,3)) #[M,N,quad_size,K,quad_size,3] +#positions_y = numpy.reshape(numpy.ascontiguousarray(p0123), [layer_num,layer_num*quad_size,layer_num*quad_size,3]) +positions_y = numpy.reshape(numpy.ascontiguousarray(p0123), [-1,3]) + +p0 = None +p1 = None +p2 = None +p3 = None +p0123 = None + +total_len = len(positions_y) +batch_len = total_len//batch_num +coarse_feature_y = numpy.zeros([total_len,num_bottleneck_features],numpy.uint8) +for i in range(batch_num): + t0 = numpy.reshape(positions_y[i*batch_len:(i+1)*batch_len], [n_device,-1,3]) + t0 = get_density_color_p(t0,vars) + coarse_feature_y[i*batch_len:(i+1)*batch_len] = numpy.reshape(t0,[-1,num_bottleneck_features]) +coarse_feature_y = numpy.reshape(coarse_feature_y,[layer_num,texture_size,texture_size,num_bottleneck_features]) +coarse_feature_y[:,-quad_size:,:] = 0 +coarse_feature_y[:,:,-quad_size:] = 0 + +positions_y = None + +buffer_y = [] +for i in range(layer_num): + if not numpy.any(coarse_feature_y[i,:,:,0]>0): + buffer_y.append(None) + continue + feats = 
def get_png_uv(out_cell_num,out_img_w,out_img_size):
    """Return the four corner UVs of atlas cell number `out_cell_num`.

    Cells are laid out row-major, `out_img_w` cells per row, each cell being
    `out_cell_size` pixels on a side (module-level value). Corners sit half a
    pixel inside the cell and are divided by `out_img_size`.

    Returns:
      (uv0, uv1, uv2, uv3): float32 arrays of length 2, ordered
      (row, column). uv1 advances along the row coordinate, uv2 along the
      column coordinate, uv3 along both.
    """
    cell_row, cell_col = divmod(out_cell_num, out_img_w)

    # half-pixel inset on each side of the cell
    row_lo = cell_row*out_cell_size+0.5
    row_hi = (cell_row+1)*out_cell_size-0.5
    col_lo = cell_col*out_cell_size+0.5
    col_hi = (cell_col+1)*out_cell_size-0.5

    corners = (
        (row_lo, col_lo),
        (row_hi, col_lo),
        (row_lo, col_hi),
        (row_hi, col_hi),
    )
    uv0, uv1, uv2, uv3 = [
        numpy.array(list(corner),numpy.float32)/out_img_size for corner in corners
    ]
    return uv0,uv1,uv2,uv3
#mesh vertices
bag_of_v = []

# Quad extraction. For every grid cell (i,j,k) up to three quads are emitted,
# always in the order z plane, x plane, y plane. The two scene families differ
# only in the direction the k layers are visited:
#synthetic and real360
#up is z-
#order: z-,x+,y+
#forwardfacing
#front is z-
#order: z+,x+,y+
if scene_type=="synthetic" or scene_type=="real360":
    _k_order = range(layer_num-1,-1,-1)
elif scene_type=="forwardfacing":
    _k_order = range(layer_num)
else:
    # any other scene type emits no quads
    _k_order = range(0)

# size of one grid cell along each axis
_cell_extent = (grid_max_numpy - grid_min_numpy)/point_grid_size

# Feature buffers indexed by the axis the plane is perpendicular to.
_plane_buffers = (buffer_x, buffer_y, buffer_z)

# (normal_axis, row_axis, col_axis) with axes 0,1,2 = i,j,k.
# normal_axis selects the buffer, the layer inside it, and the plane slot in
# point_UV_grid; row/col axes index the texture patch inside that layer.
_PLANES = (
    (2, 0, 1),  # z plane
    (0, 1, 2),  # x plane
    (1, 0, 2),  # y plane
)

# corner offsets along (row_axis, col_axis), in p0,p1,p2,p3 order
_CORNER_STEPS = ((0, 0), (1, 0), (0, 1), (1, 1))


def _quad_corners(cell, row_axis, col_axis):
    """Return the four corner positions [p0,p1,p2,p3] of one quad of `cell`."""
    corners = []
    for row_step, col_step in _CORNER_STEPS:
        idx = list(cell)
        idx[row_axis] += row_step
        idx[col_axis] += col_step
        # stored per-vertex offset + cell-centre position, mapped back through
        # inverse_taper_coord_numpy
        p = v_grid[idx[0],idx[1],idx[2]] + (numpy.array(idx,numpy.float32)+0.5)*_cell_extent + grid_min_numpy
        corners.append(inverse_taper_coord_numpy(p))
    return corners


for k in _k_order:
    for i in range(layer_num):
        for j in range(layer_num):
            _cell = (i, j, k)
            for _normal_axis, _row_axis, _col_axis in _PLANES:
                _layer = _cell[_normal_axis]
                _row = _cell[_row_axis]
                _col = _cell[_col_axis]

                # skip the first/last layer and the last row/column of cells
                if _layer==0 or _layer==layer_num-1 or _row==layer_num-1 or _col==layer_num-1:
                    continue

                feats = _plane_buffers[_normal_axis][_layer]
                if feats is None:
                    continue
                # keep the quad only if its patch has a non-zero value in
                # channel 2 of the first feature image
                if not (numpy.max(feats[0][_row*quad_size:(_row+1)*quad_size+1,_col*quad_size:(_col+1)*quad_size+1,2])>0):
                    continue

                write_patch_to_png(out_img,out_cell_num,out_img_w,_row,_col,feats)
                uv0,uv1,uv2,uv3 = get_png_uv(out_cell_num,out_img_w,out_img_size)
                out_cell_num += 1

                p0,p1,p2,p3 = _quad_corners(_cell, _row_axis, _col_axis)

                point_UV_grid[i,j,k,_normal_axis,0] = uv0
                point_UV_grid[i,j,k,_normal_axis,1] = uv1
                point_UV_grid[i,j,k,_normal_axis,2] = uv2
                point_UV_grid[i,j,k,_normal_axis,3] = uv3

                bag_of_v.append([p0,p1,p2,p3])

# Drop the extra reference so that clearing buffer_x/y/z afterwards still
# releases the feature buffers.
del _plane_buffers
mlp_features_b, acc_b): + + mlp_features_b = mlp_features_b.astype(np.float32)/255 #[N,4,C] + mlp_features_b = mlp_features_b * acc_b[..., None] #[N,4,C] + mlp_features_b = np.mean(mlp_features_b, axis=-2) #[N,C] + + acc_b = np.mean(acc_b.astype(np.float32), axis=-1) #[N] + + # ... as well as view-dependent colors. + dirs = normalize(rays[1]) #[N,4,3] + dirs = np.mean(dirs, axis=-2) #[N,3] + features_dirs_enc_b = np.concatenate([mlp_features_b, dirs], axis=-1) #[N,C+3] + rgb_b = jax.nn.sigmoid(color_model.apply(vars[-1], features_dirs_enc_b)) + + # Composite onto the background color. + if white_bkgd: + rgb_b = rgb_b * acc_b[..., None] + (1. - acc_b[..., None]) + else: + bgc = bg_color + rgb_b = rgb_b * acc_b[..., None] + (1. - acc_b[..., None]) * bgc + + return rgb_b, acc_b + +#%% -------------------------------------------------------------------------------- +# ## Set up pmap'd rendering for test time evaluation. +#%% +#for eval +texture_alpha = numpy.zeros([out_img_size,out_img_size,1], numpy.uint8) +texture_features = numpy.zeros([out_img_size,out_img_size,8], numpy.uint8) + +texture_alpha[:,:,0] = (out_img[0][:,:,2]>0) + +texture_features[:,:,0:3] = out_img[0][:,:,2::-1] +texture_features[:,:,3] = out_img[0][:,:,3] +texture_features[:,:,4:7] = out_img[1][:,:,2::-1] +texture_features[:,:,7] = out_img[1][:,:,3] +#%% +test_batch_size = 4096*n_device + +render_rays_get_uv_p = jax.pmap(lambda rays, vars, uv, alp: render_rays_get_uv( + rays, vars, uv, alp), + in_axes=(0, None, None, None)) + +render_rays_get_color_p = jax.pmap(lambda rays, vars, mlp_features_b, acc_b: render_rays_get_color( + rays, vars, mlp_features_b, acc_b), + in_axes=(0, None, 0, 0)) + + +def render_test(rays, vars, uv, alp, feat): + sh = rays[0].shape + rays = [x.reshape((jax.local_device_count(), -1) + sh[1:]) for x in rays] + acc_b, selected_uv = render_rays_get_uv_p(rays, vars, uv, alp) + + #deferred features + selected_uv = numpy.array(selected_uv) + mlp_features_b = 
def render_loop(rays, vars, uv, alp, feat, chunk):
    """Render `rays` in slices of `chunk` rays and stitch the results.

    Each array in `rays` is flattened to [-1, 4, 3] and zero-padded so the
    ray count is a multiple of the local device count. Every slice goes
    through `render_test`; the padding is cut off again before the three
    outputs are reshaped back to the leading shape of the input.

    Returns:
      A list of three arrays (the three outputs of `render_test`,
      concatenated over slices), each shaped `rays[0].shape[:-2] + [-1]`.
    """
    leading_shape = list(rays[0].shape[:-2])
    flat_rays = [r.reshape([-1, 4, 3]) for r in rays]
    num_rays = flat_rays[0].shape[0]

    # round the ray count up to the next multiple of the device count
    num_devices = jax.local_device_count()
    pad_rows = ((num_rays - 1) // num_devices + 1) * num_devices - num_rays
    flat_rays = [np.pad(r, ((0,pad_rows),(0,0),(0,0))) for r in flat_rays]

    chunk_outputs = []
    for start in range(0, flat_rays[0].shape[0], chunk):
        chunk_rays = [r[start:start+chunk] for r in flat_rays]
        chunk_outputs.append(render_test(chunk_rays, vars, uv, alp, feat))

    results = []
    for field in range(3):
        joined = np.concatenate([pieces[field] for pieces in chunk_outputs])
        # [:num_rays] removes the padded rows
        results.append(np.reshape(joined[:num_rays], leading_shape + [-1]))
    return results
#additional views
if scene_type=="synthetic":
    def generate_spherical_poses(poses):
        """Build extra camera poses on five horizontal rings around the origin.

        `rad` is the RMS distance of the input camera origins
        (`poses[:, :3, 3]`) from the origin. Ring heights are the centroid
        height and four blends of it with the highest / lowest input camera.
        Each ring holds 60 poses at angles from 0 to 2*pi, at ring radius
        sqrt(rad**2 - height**2).

        Returns:
          Array of shape [5*60, 3, 4]; the columns of each pose are three
          axis vectors followed by the camera origin.
        """
        origins = poses[:, :3, 3]
        rad = np.sqrt(np.mean(np.sum(np.square(origins), -1)))
        centroid = np.mean(origins, 0)
        pmax = np.max(origins, 0)
        pmin = np.min(origins, 0)

        ring_heights = [
            centroid[2],
            centroid[2]*0.6 + pmax[2]*0.4,
            centroid[2]*0.2 + pmax[2]*0.8,
            centroid[2]*0.6 + pmin[2]*0.4,
            centroid[2]*0.2 + pmin[2]*0.8,
        ]

        up = np.array([0, 0, -1.])
        ring_poses = []
        for zh in ring_heights:
            radcircle = np.sqrt(rad**2 - zh**2)
            for th in np.linspace(0., 2. * np.pi, 60):
                camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
                # third axis is the normalised camera origin; the other two
                # come from cross products with `up`
                axis2 = normalize(camorigin)
                axis0 = normalize(np.cross(axis2, up))
                axis1 = normalize(np.cross(axis2, axis0))
                ring_poses.append(np.stack([axis0, axis1, axis2, camorigin], 1))

        return np.stack(ring_poses, 0)[:, :3, :4]

    poses = data['train']['c2w']
    additional_poses = generate_spherical_poses(poses)

    # Render every extra view and mark the texels it touches as visible.
    print("Removing invisible triangles")
    for p in tqdm(additional_poses):
        out = render_loop(camera_ray_batch(p, hwf), vars, point_UV_grid, texture_alpha, texture_features, test_batch_size)
        uv = np.reshape(out[2],[-1,2])
        texture_mask[uv[:,0],uv[:,1]] = 1
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
    """Computes the structural similarity (SSIM) between two images.

    Modeled after tf.image.ssim and intended to give comparable output.

    Args:
      img0: array. An image of size [..., width, height, num_channels].
      img1: array. An image of size [..., width, height, num_channels].
      max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
      filter_size: int >= 1. Window size.
      filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
      k1: float > 0. One of the SSIM dampening parameters.
      k2: float > 0. One of the SSIM dampening parameters.
      return_map: Bool. If True, the per-pixel SSIM "map" is returned.

    Returns:
      Each image's mean SSIM, or a tensor of individual values if `return_map`.
    """
    # Normalised 1D Gaussian kernel centred on the window.
    half_width = filter_size // 2
    centre_shift = (2 * half_width - filter_size + 1) / 2
    squared_dist = ((jnp.arange(filter_size) - half_width + centre_shift) / filter_sigma)**2
    kernel = jnp.exp(-0.5 * squared_dist)
    kernel = kernel / jnp.sum(kernel)

    # Separable blur: one pass per spatial axis (faster than a 2D convolution).
    def blur_first_axis(z):
        return jsp.signal.convolve2d(z, kernel[:, None], mode="valid")

    def blur_second_axis(z):
        return jsp.signal.convolve2d(z, kernel[None, :], mode="valid")

    # Vectorise both passes over every leading (batch) axis and the channel
    # axis, leaving the two spatial axes for convolve2d.
    rank = len(img0.shape)
    mapped_axes = tuple(range(rank - 3)) + (rank - 1,)
    for axis in mapped_axes:
        blur_first_axis = jax.vmap(blur_first_axis, in_axes=axis, out_axes=axis)
        blur_second_axis = jax.vmap(blur_second_axis, in_axes=axis, out_axes=axis)

    def blur(z):
        return blur_first_axis(blur_second_axis(z))

    mean0 = blur(img0)
    mean1 = blur(img1)
    mean00 = mean0 * mean0
    mean11 = mean1 * mean1
    mean01 = mean0 * mean1
    var0 = blur(img0**2) - mean00
    var1 = blur(img1**2) - mean11
    cov01 = blur(img0 * img1) - mean01

    # Clip to valid values: variances are non-negative and the covariance
    # magnitude cannot exceed the geometric mean of the variances.
    var0 = jnp.maximum(0., var0)
    var1 = jnp.maximum(0., var1)
    cov01 = jnp.sign(cov01) * jnp.minimum(
        jnp.sqrt(var0 * var1), jnp.abs(cov01))

    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mean01 + c1) * (2 * cov01 + c2)
    denom = (mean00 + mean11 + c1) * (var0 + var1 + c2)
    ssim_map = numer / denom

    if return_map:
        return ssim_map
    # average over the two spatial axes and the channel axis
    return jnp.mean(ssim_map, tuple(range(rank - 3, rank)))
tsy = py*out_cell_size + tey = py*out_cell_size+out_cell_size + tsx = px*out_cell_size + tex = px*out_cell_size+out_cell_size + nsy = ny*out_cell_size + ney = ny*out_cell_size+out_cell_size + nsx = nx*out_cell_size + nex = nx*out_cell_size+out_cell_size + + for i in range(out_feat_num): + new_img[i][nsy:ney,nsx:nex] = out_img[i][tsy:tey,tsx:tex] + + return True + + + +#write mesh + +obj_save_dir = "obj" +if not os.path.exists(obj_save_dir): + os.makedirs(obj_save_dir) + +obj_f = open(obj_save_dir+"/shape.obj",'w') + +vcount = 0 + +for i in range(out_cell_num): + quad_visible, t1_visible, t2_visible = check_triangle_visible(texture_mask, i) + if quad_visible: + copy_patch_to_png(out_img,i,new_img,new_cell_num) + p0,p1,p2,p3 = bag_of_v[i] + uv0,uv1,uv2,uv3 = get_png_uv(new_cell_num,new_img_w,new_img_size_w) + new_cell_num += 1 + + if scene_type=="synthetic" or scene_type=="real360": + obj_f.write("v %.6f %.6f %.6f\n" % (p0[0],p0[2],-p0[1])) + obj_f.write("v %.6f %.6f %.6f\n" % (p1[0],p1[2],-p1[1])) + obj_f.write("v %.6f %.6f %.6f\n" % (p2[0],p2[2],-p2[1])) + obj_f.write("v %.6f %.6f %.6f\n" % (p3[0],p3[2],-p3[1])) + elif scene_type=="forwardfacing": + obj_f.write("v %.6f %.6f %.6f\n" % (p0[0],p0[1],p0[2])) + obj_f.write("v %.6f %.6f %.6f\n" % (p1[0],p1[1],p1[2])) + obj_f.write("v %.6f %.6f %.6f\n" % (p2[0],p2[1],p2[2])) + obj_f.write("v %.6f %.6f %.6f\n" % (p3[0],p3[1],p3[2])) + + obj_f.write("vt %.6f %.6f\n" % (uv0[1],1-uv0[0]*new_img_size_ratio)) + obj_f.write("vt %.6f %.6f\n" % (uv1[1],1-uv1[0]*new_img_size_ratio)) + obj_f.write("vt %.6f %.6f\n" % (uv2[1],1-uv2[0]*new_img_size_ratio)) + obj_f.write("vt %.6f %.6f\n" % (uv3[1],1-uv3[0]*new_img_size_ratio)) + if t1_visible: + obj_f.write("f %d/%d %d/%d %d/%d\n" % (vcount+1,vcount+1,vcount+2,vcount+2,vcount+4,vcount+4)) + if t2_visible: + obj_f.write("f %d/%d %d/%d %d/%d\n" % (vcount+1,vcount+1,vcount+4,vcount+4,vcount+3,vcount+3)) + vcount += 4 + +for j in range(out_feat_num): + 
cv2.imwrite(obj_save_dir+"/shape.pngfeat"+str(j)+".png", new_img[j], [cv2.IMWRITE_PNG_COMPRESSION, 9]) +obj_f.close() + + +#%% +#export weights for the MLP +mlp_params = {} + +mlp_params['0_weights'] = vars[-1]['params']['Dense_0']['kernel'].tolist() +mlp_params['1_weights'] = vars[-1]['params']['Dense_1']['kernel'].tolist() +mlp_params['2_weights'] = vars[-1]['params']['Dense_2']['kernel'].tolist() +mlp_params['0_bias'] = vars[-1]['params']['Dense_0']['bias'].tolist() +mlp_params['1_bias'] = vars[-1]['params']['Dense_1']['bias'].tolist() +mlp_params['2_bias'] = vars[-1]['params']['Dense_2']['bias'].tolist() + +scene_params_path = obj_save_dir+'/mlp.json' +with open(scene_params_path, 'wb') as f: + f.write(json.dumps(mlp_params).encode('utf-8')) +#%% -------------------------------------------------------------------------------- +# ## Split the large texture image into images of size 4096 +#%% +import numpy as np + +target_dir = obj_save_dir+"_phone" + +texture_size = 4096 +patchsize = 17 +texture_patch_size = texture_size//patchsize + +if not os.path.exists(target_dir): + os.makedirs(target_dir) + + +source_obj_dir = obj_save_dir+"/shape.obj" +source_png0_dir = obj_save_dir+"/shape.pngfeat0.png" +source_png1_dir = obj_save_dir+"/shape.pngfeat1.png" + +source_png0 = cv2.imread(source_png0_dir,cv2.IMREAD_UNCHANGED) +source_png1 = cv2.imread(source_png1_dir,cv2.IMREAD_UNCHANGED) + +img_h,img_w,_ = source_png0.shape + + + +num_splits = 0 #this is a counter + + +fin = open(source_obj_dir,'r') +lines = fin.readlines() +fin.close() + + + +current_img_idx = 0 +current_img0 = np.zeros([texture_size,texture_size,4],np.uint8) +current_img1 = np.zeros([texture_size,texture_size,4],np.uint8) +current_quad_count = 0 +current_obj = open(target_dir+"/shape"+str(current_img_idx)+".obj",'w') +current_v_count = 0 +current_v_offset = 0 + +#v-vt-f cycle + +for i in range(len(lines)): + line = lines[i].split() + if len(line)==0: + continue + + elif line[0] == 'v': + 
current_obj.write(lines[i]) + current_v_count += 1 + + elif line[0] == 'vt': + if lines[i-1].split()[0] == "v": + + line = lines[i].split() + x0 = float(line[1]) + y0 = 1-float(line[2]) + + line = lines[i+1].split() + x1 = float(line[1]) + y1 = 1-float(line[2]) + + line = lines[i+2].split() + x2 = float(line[1]) + y2 = 1-float(line[2]) + + line = lines[i+3].split() + x3 = float(line[1]) + y3 = 1-float(line[2]) + + xc = (x0+x1+x2+x3)*img_w/4 + yc = (y0+y1+y2+y3)*img_h/4 + + old_cell_x = int(xc/patchsize) + old_cell_y = int(yc/patchsize) + + new_cell_x = current_quad_count%texture_patch_size + new_cell_y = current_quad_count//texture_patch_size + current_quad_count += 1 + + #copy patch + + tsy = old_cell_y*patchsize + tey = old_cell_y*patchsize+patchsize + tsx = old_cell_x*patchsize + tex = old_cell_x*patchsize+patchsize + nsy = new_cell_y*patchsize + ney = new_cell_y*patchsize+patchsize + nsx = new_cell_x*patchsize + nex = new_cell_x*patchsize+patchsize + + current_img0[nsy:ney,nsx:nex] = source_png0[tsy:tey,tsx:tex] + current_img1[nsy:ney,nsx:nex] = source_png1[tsy:tey,tsx:tex] + + #write uv + + uv0_y = (new_cell_y*patchsize+0.5)/texture_size + uv0_x = (new_cell_x*patchsize+0.5)/texture_size + + uv1_y = ((new_cell_y+1)*patchsize-0.5)/texture_size + uv1_x = (new_cell_x*patchsize+0.5)/texture_size + + uv2_y = (new_cell_y*patchsize+0.5)/texture_size + uv2_x = ((new_cell_x+1)*patchsize-0.5)/texture_size + + uv3_y = ((new_cell_y+1)*patchsize-0.5)/texture_size + uv3_x = ((new_cell_x+1)*patchsize-0.5)/texture_size + + current_obj.write("vt %.6f %.6f\n" % (uv0_x,1-uv0_y)) + current_obj.write("vt %.6f %.6f\n" % (uv1_x,1-uv1_y)) + current_obj.write("vt %.6f %.6f\n" % (uv2_x,1-uv2_y)) + current_obj.write("vt %.6f %.6f\n" % (uv3_x,1-uv3_y)) + + + elif line[0] == 'f': + f1 = int(line[1].split("/")[0])-current_v_offset + f2 = int(line[2].split("/")[0])-current_v_offset + f3 = int(line[3].split("/")[0])-current_v_offset + current_obj.write("f %d/%d %d/%d %d/%d\n" % 
(f1,f1,f2,f2,f3,f3)) + + #create new texture image if current is fill + if i==len(lines)-1 or (lines[i+1].split()[0]!='f' and current_quad_count==texture_patch_size*texture_patch_size): + current_obj.close() + + # the following is only required for iphone + # because iphone runs alpha test before the fragment shader + # the viewer code is also changed accordingly + current_img0[:,:,3] = current_img0[:,:,3]//2+128 + current_img1[:,:,3] = current_img1[:,:,3]//2+128 + + cv2.imwrite(target_dir+"/shape"+str(current_img_idx)+".pngfeat0.png", current_img0, [cv2.IMWRITE_PNG_COMPRESSION,9]) + cv2.imwrite(target_dir+"/shape"+str(current_img_idx)+".pngfeat1.png", current_img1, [cv2.IMWRITE_PNG_COMPRESSION,9]) + current_img_idx += 1 + current_img0 = np.zeros([texture_size,texture_size,4],np.uint8) + current_img1 = np.zeros([texture_size,texture_size,4],np.uint8) + current_quad_count = 0 + if i!=len(lines)-1: + current_obj = open(target_dir+"/shape"+str(current_img_idx)+".obj",'w') + current_v_offset += current_v_count + current_v_count = 0 + + + + +#copy the small MLP +source_json_dir = obj_save_dir+"/mlp.json" +current_json_dir = target_dir+"/mlp.json" +fin = open(source_json_dir,'r') +line = fin.readline() +fin.close() +fout = open(current_json_dir,'w') +fout.write(line.strip()[:-1]) +fout.write(",\"obj_num\": "+str(current_img_idx)+"}") +fout.close() + +#%% -------------------------------------------------------------------------------- +# # Save images for testing +#%% + +pred_frames = np.array(frames,np.float32) +gt_frames = np.array(data['test']['images'],np.float32) + +pickle.dump(pred_frames, open("pred_frames.pkl", "wb")) +pickle.dump(gt_frames, open("gt_frames.pkl", "wb")) diff --git a/jax3d/projects/mobilenerf/stage3_with_box.py b/jax3d/projects/mobilenerf/stage3_with_box.py new file mode 100644 index 0000000..7a88b93 --- /dev/null +++ b/jax3d/projects/mobilenerf/stage3_with_box.py @@ -0,0 +1,2768 @@ +# Copyright 2022 The jax3d Authors. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


scene_type = "real360"
object_name = "gardenvase"
scene_dir = "datasets/nerf_real_360/"+object_name

# synthetic
# chair drums ficus hotdog lego materials mic ship

# forwardfacing
# fern flower fortress horns leaves orchids room trex

# real360
# bicycle flowerbed gardenvase stump treehill
# fulllivingroom kitchencounter kitchenlego officebonsai

#%% --------------------------------------------------------------------------------
# ## General imports
#%%
import copy
import gc
import json
import os
import numpy
import cv2
from tqdm import tqdm
import pickle
import jax
import jax.numpy as np
from jax import random
import flax
import flax.linen as nn
import functools
import math
from typing import Sequence, Callable
import time
import matplotlib.pyplot as plt
from PIL import Image
from multiprocessing.pool import ThreadPool

print(jax.local_devices())
if len(jax.local_devices())!=8:
  # Abort with a readable error rather than a deliberate ZeroDivisionError.
  raise RuntimeError("ERROR: need 8 v100 GPUs")
weights_dir = "weights"
samples_dir = "samples"
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(weights_dir, exist_ok=True)
os.makedirs(samples_dir, exist_ok=True)
def write_floatpoint_image(name,img):
  """Write a float image to `name` as 8-bit.

  Values are scaled by 255 and clipped to [0, 255]. The last axis is reversed
  before calling cv2.imwrite (presumably RGB -> OpenCV's BGR order).

  Args:
    name: output file path handed to cv2.imwrite.
    img: array indexed [row, col, channel]; assumed to hold values in [0, 1].
  """
  img = numpy.clip(numpy.array(img)*255,0,255).astype(numpy.uint8)
  cv2.imwrite(name,img[:,:,::-1])
#%% --------------------------------------------------------------------------------
# ## Load the dataset.
#%%
# """ Load dataset """

if scene_type=="synthetic":
  white_bkgd = True
elif scene_type=="forwardfacing":
  white_bkgd = False
elif scene_type=="real360":
  white_bkgd = False
else:
  # Fail early: otherwise `white_bkgd` and `data` would silently stay undefined.
  raise ValueError("Unknown scene_type: {}".format(scene_type))


#https://github.com/google-research/google-research/blob/master/snerg/nerf/datasets.py


if scene_type=="synthetic":

  def load_blender(data_dir, split):
    """Load one split of a Blender (nerf_synthetic) scene.

    Reads `transforms_<split>.json` plus the PNG frames it references and
    composites the alpha channel onto white or black depending on the
    module-level `white_bkgd`.

    Returns:
      dict with 'images' [N,H,W,3], 'c2w' (stacked transform matrices) and
      'hwf' (height, width, focal).
    """
    with open(
        os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
      meta = json.load(fp)

    cams = []
    paths = []
    for frame in meta["frames"]:
      cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
      paths.append(os.path.join(data_dir, frame["file_path"] + ".png"))

    def image_read_fn(fname):
      with open(fname, "rb") as imgin:
        image = np.array(Image.open(imgin), dtype=np.float32) / 255.
      return image
    # pool.map blocks until every image is read; the `with` block cleans up.
    with ThreadPool() as pool:
      images = pool.map(image_read_fn, paths)

    images = np.stack(images, axis=0)
    if white_bkgd:
      images = (images[..., :3] * images[..., -1:] + (1. - images[..., -1:]))
    else:
      images = images[..., :3] * images[..., -1:]

    h, w = images.shape[1:3]
    camera_angle_x = float(meta["camera_angle_x"])
    focal = .5 * w / np.tan(.5 * camera_angle_x)

    hwf = np.array([h, w, focal], dtype=np.float32)
    poses = np.stack(cams, axis=0)
    return {'images' : images, 'c2w' : poses, 'hwf' : hwf}

  data = {'train' : load_blender(scene_dir, 'train'),
          'test' : load_blender(scene_dir, 'test')}

  splits = ['train', 'test']
  for s in splits:
    print(s)
    for k in data[s]:
      print(f' {k}: {data[s][k].shape}')

  images, poses, hwf = data['train']['images'], data['train']['c2w'], data['train']['hwf']
  write_floatpoint_image(samples_dir+"/training_image_sample.png",images[0])

  for i in range(3):
    plt.figure()
    plt.scatter(poses[:,i,3], poses[:,(i+1)%3,3])
    plt.axis('equal')
    plt.savefig(samples_dir+"/training_camera"+str(i)+".png")

elif scene_type=="forwardfacing" or scene_type=="real360":

  # The pose helpers below need mutable host arrays, so they call `numpy`
  # explicitly instead of temporarily rebinding the global `np` alias
  # (which stays bound to jax.numpy throughout).
  import jax.numpy as jnp

  def _viewmatrix(z, up, pos):
    """Construct lookat view matrix."""
    vec2 = _normalize(z)
    vec1_avg = up
    vec0 = _normalize(numpy.cross(vec1_avg, vec2))
    vec1 = _normalize(numpy.cross(vec2, vec0))
    m = numpy.stack([vec0, vec1, vec2, pos], 1)
    return m

  def _normalize(x):
    """Normalization helper function."""
    return x / numpy.linalg.norm(x)

  def _poses_avg(poses):
    """Average poses according to the original NeRF code."""
    hwf = poses[0, :3, -1:]
    center = poses[:, :3, 3].mean(0)
    vec2 = _normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = numpy.concatenate([_viewmatrix(vec2, up, center), hwf], 1)
    return c2w

  def _recenter_poses(poses):
    """Recenter poses according to the original NeRF code."""
    poses_ = poses.copy()
    bottom = numpy.reshape([0, 0, 0, 1.], [1, 4])
    c2w = _poses_avg(poses)
    c2w = numpy.concatenate([c2w[:3, :4], bottom], -2)
    bottom = numpy.tile(numpy.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = numpy.concatenate([poses[:, :3, :4], bottom], -2)
    poses = numpy.linalg.inv(c2w) @ poses
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses

  def _transform_poses_pca(poses):
    """Transforms poses so principal components lie on XYZ axes."""
    poses_ = poses.copy()
    t = poses[:, :3, 3]
    t_mean = t.mean(axis=0)
    t = t - t_mean

    # NOTE: kept as `eig` (not `eigh`) on purpose. A different eigenvector
    # sign convention would yield poses that no longer match the ones the
    # earlier training stages were run with.
    eigval, eigvec = numpy.linalg.eig(t.T @ t)
    # Sort eigenvectors in order of largest to smallest eigenvalue.
    inds = numpy.argsort(eigval)[::-1]
    eigvec = eigvec[:, inds]
    rot = eigvec.T
    if numpy.linalg.det(rot) < 0:
      rot = numpy.diag(numpy.array([1, 1, -1])) @ rot

    transform = numpy.concatenate([rot, rot @ -t_mean[:, None]], -1)
    bottom = numpy.broadcast_to([0, 0, 0, 1.], poses[..., :1, :4].shape)
    pad_poses = numpy.concatenate([poses[..., :3, :4], bottom], axis=-2)
    poses_recentered = transform @ pad_poses
    poses_recentered = poses_recentered[..., :3, :4]
    transform = numpy.concatenate([transform, numpy.eye(4)[3:]], axis=0)

    # Flip coordinate system if z component of y-axis is negative
    if poses_recentered.mean(axis=0)[2, 1] < 0:
      poses_recentered = numpy.diag(numpy.array([1, -1, -1])) @ poses_recentered
      transform = numpy.diag(numpy.array([1, -1, -1, 1])) @ transform

    # Just make sure it's in the [-1, 1]^3 cube
    scale_factor = 1. / numpy.max(numpy.abs(poses_recentered[:, :3, 3]))
    poses_recentered[:, :3, 3] *= scale_factor
    transform = numpy.diag(numpy.array([scale_factor] * 3 + [1])) @ transform

    poses_[:, :3, :4] = poses_recentered[:, :3, :4]
    poses_recentered = poses_
    return poses_recentered, transform

  def load_LLFF(data_dir, split, factor = 4, llffhold = 8):
    """Load one split of an LLFF-format scene.

    Args:
      data_dir: scene folder containing `images[_<factor>]/` and
        `poses_bounds.npy`.
      split: "train" selects the non-held-out frames; anything else selects
        every `llffhold`-th frame.
      factor: image downsampling factor; values <= 0 use the `images` folder
        and leave the focal length unscaled.
      llffhold: stride of the held-out test frames.

    Returns:
      dict with 'images', 'c2w' ([N,3,4]) and 'hwf' as jax arrays.

    Raises:
      ValueError: if the image folder does not exist.
      RuntimeError: if the image and pose counts differ.
    """
    # Load images.
    imgdir_suffix = ""
    if factor > 0:
      imgdir_suffix = "_{}".format(factor)
    imgdir = os.path.join(data_dir, "images" + imgdir_suffix)
    if not os.path.exists(imgdir):
      raise ValueError("Image folder {} doesn't exist.".format(imgdir))
    imgfiles = [
        os.path.join(imgdir, f)
        for f in sorted(os.listdir(imgdir))
        if f.endswith(("JPG", "jpg", "png"))
    ]
    def image_read_fn(fname):
      with open(fname, "rb") as imgin:
        image = numpy.array(Image.open(imgin), dtype=numpy.float32) / 255.
      return image
    # pool.map blocks until every image is read; the `with` block cleans up.
    with ThreadPool() as pool:
      images = pool.map(image_read_fn, imgfiles)
    images = numpy.stack(images, axis=-1)

    # Load poses and bds.
    with open(os.path.join(data_dir, "poses_bounds.npy"),
              "rb") as fp:
      poses_arr = numpy.load(fp)
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])
    if poses.shape[-1] != images.shape[-1]:
      raise RuntimeError("Mismatch between imgs {} and poses {}".format(
          images.shape[-1], poses.shape[-1]))

    # Update poses according to downsampling.
    poses[:2, 4, :] = numpy.array(images.shape[:2]).reshape([2, 1])
    if factor > 0:
      # factor <= 0 means "no downsampling"; dividing by it would produce an
      # infinite or negative focal length.
      poses[2, 4, :] = poses[2, 4, :] * 1. / factor

    # Correct rotation matrix ordering and move variable dim to axis 0.
    poses = numpy.concatenate(
        [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = numpy.moveaxis(poses, -1, 0).astype(numpy.float32)
    images = numpy.moveaxis(images, -1, 0)
    bds = numpy.moveaxis(bds, -1, 0).astype(numpy.float32)


    if scene_type=="real360":
      # Rotate/scale poses to align ground with xy plane and fit to unit cube.
      poses, _ = _transform_poses_pca(poses)
    else:
      # Rescale according to a default bd factor.
      scale = 1. / (bds.min() * .75)
      poses[:, :3, 3] *= scale
      bds *= scale
      # Recenter poses
      poses = _recenter_poses(poses)

    # Select the split.
    all_indices = numpy.arange(images.shape[0])
    i_test = all_indices[::llffhold]
    # Sorted set difference: same indices as the former `i not in i_test`
    # scan, without the quadratic cost, and integer-typed even when empty.
    i_train = numpy.setdiff1d(all_indices, i_test)
    if split == "train":
      indices = i_train
    else:
      indices = i_test
    images = images[indices]
    poses = poses[indices]

    camtoworlds = poses[:, :3, :4]
    focal = poses[0, -1, -1]
    h, w = images.shape[1:3]

    hwf = numpy.array([h, w, focal], dtype=numpy.float32)

    return {'images' : jnp.array(images), 'c2w' : jnp.array(camtoworlds), 'hwf' : jnp.array(hwf)}

  data = {'train' : load_LLFF(scene_dir, 'train'),
          'test' : load_LLFF(scene_dir, 'test')}

  splits = ['train', 'test']
  for s in splits:
    print(s)
    for k in data[s]:
      print(f' {k}: {data[s][k].shape}')

  images, poses, hwf = data['train']['images'], data['train']['c2w'], data['train']['hwf']
  write_floatpoint_image(samples_dir+"/training_image_sample.png",images[0])

  for i in range(3):
    plt.figure()
    plt.scatter(poses[:,i,3], poses[:,(i+1)%3,3])
    plt.axis('equal')
    plt.savefig(samples_dir+"/training_camera"+str(i)+".png")

  bg_color = jnp.mean(images)

#%% --------------------------------------------------------------------------------
# ## Helper functions
#%%
adam_kwargs = {
    'beta1': 0.9,
    'beta2': 0.999,
    'eps': 1e-15,
}

n_device = jax.local_device_count()

rng = random.PRNGKey(1)



# General math functions.

def matmul(a, b):
  """jnp.matmul defaults to bfloat16, but this helper function doesn't."""
  return np.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

def normalize(x):
  """Scale `x` to unit length along its last axis."""
  return x / np.linalg.norm(x, axis=-1, keepdims=True)

def sinusoidal_encoding(position, minimum_frequency_power,
    maximum_frequency_power,include_identity = False):
  """Encode `position` with sines and cosines at power-of-two frequencies.

  Args:
    position: array [..., D].
    minimum_frequency_power: first exponent p of the 2**p frequencies.
    maximum_frequency_power: exclusive upper bound on the exponents.
    include_identity: if True, `position` itself is prepended to the result.

  Returns:
    Array [..., 2*D*num_frequencies] (plus D when `include_identity`).
  """
  # Compute the sinusoidal encoding components
  frequency = 2.0**np.arange(minimum_frequency_power, maximum_frequency_power)
  angle = position[..., None, :] * frequency[:, None]
  # sin(angle + pi/2) is the cosine term.
  encoding = np.sin(np.stack([angle, angle + 0.5 * np.pi], axis=-2))
  # Flatten encoding dimensions
  encoding = encoding.reshape(*position.shape[:-1], -1)
  # Add identity component
  if include_identity:
    encoding = np.concatenate([position, encoding], axis=-1)
  return encoding

# Pose/ray math.

def generate_rays(pixel_coords, pix2cam, cam2world):
  """Generate camera rays from pixel coordinates and poses.

  Args:
    pixel_coords: array [..., 2] of (x, y) pixel coordinates; 0.5 is added so
      integer coordinates refer to pixel centers.
    pix2cam: [3, 3] inverse intrinsic matrix (see `pix2cam_matrix`).
    cam2world: array [..., 3, 4] (or larger) whose leading dimensions
      broadcast against those of `pixel_coords`.

  Returns:
    (ray_origins, ray_dirs), both shaped like `pixel_coords` with a last axis
    of size 3. Directions are not normalized.
  """
  homog = np.ones_like(pixel_coords[..., :1])
  pixel_dirs = np.concatenate([pixel_coords + .5, homog], axis=-1)[..., None]
  cam_dirs = matmul(pix2cam, pixel_dirs)
  ray_dirs = matmul(cam2world[..., :3, :3], cam_dirs)[..., 0]
  ray_origins = np.broadcast_to(cam2world[..., :3, 3], ray_dirs.shape)

  return ray_origins, ray_dirs

def pix2cam_matrix(height, width, focal):
  """Inverse intrinsic matrix for a pinhole camera."""
  return np.array([
      [1./focal, 0, -.5 * width / focal],
      [0, -1./focal, .5 * height / focal],
      [0, 0, -1.],
  ])

def _supersample_pixel_coords(x_ind, y_ind):
  """Return the four quarter-pixel-offset (x, y) samples of each pixel.

  Sample order is (-,-), (+,-), (-,+), (+,+) in (x, y); output is [..., 4, 2].
  """
  offsets = np.array([[-0.25, -0.25], [0.25, -0.25],
                      [-0.25, 0.25], [0.25, 0.25]], dtype=np.float32)
  return np.stack([x_ind, y_ind], axis=-1)[..., None, :] + offsets

def camera_ray_batch_xxxxx_original(cam2world, hwf):
  """Generate rays for a pinhole camera with given extrinsic and intrinsic."""
  height, width = int(hwf[0]), int(hwf[1])
  pix2cam = pix2cam_matrix(*hwf)
  pixel_coords = np.stack(np.meshgrid(np.arange(width), np.arange(height)), axis=-1)
  return generate_rays(pixel_coords, pix2cam, cam2world)

def camera_ray_batch(cam2world, hwf): ### antialiasing by supersampling
  """Generate rays for a pinhole camera with given extrinsic and intrinsic.

  Every pixel gets four rays (2x2 supersampling), so both returned arrays are
  shaped [height, width, 4, 3].
  """
  height, width = int(hwf[0]), int(hwf[1])
  pix2cam = pix2cam_matrix(*hwf)
  x_ind, y_ind = np.meshgrid(np.arange(width), np.arange(height))
  pixel_coords = _supersample_pixel_coords(x_ind, y_ind)

  return generate_rays(pixel_coords, pix2cam, cam2world)

def random_ray_batch_xxxxx_original(rng, batch_size, data):
  """Generate a random batch of ray data."""
  keys = random.split(rng, 3)
  cam_ind = random.randint(keys[0], [batch_size], 0, data['c2w'].shape[0])
  y_ind = random.randint(keys[1], [batch_size], 0, data['images'].shape[1])
  x_ind = random.randint(keys[2], [batch_size], 0, data['images'].shape[2])
  pixel_coords = np.stack([x_ind, y_ind], axis=-1)
  pix2cam = pix2cam_matrix(*data['hwf'])
  cam2world = data['c2w'][cam_ind, :3, :4]
  rays = generate_rays(pixel_coords, pix2cam, cam2world)
  pixels = data['images'][cam_ind, y_ind, x_ind]
  return rays, pixels

def random_ray_batch(rng, batch_size, data): ### antialiasing by supersampling
  """Generate a random batch of ray data.

  Args:
    rng: jax PRNG key.
    batch_size: number of pixels to draw.
    data: dict with 'c2w', 'images' and 'hwf' entries.

  Returns:
    ((ray_origins, ray_dirs), pixels): rays are [batch_size, 4, 3] (four
    supersamples per pixel); pixels are the colors of the drawn pixels.
  """
  keys = random.split(rng, 3)
  cam_ind = random.randint(keys[0], [batch_size], 0, data['c2w'].shape[0])
  y_ind = random.randint(keys[1], [batch_size], 0, data['images'].shape[1])
  x_ind = random.randint(keys[2], [batch_size], 0, data['images'].shape[2])
  pixel_coords = _supersample_pixel_coords(x_ind, y_ind)
  pix2cam = pix2cam_matrix(*data['hwf'])
  # One pose per pixel; the size-1 axis broadcasts over the four supersamples
  # instead of gathering every pose four times.
  cam2world = data['c2w'][cam_ind, :3, :4][:, None]
  rays = generate_rays(pixel_coords, pix2cam, cam2world)
  pixels = data['images'][cam_ind, y_ind, x_ind]
  return rays, pixels


# Learning rate helpers.

def log_lerp(t, v0, v1):
  """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
  if v0 <= 0 or v1 <= 0:
    raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
  lv0 = np.log(v0)
  lv1 = np.log(v1)
  return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)

def lr_fn(step, max_steps, lr0, lr1, lr_delay_steps=20000, lr_delay_mult=0.1):
  # Learning rate: log-linear decay from lr0 to lr1 over max_steps, scaled by
  # a warm-up factor that rises from lr_delay_mult to 1 over lr_delay_steps.
  if lr_delay_steps > 0:
    # A kind of reverse cosine decay.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.
  # The closing `return delay_rate * log_lerp(...)` statement of lr_fn follows
  # immediately after this span in the file and is not repeated here.
+ return delay_rate * log_lerp(step / max_steps, lr0, lr1) + +#%% -------------------------------------------------------------------------------- +# ## Plane parameters and setup +#%% +#scene scales + +if scene_type=="synthetic": + scene_grid_scale = 1.2 + if "hotdog" in scene_dir or "mic" in scene_dir or "ship" in scene_dir: + scene_grid_scale = 1.5 + grid_min = np.array([-1, -1, -1]) * scene_grid_scale + grid_max = np.array([ 1, 1, 1]) * scene_grid_scale + point_grid_size = 128 + + def get_taper_coord(p): + return p + def inverse_taper_coord(p): + return p + +elif scene_type=="forwardfacing": + scene_grid_taper = 1.25 + scene_grid_zstart = 25.0 + scene_grid_zend = 1.0 + scene_grid_scale = 0.7 + grid_min = np.array([-scene_grid_scale, -scene_grid_scale, 0]) + grid_max = np.array([ scene_grid_scale, scene_grid_scale, 1]) + point_grid_size = 128 + + def get_taper_coord(p): + pz = np.maximum(-p[..., 2:3],1e-10) + px = p[..., 0:1]/(pz*scene_grid_taper) + py = p[..., 1:2]/(pz*scene_grid_taper) + pz = (np.log(pz) - np.log(scene_grid_zend))/(np.log(scene_grid_zstart) - np.log(scene_grid_zend)) + return np.concatenate([px,py,pz],axis=-1) + def inverse_taper_coord(p): + pz = np.exp( p[..., 2:3] * \ + (np.log(scene_grid_zstart) - np.log(scene_grid_zend)) + \ + np.log(scene_grid_zend) ) + px = p[..., 0:1]*(pz*scene_grid_taper) + py = p[..., 1:2]*(pz*scene_grid_taper) + pz = -pz + return np.concatenate([px,py,pz],axis=-1) + +elif scene_type=="real360": + scene_grid_zmax = 16.0 + if object_name == "gardenvase": + scene_grid_zmax = 9.0 + grid_min = np.array([-1, -1, -1]) + grid_max = np.array([ 1, 1, 1]) + point_grid_size = 128 + + def get_taper_coord(p): + return p + def inverse_taper_coord(p): + return p + + #approximate solution of e^x = ax+b + #(np.exp( x ) + (x-1)) / x = scene_grid_zmax + #np.exp( x ) - scene_grid_zmax*x + (x-1) = 0 + scene_grid_zcc = -1 + for i in range(10000): + j = numpy.log(scene_grid_zmax)+i/1000.0 + if numpy.exp(j) - scene_grid_zmax*j + (j-1) >0: + 
scene_grid_zcc = j + break + if scene_grid_zcc<0: + print("ERROR: approximate solution of e^x = ax+b failed") + 1/0 + + + +grid_dtype = np.float32 + +#plane parameter grid +point_grid = np.zeros( + (point_grid_size, point_grid_size, point_grid_size, 3), + dtype=grid_dtype) +acc_grid = np.zeros( + (point_grid_size, point_grid_size, point_grid_size), + dtype=grid_dtype) +point_grid_diff_lr_scale = 16.0/point_grid_size + + + +def get_acc_grid_masks(taper_positions, acc_grid): + grid_positions = (taper_positions - grid_min) * \ + (point_grid_size / (grid_max - grid_min) ) + grid_masks = (grid_positions[..., 0]>=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2] 0: + net = np.concatenate([net, inputs], axis=-1) + + net = dense_layer(self.out_dim)(net) + + return net + +# Set up the MLPs for color and density. +class MLP(nn.Module): + features: Sequence[int] + + @nn.compact + def __call__(self, x): + for feat in self.features[:-1]: + x = nn.relu(nn.Dense(feat)(x)) + x = nn.Dense(self.features[-1])(x) + return x + + +density_model = RadianceField(1) +feature_model = RadianceField(num_bottleneck_features) +color_model = MLP([16,16,3]) + +# These are the variables we will be optimizing during trianing. 
+model_vars = [point_grid, acc_grid, + density_model.init( + jax.random.PRNGKey(0), + np.zeros([1, 3])), + feature_model.init( + jax.random.PRNGKey(0), + np.zeros([1, 3])), + color_model.init( + jax.random.PRNGKey(0), + np.zeros([1, 3+num_bottleneck_features])), + ] + +#avoid bugs +point_grid = None +acc_grid = None +#%% -------------------------------------------------------------------------------- +# ## Load weights +#%% +vars = pickle.load(open(weights_dir+"/"+"weights_stage2_1.pkl", "rb")) +model_vars = vars +#%% -------------------------------------------------------------------------------- +# ## Get mesh +#%% + +#%% +#extract mesh vertices +layer_num = point_grid_size + +v_grid = numpy.zeros([layer_num+1,layer_num+1,layer_num+1,3], numpy.float32) +v_grid[:-1,:-1,:-1] = numpy.array(vars[0])*point_grid_diff_lr_scale +#%% +#get UV coordinates + +if scene_type=="synthetic": + texture_size = 1024*2 + batch_num = 8*8*8 +elif scene_type=="forwardfacing": + texture_size = 1024*2 + batch_num = 8*8*8 +elif scene_type=="real360": + texture_size = 1024*2 + batch_num = 8*8*8 + +test_threshold = 0.1 + + +out_feat_num = num_bottleneck_features//4 + +quad_size = texture_size//layer_num +assert quad_size*layer_num == texture_size +#pre-compute weights for each quad +# 0 - 1 x +# | \ | +# 2 - 3 +# y +quad_weights = numpy.zeros([quad_size,quad_size,4],numpy.float32) +for i in range(quad_size): + for j in range(quad_size): + x = (i)/quad_size + y = (j)/quad_size + if x>y: + quad_weights[i,j,0] = 1-x + quad_weights[i,j,1] = x-y + quad_weights[i,j,2] = 0 + quad_weights[i,j,3] = y + else: + quad_weights[i,j,0] = 1-y + quad_weights[i,j,1] = 0 + quad_weights[i,j,2] = y-x + quad_weights[i,j,3] = x +quad_weights = numpy.reshape(quad_weights,[quad_size*quad_size,4]) +quad_weights = numpy.transpose(quad_weights, (1,0)) #[4,quad_size*quad_size] + +grid_max_numpy = numpy.array(grid_max,numpy.float32) +grid_min_numpy = numpy.array(grid_min,numpy.float32) + +i_grid = 
numpy.zeros([layer_num,layer_num,layer_num],numpy.int32) +j_grid = numpy.zeros([layer_num,layer_num,layer_num],numpy.int32) +k_grid = numpy.zeros([layer_num,layer_num,layer_num],numpy.int32) + +i_grid[:,:,:] = numpy.reshape(numpy.arange(layer_num),[-1,1,1]) +j_grid[:,:,:] = numpy.reshape(numpy.arange(layer_num),[1,-1,1]) +k_grid[:,:,:] = numpy.reshape(numpy.arange(layer_num),[1,1,-1]) + + + +def get_density_color(pts, vars): + #redefine net + + acc_grid_masks = get_acc_grid_masks(pts, vars[1]) + + # Now use the MLP to compute density and features + mlp_alpha = density_model.apply(vars[-3], pts) + mlp_alpha = jax.nn.sigmoid(mlp_alpha[..., 0]-8) + mlp_alpha = mlp_alpha * (acc_grid_masks>=test_threshold) + mlp_alpha = (mlp_alpha>0.5).astype(np.uint8) + + #previous: (features+dirs)->MLP->(RGB) + mlp_features = jax.nn.sigmoid(feature_model.apply(vars[-2], pts)) + #discretize + mlp_features_ = np.round(mlp_features*255).astype(np.uint8) + mlp_features_0 = np.clip(mlp_features_[...,0:1],1,255)*mlp_alpha[..., None] + mlp_features_1 = mlp_features_[...,1:]*mlp_alpha[..., None] + mlp_features_ = np.concatenate([mlp_features_0,mlp_features_1],axis=-1) + + return mlp_features_ + +get_density_color_p = jax.pmap(lambda pts, vars: get_density_color(pts,vars), + in_axes=(0, None)) + + + +def get_feature_png(feat): + h,w,c = feat.shape + #deal with opencv BGR->RGB + if c%4!=0: + print("ERROR: c%4!=0") + 1/0 + out = [] + for i in range(out_feat_num): + ff = numpy.zeros([h,w,4],numpy.uint8) + ff[...,0] = feat[..., i*4+2] + ff[...,1] = feat[..., i*4+1] + ff[...,2] = feat[..., i*4+0] + ff[...,3] = feat[..., i*4+3] + out.append(ff) + return out + + + + + +##### z planes + +x,y,z = j_grid,k_grid,i_grid +p0 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = j_grid+1,k_grid,i_grid +p1 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - 
grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = j_grid,k_grid+1,i_grid +p2 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = j_grid+1,k_grid+1,i_grid +p3 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +p0123 = numpy.stack([p0,p1,p2,p3],axis=-1) #[M,N,K,3,4] +p0123 = p0123 @ quad_weights #[M,N,K,3,quad_size*quad_size] +p0123 = numpy.reshape(p0123, [layer_num,layer_num,layer_num,3,quad_size,quad_size]) #[M,N,K,3,quad_size,quad_size] +p0123 = numpy.transpose(p0123, (0,1,4,2,5,3)) #[M,N,quad_size,K,quad_size,3] +#positions_z = numpy.reshape(numpy.ascontiguousarray(p0123), [layer_num,layer_num*quad_size,layer_num*quad_size,3]) +positions_z = numpy.reshape(numpy.ascontiguousarray(p0123), [-1,3]) + +p0 = None +p1 = None +p2 = None +p3 = None +p0123 = None + +total_len = len(positions_z) +batch_len = total_len//batch_num +coarse_feature_z = numpy.zeros([total_len,num_bottleneck_features],numpy.uint8) +for i in range(batch_num): + t0 = numpy.reshape(positions_z[i*batch_len:(i+1)*batch_len], [n_device,-1,3]) + t0 = get_density_color_p(t0,vars) + coarse_feature_z[i*batch_len:(i+1)*batch_len] = numpy.reshape(t0,[-1,num_bottleneck_features]) +coarse_feature_z = numpy.reshape(coarse_feature_z,[layer_num,texture_size,texture_size,num_bottleneck_features]) +coarse_feature_z[:,-quad_size:,:] = 0 +coarse_feature_z[:,:,-quad_size:] = 0 + +positions_z = None + +buffer_z = [] +for i in range(layer_num): + if not numpy.any(coarse_feature_z[i,:,:,0]>0): + buffer_z.append(None) + continue + feats = get_feature_png(coarse_feature_z[i]) + buffer_z.append(feats) + +coarse_feature_z = None + + + +##### x planes + +x,y,z = i_grid,j_grid,k_grid +p0 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + 
grid_min_numpy +x,y,z = i_grid,j_grid+1,k_grid +p1 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = i_grid,j_grid,k_grid+1 +p2 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +x,y,z = i_grid,j_grid+1,k_grid+1 +p3 = v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy +p0123 = numpy.stack([p0,p1,p2,p3],axis=-1) #[M,N,K,3,4] +p0123 = p0123 @ quad_weights #[M,N,K,3,quad_size*quad_size] +p0123 = numpy.reshape(p0123, [layer_num,layer_num,layer_num,3,quad_size,quad_size]) #[M,N,K,3,quad_size,quad_size] +p0123 = numpy.transpose(p0123, (0,1,4,2,5,3)) #[M,N,quad_size,K,quad_size,3] +#positions_x = numpy.reshape(numpy.ascontiguousarray(p0123), [layer_num,layer_num*quad_size,layer_num*quad_size,3]) +positions_x = numpy.reshape(numpy.ascontiguousarray(p0123), [-1,3]) + +p0 = None +p1 = None +p2 = None +p3 = None +p0123 = None + +total_len = len(positions_x) +batch_len = total_len//batch_num +coarse_feature_x = numpy.zeros([total_len,num_bottleneck_features],numpy.uint8) +for i in range(batch_num): + t0 = numpy.reshape(positions_x[i*batch_len:(i+1)*batch_len], [n_device,-1,3]) + t0 = get_density_color_p(t0,vars) + coarse_feature_x[i*batch_len:(i+1)*batch_len] = numpy.reshape(t0,[-1,num_bottleneck_features]) +coarse_feature_x = numpy.reshape(coarse_feature_x,[layer_num,texture_size,texture_size,num_bottleneck_features]) +coarse_feature_x[:,-quad_size:,:] = 0 +coarse_feature_x[:,:,-quad_size:] = 0 + +positions_x = None + +buffer_x = [] +for i in range(layer_num): + if not numpy.any(coarse_feature_x[i,:,:,0]>0): + buffer_x.append(None) + continue + feats = get_feature_png(coarse_feature_x[i]) + buffer_x.append(feats) + +coarse_feature_x = None + + + +##### y planes + +x,y,z = j_grid,i_grid,k_grid 
def _grid_vertex_positions(x, y, z):
  """Positions of the grid vertices addressed by index arrays x, y, z.

  Each vertex is its learned offset `v_grid[x,y,z]` plus the centre of grid
  cell (x,y,z) expressed in the [grid_min_numpy, grid_max_numpy] box.
  """
  return v_grid[x,y,z] + (numpy.stack([x,y,z],axis=-1).astype(numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy


##### y planes (continued)
# x,y,z still hold j_grid,i_grid,k_grid from the statement just before this
# section, so p0 is the first corner of every y-plane quad.
p0 = _grid_vertex_positions(x,y,z)
x,y,z = j_grid+1,i_grid,k_grid
p1 = _grid_vertex_positions(x,y,z)
x,y,z = j_grid,i_grid,k_grid+1
p2 = _grid_vertex_positions(x,y,z)
x,y,z = j_grid+1,i_grid,k_grid+1
p3 = _grid_vertex_positions(x,y,z)
p0123 = numpy.stack([p0,p1,p2,p3],axis=-1) #[M,N,K,3,4]
p0123 = p0123 @ quad_weights #[M,N,K,3,quad_size*quad_size]
p0123 = numpy.reshape(p0123, [layer_num,layer_num,layer_num,3,quad_size,quad_size]) #[M,N,K,3,quad_size,quad_size]
p0123 = numpy.transpose(p0123, (0,1,4,2,5,3)) #[M,N,quad_size,K,quad_size,3]
positions_y = numpy.reshape(numpy.ascontiguousarray(p0123), [-1,3])

# Drop the large temporaries before the network evaluation below.
p0 = None
p1 = None
p2 = None
p3 = None
p0123 = None

# Evaluate the baked features batch by batch; every batch is split evenly
# across the devices for the pmap'd network.
total_len = len(positions_y)
batch_len = total_len//batch_num
coarse_feature_y = numpy.zeros([total_len,num_bottleneck_features],numpy.uint8)
for i in range(batch_num):
  batch_slice = slice(i*batch_len, (i+1)*batch_len)
  t0 = numpy.reshape(positions_y[batch_slice], [n_device,-1,3])
  t0 = get_density_color_p(t0,vars)
  coarse_feature_y[batch_slice] = numpy.reshape(t0,[-1,num_bottleneck_features])
coarse_feature_y = numpy.reshape(coarse_feature_y,[layer_num,texture_size,texture_size,num_bottleneck_features])
# The last row/column of quads in each slice is cleared (same as the other planes).
coarse_feature_y[:,-quad_size:,:] = 0
coarse_feature_y[:,:,-quad_size:] = 0

positions_y = None

# One entry per slice: the list of 4-channel images, or None for an empty slice.
buffer_y = []
for i in range(layer_num):
  if not numpy.any(coarse_feature_y[i,:,:,0]>0):
    buffer_y.append(None)
    continue
  feats = get_feature_png(coarse_feature_y[i])
  buffer_y.append(feats)

coarse_feature_y = None

#%%
# Debug image of one x slice. The middle slice can be empty (None); fall back
# to the first non-empty slice instead of crashing on `None[0]`.
_sample_feats = buffer_x[layer_num//2]
if _sample_feats is None:
  _sample_feats = next((f for f in buffer_x if f is not None), None)
if _sample_feats is not None:
  write_floatpoint_image(samples_dir+"/s3_slice_sample.png",_sample_feats[0]/255.0)
_sample_feats = None
#%%
# Texture atlas: out_feat_num square images, filled cell by cell in row-major
# order. Each cell is (quad_size+1)^2 texels.
out_img_size = 1024*20
out_img = []
for i in range(out_feat_num):
  out_img.append(numpy.zeros([out_img_size,out_img_size,4], numpy.uint8))
out_cell_num = 0
out_cell_size = quad_size+1
out_img_h = out_img_size//out_cell_size
out_img_w = out_img_size//out_cell_size



if scene_type=="synthetic":
  def inverse_taper_coord_numpy(p):
    """Identity: synthetic scenes use untapered coordinates."""
    return p

elif scene_type=="forwardfacing":
  def inverse_taper_coord_numpy(p):
    """Map tapered grid coordinates p[...,3] back to scene coordinates.

    z is interpolated in log space between scene_grid_zend (p_z=0) and
    scene_grid_zstart (p_z=1), x/y are scaled by depth, and z is negated.
    """
    pz = numpy.exp( p[..., 2:3] * \
          (numpy.log(scene_grid_zstart) - numpy.log(scene_grid_zend)) + \
          numpy.log(scene_grid_zend) )
    px = p[..., 0:1]*(pz*scene_grid_taper)
    py = p[..., 1:2]*(pz*scene_grid_taper)
    pz = -pz
    return numpy.concatenate([px,py,pz],axis=-1)

elif scene_type=="real360":
  def inverse_taper_coord_numpy(p):
    """Identity: real360 scenes use untapered coordinates here."""
    return p



def write_patch_to_png(out_img,out_cell_num,out_img_w,j,k,feats):
  """Copy the texels of quad (j,k) from slice images `feats` into atlas cell
  number `out_cell_num` of every image in `out_img`.

  Source quads overlap by one texel (stride quad_size, extent out_cell_size);
  atlas cells do not overlap (stride out_cell_size).
  """
  py = out_cell_num//out_img_w
  px = out_cell_num%out_img_w

  osy = j*quad_size
  oey = j*quad_size+out_cell_size
  tsy = py*out_cell_size
  tey = py*out_cell_size+out_cell_size
  osx = k*quad_size
  oex = k*quad_size+out_cell_size
  tsx = px*out_cell_size
  tex = px*out_cell_size+out_cell_size

  for i in range(out_feat_num):
    out_img[i][tsy:tey,tsx:tex] = feats[i][osy:oey,osx:oex]

def get_png_uv(out_cell_num,out_img_w,out_img_size):
  """Return the four (row, col) texture coordinates of an atlas cell.

  Coordinates point at the centres of the cell's corner texels (hence the
  +/-0.5) and are normalised by `out_img_size`.
  """
  py = out_cell_num//out_img_w
  px = out_cell_num%out_img_w

  uv0 = numpy.array([py*out_cell_size+0.5,     px*out_cell_size+0.5],numpy.float32)/out_img_size
  uv1 = numpy.array([(py+1)*out_cell_size-0.5, px*out_cell_size+0.5],numpy.float32)/out_img_size
  uv2 = numpy.array([py*out_cell_size+0.5,     (px+1)*out_cell_size-0.5],numpy.float32)/out_img_size
  uv3 = numpy.array([(py+1)*out_cell_size-0.5, (px+1)*out_cell_size-0.5],numpy.float32)/out_img_size

  return uv0,uv1,uv2,uv3


#for eval
# Indexed [i,j,k, plane (0=x,1=y,2=z), corner 0..3, (row,col) uv].
point_UV_grid = numpy.zeros([point_grid_size,point_grid_size,point_grid_size,3,4,2], numpy.float32)

#mesh vertices
# bag_of_v[n] holds the four corner positions of atlas cell n.
bag_of_v = []


# (slot in point_UV_grid, corner offsets p0..p3 relative to (i,j,k)) per plane.
_PLANE_X = (0, ((0,0,0),(0,1,0),(0,0,1),(0,1,1)))
_PLANE_Y = (1, ((0,0,0),(1,0,0),(0,0,1),(1,0,1)))
_PLANE_Z = (2, ((0,0,0),(1,0,0),(0,1,0),(1,1,0)))


def _grid_vertex(i, j, k):
  """Position of the single grid vertex (i,j,k), before taper inversion."""
  return v_grid[i,j,k] + (numpy.array([i,j,k],numpy.float32)+0.5)*((grid_max_numpy - grid_min_numpy)/point_grid_size) + grid_min_numpy


def _try_emit_quad(feats, row, col, i, j, k, plane, cell_num):
  """Emit the quad of cell (i,j,k) on `plane` if it has any visible texel.

  `feats` is the slice's image list (or None for an empty slice) and
  (row, col) the quad's position inside that slice. On success the patch is
  copied to atlas cell `cell_num`, its uvs are stored in point_UV_grid, its
  corners are appended to bag_of_v, and True is returned; otherwise nothing
  is touched and False is returned.
  """
  if feats is None:
    return False
  # Channel 2 of the first image is the quantity tested for "non-empty".
  occupancy = feats[0][row*quad_size:(row+1)*quad_size+1,col*quad_size:(col+1)*quad_size+1,2]
  if not (numpy.max(occupancy)>0):
    return False
  # Fail with a readable message instead of a NumPy broadcast error when the
  # atlas has no free cell left.
  if cell_num >= out_img_h*out_img_w:
    raise RuntimeError("Texture atlas is full: %d cells of %dx%d texels do not fit in a %dx%d image"
                       % (cell_num+1, out_cell_size, out_cell_size, out_img_size, out_img_size))

  write_patch_to_png(out_img,cell_num,out_img_w,row,col,feats)
  uvs = get_png_uv(cell_num,out_img_w,out_img_size)

  slot, offsets = plane
  corners = [inverse_taper_coord_numpy(_grid_vertex(i+di,j+dj,k+dk))
             for (di,dj,dk) in offsets]
  for corner_id, uv in enumerate(uvs):
    point_UV_grid[i,j,k,slot,corner_id] = uv

  bag_of_v.append(corners)
  return True


# Only the traversal order of k differs between scene types:
#synthetic and real360
#up is z-
#order: z-,x+,y+
#forwardfacing
#front is z-
#order: z+,x+,y+
if scene_type=="synthetic" or scene_type=="real360":
  _k_order = range(layer_num-1,-1,-1)
elif scene_type=="forwardfacing":
  _k_order = range(layer_num)
else:
  _k_order = range(0)  # unknown scene type: no quads are emitted

_last = layer_num-1
for k in _k_order:
  for i in range(layer_num):
    for j in range(layer_num):

      # z plane (slice k, quad at row i / col j)
      if not(k==0 or k==_last or i==_last or j==_last):
        if _try_emit_quad(buffer_z[k], i, j, i, j, k, _PLANE_Z, out_cell_num):
          out_cell_num += 1

      # x plane (slice i, quad at row j / col k)
      if not(i==0 or i==_last or j==_last or k==_last):
        if _try_emit_quad(buffer_x[i], j, k, i, j, k, _PLANE_X, out_cell_num):
          out_cell_num += 1

      # y plane (slice j, quad at row i / col k)
      if not(j==0 or j==_last or i==_last or k==_last):
        if _try_emit_quad(buffer_y[j], i, k, i, j, k, _PLANE_Y, out_cell_num):
          out_cell_num += 1



print("Number of quad faces:", out_cell_num)

buffer_x = None
buffer_y = None
buffer_z = None
#%% --------------------------------------------------------------------------------
# ## The boxes
#%%
box_layer_num = layer_num//2+1
box_texture_size = (layer_num//2)*quad_size+1

layers_ = numpy.arange(box_layer_num,dtype=numpy.float32)/(box_layer_num-1) #[0,1]
layers_ = layers_+1 #[1,2]

layers = numpy.arange(box_texture_size,dtype=numpy.float32)/(box_texture_size-1) #[0,1]
layers = layers*2-1 #[-1,1]


def _box_face_positions(axis, sign):
  """Sample positions of the nested box faces perpendicular to `axis` (0=x, 1=y).

  Face `l` sits at coordinate sign*layers_[l] along `axis`; the other
  horizontal axis and z span layers_[l]*[-1,1]. Returns an [N,3] array.
  """
  pos = numpy.zeros([box_layer_num,box_texture_size,box_texture_size,3],numpy.float32)
  pos[:,:,:,axis] = sign*layers_[:,None,None]
  pos[:,:,:,1-axis] = layers_[:,None,None] * layers[None,:,None]
  pos[:,:,:,2] = layers_[:,None,None] * layers[None,None,:]
  return numpy.reshape(pos, [-1,3])


positions_x_p = _box_face_positions(0, 1)
positions_x_n = _box_face_positions(0, -1)
positions_y_p = _box_face_positions(1, 1)
positions_y_n = _box_face_positions(1, -1)

total_len = len(positions_y_n)*4
# Round the batch length up so that it is a multiple of n_device and
# batch_len*batch_num covers total_len; the surplus is zero padding.
batch_len = (total_len//(batch_num*n_device)+1)*n_device

positions_box = numpy.concatenate([positions_x_p,positions_x_n,
                                   positions_y_p,positions_y_n,
                                   numpy.zeros([batch_len*batch_num-total_len,3],numpy.float32) #padding
                                   ],axis=0)
positions_x_p = None
positions_x_n = None
positions_y_p = None
positions_y_n = None
layers_ = None
layers = None



def get_density_color(pts, vars):
  """Return quantised uint8 features at `pts`, zeroed where the binarised
  opacity is 0 (box version: no occupancy-grid thresholding is applied).

  Feature 0 is clamped to >=1 before masking, so it is non-zero exactly where
  the point is opaque.
  """
  #redefine net

  # Now use the MLP to compute density and features
  mlp_alpha = density_model.apply(vars[-3], pts)
  mlp_alpha = jax.nn.sigmoid(mlp_alpha[..., 0]-8)
  mlp_alpha = (mlp_alpha>0.5).astype(np.uint8)

  #previous: (features+dirs)->MLP->(RGB)
  mlp_features = jax.nn.sigmoid(feature_model.apply(vars[-2], pts))
  #discretize
  mlp_features_ = np.round(mlp_features*255).astype(np.uint8)
  mlp_features_0 = np.clip(mlp_features_[...,0:1],1,255)*mlp_alpha[..., None]
  mlp_features_1 = mlp_features_[...,1:]*mlp_alpha[..., None]
  mlp_features_ = np.concatenate([mlp_features_0,mlp_features_1],axis=-1)

  return mlp_features_

get_density_color_p = jax.pmap(lambda pts, vars: get_density_color(pts,vars),
                               in_axes=(0, None))


coarse_feature_box = numpy.zeros([batch_len*batch_num,num_bottleneck_features],numpy.uint8)
for i in range(batch_num):
  batch_slice = slice(i*batch_len, (i+1)*batch_len)
  t0 = numpy.reshape(positions_box[batch_slice], [n_device,-1,3])
  t0 = get_density_color_p(t0,vars)
  coarse_feature_box[batch_slice] = numpy.reshape(t0,[-1,num_bottleneck_features])
# Drop the padding rows, then split into [face, layer, row, col, feature].
coarse_feature_box = numpy.ascontiguousarray(coarse_feature_box[:total_len])
coarse_feature_box = numpy.reshape(coarse_feature_box,[4,box_layer_num,box_texture_size,box_texture_size,num_bottleneck_features])

positions_box = None
#%%
# Debug image of box layer 16 of the +x faces; clamp the index so that small
# grids (box_layer_num <= 16) do not raise IndexError.
write_floatpoint_image(samples_dir+"/s3_slicebox_sample.png",coarse_feature_box[0,min(16,box_layer_num-1),:,:,:3]/255.0)
#%% --------------------------------------------------------------------------------
# ## Main rendering functions
#%%
+#compute ray-gridcell intersections + +if scene_type=="real360": + + def gridcell_from_rays(rays): + ray_origins = rays[0] + ray_directions = rays[1] + + dtype = ray_origins.dtype + batch_shape = ray_origins.shape[:-1] + small_step = 1e-5 + epsilon = 1e-5 + + ox = ray_origins[..., 0:1] + oy = ray_origins[..., 1:2] + oz = ray_origins[..., 2:3] + + dx = ray_directions[..., 0:1] + dy = ray_directions[..., 1:2] + dz = ray_directions[..., 2:3] + + dxm = (np.abs(dx)=1) & (grid_positions[..., 0]=1) & (grid_positions[..., 1]=1) & (grid_positions[..., 2]=0) & (b>=0) & (c>=0) & np.logical_not(denominator_mask) + return a,b,c,mask + + + + +cell_size_x = (grid_max[0] - grid_min[0])/point_grid_size +half_cell_size_x = cell_size_x/2 +neg_half_cell_size_x = -half_cell_size_x +cell_size_y = (grid_max[1] - grid_min[1])/point_grid_size +half_cell_size_y = cell_size_y/2 +neg_half_cell_size_y = -half_cell_size_y +cell_size_z = (grid_max[2] - grid_min[2])/point_grid_size +half_cell_size_z = cell_size_z/2 +neg_half_cell_size_z = -half_cell_size_z + +def get_inside_cell_mask(P,ooxyz): + P_ = get_taper_coord(P) - ooxyz + return (P_[..., 0]>=neg_half_cell_size_x) \ + & (P_[..., 0]=neg_half_cell_size_y) \ + & (P_[..., 1]=neg_half_cell_size_z) \ + & (P_[..., 2]tx_n,tx_p,tx_n) + ty = np.where(ty_p>ty_n,ty_p,ty_n) + txid = np.where(tx_p>tx_n,0,1).astype(np.int32) + tyid = np.where(ty_p>ty_n,2,3).astype(np.int32) + + tx_py = oy + dy * tx + ty_px = ox + dx * ty + txym = (np.abs(tx_py) 0.5) #[N,4] + + ind = np.argmax(weights_b, axis=-1, keepdims=True) #[N,4,1] + selected_uv = np.take_along_axis(world_uv, ind[..., None], axis=-2) #[N,4,1,2] + selected_uv = selected_uv[..., 0,:] * acc_b[..., None] #[N,4,2] + + + #---------- ray-box intersection points + skybox_masks, skybox_uv = compute_box_intersection_and_return_uv(rays) + + return acc_b, selected_uv, skybox_masks, skybox_uv + +def render_rays_get_color(rays, vars, mlp_features_b, acc_b): + + mlp_features_b = 
mlp_features_b.astype(np.float32)/255 #[N,4,C] + mlp_features_b = mlp_features_b * acc_b[..., None] #[N,4,C] + mlp_features_b = np.mean(mlp_features_b, axis=-2) #[N,C] + + acc_b = np.mean(acc_b.astype(np.float32), axis=-1) #[N] + + # ... as well as view-dependent colors. + dirs = normalize(rays[1]) #[N,4,3] + dirs = np.mean(dirs, axis=-2) #[N,3] + features_dirs_enc_b = np.concatenate([mlp_features_b, dirs], axis=-1) #[N,C+3] + rgb_b = jax.nn.sigmoid(color_model.apply(vars[-1], features_dirs_enc_b)) + + # Composite onto the background color. + if white_bkgd: + rgb_b = rgb_b * acc_b[..., None] + (1. - acc_b[..., None]) + else: + bgc = bg_color + rgb_b = rgb_b * acc_b[..., None] + (1. - acc_b[..., None]) * bgc + + return rgb_b, acc_b + +#%% -------------------------------------------------------------------------------- +# ## Set up pmap'd rendering for test time evaluation. +#%% +#for eval +texture_alpha = numpy.zeros([out_img_size,out_img_size,1], numpy.uint8) +texture_features = numpy.zeros([out_img_size,out_img_size,8], numpy.uint8) + +texture_alpha[:,:,0] = (out_img[0][:,:,2]>0) + +texture_features[:,:,0:3] = out_img[0][:,:,2::-1] +texture_features[:,:,3] = out_img[0][:,:,3] +texture_features[:,:,4:7] = out_img[1][:,:,2::-1] +texture_features[:,:,7] = out_img[1][:,:,3] + +texture_alpha_box = (coarse_feature_box[:,:,:,:,0]>0).astype(numpy.uint8) +#coarse_feature_box [4,box_layer_num,box_texture_size,box_texture_size,num_bottleneck_features] +#%% +test_batch_size = 4096*n_device + +render_rays_get_uv_p = jax.pmap(lambda rays, vars, uv, alp: render_rays_get_uv( + rays, vars, uv, alp), + in_axes=(0, None, None, None)) + +render_rays_get_color_p = jax.pmap(lambda rays, vars, mlp_features_b, acc_b: render_rays_get_color( + rays, vars, mlp_features_b, acc_b), + in_axes=(0, None, 0, 0)) + + +def render_test(rays, vars, uv, alp, alp_box, feat, feat_box): + sh = rays[0].shape + rays = [x.reshape((jax.local_device_count(), -1) + sh[1:]) for x in rays] + acc_b, selected_uv, 
skybox_masks, skybox_uv = render_rays_get_uv_p(rays, vars, uv, alp) + + #deferred features + selected_uv = numpy.array(selected_uv) + acc_b = numpy.array(acc_b) + mlp_features_b = feat[selected_uv[...,0],selected_uv[...,1]] + + #box features + skybox_masks = numpy.array(skybox_masks) + skybox_uv = numpy.array(skybox_uv) + mlp_alpha_b = alp_box[skybox_uv[...,0],skybox_uv[...,1],skybox_uv[...,2],skybox_uv[...,3]] #[N,4,P] + mlp_alpha_b = mlp_alpha_b*skybox_masks + weights_b = compute_volumetric_rendering_weights_with_alpha_numpy(mlp_alpha_b) #[N,4,P] + acc_b_box = (numpy.sum(weights_b, axis=-1) > 0.5) #[N,4] + acc_b_box = acc_b_box & numpy.logical_not(acc_b) + + ind = numpy.argmax(weights_b, axis=-1)[..., None] #[N,4,1] + selected_uv_box = numpy.take_along_axis(skybox_uv, ind[..., None], axis=-2) #[N,4,1,4] + selected_uv_box = selected_uv_box[..., 0,:] * acc_b_box[..., None] #[N,4,4] + box_features_b = feat_box[selected_uv_box[...,0],selected_uv_box[...,1],selected_uv_box[...,2],selected_uv_box[...,3]] + + mlp_features_b = mlp_features_b*acc_b[...,None] + box_features_b*acc_b_box[...,None] + acc_b = acc_b | acc_b_box + + rgb_b, acc_b = render_rays_get_color_p(rays, vars, mlp_features_b, acc_b) + + out = [rgb_b, acc_b, selected_uv, selected_uv_box] + out = [numpy.reshape(numpy.array(out[i]),sh[:-2]+(-1,)) for i in range(4)] + return out + +def render_loop(rays, vars, uv, alp, alp_box, feat, feat_box, chunk): + sh = list(rays[0].shape[:-2]) + rays = [x.reshape([-1, 4, 3]) for x in rays] + l = rays[0].shape[0] + n = jax.local_device_count() + p = ((l - 1) // n + 1) * n - l + rays = [np.pad(x, ((0,p),(0,0),(0,0))) for x in rays] + outs = [render_test([x[i:i+chunk] for x in rays], vars, uv, alp, alp_box, feat, feat_box) + for i in range(0, rays[0].shape[0], chunk)] + outs = [np.reshape( + np.concatenate([z[i] for z in outs])[:l], sh + [-1]) for i in range(4)] + return outs + +# Make sure that everything works, by rendering an image from the test set + +if 
scene_type=="synthetic": + selected_test_index = 97 + preview_image_height = 800 + +elif scene_type=="forwardfacing": + selected_test_index = 0 + preview_image_height = 756//2 + +elif scene_type=="real360": + selected_test_index = 0 + preview_image_height = 840//2 + +rays = camera_ray_batch( + data['test']['c2w'][selected_test_index], data['test']['hwf']) +gt = data['test']['images'][selected_test_index] +out = render_loop(rays, model_vars, point_UV_grid, texture_alpha, texture_alpha_box, + texture_features, coarse_feature_box, test_batch_size) +rgb = out[0] +acc = out[1] +write_floatpoint_image(samples_dir+"/s3_"+str(0)+"_rgb_discretized.png",rgb) +write_floatpoint_image(samples_dir+"/s3_"+str(0)+"_gt.png",gt) +write_floatpoint_image(samples_dir+"/s3_"+str(0)+"_acc_discretized.png",acc) +#%% -------------------------------------------------------------------------------- +# ## Remove invisible triangles +#%% +gc.collect() + +render_poses = data['train']['c2w'] +texture_mask = numpy.zeros([out_img_size,out_img_size], numpy.uint8) +texture_mask_box = numpy.zeros(texture_alpha_box.shape, numpy.uint8) +print("Removing invisible triangles") +for p in tqdm(render_poses): + out = render_loop(camera_ray_batch(p, hwf), vars, point_UV_grid, + texture_alpha, texture_alpha_box, + texture_features, coarse_feature_box, test_batch_size) + uv = np.reshape(out[2],[-1,2]) + texture_mask[uv[:,0],uv[:,1]] = 1 + uv = np.reshape(out[3],[-1,4]) + texture_mask_box[uv[:,0],uv[:,1],uv[:,2],uv[:,3]] = 1 +#%% +#mask invisible triangles for eval + +#count visible quads +num_visible_quads = 0 + +quad_t1_mask = numpy.zeros([out_cell_size,out_cell_size],numpy.uint8) +quad_t2_mask = numpy.zeros([out_cell_size,out_cell_size],numpy.uint8) +for i in range(out_cell_size): + for j in range(out_cell_size): + if i>=j: + quad_t1_mask[i,j] = 1 + if i<=j: + quad_t2_mask[i,j] = 1 + +#more strict check for boxes +quad_t1_mask_box = numpy.copy(quad_t1_mask) +quad_t2_mask_box = numpy.copy(quad_t2_mask) 
+quad_t1_mask_box[0,:] = 0 +quad_t1_mask_box[-1,:] = 0 +quad_t1_mask_box[:,0] = 0 +quad_t1_mask_box[:,-1] = 0 +quad_t2_mask_box[0,:] = 0 +quad_t2_mask_box[-1,:] = 0 +quad_t2_mask_box[:,0] = 0 +quad_t2_mask_box[:,-1] = 0 + +def check_triangle_visible(mask,out_cell_num): + py = out_cell_num//out_img_w + px = out_cell_num%out_img_w + + tsy = py*out_cell_size + tey = py*out_cell_size+out_cell_size + tsx = px*out_cell_size + tex = px*out_cell_size+out_cell_size + + quad_m = mask[tsy:tey,tsx:tex] + t1_visible = numpy.any(quad_m*quad_t1_mask) + t2_visible = numpy.any(quad_m*quad_t2_mask) + + return (t1_visible or t2_visible), t1_visible, t2_visible + +def mask_triangle_invisible(mask,out_cell_num,imga): + py = out_cell_num//out_img_w + px = out_cell_num%out_img_w + + tsy = py*out_cell_size + tey = py*out_cell_size+out_cell_size + tsx = px*out_cell_size + tex = px*out_cell_size+out_cell_size + + quad_m = mask[tsy:tey,tsx:tex] + t1_visible = numpy.any(quad_m*quad_t1_mask) + t2_visible = numpy.any(quad_m*quad_t2_mask) + + if not (t1_visible or t2_visible): + imga[tsy:tey,tsx:tex] = 0 + + elif not t1_visible: + imga[tsy:tey,tsx:tex] = imga[tsy:tey,tsx:tex]*quad_t2_mask[:,:,None] + + elif not t2_visible: + imga[tsy:tey,tsx:tex] = imga[tsy:tey,tsx:tex]*quad_t1_mask[:,:,None] + + return (t1_visible or t2_visible), t1_visible, t2_visible + + +for i in range(out_cell_num): + quad_visible, t1_visible, t2_visible = mask_triangle_invisible(texture_mask, i, texture_alpha) + if quad_visible: + num_visible_quads += 1 + + + +# for boxes +num_visible_quads_box = 0 + + +def check_triangle_visible_box(mask,xyid,layerid,py,px): + tsy = py*quad_size + tey = py*quad_size+out_cell_size + tsx = px*quad_size + tex = px*quad_size+out_cell_size + + quad_m = mask[xyid,layerid,tsy:tey,tsx:tex] + t1_visible = numpy.any(quad_m*quad_t1_mask_box) + t2_visible = numpy.any(quad_m*quad_t2_mask_box) + + return (t1_visible or t2_visible), t1_visible, t2_visible + +def 
mask_triangle_invisible_box(mask,xyid,layerid,py,px,imga): + tsy = py*quad_size + tey = py*quad_size+out_cell_size + tsx = px*quad_size + tex = px*quad_size+out_cell_size + + quad_m = mask[xyid,layerid,tsy:tey,tsx:tex] + t1_visible = numpy.any(quad_m*quad_t1_mask_box) + t2_visible = numpy.any(quad_m*quad_t2_mask_box) + + if not (t1_visible or t2_visible): + imga[xyid,layerid,tsy:tey,tsx:tex] = 0 + + elif not t1_visible: + imga[xyid,layerid,tsy:tey,tsx:tex] = imga[xyid,layerid,tsy:tey,tsx:tex]*quad_t2_mask + + elif not t2_visible: + imga[xyid,layerid,tsy:tey,tsx:tex] = imga[xyid,layerid,tsy:tey,tsx:tex]*quad_t1_mask + + return (t1_visible or t2_visible), t1_visible, t2_visible + + +for t in range(4): + for i in range(box_layer_num): + for j in range(box_texture_size//quad_size): + for k in range(box_texture_size//quad_size): + quad_visible, t1_visible, t2_visible = mask_triangle_invisible_box( + texture_mask_box, t,i,j,k, texture_alpha_box) + if quad_visible: + num_visible_quads_box += 1 + +print("Number of quad faces:", num_visible_quads) +print("Number of box quad faces:", num_visible_quads_box) + +#%% -------------------------------------------------------------------------------- +# ## Eval +#%% +gc.collect() + +render_poses = data['test']['c2w'][:len(data['test']['images'])] +frames = [] +framemasks = [] +print("Testing") +for p in tqdm(render_poses): + out = render_loop(camera_ray_batch(p, hwf), vars, point_UV_grid, + texture_alpha, texture_alpha_box, + texture_features, coarse_feature_box, test_batch_size) + frames.append(out[0]) + framemasks.append(out[1]) +psnrs_test = [-10 * np.log10(np.mean(np.square(rgb - gt))) for (rgb, gt) in zip(frames, data['test']['images'])] +print("Test set average PSNR: %f" % np.array(psnrs_test).mean()) + +#%% +import jax.numpy as jnp +import jax.scipy as jsp + +def compute_ssim(img0, + img1, + max_val, + filter_size=11, + filter_sigma=1.5, + k1=0.01, + k2=0.03, + return_map=False): + """Computes SSIM from two images. 
+ This function was modeled after tf.image.ssim, and should produce comparable + output. + Args: + img0: array. An image of size [..., width, height, num_channels]. + img1: array. An image of size [..., width, height, num_channels]. + max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. + filter_size: int >= 1. Window size. + filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. + k1: float > 0. One of the SSIM dampening parameters. + k2: float > 0. One of the SSIM dampening parameters. + return_map: Bool. If True, will cause the per-pixel SSIM "map" to returned + Returns: + Each image's mean SSIM, or a tensor of individual values if `return_map`. + """ + # Construct a 1D Gaussian blur filter. + hw = filter_size // 2 + shift = (2 * hw - filter_size + 1) / 2 + f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2 + filt = jnp.exp(-0.5 * f_i) + filt /= jnp.sum(filt) + + # Blur in x and y (faster than the 2D convolution). + filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid") + filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid") + + # Vmap the blurs to the tensor size, and then compose them. + num_dims = len(img0.shape) + map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1]) + for d in map_axes: + filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d) + filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d) + filt_fn = lambda z: filt_fn1(filt_fn2(z)) + + mu0 = filt_fn(img0) + mu1 = filt_fn(img1) + mu00 = mu0 * mu0 + mu11 = mu1 * mu1 + mu01 = mu0 * mu1 + sigma00 = filt_fn(img0**2) - mu00 + sigma11 = filt_fn(img1**2) - mu11 + sigma01 = filt_fn(img0 * img1) - mu01 + + # Clip the variances and covariances to valid values. 
+ # Variance must be non-negative: + sigma00 = jnp.maximum(0., sigma00) + sigma11 = jnp.maximum(0., sigma11) + sigma01 = jnp.sign(sigma01) * jnp.minimum( + jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01)) + + c1 = (k1 * max_val)**2 + c2 = (k2 * max_val)**2 + numer = (2 * mu01 + c1) * (2 * sigma01 + c2) + denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) + ssim_map = numer / denom + ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims))) + return ssim_map if return_map else ssim + +# Compiling to the CPU because it's faster and more accurate. +ssim_fn = jax.jit( + functools.partial(compute_ssim, max_val=1.), backend="cpu") + +ssim_values = [] +for i in range(len(data['test']['images'])): + ssim = ssim_fn(frames[i], data['test']['images'][i]) + ssim_values.append(float(ssim)) + +print("Test set average SSIM: %f" % np.array(ssim_values).mean()) +#%% -------------------------------------------------------------------------------- +# ## Write mesh +#%% + +#use texture_mask to decide keep or drop + +new_img_sizes = [ + [1024,1024], + [2048,1024], + [2048,2048], + [4096,2048], + [4096,4096], + [8192,4096], + [8192,8192], + [16384,8192], + [16384,16384], + [32768,16384], + [32768,32768], +] + +fit_flag = False +for i in range(len(new_img_sizes)): + new_img_size_w,new_img_size_h = new_img_sizes[i] + new_img_size_ratio = new_img_size_w/new_img_size_h + new_img_h = new_img_size_h//out_cell_size + new_img_w = new_img_size_w//out_cell_size + if num_visible_quads+num_visible_quads_box<=new_img_h*new_img_w: + fit_flag = True + break + +if fit_flag: + print("Texture image size:", new_img_size_w,new_img_size_h) +else: + print("Texture image too small", new_img_size_w,new_img_size_h) + 1/0 + + +new_img = [] +for i in range(out_feat_num): + new_img.append(numpy.zeros([new_img_size_h,new_img_size_w,4], numpy.uint8)) +new_cell_num = 0 + + +def copy_patch_to_png(out_img,out_cell_num,new_img,new_cell_num): + py = out_cell_num//out_img_w + px = out_cell_num%out_img_w + + ny = 
def copy_patch_to_png(out_img,out_cell_num,new_img,new_cell_num):
  """Copy one texture cell from the source atlas into the packed atlas.

  Reads the module-level `out_img_w`, `new_img_w`, `out_cell_size` and
  `out_feat_num`.

  Args:
    out_img: sequence of source feature images, indexed [feature][y, x].
    out_cell_num: row-major index of the source cell (out_img_w cells per row).
    new_img: sequence of destination feature images; modified in place.
    new_cell_num: row-major index of the destination cell (new_img_w cells
      per row).
  Returns:
    Always True.
  """
  py, px = divmod(out_cell_num, out_img_w)
  ny, nx = divmod(new_cell_num, new_img_w)

  tsy = py*out_cell_size
  tsx = px*out_cell_size
  nsy = ny*out_cell_size
  nsx = nx*out_cell_size

  for i in range(out_feat_num):
    new_img[i][nsy:nsy+out_cell_size,nsx:nsx+out_cell_size] = (
        out_img[i][tsy:tsy+out_cell_size,tsx:tsx+out_cell_size])

  return True


def write_patch_to_png_box(new_img,new_cell_num,xyid,layerid,py,px,feats):
  """Copy one box-layer patch of `feats` into the packed atlas.

  Reads the module-level `new_img_w`, `quad_size` and `out_cell_size`.
  Source patches start every `quad_size` pixels but are `out_cell_size`
  pixels wide.

  Args:
    new_img: sequence of at least two destination images; modified in place.
    new_cell_num: row-major index of the destination cell.
    xyid: index of the box face, first axis of `feats`.
    layerid: index of the box layer, second axis of `feats`.
    py: patch row on the face, in units of `quad_size` pixels.
    px: patch column on the face, in units of `quad_size` pixels.
    feats: array indexed [xyid, layerid, y, x, channel] with 8 channels.
  """
  ny, nx = divmod(new_cell_num, new_img_w)

  tsy = py*quad_size
  tey = tsy+out_cell_size
  tsx = px*quad_size
  tex = tsx+out_cell_size
  nsy = ny*out_cell_size
  ney = nsy+out_cell_size
  nsx = nx*out_cell_size
  nex = nsx+out_cell_size

  # Feature channels 0-3 go to image 0 and 4-7 to image 1. Within each image
  # the first three channels are written in reversed order (presumably
  # RGB -> BGR for cv2.imwrite -- TODO confirm); the fourth stays in place.
  for img_idx in range(2):
    for dst_channel, src_channel in enumerate((2, 1, 0, 3)):
      new_img[img_idx][nsy:ney,nsx:nex,dst_channel] = (
          feats[xyid,layerid,tsy:tey,tsx:tex,4*img_idx+src_channel])


def get_box_v_from_uv(xyid,layerid,py,px):
  """Return the four 3D corners of one quad on a box layer.

  Reads the module-level `point_grid_size`, `scene_grid_zcc`, `quad_size`,
  `out_cell_size` and `box_texture_size`.

  Args:
    xyid: box face; 0 -> x=+d, 1 -> x=-d, 2 -> y=+d, 3 -> y=-d.
    layerid: index of the box layer; determines the face distance d.
    py: patch row on the face, in units of `quad_size` pixels.
    px: patch column on the face, in units of `quad_size` pixels.
  Returns:
    Tuple (p0, p1, p2, p3) of [x, y, z] lists.
  Raises:
    ValueError: if `xyid` is not 0, 1, 2 or 3 (previously this surfaced as an
      UnboundLocalError at the return statement).
  """
  if xyid not in (0, 1, 2, 3):
    raise ValueError("xyid must be 0, 1, 2 or 3, got %r" % (xyid,))

  # Distance of the layer from the origin, exponential in layerid.
  d = float(layerid)/(point_grid_size//2)
  d = (numpy.exp( d * scene_grid_zcc ) + (scene_grid_zcc-1)) / scene_grid_zcc

  # Patch bounds in texels, mapped to [-1, 1] and scaled onto the layer.
  uv = np.array([py*quad_size,py*quad_size+out_cell_size,px*quad_size,px*quad_size+out_cell_size],np.float32)
  uv = (uv/box_texture_size*2-1)*d

  u0,u1,v0,v1 = uv
  corners = ((u0,v0),(u1,v0),(u0,v1),(u1,v1))

  if xyid==0:
    quad = [[d,u,v] for u,v in corners]
  elif xyid==1:
    quad = [[-d,u,v] for u,v in corners]
  elif xyid==2:
    quad = [[u,d,v] for u,v in corners]
  else:
    quad = [[u,-d,v] for u,v in corners]

  p0,p1,p2,p3 = quad
  return p0,p1,p2,p3
#write mesh

obj_save_dir = "obj"
os.makedirs(obj_save_dir, exist_ok=True)


def _write_quad(obj_file, corners, uvs, t1_visible, t2_visible, vcount):
  """Append one textured quad (4 `v`, 4 `vt`, up to 2 `f` records) to an OBJ.

  Args:
    obj_file: writable text file.
    corners: the four [x, y, z] corners p0..p3 of the quad.
    uvs: the four texture coordinates uv0..uv3 from get_png_uv.
    t1_visible: whether to emit triangle (p0, p1, p3).
    t2_visible: whether to emit triangle (p0, p3, p2).
    vcount: number of vertices written before this quad; OBJ indices are
      1-based, hence the +1..+4 below.
  """
  if scene_type=="synthetic" or scene_type=="real360":
    # Written as (x, z, -y): swaps the up axis for the exported mesh.
    for p in corners:
      obj_file.write("v %.6f %.6f %.6f\n" % (p[0],p[2],-p[1]))
  elif scene_type=="forwardfacing":
    for p in corners:
      obj_file.write("v %.6f %.6f %.6f\n" % (p[0],p[1],p[2]))

  # The v coordinate is flipped and rescaled by the atlas aspect ratio.
  for uv in uvs:
    obj_file.write("vt %.6f %.6f\n" % (uv[1],1-uv[0]*new_img_size_ratio))

  if t1_visible:
    obj_file.write("f %d/%d %d/%d %d/%d\n" % (vcount+1,vcount+1,vcount+2,vcount+2,vcount+4,vcount+4))
  if t2_visible:
    obj_file.write("f %d/%d %d/%d %d/%d\n" % (vcount+1,vcount+1,vcount+4,vcount+4,vcount+3,vcount+3))


vcount = 0

# The context manager guarantees the OBJ is flushed and closed even if the
# export fails half-way.
with open(obj_save_dir+"/shape.obj",'w') as obj_f:
  for i in range(out_cell_num):
    quad_visible, t1_visible, t2_visible = check_triangle_visible(texture_mask, i)
    if quad_visible:
      copy_patch_to_png(out_img,i,new_img,new_cell_num)
      p0,p1,p2,p3 = bag_of_v[i]
      uv0,uv1,uv2,uv3 = get_png_uv(new_cell_num,new_img_w,new_img_size_w)
      new_cell_num += 1

      _write_quad(obj_f, (p0,p1,p2,p3), (uv0,uv1,uv2,uv3),
                  t1_visible, t2_visible, vcount)
      vcount += 4

  # boxes
  for t in range(4):
    for i in range(box_layer_num):
      for j in range(box_texture_size//quad_size):
        for k in range(box_texture_size//quad_size):
          quad_visible, t1_visible, t2_visible = check_triangle_visible_box(
              texture_mask_box, t,i,j,k)
          if quad_visible:
            write_patch_to_png_box(new_img,new_cell_num,t,i,j,k,coarse_feature_box)
            p0,p1,p2,p3 = get_box_v_from_uv(t,i,j,k)
            uv0,uv1,uv2,uv3 = get_png_uv(new_cell_num,new_img_w,new_img_size_w)
            new_cell_num += 1

            _write_quad(obj_f, (p0,p1,p2,p3), (uv0,uv1,uv2,uv3),
                        t1_visible, t2_visible, vcount)
            vcount += 4


# Save the packed feature atlases next to the mesh.
for j in range(out_feat_num):
  cv2.imwrite(obj_save_dir+"/shape.pngfeat"+str(j)+".png", new_img[j], [cv2.IMWRITE_PNG_COMPRESSION, 9])
#%%
#export weights for the MLP
# Keys are written in the order 0_weights, 1_weights, 2_weights, 0_bias,
# 1_bias, 2_bias; values are nested Python lists so they serialise to JSON.
dense_params = vars[-1]['params']
mlp_params = {}
for field, suffix in (('kernel', 'weights'), ('bias', 'bias')):
  for layer in range(3):
    mlp_params['%d_%s' % (layer, suffix)] = (
        dense_params['Dense_%d' % layer][field].tolist())

scene_params_path = obj_save_dir+'/mlp.json'
serialized = json.dumps(mlp_params).encode('utf-8')
with open(scene_params_path, 'wb') as f:
  f.write(serialized)
#%% --------------------------------------------------------------------------------
# ## Split the large texture image into images of size 4096
#%%
import numpy as np

target_dir = obj_save_dir+"_phone"

texture_size = 4096
patchsize = 17  # NOTE(review): presumably equal to out_cell_size -- confirm
texture_patch_size = texture_size//patchsize  # cells per row/column of a page

os.makedirs(target_dir, exist_ok=True)


source_obj_dir = obj_save_dir+"/shape.obj"
source_png0_dir = obj_save_dir+"/shape.pngfeat0.png"
source_png1_dir = obj_save_dir+"/shape.pngfeat1.png"

source_png0 = cv2.imread(source_png0_dir,cv2.IMREAD_UNCHANGED)
source_png1 = cv2.imread(source_png1_dir,cv2.IMREAD_UNCHANGED)
# cv2.imread signals failure by returning None instead of raising.
if source_png0 is None or source_png1 is None:
  raise FileNotFoundError(
      "could not read the feature textures in " + obj_save_dir)

img_h,img_w,_ = source_png0.shape

with open(source_obj_dir,'r') as fin:
  lines = fin.readlines()
# Tokenise every line once; blank lines become empty lists.
tokens = [raw_line.split() for raw_line in lines]


def _save_texture_page(page_idx, img0, img1):
  """Write the two feature textures of one finished page to target_dir."""
  # the following is only required for iphone
  # because iphone runs alpha test before the fragment shader
  # the viewer code is also changed accordingly
  img0[:,:,3] = img0[:,:,3]//2+128
  img1[:,:,3] = img1[:,:,3]//2+128

  cv2.imwrite(target_dir+"/shape"+str(page_idx)+".pngfeat0.png", img0, [cv2.IMWRITE_PNG_COMPRESSION,9])
  cv2.imwrite(target_dir+"/shape"+str(page_idx)+".pngfeat1.png", img1, [cv2.IMWRITE_PNG_COMPRESSION,9])


current_img_idx = 0
current_img0 = np.zeros([texture_size,texture_size,4],np.uint8)
current_img1 = np.zeros([texture_size,texture_size,4],np.uint8)
current_quad_count = 0
current_obj = open(target_dir+"/shape"+str(current_img_idx)+".obj",'w')
current_v_count = 0   # vertices written to the current OBJ
current_v_offset = 0  # vertices written to all previous OBJs

quads_per_page = texture_patch_size*texture_patch_size
last_line_idx = len(lines)-1

#v-vt-f cycle

for i, line in enumerate(tokens):
  if not line:
    continue

  if line[0] == 'v':
    current_obj.write(lines[i])
    current_v_count += 1

  elif line[0] == 'vt':
    # Only the first vt of a quad (the one right after its v records) is
    # handled; it consumes all four vt records of the quad at once.
    if tokens[i-1] and tokens[i-1][0] == "v":
      (x0,y0),(x1,y1),(x2,y2),(x3,y3) = [
          (float(t[1]), 1-float(t[2])) for t in tokens[i:i+4]]

      # The centre of the quad's texels identifies its cell in the big atlas.
      xc = (x0+x1+x2+x3)*img_w/4
      yc = (y0+y1+y2+y3)*img_h/4

      old_cell_x = int(xc/patchsize)
      old_cell_y = int(yc/patchsize)

      new_cell_x = current_quad_count%texture_patch_size
      new_cell_y = current_quad_count//texture_patch_size
      current_quad_count += 1

      #copy patch

      tsy = old_cell_y*patchsize
      tsx = old_cell_x*patchsize
      nsy = new_cell_y*patchsize
      nsx = new_cell_x*patchsize

      current_img0[nsy:nsy+patchsize,nsx:nsx+patchsize] = source_png0[tsy:tsy+patchsize,tsx:tsx+patchsize]
      current_img1[nsy:nsy+patchsize,nsx:nsx+patchsize] = source_png1[tsy:tsy+patchsize,tsx:tsx+patchsize]

      #write uv
      # Coordinates are inset by half a texel on every side of the cell.

      x_lo = (new_cell_x*patchsize+0.5)/texture_size
      x_hi = ((new_cell_x+1)*patchsize-0.5)/texture_size
      y_lo = (new_cell_y*patchsize+0.5)/texture_size
      y_hi = ((new_cell_y+1)*patchsize-0.5)/texture_size

      current_obj.write("vt %.6f %.6f\n" % (x_lo,1-y_lo))
      current_obj.write("vt %.6f %.6f\n" % (x_lo,1-y_hi))
      current_obj.write("vt %.6f %.6f\n" % (x_hi,1-y_lo))
      current_obj.write("vt %.6f %.6f\n" % (x_hi,1-y_hi))

  elif line[0] == 'f':
    # Re-base the vertex indices onto the current OBJ; v and vt share indices.
    f1,f2,f3 = [int(line[k].split("/")[0])-current_v_offset for k in (1,2,3)]
    current_obj.write("f %d/%d %d/%d %d/%d\n" % (f1,f1,f2,f2,f3,f3))

    is_last_line = i==last_line_idx
    next_is_face = (not is_last_line) and bool(tokens[i+1]) and tokens[i+1][0]=='f'

    #create new texture image if current is full
    if is_last_line or (not next_is_face and current_quad_count==quads_per_page):
      current_obj.close()
      _save_texture_page(current_img_idx, current_img0, current_img1)
      current_img_idx += 1
      current_img0 = np.zeros([texture_size,texture_size,4],np.uint8)
      current_img1 = np.zeros([texture_size,texture_size,4],np.uint8)
      current_quad_count = 0
      if not is_last_line:
        current_obj = open(target_dir+"/shape"+str(current_img_idx)+".obj",'w')
        current_v_offset += current_v_count
        current_v_count = 0

# If the OBJ did not end on a face record the last page is still pending:
# close its OBJ and save its textures so no quads are lost.
if not current_obj.closed:
  current_obj.close()
  if current_quad_count:
    _save_texture_page(current_img_idx, current_img0, current_img1)
    current_img_idx += 1
#copy the small MLP
# Parse and re-serialise the JSON rather than patching the text of its first
# line, so the copy stays valid whatever the layout of the source file is.
source_json_dir = obj_save_dir+"/mlp.json"
current_json_dir = target_dir+"/mlp.json"
with open(source_json_dir,'r') as fin:
  phone_mlp_params = json.load(fin)
# Number of shapeN.obj / texture pages written by the split step.
phone_mlp_params["obj_num"] = current_img_idx
with open(current_json_dir,'w') as fout:
  json.dump(phone_mlp_params, fout)

#%% --------------------------------------------------------------------------------
# # Save images for testing
#%%

pred_frames = np.array(frames,np.float32)
gt_frames = np.array(data['test']['images'],np.float32)

# Context managers make sure both pickle files are flushed and closed.
with open("pred_frames.pkl", "wb") as pred_file:
  pickle.dump(pred_frames, pred_file)
with open("gt_frames.pkl", "wb") as gt_file:
  pickle.dump(gt_frames, gt_file)
+ +
+
+
+
+ + + + + + + + + diff --git a/jax3d/projects/mobilenerf/view_unbounded.html b/jax3d/projects/mobilenerf/view_unbounded.html new file mode 100644 index 0000000..e921f9f --- /dev/null +++ b/jax3d/projects/mobilenerf/view_unbounded.html @@ -0,0 +1,547 @@ +
+ +
+
+
+
+ + + + + + + + +