Skip to content

Commit

Permalink
每添加一个图像后就存一次mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
leo-frank committed Oct 6, 2023
1 parent 82f42ef commit 30915cb
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 24 deletions.
3 changes: 2 additions & 1 deletion pipelines/Camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import random
from utils.util import log, debug
from utils import camera

from loguru import logger
# from models.cnn_model.encoder import SpatialEncoder
epsilon = 1e-6
from typing import Optional
Expand Down Expand Up @@ -584,6 +584,7 @@ def prealign_cameras(self, opt, pose, pose_GT):
def eval_poses(self,
pick_cam_id=None,
mode="normal"):
# print( ATE, rot_error, t_error )
poses_all, poses_gt_all = self.get_all_poses(pick_cam_id=pick_cam_id)

if poses_all.shape[0] > 2:
Expand Down
10 changes: 7 additions & 3 deletions pipelines/Initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Optional
from multiprocessing.dummy import Pool as ThreadPool
import pycolmap
from loguru import logger


class Initializer():
Expand Down Expand Up @@ -144,9 +145,10 @@ def run(self,
ref_indx: Optional[int] = 0,
src_indx: Optional[int] = 1):
loader = tqdm.trange(self.max_iter, desc="Initialization", leave=False)
logger.info("ref index: {}, src index: {}".format(ref_indx, src_indx))
camera0 = cameraset.cameras[ref_indx]
camera1 = cameraset.cameras[src_indx]
for it in loader:
for it in loader: # 500

ret = edict()
self.optim.zero_grad()
Expand All @@ -169,10 +171,12 @@ def run(self,
ret=ret, Renderer=Renderer)

loss = self.compute_loss(ret)
loss = self.summarize_loss(self.opt, loss)

if it%10==0: #TODO: add an option to control the frequency of printing
if it%10==0: # print statistics every 10 iterations, including terms such as:
# 'PSNR': 31.099340438842773, 'reproj_error': 2.2479755878448486, 'sdf_surf': 0.004002795554697514, 'eikonal_loss': 0.034246332943439484, 'rgb': 0.010943753644824028, 'DC_Loss': 0.001470054849050939
self.print_loss(loss_dict=loss,PSNR=ret.get('PSNR'))
loss = self.summarize_loss(self.opt, loss)

loss.all.backward()

self.optim.step()
Expand Down
36 changes: 21 additions & 15 deletions pipelines/LevelS2fM.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tqdm
import utils.util as util
from easydict import EasyDict as edict

from loguru import logger
from utils.util import log
import utils.camera as camera
from . import base
Expand Down Expand Up @@ -124,9 +124,10 @@ def train(self, opt):
random_view = None

