Implement a detail variation method of the image input #88

Open · wants to merge 3 commits into main
Changes from all commits
261 changes: 209 additions & 52 deletions app.py

Large diffs are not rendered by default.

37,610 changes: 37,610 additions & 0 deletions assets/example_mesh/typical_creature_dragon.obj

Large diffs are not rendered by default.

40,766 changes: 40,766 additions & 0 deletions assets/example_mesh/typical_creature_elephant.obj

Large diffs are not rendered by default.

32,477 changes: 32,477 additions & 0 deletions assets/example_mesh/typical_creature_furry.obj

Large diffs are not rendered by default.

77,621 changes: 77,621 additions & 0 deletions assets/example_mesh/typical_creature_quadruped.obj

Large diffs are not rendered by default.

67,374 changes: 67,374 additions & 0 deletions assets/example_mesh/typical_creature_robot_crab.obj

Large diffs are not rendered by default.

26,204 changes: 26,204 additions & 0 deletions assets/example_mesh/typical_creature_robot_dinosour.obj

Large diffs are not rendered by default.

60,040 changes: 60,040 additions & 0 deletions assets/example_mesh/typical_creature_rock_monster.obj

Large diffs are not rendered by default.

52,316 changes: 52,316 additions & 0 deletions assets/example_mesh/typical_humanoid_block_robot.obj

Large diffs are not rendered by default.

40,586 changes: 40,586 additions & 0 deletions assets/example_mesh/typical_humanoid_dragonborn.obj

Large diffs are not rendered by default.

84,297 changes: 84,297 additions & 0 deletions assets/example_mesh/typical_humanoid_dwarf.obj

Large diffs are not rendered by default.

41,758 changes: 41,758 additions & 0 deletions assets/example_mesh/typical_humanoid_goblin.obj

Large diffs are not rendered by default.

60,009 changes: 60,009 additions & 0 deletions assets/example_mesh/typical_humanoid_mech.obj

Large diffs are not rendered by default.

104 changes: 104 additions & 0 deletions example_detail_variation.py
@@ -0,0 +1,104 @@
import os
import sys
# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'.
# 'auto' is faster but will do benchmarking at the beginning.
# Recommended to set to 'native' if run only once.

import subprocess
import imageio
import trimesh
import torch
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils

"""
Image-conditioned detail variation (Sec. 3.4 of the paper):
1. Voxelize a GIVEN mesh into the form of a sparse structure.
A CUDA voxelizer is required; install it from https://github.com/Forceflow/cuda_voxelizer.
2. Run ONLY the second stage with the image prompt.
"""

VOXELIZER = "cuda_voxelizer"
VOXEL_RESOLUTION = 64
VOXELIZER_CMD = "{} -f {} -s {} -o binvox"

def _check_voxelizer_exists(executable):
try:
subprocess.check_output([executable])
except (FileNotFoundError, subprocess.CalledProcessError):
print("Cannot find the voxelizer! Install it from https://github.com/Forceflow/cuda_voxelizer.")
sys.exit(-1)

def _voxelize_mesh(inpath: str):
outpath = os.path.join(os.path.dirname(inpath), f"{os.path.basename(inpath)}_{VOXEL_RESOLUTION}.binvox")
if os.path.exists(outpath):
return outpath

_check_voxelizer_exists(VOXELIZER)
cmd = VOXELIZER_CMD.format(VOXELIZER, inpath, VOXEL_RESOLUTION)
voxelizer_ans = subprocess.run(cmd, capture_output=True, shell=True)
if voxelizer_ans.returncode != 0 or not os.path.exists(outpath):
print("The voxelizer failed with error:")
print(voxelizer_ans.stderr)
sys.exit(-1)
return outpath

# Load a pipeline from a model folder or a Hugging Face model hub.
pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
pipeline.cuda()

# Saving directory for the results
saveroot = "results/texgen"

example_mesh_image_pairs = [
("assets/example_mesh/typical_creature_dragon.obj", "assets/example_image/typical_creature_elephant.png"),
("assets/example_mesh/typical_creature_dragon.obj", "assets/example_image/typical_creature_furry.png"),
("assets/example_mesh/typical_creature_dragon.obj", "assets/example_image/typical_creature_robot_dinosour.png"),
("assets/example_mesh/typical_creature_dragon.obj", "assets/example_image/typical_creature_robot_crab.png"),
("assets/example_mesh/typical_humanoid_block_robot.obj", "assets/example_image/typical_building_mushroom.png"),
("assets/example_mesh/typical_humanoid_block_robot.obj", "assets/example_image/typical_humanoid_mech.png")
]

for mesh_image_pair in example_mesh_image_pairs:
mesh_path, image_path = mesh_image_pair
instance_name = f"{os.path.splitext(os.path.basename(mesh_path))[0]}+{os.path.splitext(os.path.basename(image_path))[0]}"
savedir = os.path.join(saveroot, instance_name)
os.makedirs(savedir, exist_ok=True)

# Load the image
image = Image.open(image_path)

binary_voxel = trimesh.load(_voxelize_mesh(mesh_path)).matrix
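# binary_voxel is now a dense boolean (res, res, res) occupancy grid; trimesh is expected to
# load the .binvox file produced by cuda_voxelizer as a VoxelGrid whose .matrix is the dense array.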

# Run the pipeline
outputs = pipeline.run_detail_variation(
binary_voxel,
image,
seed=1,
# more steps, larger cfg
slat_sampler_params={
"steps": 35,
"cfg_strength": 6.0,
},
)

torch.cuda.empty_cache()
# Render the outputs
video = render_utils.render_video(outputs['gaussian'][0])['color']
imageio.mimsave(os.path.join(savedir, f"{instance_name}_gs.mp4"), video, fps=30)
video = render_utils.render_video(outputs['radiance_field'][0])['color']
imageio.mimsave(os.path.join(savedir, f"{instance_name}_rf.mp4"), video, fps=30)
video = render_utils.render_video(outputs['mesh'][0])['normal']
imageio.mimsave(os.path.join(savedir, f"{instance_name}_mesh.mp4"), video, fps=30)

# GLB files can be extracted from the outputs
glb = postprocessing_utils.to_trimesh(
outputs['gaussian'][0],
outputs['mesh'][0],
# Optional parameters
simplify=0.95, # Ratio of triangles to remove in the simplification process
texture_size=1024, # Size of the texture used for the GLB
debug=False,
verbose=True
)
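# Hypothetical addition (not part of this PR): save the textured mesh, assuming
# to_trimesh returns a trimesh.Trimesh-like object with an export() method.
glb.export(os.path.join(savedir, f"{instance_name}.glb"))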
75 changes: 69 additions & 6 deletions trellis/pipelines/trellis_image_to_3d.py
@@ -1,9 +1,10 @@
import os
from typing import *
from contextlib import contextmanager
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict
from torchvision import transforms
@@ -81,6 +82,19 @@ def _init_image_cond_model(self, name: str):
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
self.image_cond_model_transform = transform

def preprocess_voxel(self, binary_voxel: np.ndarray, voxel_res: int = 64) -> torch.Tensor:
"""
Preprocess(read / voxelize) the given 3D object.
"""
assert all([s == voxel_res for s in binary_voxel.shape]), "Input voxels have incompatible resolution {}".format(binary_voxel.shape)
# Active voxels (N_p x 3)
x, y, z = np.nonzero(binary_voxel)
values_sum = x * voxel_res * voxel_res + y * voxel_res + z
active_voxels = np.stack([x, y, z], axis=1)[np.argsort(values_sum)]
active_voxels = np.concatenate((np.zeros((len(active_voxels), 1), dtype=np.int32), active_voxels), axis=1)
# Prepend the batch index column (all zeros) to form [batch_index, x, y, z] rows
return torch.from_numpy(active_voxels).int().cuda()
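# For illustration (not part of this PR): a 64^3 grid with a single active voxel at
# (1, 2, 3) yields tensor([[0, 1, 2, 3]], dtype=torch.int32) on the GPU, i.e. one row per
# active voxel with columns [batch_index, x, y, z].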

def preprocess_image(self, input: Image.Image) -> Image.Image:
"""
@@ -253,11 +267,51 @@ def sample_slat(
slat = slat * std + mean

return slat

@torch.no_grad()
def run_detail_variation(
self,
binary_voxel: np.ndarray,
images: Union[Image.Image, List[Image.Image]],
num_samples: int = 1,
seed: int = 42,
sparse_structure_sampler_params: dict = {},
slat_sampler_params: dict = {},
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
preprocess_image: bool = True,
) -> dict:
"""
Run the texture generation(2nd stage) pipeline.

Args:
binary_voxel (np.ndarray): The input binary voxel.
image (Image.Image or a list of Image.Image): The image prompt(s).
num_samples (int): The number of samples to generate.
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
preprocess_image (bool): Whether to preprocess the image.
"""
conds = []
if isinstance(images, Image.Image):
images = [images]
# Get condition for each image prompt and take average
for image in images:
if preprocess_image:
image = self.preprocess_image(image)
conds.append(self.get_cond([image]))
cond = {
key: torch.stack([item[key] for item in conds], dim=0).mean(dim=0) for key in conds[0].keys()
}

torch.manual_seed(seed)
coords = self.preprocess_voxel(binary_voxel)
slat = self.sample_slat(cond, coords, slat_sampler_params)
return self.decode_slat(slat, formats)
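# Minimal usage sketch (assumed paths; mirrors example_detail_variation.py):
#   pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
#   pipeline.cuda()
#   outputs = pipeline.run_detail_variation(
#       binary_voxel, Image.open("prompt.png"), seed=1,
#       slat_sampler_params={"steps": 35, "cfg_strength": 6.0},
#   )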

@torch.no_grad()
def run(
self,
image: Image.Image,
images: Union[Image.Image, List[Image.Image]],
num_samples: int = 1,
seed: int = 42,
sparse_structure_sampler_params: dict = {},
@@ -269,15 +323,24 @@ def run(
Run the pipeline.

Args:
image (Image.Image): The image prompt.
images (Image.Image or List[Image.Image]): The image prompt(s).
num_samples (int): The number of samples to generate.
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
preprocess_image (bool): Whether to preprocess the image.
"""
if preprocess_image:
image = self.preprocess_image(image)
cond = self.get_cond([image])
conds = []
if isinstance(images, Image.Image):
images = [images]
# Get condition for each image prompt and take average
for image in images:
if preprocess_image:
image = self.preprocess_image(image)
conds.append(self.get_cond([image]))
cond = {
key: torch.stack([item[key] for item in conds], dim=0).mean(dim=0) for key in conds[0].keys()
}

torch.manual_seed(seed)
coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
slat = self.sample_slat(cond, coords, slat_sampler_params)
4 changes: 2 additions & 2 deletions trellis/representations/gaussian/gaussian_model.py
@@ -110,7 +110,7 @@ def from_features(self, features):
def from_opacity(self, opacities):
self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias

def construct_list_of_attributes(self):
def construct_list_of_attributes(self, mode='all'):
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
# All channels except the 3 DC
for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
@@ -121,7 +121,7 @@ def construct_list_of_attributes(self):
for i in range(self._rotation.shape[1]):
l.append('rot_{}'.format(i))
return l

def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
xyz = self.get_xyz.detach().cpu().numpy()
normals = np.zeros_like(xyz)
17 changes: 17 additions & 0 deletions trellis/utils/general_utils.py
@@ -1,3 +1,5 @@
import os
import zipfile
import numpy as np
import cv2
import torch
@@ -185,3 +187,18 @@ def indent(s, n=4):
lines[i] = ' ' * n + lines[i]
return '\n'.join(lines)

def zip_directory(directory_path, zip_file_path):
"""
compress directory into a zipfile
"""
if not os.path.isdir(directory_path):
raise FileNotFoundError(f"Directory {directory_path} not found.")

with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zip_file:
for root, dirs, files in os.walk(directory_path):
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
zip_file.write(dir_path, os.path.relpath(dir_path, directory_path))
for file_name in files:
file_path = os.path.join(root, file_name)
zip_file.write(file_path, os.path.relpath(file_path, directory_path))
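# Example (hypothetical paths):
#   zip_directory("results/texgen/dragon+elephant", "results/texgen/dragon+elephant.zip")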