-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
62 lines (50 loc) · 1.9 KB
/
eval.py
File metadata and controls
62 lines (50 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import torch.nn as nn
import torch.optim as optim
from model import ResNet50WithFeatures
from dataload import OpenSARShipDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
def inference_on_test_data(model_path, test_data_path, device='cuda',
                           batch_size=1, class_names=None, verbose=True):
    """
    Perform inference on test data collected from real SAR-AIS integration.

    Loads a ``ResNet50WithFeatures`` checkpoint and runs it over the
    ``'test'`` split of an ``OpenSARShipDataset`` rooted at
    ``test_data_path`` (built with ``extract_features=True``). Collects one
    prediction per image.

    Args:
        model_path: Path to a checkpoint written by ``torch.save``. It may be
            either a dict holding the weights under ``'model_state_dict'`` or
            a bare state dict.
        test_data_path: Root directory handed to ``OpenSARShipDataset``.
        device: Device used for inference (default ``'cuda'``). The
            checkpoint's tensors are remapped onto this device when loaded.
        batch_size: Images per forward pass. Defaults to 1, the original
            behaviour. Values > 1 require the dataset to yield equally sized
            tensors so the DataLoader can collate them (TODO confirm for
            OpenSARShipDataset).
        class_names: Labels indexed by the model's output column. Defaults to
            ``['Bulk Carrier', 'Container Ship', 'Fishing', 'Tanker']``. Its
            length is used as the model's ``num_classes``.
        verbose: If True (default), print one summary line per image.

    Returns:
        A list with one dict per test image, in dataset order, with these keys:
        ``'img_id'``; ``'predicted_class'`` (label string); ``'confidence'``
        (softmax probability of the predicted class, as a Python float); and
        ``'probabilities'`` (numpy array of all class probabilities).
    """
    if class_names is None:
        class_names = ['Bulk Carrier', 'Container Ship', 'Fishing', 'Tanker']

    # Load model. map_location remaps tensors saved on another device onto
    # `device`. Without it, a GPU checkpoint cannot be loaded with
    # device='cpu' on a machine without CUDA.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        # No wrapper dict: treat the loaded object itself as the state dict.
        state_dict = checkpoint
    model = ResNet50WithFeatures(num_features=14, num_classes=len(class_names))
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # Load test dataset
    test_dataset = OpenSARShipDataset(
        root_dir=test_data_path,
        split='test',
        extract_features=True,
    )
    # shuffle=False keeps results in dataset order.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    results = []
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            features = batch['features'].to(device)
            img_ids = batch['img_id']

            output = model(images, features)
            probs = F.softmax(output, dim=1).cpu()
            # Predict from the logits, as before. Softmax supplies only the
            # reported probabilities.
            pred_classes = output.argmax(dim=1).tolist()

            for i, pred_class in enumerate(pred_classes):
                img_id = img_ids[i]
                confidence = probs[i, pred_class].item()
                predicted = class_names[pred_class]
                results.append({
                    'img_id': img_id,
                    'predicted_class': predicted,
                    'confidence': confidence,
                    'probabilities': probs[i].numpy(),
                })
                if verbose:
                    print(f"Image: {img_id} | "
                          f"Predicted: {predicted} | "
                          f"Confidence: {confidence:.4f}")
    return results
# Usage
# results = inference_on_test_data('resnet50_ship_classifier.pth', '/path/to/test/data')