diff --git a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py index 24dc456d15..2539a4aa05 100644 --- a/nerfstudio/data/dataparsers/nerfstudio_dataparser.py +++ b/nerfstudio/data/dataparsers/nerfstudio_dataparser.py @@ -31,6 +31,7 @@ get_train_eval_split_all, get_train_eval_split_filename, get_train_eval_split_fraction, + get_train_eval_split_indices, get_train_eval_split_interval, ) from nerfstudio.utils.io import load_from_json @@ -59,7 +60,7 @@ class NerfstudioDataParserConfig(DataParserConfig): """The method to use to center the poses.""" auto_scale_poses: bool = True """Whether to automatically scale the poses to fit in +/- 1 bounding box.""" - eval_mode: Literal["fraction", "filename", "interval", "all"] = "fraction" + eval_mode: Literal["fraction", "filename", "interval", "all", "indices"] = "fraction" """ The method to use for splitting the dataset into train and eval. Fraction splits based on a percentage for train and the remaining for eval. @@ -77,6 +78,8 @@ class NerfstudioDataParserConfig(DataParserConfig): """Replace the unknown pixels with this color. Relevant if you have a mask but still sample everywhere.""" load_3D_points: bool = False """Whether to load the 3D points from the colmap reconstruction.""" + eval_image_indices: Tuple[int, ...] = (0,) + """Specifies the image indices to use during eval; if None, uses all.""" @dataclass @@ -212,6 +215,8 @@ def _generate_dataparser_outputs(self, split="train"): i_train, i_eval = get_train_eval_split_filename(image_filenames) elif self.config.eval_mode == "interval": i_train, i_eval = get_train_eval_split_interval(image_filenames, self.config.eval_interval) + elif self.config.eval_mode == "indices": + i_train, i_eval = get_train_eval_split_indices(image_filenames, self.config.eval_image_indices) elif self.config.eval_mode == "all": CONSOLE.log( "[yellow] Be careful with '--eval-mode=all'. If using camera optimization, the cameras may diverge in the current implementation, giving unpredictable results." diff --git a/nerfstudio/data/utils/dataparsers_utils.py b/nerfstudio/data/utils/dataparsers_utils.py index 0c79cbde18..7e8e7bfff9 100644 --- a/nerfstudio/data/utils/dataparsers_utils.py +++ b/nerfstudio/data/utils/dataparsers_utils.py @@ -99,3 +99,23 @@ def get_train_eval_split_all(image_filenames: List) -> Tuple[np.ndarray, np.ndar i_train = i_all i_eval = i_all return i_train, i_eval + + +def get_train_eval_split_indices( + image_filenames: List, eval_image_indices: Tuple[int, ...] +) -> Tuple[np.ndarray, np.ndarray]: + """ + Get the train/eval split based on specified indices in the config. + + Args: + image_filenames: list of image filenames + eval_image_indices: Tuple of indices to use for evaluation. + """ + for idx in eval_image_indices: + if idx >= len(image_filenames) or idx < 0: + raise ValueError(f"Eval index {idx} is out of bounds for the number of images {len(image_filenames)}.") + all_indices = set(range(len(image_filenames))) + eval_indices = set(eval_image_indices) + train_indices = all_indices - eval_indices + + return np.array(sorted(train_indices)), np.array(sorted(eval_indices))