Skip to content

Commit 9ccba46

Browse files
committed
Code cleaned up and documentation added
1 parent c72b7ec commit 9ccba46

2 files changed

Lines changed: 60 additions & 42 deletions

File tree

GEMstack/offboard/mast3r_3d_reconstruction/mast3r_runner.py

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def get_args_parser():
5050

5151

5252
class SparseGA():
53+
'''
54+
Sparse global alignment (SparseGA) result class; provides the sparse and dense 3D points of the reconstructed scene.
55+
'''
5356
def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None):
5457
def fetch_img(im):
5558
def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.)
@@ -87,6 +90,9 @@ def get_sparse_pts3d(self):
8790
return self.pts3d
8891

8992
def get_dense_pts3d(self, clean_depth=True, subsample=8):
93+
'''
94+
Get dense 3D points.
95+
'''
9096
assert self.canonical_paths, 'cache_path is required for dense 3d points'
9197
device = self.cam2w.device
9298
confs = []
@@ -223,6 +229,9 @@ def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc
223229
def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
224230
cam_color=None, as_pointcloud=False,
225231
transparent_cams=False, silent=False):
232+
'''
233+
Convert scene output to GLB file.
234+
'''
226235
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
227236
pts3d = to_numpy(pts3d)
228237
imgs = to_numpy(imgs)
@@ -275,6 +284,7 @@ def convert_scene_output_to_ply_impl(outfile, imgs, pts3d, mask, scale=1.0, appl
275284
pts3d (list of np.ndarray): 3D points per view, shape (H * W, 3).
276285
mask (list of np.ndarray): Boolean masks indicating valid points per view (H, W).
277286
scale (float): Scale factor to apply to the 3D points.
287+
apply_y_flip (bool): If True, apply a y-axis flip to the 3D points.
278288
silent (bool): If False, print export message.
279289
280290
Returns:
@@ -314,6 +324,9 @@ def convert_scene_output_to_ply_impl(outfile, imgs, pts3d, mask, scale=1.0, appl
314324
return outfile
315325

316326
def convert_scene_output_to_ply(outfile, data, scale=1.0, apply_y_flip=False, min_conf_thr=1.5, clean=True, TSDF_thresh=0):
327+
'''
328+
Convert scene output to a PLY file. Points that are not visible in the images are filtered out, using either TSDF post-processing (when TSDF_thresh > 0) or plain confidence thresholding (min_conf_thr).
329+
'''
317330
imgs = to_numpy(data.imgs)
318331
if TSDF_thresh > 0:
319332
tsdf = TSDFPostProcess(data, TSDF_thresh=TSDF_thresh)
@@ -353,28 +366,6 @@ def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=F
353366
return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
354367
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
355368

356-
def sort_images_from_longest_endpoint(D_square, data_length):
357-
D_square = D_square.copy()
358-
# Find the two farthest points
359-
i, j = np.unravel_index(np.argmax(D_square), D_square.shape)
360-
start_idx = i # or j — either works
361-
362-
# Greedy traversal using the precomputed distance matrix
363-
N = data_length
364-
visited = np.zeros(N, dtype=bool)
365-
visited[start_idx] = True
366-
path = [start_idx]
367-
368-
current_idx = start_idx
369-
for _ in range(N - 1):
370-
dists = D_square[current_idx]
371-
dists[visited] = np.inf # Ignore visited
372-
next_idx = np.argmin(dists)
373-
path.append(next_idx)
374-
visited[next_idx] = True
375-
current_idx = next_idx
376-
return path
377-
378369
def get_reconstructed_scene(outdir, gradio_delete_cache, model, retrieval_model, device, silent, image_size,
379370
current_scene_state, filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr,
380371
matching_conf_thr, as_pointcloud, mask_sky, clean_depth,
@@ -429,17 +420,7 @@ def get_reconstructed_scene(outdir, gradio_delete_cache, model, retrieval_model,
429420
model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
430421
opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
431422
matching_conf_thr=matching_conf_thr, **kw)
432-
# if current_scene_state is not None and \
433-
# not current_scene_state.should_delete and \
434-
# current_scene_state.outfile_name is not None:
435-
# outfile_name = current_scene_state.outfile_name
436-
# else:
437-
# outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
438-
439-
# scene_state = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name)
440-
# outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
441-
# clean_depth, transparent_cams, cam_size, TSDF_thresh)
423+
442424
scene.get_dense_pts3d()
443-
# outfile = convert_scene_output_to_ply(outfile_name, scene, scale=1.0, apply_y_flip=False, min_conf_thr=min_conf_thr, clean=clean_depth)
444425
return scene
445426

GEMstack/offboard/mast3r_3d_reconstruction/scale_pointcloud_based_on_geotag.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,16 @@ def to_cpu(x): return todevice(x, 'cpu')
6161
def to_cuda(x): return todevice(x, 'cuda')
6262

6363
def dms_to_decimal(d, m, s, ref):
64+
'''
65+
Convert degrees, minutes, seconds to decimal degrees.
66+
'''
6467
dd = d + m / 60 + s / 3600
6568
return -dd if ref in ['S', 'W'] else dd
6669

6770
def get_gps_from_exif(image_path):
71+
'''
72+
Get GPS information from an image file.
73+
'''
6874
img = Image.open(image_path)
6975
exif = img._getexif()
7076
if not exif:
@@ -80,6 +86,9 @@ def get_gps_from_exif(image_path):
8086
return gps_info
8187

8288
def parse_gps_info(image_path):
89+
'''
90+
Parse GPS information from an image file.
91+
'''
8392
gps_info = get_gps_from_exif(image_path)
8493
# Latitude
8594
lat_ref = gps_info.get('GPSLatitudeRef', 'N')
@@ -111,6 +120,9 @@ def parse_gps_info(image_path):
111120
}
112121

113122
def gps_to_xyz(gps_lookup, crs_from, crs_to):
123+
'''
124+
Convert the GPS coordinates (lat, lon, alt) of each image in gps_lookup to xyz coordinates by transforming them from crs_from to crs_to.
125+
'''
114126
xyz_lookup = {}
115127
transformer = Transformer.from_crs(crs_from, crs_to, always_xy=True) # includes altitude
116128
for image_name, (lat, lon, alt) in gps_lookup.items():
@@ -125,6 +137,8 @@ def estimate_3d_scale_from_gps(camera_centers, gps_xyz, camera_image_names, min_
125137
Inputs:
126138
camera_centers: (N, 3) array in MASt3r units (arbitrary scale)
127139
gps_xyz: (N, 3) array in meters [x, y, z] from lat/lon/alt
140+
camera_image_names: list of camera image names
141+
min_dist_threshold: minimum distance threshold for valid GPS pairs
128142
Returns:
129143
scale: estimated meters-per-unit scale factor
130144
"""
@@ -153,6 +167,19 @@ def estimate_3d_scale_from_gps(camera_centers, gps_xyz, camera_image_names, min_
153167

154168

155169
def estimate_scale_ransac(camera_centers, gps_xyz, camera_image_names, threshold=0.05, iterations=1000, min_dist=1.0):
170+
'''
171+
Estimate scale factor between MASt3r's camera centers and GPS 3D coordinates using RANSAC.
172+
173+
Inputs:
174+
camera_centers: (N, 3) array in MASt3r units (arbitrary scale)
175+
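gps_xyz: GPS positions in meters [x, y, z] from lat/lon/alt (NOTE: the caller passes the per-image lookup returned by gps_to_xyz rather than an (N, 3) array - confirm the expected type)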
176+
camera_image_names: list of camera image names
177+
threshold: threshold for inliers
178+
iterations: number of RANSAC iterations
179+
min_dist: minimum distance threshold for valid GPS pairs
180+
Returns:
181+
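(scale, num_inliers, num_scales): estimated meters-per-unit scale factor, number of inliers supporting it, and total number of candidate scales considered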
182+
'''
156183
scales = []
157184
pairs = []
158185

@@ -191,11 +218,17 @@ def estimate_scale_ransac(camera_centers, gps_xyz, camera_image_names, threshold
191218
return best_scale, len(best_inliers), len(scales)
192219

193220
def extract_image_names(image_paths):
221+
'''
222+
Extract image names from a list of image paths.
223+
'''
194224
return [path.split('/')[-1] for path in image_paths]
195225

196226

197227

198228
def collect_gps_data(data_folder):
229+
'''
230+
Collect GPS data with extra metadata from a folder of images.
231+
'''
199232
records = []
200233
for fname in sorted(os.listdir(data_folder)):
201234
if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
@@ -218,6 +251,9 @@ def collect_gps_data(data_folder):
218251

219252

220253
def run_mast3r(args):
254+
'''
255+
Run MASt3R on a folder of images.
256+
'''
221257
if args.weights is not None:
222258
weights_path = args.weights
223259
else:
@@ -246,22 +282,14 @@ def add_parse_args(parser, is_scene_path=False):
246282
parser.add_argument('--crs_from', type=str, required=False, default='EPSG:4979', help='EPSG code of the input CRS')
247283
parser.add_argument('--crs_to', type=str, required=False, default='EPSG:32616', help='EPSG code of the output CRS')
248284
if not is_scene_path:
249-
# parser.add_argument('--retrieval_model', type=str, required=False, default=None, help='Retrieval model weights path that is used to make image pairs')
250-
# parser.add_argument('--device', type=str, required=False, default='cuda:0', help='Device to run the model on')
251-
# parser.add_argument('--silent', type=bool, required=True, help='Whether to run the model silently')
252-
# parser.add_argument('--image_size', type=int, required=True, help='Image size')
253285
parser.add_argument('--optim_level', type=str, required=False, default='refine+depth', choices=['coarse', 'refine', 'refine+depth'], help='Optimization level')
254286
parser.add_argument('--lr1', type=float, required=False, default=0.07, help='Learning rate for the first refinement iteration stage')
255287
parser.add_argument('--niter1', type=int, required=False, default=300, help='Number of iterations for the first refinement iteration stage')
256288
parser.add_argument('--lr2', type=float, required=False, default=0.01, help='Learning rate for the second refinement iteration stage')
257289
parser.add_argument('--niter2', type=int, required=False, default=300, help='Number of iterations for the second refinement iteration stage')
258290
parser.add_argument('--min_conf_thr', type=float, required=False, default=1.5, help='Minimum confidence threshold')
259291
parser.add_argument('--matching_conf_thr', type=float, required=False, default=0., help='Matching confidence threshold')
260-
# parser.add_argument('--as_pointcloud', type=bool, required=True, help='Whether to output a pointcloud')
261-
# parser.add_argument('--mask_sky', type=bool, required=True, help='Whether to mask the sky')
262292
parser.add_argument('--clean_depth', type=bool, required=False, default=True, help='Whether to clean the depth')
263-
# parser.add_argument('--transparent_cams', type=bool, required=True, help='Whether to make the cameras transparent')
264-
# parser.add_argument('--cam_size', type=float, required=True, help='Camera size')
265293

266294
available_scenegraph_type = [
267295
("complete: all possible image pairs", "complete"),
@@ -295,6 +323,9 @@ def add_parse_args(parser, is_scene_path=False):
295323
return parser
296324

297325
def scale_pointcloud_based_on_geotag():
326+
'''
327+
Scale a pointcloud based on GPS data. If no scene file is provided, MASt3R will be run to generate a scene file.
328+
'''
298329
parser = argparse.ArgumentParser()
299330

300331
# Add known args
@@ -316,9 +347,11 @@ def scale_pointcloud_based_on_geotag():
316347
with open(args.scene_path, 'rb') as f:
317348
data = pickle.load(f)
318349

319-
350+
# Get camera centers
320351
cam2w = data.get_im_poses()
321352
camera_centers = cam2w[:, :3, 3] # Extract translation component from [R|t]
353+
354+
# Collect GPS data
322355
df = collect_gps_data(args.folder_path)
323356
image_gps_data = df.to_numpy()
324357
gps_lookup = {
@@ -327,6 +360,8 @@ def scale_pointcloud_based_on_geotag():
327360
}
328361
image_names = extract_image_names(data.img_paths)
329362
xyz_lookup = gps_to_xyz(gps_lookup, args.crs_from, args.crs_to)
363+
364+
# Estimate scale
330365
scale = 1.0
331366
if args.scale_method == 'ransac':
332367
scale, sfm_dists, gps_dists = estimate_scale_ransac(camera_centers.cpu().numpy(), xyz_lookup, image_names)
@@ -335,6 +370,8 @@ def scale_pointcloud_based_on_geotag():
335370
else:
336371
raise ValueError(f"Invalid scale method: {args.scale_method}")
337372
print(f"Estimated scale: {scale}")
373+
374+
# Convert scene output to PLY
338375
convert_scene_output_to_ply(args.output_path, data, scale=scale, apply_y_flip=False, min_conf_thr=args.min_conf_thr, clean=args.clean_depth, TSDF_thresh=args.TSDF_thresh)
339376

340377
if __name__ == "__main__":

0 commit comments

Comments
 (0)