rendered_rgb_list = []
for self.it in loader:
for self.it in loader: # 20000000
var.iter = self.it
if random_view is not None:
logger.debug("rendering a random view")
rendered_output = self.camera_set.cameras[0].render_img_by_slices(
self.sdf_func,
self.color_func,
Expand All @@ -142,6 +143,7 @@ def train(self, opt):
# plt.show()
# import pdb; pdb.set_trace()
if len(self.camera_set) < 2:
logger.debug("initializition phase")
# Initialization
if opt.resume == False:
var.indx_init = [pose_graph[0], pose_graph[1]]
Expand Down Expand Up @@ -173,6 +175,7 @@ def train(self, opt):
# extra_info=self.sdf_func,
# N=512)
if opt.resume == False:
# initialize from the first two views, then save a checkpoint and the geometry
Initializer.run(self.camera_set, self.point_set,
self.sdf_func, self.color_func, Renderer=self.Renderer)
self.save_checkpoint(opt, ep=None, it=self.it + 1, latest=False)
Expand All @@ -181,7 +184,7 @@ def train(self, opt):

random_view = slerp(self.camera_set.cameras[0].get_pose()[0], self.camera_set.cameras[1].get_pose()[0], 0.5)
del Initializer
else:
else: # 开始incremental sfm
if (opt.resume == True) & (load_finish == False):
print("------reloading cameras-------")
for i in tqdm.tqdm(range(len(self.cam_info_reloaded["cam_id"][2:])), desc="reloading cameras"):
Expand Down Expand Up @@ -234,13 +237,13 @@ def train(self, opt):
continue
# start incremental SfM
pose_graph_left = [p for p in pose_graph if p not in self.camera_set.cam_ids]
print(f"---------------- {len(pose_graph_left)} frames left ------------------")
logger.info(f"---------------- {len(pose_graph_left)} frames left ------------------")
if len(pose_graph_left) == 0:
self.vis_geo_rgb(opt, cameraset=self.camera_set, new_camera=self.camera_set.cameras[0],
pointset=self.point_set, cam_only=True)
print(f"finish!")
logger.info(f"finish!")

print("---------- searching next best view -------------")
logger.info("---------- searching next best view -------------")
if opt.nbv_mode == "colmap":
new_id = pose_graph_left[0]
else:
Expand Down Expand Up @@ -279,7 +282,7 @@ def train(self, opt):
nbv = np.argmax(pnp_score)
print("--------- max number is {} --------------".format(pnp_num_mches[nbv]))
new_id = pose_graph_left[nbv]
print(f"-------------the best view next id is {new_id}--------------")
logger.info(f"-------------the best view next id is {new_id}--------------")
img_new = self.train_data.all.image[new_id:new_id + 1]
camera_new = Camera.Camera(opt=opt,
id=new_id,
Expand All @@ -293,8 +296,8 @@ def train(self, opt):
Normal_omn=None,
Extrinsic=None,
idx2d_to_3d=None)
print(f"Total cameras num:{len(self.camera_set.cam_ids)}")
print(f"new cam_id: {new_id}")
logger.info(f"Total cameras num:{len(self.camera_set.cam_ids)}, still not add the new camera")
logger.info(f"new cam_id: {new_id}")
# -------------------Registration: PnP+Triangulation----------------------------------------------
if opt.Ablate_config.tri_trad == True:
Register = Registration_Trad.Registration(opt, self.sdf_func, cameraset=self.camera_set)
Expand Down Expand Up @@ -350,7 +353,7 @@ def train(self, opt):
reproj_tem = 100
iter_cycle = 0
mode = "sfm_refine"
print("-------------- reproj+rendering registration refine --------------------")
logger.info("reproj+rendering registration refine")
while (reproj_tem > 2.5):
if iter_cycle >= 1:
break
Expand All @@ -371,13 +374,13 @@ def train(self, opt):
torch.cuda.empty_cache()
iter_cycle += 1
# ------------------local ba---------------------------------------------------------------------
logger.info("Local BA")
reproj_tem = 100
iter_cycle = 0
mode = "sfm"
while (reproj_tem > 1.):
if iter_cycle >= 5:
break
# -----------------------------------
Local_BA_id = [camera_new.id] + src_cam_id
Bundler = BA.BA(opt=opt,
cameraset=self.camera_set,
Expand All @@ -390,18 +393,19 @@ def train(self, opt):
color_func=self.color_func,
Renderer=self.Renderer)
print(f"local frames num:{len(src_cam_id + [camera_new.id])}")
print("reproj_tem:", reproj_tem)
rot_error, t_error = self.camera_set.eval_poses(pick_cam_id=src_cam_id + [camera_new.id])
del Bundler
torch.cuda.empty_cache()
iter_cycle += 1
# ------------------global ba---------------------------------------------------------------------
logger.info("Global BA")
reproj_tem = 100
iter_cycle = 0
mode = "sfm"
while (reproj_tem > 1.):
if iter_cycle >= 5:
break
# -----------------global ba--------------------------------------------------------------------
Bundler = BA.BA(opt=opt,
cameraset=self.camera_set,
pointset=self.point_set,
Expand All @@ -412,12 +416,14 @@ def train(self, opt):
reproj_tem = Bundler.run_ba(sdf_func=self.sdf_func,
color_func=self.color_func,
Renderer=self.Renderer)
print("reproj_tem:", reproj_tem)
rot_error, t_error = self.camera_set.eval_poses()
del Bundler
torch.cuda.empty_cache()
iter_cycle += 1
# ------------------ rendering refine -------------------------------------------------------------
if opt.sfm_mode == "full":
logger.info("rendering refine") # 'PSNR': 16.81053924560547, 'eikonal_loss': 0.1209036335349083, 'rgb': 0.05632907524704933, 'DC_Loss': 0.0019330104114487767, 'sdf_surf': 0.0005162741290405393, 'tracing_loss': 0.00271444208920002}
Refiner = rendering_refine.Refine(opt=opt,
cameraset=self.camera_set,
pointset=self.point_set,
Expand All @@ -438,15 +444,15 @@ def train(self, opt):
save_latest = False
else:
save_latest = True
self.vis_geo_rgb(opt, cameraset=self.camera_set, new_camera=camera_new, pointset=self.point_set,
vis_only=False, cam_only=True)
if (len(self.camera_set.cameras) % opt.freq.ckpt == 0) & (opt.sfm_mode == "full"):
if opt.sfm_mode == "full":
self.vis_geo_rgb(opt, cameraset=self.camera_set, new_camera=camera_new, pointset=self.point_set,
vis_only=False, cam_only=False)
self.save_checkpoint(opt, ep=None, it=self.it + 1, latest=save_latest)

logger.info("end of registration of image {}".format(new_id))
# if self.it % opt.freq.ckpt == 0: self.save_checkpoint(opt, ep=None, it=self.it + 1)
pose_graph_i += 1
logger.info("pose_graph_i={}".format(pose_graph_i))
torch.cuda.empty_cache()
# after training
# if opt.tb:
Expand Down
17 changes: 12 additions & 5 deletions pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from . import Camera
from . import Point3D
from utils.util import extract_mesh
from loguru import logger
import numpy as np
import torch.nn.functional as torch_F

Expand Down Expand Up @@ -134,11 +135,11 @@ def vis_geo_rgb(self,
cameraset: Camera.CameraSet,
new_camera: Camera.Camera,
pointset: Point3D.Point3DSet,
vis_only: bool = True,
vis_only: bool = False,
cam_only: bool = False):
os.makedirs("{0}/mesh".format(opt.output_path), exist_ok=True)
opt.mesh_dir = "{0}/mesh".format(opt.output_path)
# ---------------------------- vis pts ------------------------------------------
# ---------------------------- vis point clouds: _pointcloud_org.ply & _pointcloud.ply ------------------------------------------
view_ord = len(cameraset)
# vis pointset
pts3d_vis = torch.cat(pointset.get_all_parameters()["xyzs"], dim=0).detach().cpu().numpy()
Expand All @@ -156,14 +157,15 @@ def vis_geo_rgb(self,
pts3d_vis = pts3d_vis[mask_vis.squeeze()]
util.draw_pcd(pts3d_vis, f"{opt.output_path}/mesh/{view_ord}_pointcloud.ply",
(sdf_value - sdf_value.min()) / (sdf_value.max() - sdf_value.min()) * sdf_grad)
# ---------------------------- vis cam ------------------------------------------
logger.debug("save {view_ord}_pointcloud_org.ply & {view_ord}_pointcloud.ply".format(view_ord,view_ord))
# ---------------------------- vis cameras: cam00000022.json & cam00000022_gt.json------------------------------------------
cameras = {}
for cam_i in cameraset.cameras + [new_camera]:
camera_i = {"{}".format(cam_i.id): {"K": util.intr2list(cam_i.intrinsic),
"W2C": util.pose2list(cam_i.get_pose().squeeze()),
"img_size": opt.data.image_size}}
cameras.update(camera_i)
util.dict2json(os.path.join(opt.mesh_dir, 'cam{:08d}.json'.format(view_ord)), cameras)
util.dict2json(os.path.join(opt.mesh_dir, 'cam{:08d}.json'.format(view_ord)), cameras) # cam00000022.json
poses_est, poses_gt = cameraset.get_all_poses()
pose_aligned_gt, _ = cameraset.prealign_cameras(opt, poses_gt, poses_est)
pose_wis = torch.cat([poses_est.cpu(),
Expand All @@ -178,9 +180,12 @@ def vis_geo_rgb(self,
pose_wis = torch.cat([pose_aligned_gt.cpu(),
torch.tensor([[0, 0, 0, 1]]).repeat(poses_est.shape[0], 1, 1)], dim=1)
# self.wis3d.add_camera_trajectory(torch.linalg.inv(pose_wis), name=f"{view_ord}_poses_gt")
util.dict2json(os.path.join(opt.mesh_dir, 'cam{:08d}_gt.json'.format(view_ord)), cameras)
util.dict2json(os.path.join(opt.mesh_dir, 'cam{:08d}_gt.json'.format(view_ord)), cameras) # cam00000022_gt.json
if cam_only == True:
return
logger.debug("save cam{:08d}_gt.json & cam{:08d}.json".format(view_ord, view_ord))
# ------------------------------------------------- save mesh ----------------------------------------------------------------------
vis_only = False # always extract mesh
if vis_only == False:
# visualize the mesh
extract_mesh(
Expand All @@ -191,9 +196,11 @@ def vis_geo_rgb(self,
show_progress=True,
extra_info=None,
N=512)
logger.debug("save {:08d}.ply".format(view_ord))
# util_vis.vis_by_wis3d_mesh(self.wis3d, os.path.join(opt.mesh_dir, '{:08d}.ply'.format(view_ord)),
# f"{view_ord}_mesh")

# ------------------------------------------------- save 6 files in image/index ----------------------------------------------------------------------
# visualize the novel view's rgb
ret = new_camera.get_depth(sdf_func=self.sdf_func, mode="eval")
ret_render = new_camera.render_img_by_slices(sdf_func=self.sdf_func,
Expand Down

0 comments on commit 30915cb

Please sign in to comment.