forked from zichaozeng/MuSSel-V_LAD
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfigs.py
More file actions
179 lines (151 loc) · 5.52 KB
/
configs.py
File metadata and controls
179 lines (151 loc) · 5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Global configuration management for MuSSel-V experiments.
This module provides:
- Path configuration and PYTHONPATH setup
- Command-line argument parsing via tyro
- Device configuration (CUDA/CPU)
- Dataclass-based configuration structures for experiments
The configurations defined here can be used across different scripts
for consistent experiment settings and reproducibility.
"""
# %%
# Get the repository path as 'lib_path'
import os
import sys
from pathlib import Path
dir_name = None
try:
dir_name = os.path.dirname(os.path.realpath(__file__))
except NameError:
dir_name = os.path.abspath('')
lib_path = os.path.realpath(f"{Path(dir_name)}")
# Add library path to PYTHONPATH
print("[INFO]: Configs is modifying path")
if lib_path not in sys.path:
print(f"Adding library path: {lib_path} to PYTHONPATH")
sys.path.append(lib_path)
else:
print(f"Library path {lib_path} already in PYTHONPATH")
# Username on system
user_name = os.environ.get("USER", "xxxxxx")
# %%
# Import everything
from dataclasses import dataclass, field
from typing import Literal, List, Union
import torch
import tyro
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
# %%
def try_tyro(x, allow_safe_quit=False):
"""
Wrap 'x' dataclass around tyro.cli (if it works)
Parameters:
- x: A class wrapped in `dataclasses.dataclass` for tyro
- allow_safe_quit: If exit code is 0 (maybe `-h`), then exit
"""
try:
return tyro.cli(x)
except (SystemExit, Exception) as exc:
print("[WARN]: Tyro might not have parsed all arguments! "\
f"Exception: {exc}")
print("[INFO]: Ignore above warning if multiple tyro used")
if str(exc) == "0":
print("[DEBUG]: Exit code 0 received")
if allow_safe_quit:
print("[INFO]: Safe exit is enabled, exiting...")
quit(0)
return x() # Passthrough
# %% --------------------- Simple variables ---------------------
@dataclass(frozen=True)
class ProgArgs:
"""
Core program arguments
"""
# A directory for storing cache
cache_dir: Path = ""
# The directory where 'datasets-vg' are downloaded
data_vg_dir: Path = ""
# Default dataset name for VPR
vg_dataset_name: Literal["st_lucia", "pitts30k", "17places", "nordland", "tokyo247", "baidu_datasets", "Oxford", "Oxford_25m", "gardens", "hawkins","hawkins_long_corridor", "global", "VPAir", "Tartan_GNSS_rotated", "Tartan_GNSS_notrotated", "Tartan_GNSS_test_notrotated", "Tartan_GNSS_test_rotated", "laurel_caverns","eiffel"] = "hawkins_long_corridor"
# Use wandb (False = No WandB)
use_wandb: bool = False
# WandB project name
wandb_proj: str = "Baselines"
# WandB entity (should be 'vpr-vl')
wandb_entity: str = "vpr-vl"
# WandB result group (within the project)
wandb_group: str = "Oxford"
# WandB run name (within the group)
wandb_run_name: str = "Oxford/CLIP_TopK"
# Save qualitative results for WandB
wandb_save_qual: bool = False
prog_args = ProgArgs()
"""
> Note: Default placeholder, not effected by tyro.
"""
_real_path = lambda x: os.path.realpath(os.path.expanduser(x))
# Cache folder for results
caching_directory = _real_path(prog_args.cache_dir)
"""
A folder that has a lot of space (to store cache). The folder
structure is created by the program that uses this variable.
> Note: Default placeholder, not effected by tyro.
"""
# Datasets directory
datasets_dir = _real_path(prog_args.data_vg_dir)
"""
The folder where all the VPR datasets are stored in the format of
the `datasets-vg` repository. The format is usually of the form:\n
datasets\n
└── st_lucia\n
├── images\n
│ └── test\n
│ ├── database [1549 entries ...]\n
│ └── queries [1464 entries ...]\n
└── map_st_lucia.png\n
Name `st_lucia` is just an example.
> Note: Default placeholder, not effected by tyro.
"""
# Dataset name
dataset_name = prog_args.vg_dataset_name
"""
Dataset name for VPR
> Note: Default placeholder, not effected by tyro.
"""
# %% --------------------- Argument classes ---------------------
# For parsing datasets
@dataclass(frozen=True)
class BaseDatasetArgs:
"""
Dataset arguments for BaseDataset in `datasets_ws.py`
"""
# Resize shape: [H, W]
resize: List[int] = field(default_factory=lambda:[480, 640])
# Pre/post-processing methods and prediction refinement
test_method: Literal["hard_resize", "single_query", \
"central_crop", "five_crops", "nearest_crop", \
"maj_voting"] = "hard_resize"
"""
Pre/post processing method must be one of the following (str)
- hard_resize: Apply straightforward resize to above shape.
- single_query: Resize to `min(resize)`. Used when queries
have varying sizes and can't be batched.
- central_crop: Take the biggest central crop of size resize.
Preserves ratio.
- five_crops: See [1]
- nearest_crop: See [1]
- maj_voting: See [1]
[1]: Get 5 square crops with size `min(resize)` and
batch them.
"""
# Threshold value for positive distance (classification)
val_positive_dist_threshold: int = 25
base_dataset_args = BaseDatasetArgs()
"""
> Note: Default placeholder, not effected by tyro.
"""
# %%
# Experimental section
# %%
# %%
# Experimental section
# %%