-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathutils.py
More file actions
23 lines (19 loc) · 748 Bytes
/
utils.py
File metadata and controls
23 lines (19 loc) · 748 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from omegaconf import OmegaConf
from conerf.base.model_base import ModelBase
from conerf.trainers.gaussian_trainer import GaussianSplatTrainer
from conerf.trainers.scaffold_gs_trainer import ScaffoldGSTrainer
def create_trainer(
    config: OmegaConf,
    prefetch_dataset=True,
    trainset=None,
    valset=None,
    model: ModelBase = None
):
    """Factory function for training neural network trainers.

    Selects the trainer class from ``config.neural_field_type`` and constructs
    it, forwarding every argument positionally and unchanged.

    Args:
        config: Experiment configuration. Only ``config.neural_field_type`` is
            read here; the whole object is passed on to the trainer.
        prefetch_dataset: Forwarded to the trainer constructor as-is.
        trainset: Optional training dataset, forwarded as-is.
        valset: Optional validation dataset, forwarded as-is.
        model: Optional pre-built model, forwarded as-is.

    Returns:
        A ``GaussianSplatTrainer`` when ``neural_field_type == "gs"``, or a
        ``ScaffoldGSTrainer`` when ``neural_field_type == "scaffold_gs"``.

    Raises:
        NotImplementedError: If ``config.neural_field_type`` is not one of the
            supported values. The message names the offending value.
    """
    field_type = config.neural_field_type

    # Plain equality checks (rather than a dict lookup) keep the behaviour of
    # the original for any value type, including unhashable config nodes.
    if field_type == "gs":
        trainer = GaussianSplatTrainer(config, prefetch_dataset, trainset, valset, model)
    elif field_type == "scaffold_gs":
        trainer = ScaffoldGSTrainer(config, prefetch_dataset, trainset, valset, model)
    else:
        raise NotImplementedError(
            f"Unsupported neural_field_type {field_type!r}; "
            "expected one of: 'gs', 'scaffold_gs'."
        )
    return trainer