Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_train_eval_split_all,
get_train_eval_split_filename,
get_train_eval_split_fraction,
get_train_eval_split_indices,
get_train_eval_split_interval,
)
from nerfstudio.utils.io import load_from_json
Expand Down Expand Up @@ -59,7 +60,7 @@ class NerfstudioDataParserConfig(DataParserConfig):
"""The method to use to center the poses."""
auto_scale_poses: bool = True
"""Whether to automatically scale the poses to fit in +/- 1 bounding box."""
eval_mode: Literal["fraction", "filename", "interval", "all"] = "fraction"
eval_mode: Literal["fraction", "filename", "interval", "all", "indices"] = "fraction"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the indices behavior in the comment below

"""
The method to use for splitting the dataset into train and eval.
Fraction splits based on a percentage for train and the remaining for eval.
Expand All @@ -77,6 +78,8 @@ class NerfstudioDataParserConfig(DataParserConfig):
"""Replace the unknown pixels with this color. Relevant if you have a mask but still sample everywhere."""
load_3D_points: bool = False
"""Whether to load the 3D points from the colmap reconstruction."""
eval_image_indices: Tuple[int, ...] = (0,)
"""Specifies the image indices to use during eval; if None, uses all."""


@dataclass
Expand Down Expand Up @@ -212,6 +215,8 @@ def _generate_dataparser_outputs(self, split="train"):
i_train, i_eval = get_train_eval_split_filename(image_filenames)
elif self.config.eval_mode == "interval":
i_train, i_eval = get_train_eval_split_interval(image_filenames, self.config.eval_interval)
elif self.config.eval_mode == "indices":
i_train, i_eval = get_train_eval_split_indices(image_filenames, self.config.eval_image_indices)
elif self.config.eval_mode == "all":
CONSOLE.log(
"[yellow] Be careful with '--eval-mode=all'. If using camera optimization, the cameras may diverge in the current implementation, giving unpredictable results."
Expand Down
20 changes: 20 additions & 0 deletions nerfstudio/data/utils/dataparsers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,23 @@ def get_train_eval_split_all(image_filenames: List) -> Tuple[np.ndarray, np.ndar
i_train = i_all
i_eval = i_all
return i_train, i_eval


def get_train_eval_split_indices(
image_filenames: List, eval_image_indices: Tuple[int, ...]
) -> Tuple[np.ndarray, np.ndarray]:
"""
Get the train/eval split based on specified indices in the config.

Args:
image_filenames: list of image filenames
eval_image_indices: Tuple of indices to use for evaluation.
"""
for idx in eval_image_indices:
if idx >= len(image_filenames) or idx < 0:
raise ValueError(f"Eval index {idx} is out of bounds for the number of images {len(image_filenames)}.")
all_indices = set(range(len(image_filenames)))
eval_indices = set(eval_image_indices)
train_indices = all_indices - eval_indices

return np.array(sorted(train_indices)), np.array(sorted(eval_indices))