Commit 036b217

SARATRX: load HiViT pretrained weights
1 parent 9dff42b commit 036b217

10 files changed: 169 additions & 39 deletions

.gitmodules

Lines changed: 2 additions & 5 deletions
@@ -4,9 +4,6 @@
 [submodule "SARIAD/models/image/MFSA/SARDet_100K"]
     path = SARIAD/models/image/MFSA/SARDet_100K
     url = https://github.com/zcablii/SARDet_100K/
-[submodule "SARIAD/models/image/SARATR-X/SARATR-X"]
-    path = SARIAD/models/image/SARATR-X/SARATRX
-    url = https://github.com/waterdisappear/SARATR-X
 [submodule "SARIAD/models/image/SARATR-X/SARATRX"]
-    path = SARIAD/models/image/SARATR-X/SARATRX
-    url = git@github.com:waterdisappear/SARATR-X.git
+    path = SARIAD/models/image/SARATRX/SARATRX
+    url = git@github.com:lucianchauvin/SARATR-X.git

SARIAD/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .image.SARATRX import *
Lines changed: 0 additions & 1 deletion
This file was deleted.

Submodule SARATRX added at 1055154

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .lightning_model import *
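
These two one-line __init__.py additions re-export the new Lightning module up the package tree. Assuming the package layout implied by the imports (SARIAD/models/__init__.py pulling from .image.SARATRX, and the SARATRX package pulling from .lightning_model), the model class should then be importable from the top-level models package; a minimal sketch:

    # Hypothetical import path enabled by the re-exports above.
    from SARIAD.models import SARATRX

    model = SARATRX()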
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+import logging
+import torch
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from anomalib.data import Batch
+from anomalib.models.components import AnomalibModule
+from SARIAD.models.image.SARATRX.SARATRX.pretraining.models.models_hivit_mae import HiViTMaskedAutoencoder
+from SARIAD.models.image.SARATRX.SARATRX.pretraining.models.models_hivit import HiViT
+from SARIAD.models.image.SARATRX.SARATRX.pretraining.util.pos_embed import interpolate_pos_embed
+from SARIAD.utils.blob_utils import fetch_blob
+from anomalib.post_processing import PostProcessor
+
+logger = logging.getLogger(__name__)
+
+class SARATRX(AnomalibModule):
+    def __init__(self, pre_processor=True, post_processor=True, num_classes=2):
+        super().__init__(pre_processor, post_processor)
+        self.trainer_arguments = {
+            "max_epochs": 100,
+            "accelerator": "gpu",
+            "devices": 1,
+            "check_val_every_n_epoch": 1,
+            "callbacks": [],
+            "logger": True,
+        }
+
+        self.model = HiViT(num_classes=num_classes)
+        self.outputs = []
+
+        fetch_blob("mae_hivit_base_1600ep.pth", drive_file_id="1VZQz4buhlepZ5akTcEvrA3a_nxsQZ8eQ", is_archive=False)
+        checkpoint = torch.load("mae_hivit_base_1600ep.pth", map_location='cpu')
+        # print(checkpoint)
+
+        checkpoint_model = checkpoint
+        state_dict = self.model.state_dict()
+        print(len(state_dict.keys()))
+        for k in ['head.weight', 'head.bias']:
+            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+                print(f"Removing key {k} from pretrained checkpoint")
+                del checkpoint_model[k]
+
+        interpolate_pos_embed(self.model, checkpoint_model)
+
+        msg = self.model.load_state_dict(checkpoint_model, strict=False)
+        print(msg)
+        print(self.model)
+
+    def configure_post_processor(self):
+        return PostProcessor()
+
+    def training_step(self, batch: Batch, *args, **kwargs) -> None:
+        output = self.model(batch.image)
+        self.outputs.append(output)
+
+    def configure_optimizers(self):
+        pass
+
+    def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
+        pass
+
+    def learning_type(self):
+        pass
+
+    def trainer_arguments(self):
+        pass
+
+if __name__ == "__main__":
+    SARATRX()
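
The weight-loading logic in __init__ above follows the usual MAE fine-tuning recipe: drop classification-head weights whose shapes do not match, interpolate position embeddings to the target resolution, and load the remainder with strict=False so that keys present on only one side (e.g. MAE decoder weights or a freshly initialised head) are tolerated. A minimal standalone sketch of the same pattern, assuming the checkpoint is a plain state_dict and omitting the repo-specific interpolate_pos_embed step:

    import torch
    from torch import nn

    def load_pretrained_backbone(model: nn.Module, ckpt_path: str) -> None:
        """Sketch: load MAE-pretrained weights, skipping incompatible head weights."""
        checkpoint_model = torch.load(ckpt_path, map_location="cpu")
        state_dict = model.state_dict()

        # Drop the classification head if its shape differs from the target model.
        for k in ("head.weight", "head.bias"):
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # strict=False tolerates missing/unexpected keys on either side.
        msg = model.load_state_dict(checkpoint_model, strict=False)
        print(msg.missing_keys, msg.unexpected_keys)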

SARIAD/models/image/YOLO/__init__.py

Whitespace-only changes.
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import logging
+import torch
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from anomalib import LearningType
+from anomalib.data import Batch
+from anomalib.metrics import Evaluator
+from anomalib.models.components import AnomalibModule
+from anomalib.post_processing import PostProcessor
+from anomalib.visualization import Visualizer
+
+from .torch_model import YOLOAnomalyModel
+
+logger = logging.getLogger(__name__)
+
+class YOLOAnomaly(AnomalibModule):
+    def __init__(self, backbone="yolov8n.pt", pre_processor=True, post_processor=True):
+        super().__init__(pre_processor=pre_processor, post_processor=post_processor)
+
+        self.model = YOLOAnomalyModel(model_path=backbone)
+
+    def training_step(self, batch: Batch, *args, **kwargs) -> None:
+        pass
+    def fit(self) -> None:
+        pass
+    def validation_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
+        pass
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from ultralytics import YOLO
+import torch
+from torch import nn
+from torch.nn import functional as F
+from anomalib.data import InferenceBatch
+from anomalib.models.components import MultiVariateGaussian
+
+class YOLOAnomalyModel(nn.Module):
+    def __init__(self, model_path="yolov8n.pt"):
+        super().__init__()
+        self.yolo = YOLO(model_path)
+
+    def forward(self, input_tensor):
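
The forward method above is truncated in this view, so the anomaly-scoring path is not shown. One generic way to turn a detection backbone into a feature extractor for anomaly modelling is to register forward hooks on intermediate layers of the underlying torch module (self.yolo.model in ultralytics); the sketch below illustrates that idea only and is not the commit's actual implementation — the tapped layer indices are placeholders:

    import torch
    from torch import nn
    from ultralytics import YOLO

    class YOLOFeatureExtractor(nn.Module):
        """Sketch: collect intermediate feature maps from a YOLO backbone via hooks."""

        def __init__(self, model_path="yolov8n.pt", layer_indices=(4, 6, 9)):
            super().__init__()
            self.yolo = YOLO(model_path)
            self.features = {}
            # layer_indices are illustrative; suitable taps depend on the architecture.
            for idx in layer_indices:
                self.yolo.model.model[idx].register_forward_hook(self._make_hook(idx))

        def _make_hook(self, idx):
            def hook(module, inputs, output):
                self.features[idx] = output
            return hook

        @torch.no_grad()
        def forward(self, input_tensor):
            self.features.clear()
            self.yolo.model(input_tensor)  # run the underlying torch module
            return [self.features[i] for i in sorted(self.features)]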

SARIAD/utils/blob_utils.py

Lines changed: 58 additions & 33 deletions
@@ -3,66 +3,87 @@
 
 from SARIAD.config import DATASETS_PATH
 
-def fetch_blob(path, link="", drive_file_id="", kaggle="", ext="zip"):
+def fetch_blob(path, link="", drive_file_id="", kaggle="", is_archive=True, ext="zip"):
     """
     Fetches the dataset blob from a direct link, Google Drive, or Kaggle,
     and extracts it directly to the specified path.
 
     Parameters:
-    - path: str, The full path to the directory where the extracted blob should reside.
-    - link: str, optional, direct HTTP(s) link to an archive.
-    - drive_file_id: str, optional, ID for Google Drive file (archive).
+    - path: str, The full path to the directory where the extracted blob should reside (if is_archive=True)
+      or the full path to the file itself (if is_archive=False).
+    - link: str, optional, direct HTTP(s) link to an archive or single file.
+    - drive_file_id: str, optional, ID for Google Drive file (archive or single file).
     - kaggle: str, optional, KaggleHub dataset slug.
-    - ext: str, archive type (zip, tar.gz, rar, tar), used for link and drive_file_id.
-      This parameter is ignored if 'kaggle' is provided.
+    - is_archive: bool, set to True if the fetched item is an archive that needs extraction.
+      Set to False for single files like .pth.
+    - ext: str, archive type (zip, tar.gz, rar, tar) or file extension (e.g., "pth"),
+      used for link and drive_file_id. This parameter is ignored if 'kaggle' is provided.
     """
-    if os.path.exists(path) and os.path.isdir(path) and len(os.listdir(path)) > 0:
-        print(f"Dataset found locally at: {path}")
-        return
+    if is_archive:
+        if os.path.exists(path) and os.path.isdir(path) and len(os.listdir(path)) > 0:
+            print(f"Dataset found locally at: {path}")
+            return
+    else:
+        if os.path.exists(path) and os.path.isfile(path):
+            print(f"File found locally at: {path}")
+            return
 
     print(f"Dataset not found locally at {path}. Downloading...")
-    os.makedirs(path, exist_ok=True)
+    if is_archive:
+        os.makedirs(path, exist_ok=True)
+    else:
+        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
 
     if link:
-        temp_archive_name = f"{os.path.basename(path)}_archive.{ext}"
-        temp_archive_path = os.path.join(os.path.dirname(path) or '.', temp_archive_name)
+        if is_archive:
+            temp_target_path = os.path.join(os.path.dirname(path) or '.', f"{os.path.basename(path)}_archive.{ext}")
+        else:
+            temp_target_path = path
 
         response = requests.get(link, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"Failed to download file from {link}: HTTP {response.status_code}")
 
         total_size_in_bytes = int(response.headers.get('content-length', 0))
         block_size = 8192
-        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(temp_archive_path)}")
+        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(temp_target_path)}")
 
-        with open(temp_archive_path, 'wb') as f:
+        with open(temp_target_path, 'wb') as f:
             for chunk in response.iter_content(chunk_size=block_size):
                 progress_bar.update(len(chunk))
                 f.write(chunk)
         progress_bar.close()
 
-        print(f"Extracting the {ext} archive...")
-        _extract_archive(temp_archive_path, path, ext)
-        os.remove(temp_archive_path)
-        print(f"Downloaded and extracted to {path}.")
+        if is_archive:
+            print(f"Extracting the {ext} archive...")
+            _extract_archive(temp_target_path, path, ext)
+            os.remove(temp_target_path)
+            print(f"Downloaded and extracted to {path}.")
+        else:
+            print(f"Downloaded file to {path}.")
 
     elif drive_file_id:
-        temp_archive_name = f"{os.path.basename(path)}_archive.{ext}"
-        temp_archive_path = os.path.join(os.path.dirname(path) or '.', temp_archive_name)
+        if is_archive:
+            temp_target_path = os.path.join(os.path.dirname(path) or '.', f"{os.path.basename(path)}_archive.{ext}")
+        else:
+            temp_target_path = path  # For single files, download directly to the final path
 
         print(f"Downloading from Google Drive ID: {drive_file_id}")
-        gdown.download(f"https://drive.google.com/uc?id={drive_file_id}", temp_archive_path, quiet=False)
+        gdown.download(f"https://drive.google.com/uc?id={drive_file_id}", temp_target_path, quiet=False)
 
-        print(f"Extracting the {ext} archive...")
-        _extract_archive(temp_archive_path, path, ext)
-        os.remove(temp_archive_path)
-        print(f"Downloaded and extracted to {path}.")
+        if is_archive:
+            print(f"Extracting the {ext} archive...")
+            _extract_archive(temp_target_path, path, ext)
+            os.remove(temp_target_path)
+            print(f"Downloaded and extracted to {path}.")
+        else:
+            print(f"Downloaded file to {path}.")
 
     elif kaggle:
         downloaded_kaggle_path = kagglehub.dataset_download(kaggle)
         print(f"KaggleHub {kaggle} dataset downloaded to: {downloaded_kaggle_path}")
 
-        os.makedirs(path, exist_ok=True)
+        os.makedirs(path, exist_ok=True)  # Always treat kaggle as an archive/dataset for now
 
         for item in os.listdir(downloaded_kaggle_path):
             s = os.path.join(downloaded_kaggle_path, item)
@@ -132,25 +153,29 @@ def _extract_archive(archive_path, extract_to, ext):
         shutil.move(os.path.join(temp_extract_dir, item), extract_to)
     shutil.rmtree(temp_extract_dir)
 
-def fetch_dataset(dataset_name, datasets_dir=DATASETS_PATH, link="", drive_file_id="", kaggle="", ext="zip"):
+def fetch_dataset(dataset_name, datasets_dir=DATASETS_PATH, link="", drive_file_id="", kaggle="", is_archive=True, ext="zip"):
     """
     Fetches a dataset blob from a direct link, Google Drive, or Kaggle,
    maintaining backward compatibility with the original fetch_blob signature.
 
     Parameters:
-    - dataset_name: str, The name of the dataset. This will be the directory name inside datasets_dir.
+    - dataset_name: str, The name of the dataset. This will be the directory name inside datasets_dir
+      for archives, or the file name if is_archive is False.
     - datasets_dir: str, The root directory where datasets are stored.
-    - link: str, optional, direct HTTP(s) link to an archive.
-    - drive_file_id: str, optional, ID for Google Drive file (archive).
+    - link: str, optional, direct HTTP(s) link to an archive or file.
+    - drive_file_id: str, optional, ID for Google Drive file (archive or file).
     - kaggle: str, optional, KaggleHub dataset slug.
-    - ext: str, archive type (zip, tar.gz, rar, tar), used for link and drive_file_id.
-      This parameter is ignored if 'kaggle' is provided.
+    - ext: str, archive type (zip, tar.gz, rar, tar) or file extension (e.g., "pth"),
+      used for link and drive_file_id. This parameter is ignored if 'kaggle' is provided.
+    - is_archive: bool, set to True if the fetched item is an archive that needs extraction.
+      Set to False for single files like .pth.
     """
     full_dataset_path = os.path.join(datasets_dir, dataset_name)
     fetch_blob(
         path=full_dataset_path,
         link=link,
         drive_file_id=drive_file_id,
         kaggle=kaggle,
-        ext=ext
+        ext=ext,
+        is_archive=is_archive
     )
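
With the new is_archive flag, the same helper now covers both extracted dataset archives and single checkpoint files. Illustrative calls based on the signatures above — the .pth example mirrors the call added in the SARATRX module earlier in this commit, while the dataset name and Drive ID in the archive example are placeholders:

    from SARIAD.utils.blob_utils import fetch_blob, fetch_dataset

    # Single file: `path` is the file itself; nothing is extracted.
    fetch_blob(
        "mae_hivit_base_1600ep.pth",
        drive_file_id="1VZQz4buhlepZ5akTcEvrA3a_nxsQZ8eQ",
        is_archive=False,
    )

    # Archive (placeholder name/ID): the dataset directory is created under
    # DATASETS_PATH and the downloaded archive is extracted into it, then removed.
    fetch_dataset(
        "SOME_DATASET",
        drive_file_id="<drive-file-id>",
        is_archive=True,
        ext="zip",
    )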
