From 11682b1a30696a41aa2671c5dcb29bbe5a647828 Mon Sep 17 00:00:00 2001 From: Robotmurlock Date: Sun, 21 Sep 2025 10:27:53 +0200 Subject: [PATCH 1/2] feat: Implement per scene sampler within range --- .../sampler/scene_sampler_within_range.yaml | 3 + .../pred_bbox_feature_extractor.py | 63 ++++++-- mot_jepa/datasets/dataset/index/mot.py | 4 +- .../datasets/dataset/sampler/scene_sampler.py | 149 ++++++++++++++++++ tools/inference.py | 6 +- 5 files changed, 203 insertions(+), 22 deletions(-) create mode 100644 configs/dataset/sampler/scene_sampler_within_range.yaml diff --git a/configs/dataset/sampler/scene_sampler_within_range.yaml b/configs/dataset/sampler/scene_sampler_within_range.yaml new file mode 100644 index 0000000..6436a6a --- /dev/null +++ b/configs/dataset/sampler/scene_sampler_within_range.yaml @@ -0,0 +1,3 @@ +_target_: 'mot_jepa.datasets.dataset.sampler.scene_sampler.OneSceneWithRangeSampler.from_dataset' +n_scenes: 5 +n_frames: 2 \ No newline at end of file diff --git a/mot_jepa/datasets/dataset/feature_extractor/pred_bbox_feature_extractor.py b/mot_jepa/datasets/dataset/feature_extractor/pred_bbox_feature_extractor.py index 5a14ddd..c5479d4 100644 --- a/mot_jepa/datasets/dataset/feature_extractor/pred_bbox_feature_extractor.py +++ b/mot_jepa/datasets/dataset/feature_extractor/pred_bbox_feature_extractor.py @@ -1,6 +1,7 @@ import random from typing import Dict, List, Set +import enum import torch from mot_jepa.datasets.dataset.common.data import VideoClipPart @@ -9,11 +10,17 @@ from mot_jepa.utils.extra_features import ExtraFeaturesReader +class SupportedFeatures(enum.Enum): + BBOX = 'bbox' + KEYPOINTS = 'keypoints' + APPEARANCE = 'appearance' + + class PredictionBBoxFeatureExtractor(FeatureExtractor): BBOX_DIM = 5 # XYWHC KEYPOINTS_DIM = 35 # 17x2 XYC (per part) + 1 C (global) = 52 APPEARANCE_DIM = 768 # 6x128 - SUPPORTED_FEATURES = {'bbox', 'keypoints', 'appearance'} + SUPPORTED_FEATURES = {SupportedFeatures.BBOX, SupportedFeatures.KEYPOINTS, SupportedFeatures.APPEARANCE} WORKER_ID_STEP = 1_000_000 @@ -24,21 +31,27 @@ def __init__( n_tracks: int, prediction_path: str, feature_names: List[str], - extra_false_positives: bool = True + extra_false_positives: bool = True, + random_appearance_jitter_ratio: float = 0.0, + random_appearance_jitter_range: int = 0 ): super().__init__( index=index, object_id_mapping=object_id_mapping, n_tracks=n_tracks ) - feature_names = [feature_name.lower() for feature_name in feature_names] + feature_names = [SupportedFeatures(feature_name.lower()) for feature_name in feature_names] for feature_name in feature_names: assert feature_name in self.SUPPORTED_FEATURES, \ f'Unsupported feature "{feature_name}". Supported features: {self.SUPPORTED_FEATURES}' self._feature_names = set(feature_names) self._extra_features_reader = ExtraFeaturesReader(prediction_path) + + # Augmentations self._extra_false_positives = extra_false_positives + self._random_appearance_jitter_ratio = random_appearance_jitter_ratio if index.split == 'train' else 0.0 + self._random_appearance_jitter_range = random_appearance_jitter_range if index.split == 'train' else 0 self._worker_id_counter = 0 @@ -57,27 +70,27 @@ def bbox_to_tensor(bbox: List[float], score: float) -> torch.Tensor: return torch.tensor([*bbox, score], dtype=torch.float32) @staticmethod - def initialize_features(feature_names: Set[str], n_tracks: int, temporal_length: int) -> Dict[str, torch.Tensor]: + def initialize_features(feature_names: Set[SupportedFeatures], n_tracks: int, temporal_length: int) -> Dict[str, torch.Tensor]: features: Dict[str, torch.Tensor] = {} - if 'bbox' in feature_names: - features['bbox'] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.BBOX_DIM, dtype=torch.float32) - if 'keypoints' in feature_names: - features['keypoints'] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.KEYPOINTS_DIM, dtype=torch.float32) - if 'appearance' in feature_names: - features['appearance'] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.APPEARANCE_DIM, dtype=torch.float32) + if SupportedFeatures.BBOX in feature_names: + features[SupportedFeatures.BBOX.value] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.BBOX_DIM, dtype=torch.float32) + if SupportedFeatures.KEYPOINTS in feature_names: + features[SupportedFeatures.KEYPOINTS.value] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.KEYPOINTS_DIM, dtype=torch.float32) + if SupportedFeatures.APPEARANCE in feature_names: + features[SupportedFeatures.APPEARANCE.value] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.APPEARANCE_DIM, dtype=torch.float32) return features @staticmethod - def _set_features(feature_names: Set[str], features: Dict[str, torch.Tensor], object_index: int, clip_index: int, data: dict) -> None: - if 'bbox' in feature_names: + def _set_features(feature_names: Set[SupportedFeatures], features: Dict[str, torch.Tensor], object_index: int, clip_index: int, data: dict) -> None: + if SupportedFeatures.BBOX in feature_names: bbox = [*data['bbox_xywh'], data['bbox_conf']] - features['bbox'][object_index, clip_index, :] = torch.tensor(bbox, dtype=torch.float32) - if 'keypoints' in feature_names: + features[SupportedFeatures.BBOX.value][object_index, clip_index, :] = torch.tensor(bbox, dtype=torch.float32) + if SupportedFeatures.KEYPOINTS in feature_names: keypoints = sum([d[:2] for d in data['keypoints_xyc']], []) + [data['keypoints_conf']] - features['keypoints'][object_index, clip_index, :] = torch.tensor(keypoints, dtype=torch.float32) - if 'appearance' in feature_names: + features[SupportedFeatures.KEYPOINTS.value][object_index, clip_index, :] = torch.tensor(keypoints, dtype=torch.float32) + if SupportedFeatures.APPEARANCE in feature_names: embs = [[e * float(visibility) for e in emb] for emb, visibility in zip(data['appearance_embeddings'], data['appearance_visibility'])] - features['appearance'][object_index, clip_index, :] = torch.tensor(sum(embs, []), dtype=torch.float32) + features[SupportedFeatures.APPEARANCE.value][object_index, clip_index, :] = torch.tensor(sum(embs, []), dtype=torch.float32) def _extract_extra_data( self, @@ -124,5 +137,21 @@ def _extract_extra_data( next_id = self.WORKER_ID_STEP * worker_id + self._worker_id_counter video_clip_part.ids[object_index, clip_index] = next_id + if self._random_appearance_jitter_ratio > 0 and SupportedFeatures.APPEARANCE in self._feature_names: + scene_info = self._index.get_scene_info(scene_name) + jitter = random.randint(-self._random_appearance_jitter_range, self._random_appearance_jitter_range) + random_frame_index = max(0, min(scene_info.seqlength - 1, frame_index + jitter)) + aug_extra_data = self._extra_features_reader.read(scene_name, random_frame_index) + aug_object_id_to_extra_data_lookup: Dict[str, dict] = {raw['object_id']: raw for raw in aug_extra_data if raw['object_id'] is not None} + + for object_index, object_id in enumerate(object_ids): + data = aug_object_id_to_extra_data_lookup.get(object_id) + if data is None: + continue + + r = random.random() + if r < self._random_appearance_jitter_ratio: + self._set_features([SupportedFeatures.APPEARANCE], features, object_index, clip_index, data) + video_clip_part.features = features return video_clip_part diff --git a/mot_jepa/datasets/dataset/index/mot.py b/mot_jepa/datasets/dataset/index/mot.py index 659ad75..210bfc1 100644 --- a/mot_jepa/datasets/dataset/index/mot.py +++ b/mot_jepa/datasets/dataset/index/mot.py @@ -61,8 +61,7 @@ def __init__( split: str, sequence_list: Optional[List[str]] = None, label_type: LabelType = LabelType.GROUND_TRUTH, - skip_corrupted: bool = False, - test: bool = True + skip_corrupted: bool = False ) -> None: """ Args: @@ -77,6 +76,7 @@ def __init__( split=split, sequence_list=sequence_list ) + test = (split == 'test') if isinstance(paths, str): paths = [paths] diff --git a/mot_jepa/datasets/dataset/sampler/scene_sampler.py b/mot_jepa/datasets/dataset/sampler/scene_sampler.py index ee8fbbf..b756236 100644 --- a/mot_jepa/datasets/dataset/sampler/scene_sampler.py +++ b/mot_jepa/datasets/dataset/sampler/scene_sampler.py @@ -219,6 +219,120 @@ def __iter__(self) -> Iterator[List[int]]: yield batch +class OneSceneWithRangeSampler(Sampler[List[int]]): + """BatchSampler: N scenes × exactly M consecutive frames/scene. + + Each batch contains N different scenes with M consecutive frames per scene. + Frames within each scene must be consecutive (no gaps). + Scenes can be sampled multiple times across different batches. + + Args: + scene_ids: scene id per dataset index (len == len(dataset)). + n_scenes: number of distinct scenes per batch (N). + n_frames: number of consecutive frames per scene in a batch (M). + shuffle_scenes: shuffle scene order each epoch. + seed: base RNG seed. + """ + def __init__( + self, + scene_ids: List[str], + n_scenes: int, + n_frames: int, + shuffle_scenes: bool = True, + seed: int = 0, + ) -> None: + super().__init__() + self._scene_ids = scene_ids + self._n_scenes = n_scenes + self._n_frames = n_frames + self._shuffle_scenes = shuffle_scenes + self._seed = seed + self._iteration = 0 + self._by_scene = _group_indices_by_scene(scene_ids) + + # Filter out scenes that don't have enough consecutive frames + self._valid_scenes = {} + for scene_id, indices in self._by_scene.items(): + if len(indices) >= self._n_frames: + # Sort indices to ensure we can find consecutive sequences + sorted_indices = sorted(indices) + self._valid_scenes[scene_id] = sorted_indices + + # Calculate number of batches based on total possible samples + # Each scene can contribute multiple samples (consecutive frame ranges) + total_samples = 0 + for scene_id, indices in self._valid_scenes.items(): + # Number of possible consecutive ranges of length n_frames + num_ranges = len(indices) - self._n_frames + 1 + total_samples += num_ranges + + self._num_batches = math.ceil(total_samples / self._n_scenes) + + @classmethod + def from_dataset( + cls, + dataset: MOTClipDataset, + n_scenes: int, + n_frames: int, + shuffle_scenes: bool = True, + seed: int = 0, + ): + return cls( + scene_ids=dataset.scene_names_per_frame, + n_scenes=n_scenes, + n_frames=n_frames, + shuffle_scenes=shuffle_scenes, + seed=seed, + ) + + def __len__(self) -> int: + return self._num_batches + + def __iter__(self) -> Iterator[List[int]]: + rng = random.Random(self._seed + self._iteration) + self._iteration += 1 + + # Generate all possible consecutive frame ranges for each scene + all_samples = [] + for scene_id, indices in self._valid_scenes.items(): + for start_idx in range(len(indices) - self._n_frames + 1): + consecutive_frames = indices[start_idx:start_idx + self._n_frames] + all_samples.append((scene_id, consecutive_frames)) + + # Shuffle all samples + rng.shuffle(all_samples) + + # Group samples into batches, ensuring each scene appears only once per batch + batch_samples = [] + used_scenes_in_batch = set() + + for scene_id, frames in all_samples: + # If we already have this scene in current batch, skip + if scene_id in used_scenes_in_batch: + continue + + batch_samples.append((scene_id, frames)) + used_scenes_in_batch.add(scene_id) + + # When we have enough scenes for a batch, yield it + if len(batch_samples) >= self._n_scenes: + batch: List[int] = [] + for _, frames in batch_samples: + batch.extend(frames) + yield batch + + # Reset for next batch + batch_samples = [] + used_scenes_in_batch = set() + + # Yield remaining samples as final batch if any + if batch_samples: + batch: List[int] = [] + for _, frames in batch_samples: + batch.extend(frames) + yield batch + + def run_test_batch_sampler_without_repeat() -> None: scene_ids = ['A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'C', 'C', 'D', 'D'] batch_sampler = SceneBatchSamplerNoRepeat( @@ -243,6 +357,41 @@ def run_test_batch_sampler_with_repeat() -> None: print(sampled_scene_ids) +def run_test_one_scene_with_range_sampler() -> None: + """Test the OneSceneWithRangeSampler with consecutive frame sampling.""" + # Create test data with consecutive indices per scene + scene_ids = ['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'C', 'C', 'C', 'D', 'D'] + batch_sampler = OneSceneWithRangeSampler( + scene_ids=scene_ids, + n_scenes=2, + n_frames=3 + ) + + print("OneSceneWithRangeSampler test:") + print(f"Scene distribution: {dict(_group_indices_by_scene(scene_ids))}") + print(f"Valid scenes: {list(batch_sampler._valid_scenes.keys())}") + print(f"Total batches: {len(batch_sampler)}") + + for i, batch in enumerate(batch_sampler): + batch_scene_ids = [scene_ids[b_i] for b_i in batch] + print(f"Batch {i}: indices={batch}, scenes={batch_scene_ids}") + + # Verify consecutive frames within each scene + scene_groups = {} + for idx, scene_id in zip(batch, batch_scene_ids): + if scene_id not in scene_groups: + scene_groups[scene_id] = [] + scene_groups[scene_id].append(idx) + + for scene_id, indices in scene_groups.items(): + # Check if indices are consecutive + sorted_indices = sorted(indices) + is_consecutive = all(sorted_indices[i] == sorted_indices[i-1] + 1 + for i in range(1, len(sorted_indices))) + print(f" Scene {scene_id}: indices={sorted_indices}, consecutive={is_consecutive}") + + if __name__ == '__main__': run_test_batch_sampler_without_repeat() run_test_batch_sampler_with_repeat() + run_test_one_scene_with_range_sampler() diff --git a/tools/inference.py b/tools/inference.py index 142585a..a51ec0c 100755 --- a/tools/inference.py +++ b/tools/inference.py @@ -25,7 +25,7 @@ from mot_jepa.config_parser import GlobalConfig from mot_jepa.datasets.dataset import dataset_index_factory from mot_jepa.datasets.dataset.common.data import VideoClipData, VideoClipPart -from mot_jepa.datasets.dataset.feature_extractor.pred_bbox_feature_extractor import PredictionBBoxFeatureExtractor +from mot_jepa.datasets.dataset.feature_extractor.pred_bbox_feature_extractor import PredictionBBoxFeatureExtractor, SupportedFeatures from mot_jepa.datasets.dataset.motrack import MotrackDatasetWrapper from mot_jepa.datasets.dataset.transform import Transform from mot_jepa.utils import pipeline @@ -62,7 +62,7 @@ def __init__( self._model.to(device) self._model.eval() - self._feature_names = self._model.feature_names + self._feature_names = set([SupportedFeatures(feature_name) for feature_name in self._model.feature_names]) self._extra_features_reader = extra_features_reader self._sim_threshold = sim_threshold @@ -182,7 +182,7 @@ def _association( track_mm_features = F.normalize(track_mm_features, dim=-1).cpu() det_mm_features = F.normalize(det_mm_features, dim=-1).cpu() - ALPHA = 0.65 + ALPHA = 0.0 # 0.65 if ALPHA > 0.0: ema_track_mm_features = torch.zeros_like(track_mm_features) for t_i in range(track_mm_features.shape[0]): From 03207e3e60008a22593e4e00e80ae7d83b0915bc Mon Sep 17 00:00:00 2001 From: Robotmurlock Date: Tue, 4 Nov 2025 21:51:55 +0100 Subject: [PATCH 2/2] feat: Implement TDSP - DanceTrack SOTA --- configs/appearance.yaml | 6 +- configs/{bbox_only.yaml => bbox.yaml} | 6 +- configs/dataset/augmentations/default.yaml | 2 +- configs/dataset/dancetrack.yaml | 6 +- configs/dataset/dancetrack_appearance.yaml | 1 + .../feature_extractor/pred_appearance.yaml | 5 +- configs/dataset/sampler/scene_sampler.yaml | 4 +- configs/default.yaml | 21 +- configs/keypoints.yaml | 6 +- configs/model_config/mm_appearance.yaml | 2 +- configs/model_config/mm_bboxes.yaml | 2 +- configs/model_config/mm_bboxes_keypoints.yaml | 2 +- .../mm_bboxes_keypoints_appearance.yaml | 8 +- configs/model_config/mm_keypoints.yaml | 2 +- configs/model_config/mm_tdsp_appearance.yaml | 33 +++ configs/model_config/mm_tdsp_bboxes.yaml | 28 ++ .../mm_tdsp_bboxes_keypoints_appearance.yaml | 61 ++++ configs/model_config/mm_tdsp_keypoints.yaml | 28 ++ configs/resources/default.yaml | 6 +- configs/train/bce.yaml | 8 + configs/train/clip.yaml | 6 + configs/train/mm.yaml | 8 +- mot_jepa/architectures/tdcp/core.py | 262 ++++++++++++++++-- .../architectures/tdcp/feature_encoders.py | 20 +- .../tdcp/similarity_prediction.py | 119 ++++++++ mot_jepa/config_parser/core.py | 2 +- .../dataset/augmentations/appearance.py | 4 +- .../feature_extractor/feature_extractor.py | 2 +- .../pred_bbox_feature_extractor.py | 69 ++++- mot_jepa/datasets/dataset/mot.py | 2 + mot_jepa/trainer/losses/__init__.py | 2 + mot_jepa/trainer/losses/bce.py | 249 +++++++++++++++++ mot_jepa/trainer/losses/infonce.py | 2 +- mot_jepa/trainer/trainer.py | 96 +++++-- .../analysis/cameltrack/feature_extraction.py | 94 +++++-- tools/inference.py | 209 ++++---------- tools/train.py | 2 + train_with_pretrained_encoders.sh | 2 +- 38 files changed, 1107 insertions(+), 280 deletions(-) rename configs/{bbox_only.yaml => bbox.yaml} (60%) mode change 100755 => 100644 create mode 100644 configs/model_config/mm_tdsp_appearance.yaml create mode 100644 configs/model_config/mm_tdsp_bboxes.yaml create mode 100644 configs/model_config/mm_tdsp_bboxes_keypoints_appearance.yaml create mode 100644 configs/model_config/mm_tdsp_keypoints.yaml create mode 100644 configs/train/bce.yaml create mode 100644 configs/train/clip.yaml create mode 100644 mot_jepa/architectures/tdcp/similarity_prediction.py create mode 100644 mot_jepa/trainer/losses/bce.py diff --git a/configs/appearance.yaml b/configs/appearance.yaml index ce411a5..dd77916 100755 --- a/configs/appearance.yaml +++ b/configs/appearance.yaml @@ -2,10 +2,10 @@ defaults: - the_global_config - resources: default.yaml - dataset: dancetrack_appearance.yaml - - train: id.yaml + - train: bce.yaml - eval: default.yaml - - model_config: mm_appearance.yaml + - model_config: mm_tdsp_appearance.yaml - path: default.yaml -experiment_name: exp74a-fromExp73-HalvedLrDoubleEpochs +experiment_name: exp107a-fromExp105-ScaleUp dataset_name: DanceTrack diff --git a/configs/bbox_only.yaml b/configs/bbox.yaml old mode 100755 new mode 100644 similarity index 60% rename from configs/bbox_only.yaml rename to configs/bbox.yaml index 1ae2a28..fb90c07 --- a/configs/bbox_only.yaml +++ b/configs/bbox.yaml @@ -2,10 +2,10 @@ defaults: - the_global_config - resources: default.yaml - dataset: dancetrack_bbox.yaml - - train: batch.yaml + - train: bce.yaml - eval: default.yaml - - model_config: mm_bboxes.yaml + - model_config: mm_tdsp_bboxes.yaml - path: default.yaml -experiment_name: exp74b-fromExp73-HalvedLrDoubleEpochs +experiment_name: exp108b-fromExp107-BboxEmbDim512 dataset_name: DanceTrack diff --git a/configs/dataset/augmentations/default.yaml b/configs/dataset/augmentations/default.yaml index c5793b8..d791a40 100644 --- a/configs/dataset/augmentations/default.yaml +++ b/configs/dataset/augmentations/default.yaml @@ -19,7 +19,7 @@ augmentations: - _target_: mot_jepa.datasets.dataset.augmentations.video.IdentitySwitchAugmentation switch_ratio: 0.3 - _target_: mot_jepa.datasets.dataset.augmentations.appearance.AppearanceNoiseAugmentation - alpha: 0.5 + alpha: 0.40 - _target_: mot_jepa.datasets.dataset.augmentations.video.SmartIdentitySwitchAugmentation switch_ratio: 0.5 iou_threshold: 0.5 diff --git a/configs/dataset/dancetrack.yaml b/configs/dataset/dancetrack.yaml index 54fe756..f1fa4ad 100644 --- a/configs/dataset/dancetrack.yaml +++ b/configs/dataset/dancetrack.yaml @@ -2,7 +2,7 @@ defaults: - transform: scaled_bbox_keypoints.yaml - augmentations: default.yaml - feature_extractor: pred_bbox_keypoints_appearance.yaml - - sampler: scene_sampler.yaml + # - sampler: scene_sampler_within_range.yaml index: type: mot @@ -11,10 +11,10 @@ index: - /media/home/DanceTrack-orig/ n_tracks: 40 -clip_length: 30 +clip_length: 50 min_clip_tracks: 1 clip_sampling_step: 1 val_clip_sampling_step: 1 sampler: null -use_batch_sampler: false +use_batch_sampler: false \ No newline at end of file diff --git a/configs/dataset/dancetrack_appearance.yaml b/configs/dataset/dancetrack_appearance.yaml index 3dba934..f034939 100644 --- a/configs/dataset/dancetrack_appearance.yaml +++ b/configs/dataset/dancetrack_appearance.yaml @@ -2,6 +2,7 @@ defaults: - transform: scaled_bbox_keypoints.yaml - augmentations: default.yaml - feature_extractor: pred_appearance.yaml + # - sampler: scene_sampler_within_range.yaml index: type: mot diff --git a/configs/dataset/feature_extractor/pred_appearance.yaml b/configs/dataset/feature_extractor/pred_appearance.yaml index e12878a..691b4f8 100644 --- a/configs/dataset/feature_extractor/pred_appearance.yaml +++ b/configs/dataset/feature_extractor/pred_appearance.yaml @@ -3,4 +3,7 @@ extractor_params: prediction_path: /media/home/cameltrack-states/extracted-features extra_false_positives: true feature_names: - - appearance \ No newline at end of file + - appearance + + random_appearance_jitter_ratio: 0.0 + random_appearance_jitter_range: 0 \ No newline at end of file diff --git a/configs/dataset/sampler/scene_sampler.yaml b/configs/dataset/sampler/scene_sampler.yaml index d096ac9..636c759 100644 --- a/configs/dataset/sampler/scene_sampler.yaml +++ b/configs/dataset/sampler/scene_sampler.yaml @@ -1,3 +1,3 @@ _target_: 'mot_jepa.datasets.dataset.sampler.scene_sampler.SceneBatchSamplerWithRepeat.from_dataset' -n_scenes: 4 -n_frames: 8 \ No newline at end of file +n_scenes: 12 +n_frames: 1 \ No newline at end of file diff --git a/configs/default.yaml b/configs/default.yaml index d81e69e..135eb9f 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -2,10 +2,25 @@ defaults: - the_global_config - resources: default.yaml - dataset: dancetrack.yaml - - train: mm.yaml + - train: bce.yaml - eval: default.yaml - - model_config: mm_bboxes_keypoints_appearance.yaml + - model_config: mm_tdsp_bboxes_keypoints_appearance.yaml - path: default.yaml -experiment_name: exp74-fromExp73-HalvedLrDoubleEpochs +experiment_name: exp111-fromExp110-LinearSumDecoder dataset_name: DanceTrack + +resources: + batch_size: 8 + +train: + max_epochs: 10 + + optimizer_config: + _target_: torch.optim.AdamW + lr: 1e-5 + weight_decay: 1e-3 + + scheduler_config: + _target_: mot_jepa.trainer.scheduler.create_warmup_cosine_annealing_scheduler + n_warmup_epochs: 1 diff --git a/configs/keypoints.yaml b/configs/keypoints.yaml index e00c44c..5fe3176 100755 --- a/configs/keypoints.yaml +++ b/configs/keypoints.yaml @@ -2,10 +2,10 @@ defaults: - the_global_config - resources: default.yaml - dataset: dancetrack_keypoints.yaml - - train: batch.yaml + - train: bce.yaml - eval: default.yaml - - model_config: mm_keypoints.yaml + - model_config: mm_tdsp_keypoints.yaml - path: default.yaml -experiment_name: exp74k-fromExp73-HalvedLrDoubleEpochs +experiment_name: exp108k-fromExp107-BboxEmbDim512 dataset_name: DanceTrack diff --git a/configs/model_config/mm_appearance.yaml b/configs/model_config/mm_appearance.yaml index b80a0af..252e062 100644 --- a/configs/model_config/mm_appearance.yaml +++ b/configs/model_config/mm_appearance.yaml @@ -7,7 +7,7 @@ common_params: track_encoder_n_layers: 2 track_encoder_ffn_dim: 512 projector_intermediate_dim: 512 - interaction_encoder_enable: true + interaction_encoder_enable: false interaction_encoder_n_heads: 8 interaction_encoder_n_layers: 2 interaction_encoder_ffn_dim: 512 diff --git a/configs/model_config/mm_bboxes.yaml b/configs/model_config/mm_bboxes.yaml index b202f4f..952d8a0 100644 --- a/configs/model_config/mm_bboxes.yaml +++ b/configs/model_config/mm_bboxes.yaml @@ -7,7 +7,7 @@ common_params: track_encoder_n_layers: 2 track_encoder_ffn_dim: 512 projector_intermediate_dim: 512 - interaction_encoder_enable: true + interaction_encoder_enable: false interaction_encoder_n_heads: 8 interaction_encoder_n_layers: 2 interaction_encoder_ffn_dim: 512 diff --git a/configs/model_config/mm_bboxes_keypoints.yaml b/configs/model_config/mm_bboxes_keypoints.yaml index dfa5705..d109078 100644 --- a/configs/model_config/mm_bboxes_keypoints.yaml +++ b/configs/model_config/mm_bboxes_keypoints.yaml @@ -6,7 +6,7 @@ common_params: track_encoder_n_layers: 2 track_encoder_ffn_dim: 512 projector_intermediate_dim: 512 - interaction_encoder_enable: true + interaction_encoder_enable: false interaction_encoder_n_heads: 8 interaction_encoder_n_layers: 2 interaction_encoder_ffn_dim: 512 diff --git a/configs/model_config/mm_bboxes_keypoints_appearance.yaml b/configs/model_config/mm_bboxes_keypoints_appearance.yaml index 8d06f8f..22bc787 100644 --- a/configs/model_config/mm_bboxes_keypoints_appearance.yaml +++ b/configs/model_config/mm_bboxes_keypoints_appearance.yaml @@ -7,7 +7,7 @@ common_params: track_encoder_n_layers: 2 track_encoder_ffn_dim: 512 projector_intermediate_dim: 512 - interaction_encoder_enable: true + interaction_encoder_enable: false interaction_encoder_n_heads: 8 interaction_encoder_n_layers: 2 interaction_encoder_ffn_dim: 512 @@ -35,9 +35,9 @@ aggregator_params: dropout: 0.1 per_feature_checkpoint: - bbox: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp74b-fromExp73-HalvedLrDoubleEpochs/checkpoints/last.pt - keypoints: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp74k-fromExp73-HalvedLrDoubleEpochs/checkpoints/last.pt - appearance: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp74a-fromExp73-HalvedLrDoubleEpochs/checkpoints/last.pt + bbox: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp93b-fromExp90-ClipLevelBCE/checkpoints/last.pt + keypoints: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp93k-fromExp90-ClipLevelBCE/checkpoints/last.pt + appearance: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp98a-fromExp97-EmbDim512/checkpoints/last.pt object_interaction_encoder_enable: true object_interaction_encoder_params: diff --git a/configs/model_config/mm_keypoints.yaml b/configs/model_config/mm_keypoints.yaml index 108a255..447772a 100644 --- a/configs/model_config/mm_keypoints.yaml +++ b/configs/model_config/mm_keypoints.yaml @@ -7,7 +7,7 @@ common_params: track_encoder_n_layers: 2 track_encoder_ffn_dim: 512 projector_intermediate_dim: 512 - interaction_encoder_enable: true + interaction_encoder_enable: false interaction_encoder_n_heads: 8 interaction_encoder_n_layers: 2 interaction_encoder_ffn_dim: 512 diff --git a/configs/model_config/mm_tdsp_appearance.yaml b/configs/model_config/mm_tdsp_appearance.yaml new file mode 100644 index 0000000..2ab6bed --- /dev/null +++ b/configs/model_config/mm_tdsp_appearance.yaml @@ -0,0 +1,33 @@ +_target_: mot_jepa.architectures.tdcp.core.build_mm_tdsp_model +mm_dim: 1024 +similarity_prediction_head_hidden_dim: 512 + +sph_common_params: + hidden_dim: 512 +sph_per_feature_params: + appearance: + hidden_dim: 512 + +common_params: + hidden_dim: 256 + dropout: 0.1 + track_encoder_n_heads: 8 + track_encoder_n_layers: 4 + track_encoder_ffn_dim: 512 + projector_intermediate_dim: 512 + interaction_encoder_enable: true + interaction_encoder_n_heads: 8 + interaction_encoder_n_layers: 4 + interaction_encoder_ffn_dim: 512 +per_feature_params: + appearance: + hidden_dim: 512 + feature_encoder_type: parts_appearance + feature_encoder_params: + emb_size: 128 + hidden_dim: 512 + track_encoder_enable_motion_encoder: false + track_encoder_ffn_dim: 1024 + interaction_encoder_ffn_dim: 1024 +aggregator_type: sum +aggregator_params: {} diff --git a/configs/model_config/mm_tdsp_bboxes.yaml b/configs/model_config/mm_tdsp_bboxes.yaml new file mode 100644 index 0000000..6f747e5 --- /dev/null +++ b/configs/model_config/mm_tdsp_bboxes.yaml @@ -0,0 +1,28 @@ +_target_: mot_jepa.architectures.tdcp.core.build_mm_tdsp_model +mm_dim: 1024 +similarity_prediction_head_hidden_dim: 512 + +sph_common_params: + hidden_dim: 512 +sph_per_feature_params: + bbox: + hidden_dim: 512 + +common_params: + hidden_dim: 512 + dropout: 0.1 + track_encoder_n_heads: 8 + track_encoder_n_layers: 4 + track_encoder_ffn_dim: 1024 + projector_intermediate_dim: 512 + interaction_encoder_enable: true + interaction_encoder_n_heads: 8 + interaction_encoder_n_layers: 4 + interaction_encoder_ffn_dim: 1024 +per_feature_params: + bbox: + feature_encoder_type: motion + feature_encoder_params: + input_dim: 5 +aggregator_type: sum +aggregator_params: {} diff --git a/configs/model_config/mm_tdsp_bboxes_keypoints_appearance.yaml b/configs/model_config/mm_tdsp_bboxes_keypoints_appearance.yaml new file mode 100644 index 0000000..f18e535 --- /dev/null +++ b/configs/model_config/mm_tdsp_bboxes_keypoints_appearance.yaml @@ -0,0 +1,61 @@ +_target_: mot_jepa.architectures.tdcp.core.build_mm_tdsp_model +mm_dim: 1024 +similarity_prediction_head_hidden_dim: 1024 + +sph_common_params: + hidden_dim: 512 +sph_per_feature_params: + bbox: + hidden_dim: 512 + keypoints: + hidden_dim: 512 + appearance: + hidden_dim: 512 + +common_params: + hidden_dim: 512 + dropout: 0.1 + track_encoder_n_heads: 8 + track_encoder_n_layers: 4 + track_encoder_ffn_dim: 1024 + projector_intermediate_dim: 512 + interaction_encoder_enable: true + interaction_encoder_n_heads: 8 + interaction_encoder_n_layers: 4 + interaction_encoder_ffn_dim: 1024 +per_feature_params: + bbox: + feature_encoder_type: motion + feature_encoder_params: + input_dim: 5 + keypoints: + feature_encoder_type: motion + feature_encoder_params: + input_dim: 35 + appearance: + hidden_dim: 512 + feature_encoder_type: parts_appearance + feature_encoder_params: + emb_size: 128 + hidden_dim: 512 + track_encoder_enable_motion_encoder: false + track_encoder_ffn_dim: 1024 + interaction_encoder_ffn_dim: 1024 +aggregator_type: sum # transformer +aggregator_params: {} + # hidden_dim: ${model_config.mm_dim} + # n_heads: 8 + # n_layers: 4 + # dropout: 0.1 + +per_feature_checkpoint: + bbox: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp108b-fromExp107-BboxEmbDim512/checkpoints/last.pt + keypoints: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp108k-fromExp107-BboxEmbDim512/checkpoints/last.pt + appearance: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp107a-fromExp105-ScaleUp/checkpoints/last.pt + +object_interaction_encoder_enable: true +object_interaction_encoder_params: + hidden_dim: ${model_config.mm_dim} + n_heads: 8 + n_layers: 2 + dropout: 0.1 diff --git a/configs/model_config/mm_tdsp_keypoints.yaml b/configs/model_config/mm_tdsp_keypoints.yaml new file mode 100644 index 0000000..b4e48fc --- /dev/null +++ b/configs/model_config/mm_tdsp_keypoints.yaml @@ -0,0 +1,28 @@ +_target_: mot_jepa.architectures.tdcp.core.build_mm_tdsp_model +mm_dim: 1024 +similarity_prediction_head_hidden_dim: 512 + +sph_common_params: + hidden_dim: 512 +sph_per_feature_params: + keypoints: + hidden_dim: 512 + +common_params: + hidden_dim: 512 + dropout: 0.1 + track_encoder_n_heads: 8 + track_encoder_n_layers: 4 + track_encoder_ffn_dim: 1024 + projector_intermediate_dim: 512 + interaction_encoder_enable: true + interaction_encoder_n_heads: 8 + interaction_encoder_n_layers: 4 + interaction_encoder_ffn_dim: 1024 +per_feature_params: + keypoints: + feature_encoder_type: motion + feature_encoder_params: + input_dim: 35 +aggregator_type: sum +aggregator_params: {} diff --git a/configs/resources/default.yaml b/configs/resources/default.yaml index f12e0f5..c77cec2 100644 --- a/configs/resources/default.yaml +++ b/configs/resources/default.yaml @@ -1,4 +1,4 @@ -batch_size: 8 -val_batch_size: 4 +batch_size: 32 +val_batch_size: 16 accelerator: 'cuda:0' -num_workers: 12 \ No newline at end of file +num_workers: 20 \ No newline at end of file diff --git a/configs/train/bce.yaml b/configs/train/bce.yaml new file mode 100644 index 0000000..9ad96bb --- /dev/null +++ b/configs/train/bce.yaml @@ -0,0 +1,8 @@ +defaults: + - base.yaml + - _self_ + +loss_config: + _target_: mot_jepa.trainer.losses.bce.ClipLevelBCE + pos_weight: 10.0 + assoc_threshold: 1e-2 diff --git a/configs/train/clip.yaml b/configs/train/clip.yaml new file mode 100644 index 0000000..d536660 --- /dev/null +++ b/configs/train/clip.yaml @@ -0,0 +1,6 @@ +defaults: + - base.yaml + - _self_ + +loss_config: + _target_: mot_jepa.trainer.losses.infonce.ClipLevelInfoNCE \ No newline at end of file diff --git a/configs/train/mm.yaml b/configs/train/mm.yaml index f281bae..1bcb8bd 100644 --- a/configs/train/mm.yaml +++ b/configs/train/mm.yaml @@ -5,14 +5,14 @@ defaults: loss_config: _target_: mot_jepa.trainer.losses.infonce.MultiFeatureLoss mm_loss: - _target_: mot_jepa.trainer.losses.infonce.IDLevelInfoNCE + _target_: mot_jepa.trainer.losses.infonce.ClipLevelInfoNCE per_feature_losses: bbox: - _target_: mot_jepa.trainer.losses.infonce.BatchLevelInfoNCE + _target_: mot_jepa.trainer.losses.infonce.ClipLevelInfoNCE keypoints: - _target_: mot_jepa.trainer.losses.infonce.BatchLevelInfoNCE + _target_: mot_jepa.trainer.losses.infonce.ClipLevelInfoNCE appearance: - _target_: mot_jepa.trainer.losses.infonce.IDLevelInfoNCE + _target_: mot_jepa.trainer.losses.infonce.ClipLevelInfoNCE per_feature_weights: bbox: 0.3 keypoints: 0.3 diff --git a/mot_jepa/architectures/tdcp/core.py b/mot_jepa/architectures/tdcp/core.py index 597d1f1..a2269b8 100644 --- a/mot_jepa/architectures/tdcp/core.py +++ b/mot_jepa/architectures/tdcp/core.py @@ -12,6 +12,7 @@ from mot_jepa.architectures.tdcp.object_interaction_encoder import ObjectInteractionEncoder from mot_jepa.architectures.tdcp.projector import TrackToDetectionProjector from mot_jepa.architectures.tdcp.track_encoder import TrackEncoder +from mot_jepa.architectures.tdcp.similarity_prediction import TDSPMLPHead import logging logger = logging.getLogger('Architecture') @@ -64,7 +65,6 @@ def forward( Returns: Tuple ``(projected_track_features, detection_features)``. """ - det_features = self._static_encoder(det_x) if self._motion_encoder is not None: @@ -98,18 +98,27 @@ def __init__( object_interaction_encoder: Optional[ObjectInteractionEncoder] = None ): super().__init__() + self._mm_dim = mm_dim + self._tdcps = nn.ModuleDict(tdcps) - self._mm_linear_layers = nn.ModuleList([ - nn.Linear(tdcp.output_dim, mm_dim) - for tdcp in tdcps.values() - ]) + self._mm_linear_layers = nn.ModuleDict({ + feature_name: nn.Linear(tdcp.output_dim, mm_dim) + for feature_name, tdcp in tdcps.items() + }) self._aggregator = aggregator self._object_interaction_encoder = object_interaction_encoder + @property + def output_dim(self) -> int: + return self._mm_dim + @property def feature_names(self) -> Set[str]: return set(self._tdcps.keys()) + def get_tdcp(self, feature_name: str) -> TrackDetectionContrastivePrediction: + return self._tdcps[feature_name] + def forward( self, track_features: Dict[str, torch.Tensor], @@ -127,10 +136,10 @@ def forward( det_mask=det_mask, ) - mm_track_features = [lin_layer(mm_feat) for lin_layer, mm_feat in zip(self._mm_linear_layers, list(track_features.values()))] - mm_det_features = [lin_layer(mm_feat) for lin_layer, mm_feat in zip(self._mm_linear_layers, list(det_features.values()))] - agg_track_features = self._aggregator(mm_track_features) - agg_det_features = self._aggregator(mm_det_features) + mm_track_features = {feature_name: self._mm_linear_layers[feature_name](track_features[feature_name]) for feature_name in track_features} + mm_det_features = {feature_name: self._mm_linear_layers[feature_name](det_features[feature_name]) for feature_name in det_features} + agg_track_features = self._aggregator(list(mm_track_features.values())) + agg_det_features = self._aggregator(list(mm_det_features.values())) if self._object_interaction_encoder is not None: agg_track_mask = track_mask.all(dim=-1) @@ -144,6 +153,74 @@ def forward( return agg_track_features, agg_det_features, track_features, det_features +class TrackDetectionSimilarityPrediction(nn.Module): + """Similarity model comparing track and detection embeddings.""" + + def __init__( + self, + tdcp: TrackDetectionContrastivePrediction, + similarity_prediction_head: TDSPMLPHead + ) -> None: + """Args: + tdcp: Temporal encoder for track sequences. + similarity_prediction_head: MLP head for similarity prediction. + """ + super().__init__() + self._tdcp = tdcp + self._similarity_prediction_head = similarity_prediction_head + + def forward( + self, + track_x: torch.Tensor, + track_mask: torch.Tensor, + det_x: torch.Tensor, + det_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + track_features, det_features = self._tdcp(track_x, track_mask, det_x, det_mask) + logits = self._similarity_prediction_head(track_features, det_features) + return logits + + +class MultiModalTDSP(nn.Module): + def __init__( + self, + mm_tdcp: MultiModalTDCP, + sphs: Dict[str, TDSPMLPHead], + mm_sph: TDSPMLPHead + ) -> None: + super().__init__() + self._mm_tdcp = mm_tdcp + self._sphs = nn.ModuleDict(sphs) + self._mm_sph = mm_sph + + @property + def feature_names(self) -> Set[str]: + return set(self._sphs.keys()) + + def forward( + self, + track_features: Dict[str, torch.Tensor], + track_mask: torch.Tensor, + det_features: Dict[str, torch.Tensor], + det_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + assert set(track_features.keys()) == set(det_features.keys()) == set(self._sphs.keys()) + + agg_track_features, agg_det_features, track_features, det_features = self._mm_tdcp( + track_features=track_features, + track_mask=track_mask, + det_features=det_features, + det_mask=det_mask, + ) + + agg_logits = self._mm_sph(agg_track_features, agg_det_features) + sphs_logits = { + key: sph(track_features[key], det_features[key]) + for key, sph in self._sphs.items() + } + + return agg_logits, sphs_logits + def build_tdcp_model( feature_encoder_type: str = 'motion', feature_encoder_params: Dict[str, Any] = None, @@ -226,9 +303,16 @@ def build_mm_tdcp_model( aggregator_params: Dict[str, Any], per_feature_checkpoint: Optional[Dict[str, str]] = None, object_interaction_encoder_enable: bool = False, - object_interaction_encoder_params: Optional[Dict[str, Any]] = None + object_interaction_encoder_params: Optional[Dict[str, Any]] = None, + tdcps_prefix: str = '_tdcps', + mm_linear_layers_prefix: str = '_mm_linear_layers' ) -> MultiModalTDCP: per_feature_checkpoint = per_feature_checkpoint or {} + state_dicts: Dict[str, Dict[str, Any]] = { + feature_name: torch.load(per_feature_checkpoint[feature_name])['model'] + for feature_name in per_feature_params + if feature_name in per_feature_checkpoint + } tdcps: Dict[str, TrackDetectionContrastivePrediction] = {} for feature_name in per_feature_params: @@ -236,12 +320,12 @@ def build_mm_tdcp_model( tdcps[feature_name] = build_tdcp_model(**params) if feature_name in per_feature_checkpoint: logger.info(f'Loading checkpoint for {feature_name} from {per_feature_checkpoint[feature_name]}') - state_dict = torch.load(per_feature_checkpoint[feature_name])['model'] + state_dict = state_dicts[feature_name] state_dict = { - k.replace(f'_tdcps.{feature_name}.', ''): v + k.replace(f'{tdcps_prefix}.{feature_name}.', ''): v for k, v in state_dict.items() + if k.startswith(f'{tdcps_prefix}.{feature_name}.') } - state_dict = {k: v for k, v in state_dict.items() if not k.startswith('_mm_linear_layers.')} tdcps[feature_name].load_state_dict(state_dict) aggregator = tdcp_aggregator_factory( @@ -256,17 +340,133 @@ def build_mm_tdcp_model( else: object_interaction_encoder = None - return MultiModalTDCP( + mm_tdcp = MultiModalTDCP( tdcps=tdcps, aggregator=aggregator, mm_dim=mm_dim, object_interaction_encoder=object_interaction_encoder ) + for feature_name in per_feature_params: + if feature_name in per_feature_checkpoint: + state_dict = state_dicts[feature_name] + state_dict = { + k.replace(f'{mm_linear_layers_prefix}.{feature_name}.', ''): v + for k, v in state_dict.items() + if k.startswith(f'{mm_linear_layers_prefix}.{feature_name}.') + } + mm_tdcp._mm_linear_layers[feature_name].load_state_dict(state_dict) + + return mm_tdcp + + + + +def build_tdsp_model( + feature_encoder_type: str = 'motion', + feature_encoder_params: Dict[str, Any] = None, + hidden_dim: int = 256, + dropout: float = 0.1, + track_encoder_n_heads: int = 8, + track_encoder_n_layers: int = 6, + track_encoder_ffn_dim: int = 512, + track_encoder_enable_motion_encoder: bool = True, + projector_intermediate_dim: int = 512, + interaction_encoder_enable: bool = False, + interaction_encoder_n_heads: int = 8, + interaction_encoder_n_layers: int = 6, + interaction_encoder_ffn_dim: int = 512, + similarity_prediction_head_hidden_dim: int = 256, + tdcps_prefix: str = '_tdcp._tdcps' +) -> TrackDetectionSimilarityPrediction: + tdcp = build_tdcp_model( + feature_encoder_type=feature_encoder_type, + feature_encoder_params=feature_encoder_params, + hidden_dim=hidden_dim, + dropout=dropout, + track_encoder_n_heads=track_encoder_n_heads, + track_encoder_n_layers=track_encoder_n_layers, + track_encoder_ffn_dim=track_encoder_ffn_dim, + track_encoder_enable_motion_encoder=track_encoder_enable_motion_encoder, + projector_intermediate_dim=projector_intermediate_dim, + interaction_encoder_enable=interaction_encoder_enable, + interaction_encoder_n_heads=interaction_encoder_n_heads, + interaction_encoder_n_layers=interaction_encoder_n_layers, + interaction_encoder_ffn_dim=interaction_encoder_ffn_dim, + tdcps_prefix=tdcps_prefix + ) + similarity_prediction_head = TDSPMLPHead( + input_dim=hidden_dim, + hidden_dim=similarity_prediction_head_hidden_dim, + ) + return TrackDetectionSimilarityPrediction( + tdcp=tdcp, + similarity_prediction_head=similarity_prediction_head, + ) + + +def build_mm_tdsp_model( + per_feature_params: Dict[str, Any], + common_params: Dict[str, Any], + sph_per_feature_params: Dict[str, Any], + sph_common_params: Dict[str, Any], + mm_dim: int, + aggregator_type: str, + aggregator_params: Dict[str, Any], + similarity_prediction_head_hidden_dim: int = 256, + object_interaction_encoder_enable: bool = False, + object_interaction_encoder_params: Optional[Dict[str, Any]] = None, + per_feature_checkpoint: Optional[Dict[str, str]] = None, + tdcps_prefix: str = '_mm_tdcp._tdcps', + tdcp_mm_linear_layers_prefix: str = '_mm_tdcp._mm_linear_layers' +) -> MultiModalTDSP: + mm_tdcp = build_mm_tdcp_model( + per_feature_params=per_feature_params, + common_params=common_params, + mm_dim=mm_dim, + aggregator_type=aggregator_type, + aggregator_params=aggregator_params, + object_interaction_encoder_enable=object_interaction_encoder_enable, + object_interaction_encoder_params=object_interaction_encoder_params, + per_feature_checkpoint=per_feature_checkpoint, + tdcps_prefix=tdcps_prefix, + mm_linear_layers_prefix=tdcp_mm_linear_layers_prefix + ) + sphs: Dict[str, TDSPMLPHead] = {} + for feature_name in sph_per_feature_params: + params = tdcp_utils.merge_configs(sph_common_params, sph_per_feature_params[feature_name]) + sphs[feature_name] = TDSPMLPHead( + input_dim=mm_tdcp.get_tdcp(feature_name).output_dim, + **params + ) + if per_feature_checkpoint is not None: + for key in per_feature_checkpoint: + logger.info(f'Loading SPH checkpoint for {key} from {per_feature_checkpoint[key]}') + state_dict = torch.load(per_feature_checkpoint[key])['model'] + state_dict = { + k.replace(f'_sphs.{key}.', ''): v + for k, v in state_dict.items() + if k.startswith(f'_sphs.{key}') + } + sphs[key].load_state_dict(state_dict) + + mm_sph = TDSPMLPHead( + input_dim=mm_tdcp.output_dim, + hidden_dim=similarity_prediction_head_hidden_dim, + ) + return MultiModalTDSP( + mm_tdcp=mm_tdcp, + sphs=sphs, + mm_sph=mm_sph, + ) + -def run_test() -> None: +def run_test_tdcp() -> None: tdcp = build_tdcp_model( - input_dim=4, + feature_encoder_type='motion', + feature_encoder_params={ + 'input_dim': 4 + }, hidden_dim=4, track_encoder_n_heads=2, track_encoder_n_layers=1, @@ -278,10 +478,34 @@ def run_test() -> None: track_mask = torch.zeros(3, 4, 5, dtype=torch.bool) det_features = torch.randn(3, 4, 4) det_mask = torch.zeros(3, 4, dtype=torch.bool) - x_output = tdcp(track_features, track_mask, det_features, det_mask) + track_x, det_x = tdcp(track_features, track_mask, det_features, det_mask) expected_shape = (3, 4, 4) - assert x_output.shape == expected_shape, f'Test failed! Expected shape {expected_shape} but got {x_output.shape}.' + assert track_x.shape == expected_shape, f'Test failed! Expected shape {expected_shape} but got {track_x.shape}.' + assert det_x.shape == expected_shape, f'Test failed! Expected shape {expected_shape} but got {det_x.shape}.' + + +def run_test_tdsp() -> None: + tdsp = build_tdsp_model( + feature_encoder_type='motion', + feature_encoder_params={ + 'input_dim': 4 + }, + hidden_dim=4, + track_encoder_n_heads=2, + track_encoder_n_layers=1, + track_encoder_ffn_dim=8, + projector_intermediate_dim=8 + ) + + track_features = torch.randn(3, 4, 5, 8) + track_mask = torch.zeros(3, 4, 5, dtype=torch.bool) + det_features = torch.randn(3, 4, 4) + det_mask = torch.zeros(3, 4, dtype=torch.bool) + logits = tdsp(track_features, track_mask, det_features, det_mask) + expected_shape = (3, 4, 4, 1) + assert logits.shape == expected_shape, f'Test failed! Expected shape {expected_shape} but got {logits.shape}.' if __name__ == '__main__': - run_test() + run_test_tdcp() + run_test_tdsp() diff --git a/mot_jepa/architectures/tdcp/feature_encoders.py b/mot_jepa/architectures/tdcp/feature_encoders.py index 882d0dd..ca163f8 100644 --- a/mot_jepa/architectures/tdcp/feature_encoders.py +++ b/mot_jepa/architectures/tdcp/feature_encoders.py @@ -42,7 +42,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PartsAppearanceEncoder(nn.Module): - NUM_PARTS = 5 + NUM_PARTS = 6 def __init__(self, emb_size: int, hidden_dim: int, dropout: float = 0.1): """ @@ -53,14 +53,22 @@ def __init__(self, emb_size: int, hidden_dim: int, dropout: float = 0.1): self._emb_size = emb_size self._dropout = dropout - self._linear_layers = nn.ModuleList([nn.Linear(emb_size, hidden_dim, bias=True)] * (self.NUM_PARTS + 1)) + self._linear_layers = nn.ModuleList([nn.Linear(emb_size, hidden_dim, bias=True)] * self.NUM_PARTS) self._drop = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: - assert x.shape[-1] == (self._emb_size * (self.NUM_PARTS + 1)) - projected = [layer(x[..., i * self._emb_size:(i+1) * self._emb_size]) for i, layer in enumerate(self._linear_layers)] - aggregated = torch.stack(projected, dim=1).sum(dim=1) - return self._drop(aggregated) + P, E = x.shape[-2:] + assert P == self.NUM_PARTS + assert E == (self._emb_size + 1) + embeddings = x[..., :self._emb_size] + visibilities = x[..., self._emb_size] + + embeddings = embeddings * visibilities.unsqueeze(-1) + projected = self._linear_layers[0](embeddings[..., 0, :]) + for i, layer in enumerate(self._linear_layers[1:]): + projected += layer(embeddings[..., i + 1, :]) * visibilities[..., i + 1].unsqueeze(-1) + projected = self._drop(projected) + return projected FEATURE_ENCODER_CATALOG = { diff --git a/mot_jepa/architectures/tdcp/similarity_prediction.py b/mot_jepa/architectures/tdcp/similarity_prediction.py new file mode 100644 index 0000000..c7073ce --- /dev/null +++ b/mot_jepa/architectures/tdcp/similarity_prediction.py @@ -0,0 +1,119 @@ +from torch import nn +import torch +from torch.nn import functional as F + + +def create_pair_embedding( + track_features: torch.Tensor, + det_features: torch.Tensor, +) -> torch.Tensor: + """ + Create pairwise embeddings between all tracks and all detections. + + Args: + track_features: Track embeddings of shape (B, N, E) + det_features: Detection embeddings of shape (B, M, E) + + Returns: + Pair embeddings of shape (B, N, M, 3E) containing [z1, z2, |z1-z2|] + for each track-detection pair + """ + B, N, E = track_features.shape + B, M, E = det_features.shape + + # Expand dimensions for pairwise comparison + # track_features: (B, N, E) -> (B, N, 1, E) + track_expanded = track_features.unsqueeze(2) + # det_features: (B, M, E) -> (B, 1, M, E) + det_expanded = det_features.unsqueeze(1) + + # Broadcast to create all pairwise combinations + # track_broadcasted: (B, N, M, E) + # det_broadcasted: (B, N, M, E) + track_broadcasted = track_expanded.expand(B, N, M, E) + det_broadcasted = det_expanded.expand(B, N, M, E) + + # Calculate absolute difference for each pair + diff_features = torch.abs(track_broadcasted - det_broadcasted) + + # Concatenate along the last dimension: [z1, z2, |z1-z2|] + pair_embeddings = torch.cat([track_broadcasted, det_broadcasted, diff_features], dim=-1) + + return pair_embeddings + + +class TDSPMLPHead(nn.Module): + """MLP head for similarity prediction.""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + ): + super().__init__() + self._mlp = self._projector = nn.Sequential( + nn.Linear(3 * input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, 1), + ) + + def forward( + self, + track_features: torch.Tensor, + det_features: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass for similarity prediction. + + Args: + track_features: Track embeddings of shape (B, N, E) + det_features: Detection embeddings of shape (B, M, E) + + Returns: + Similarity scores of shape (B, N, M, 1) + """ + track_features = F.normalize(track_features, dim=-1) + det_features = F.normalize(det_features, dim=-1) + pair_embeddings = create_pair_embedding(track_features, det_features) + B, N, M, E3 = pair_embeddings.shape + pair_embeddings_flat = pair_embeddings.view(B * N * M, E3) + + similarity_scores_flat = self._mlp(pair_embeddings_flat) + similarity_scores = similarity_scores_flat.view(B, N, M) + return similarity_scores + + +def test_pair_embedding(): + """Test function for pair embedding creation.""" + B, N, M, E = 2, 4, 6, 8 + + # Create dummy track and detection features + track_features = torch.randn(B, N, E) + det_features = torch.randn(B, M, E) + + # Create pair embeddings + pair_embeddings = create_pair_embedding(track_features, det_features) + + print(f"Track features shape: {track_features.shape}") + print(f"Detection features shape: {det_features.shape}") + print(f"Pair embeddings shape: {pair_embeddings.shape}") + print(f"Expected shape: ({B}, {N}, {M}, {3*E})") + + # Test MLP head + mlp_head = TDSPMLPHead(input_dim=E, hidden_dim=64) + similarity_scores = mlp_head(track_features, det_features) + + print(f"Similarity scores shape: {similarity_scores.shape}") + print(f"Expected shape: ({B}, {N}, {M}, 1)") + + # Verify pairwise structure + print(f"\nPairwise verification:") + print(f"Number of track-detection pairs: {N} × {M} = {N*M}") + print(f"Each pair embedding dimension: {3*E}") + print(f"Total pair embeddings: {B} × {N} × {M} × {3*E} = {B*N*M*3*E}") + + +if __name__ == '__main__': + test_pair_embedding() + \ No newline at end of file diff --git a/mot_jepa/config_parser/core.py b/mot_jepa/config_parser/core.py index 44b1771..3c7edda 100644 --- a/mot_jepa/config_parser/core.py +++ b/mot_jepa/config_parser/core.py @@ -163,7 +163,7 @@ class EvalConfig: split: str = 'val' checkpoint: Optional[str] = None visualize: bool = False - postprocess: bool = False + postprocess: bool = True @dataclass diff --git a/mot_jepa/datasets/dataset/augmentations/appearance.py b/mot_jepa/datasets/dataset/augmentations/appearance.py index 6fd7340..f202e7b 100644 --- a/mot_jepa/datasets/dataset/augmentations/appearance.py +++ b/mot_jepa/datasets/dataset/augmentations/appearance.py @@ -21,9 +21,9 @@ def apply(self, data: VideoClipData) -> VideoClipData: return data for attr in ['observed', 'unobserved']: - observed_emb = getattr(data, attr).features['appearance'] + observed_emb = getattr(data, attr).features['appearance'][..., :-1] # Do not jitter the visibility observed_emb_std = observed_emb.std(dim=-1, keepdim=True) observed_emb_noise = torch.randn_like(observed_emb) * observed_emb_std - getattr(data, attr).features['appearance'] = observed_emb + self._alpha * observed_emb_noise + getattr(data, attr).features['appearance'][..., :-1] = observed_emb + self._alpha * observed_emb_noise return data \ No newline at end of file diff --git a/mot_jepa/datasets/dataset/feature_extractor/feature_extractor.py b/mot_jepa/datasets/dataset/feature_extractor/feature_extractor.py index 95fdc73..f3af951 100644 --- a/mot_jepa/datasets/dataset/feature_extractor/feature_extractor.py +++ b/mot_jepa/datasets/dataset/feature_extractor/feature_extractor.py @@ -1,6 +1,6 @@ import logging from abc import abstractmethod, ABC -from typing import Dict, Optional +from typing import Dict import torch diff --git a/mot_jepa/datasets/dataset/feature_extractor/pred_bbox_feature_extractor.py b/mot_jepa/datasets/dataset/feature_extractor/pred_bbox_feature_extractor.py index c5479d4..3751b34 100644 --- a/mot_jepa/datasets/dataset/feature_extractor/pred_bbox_feature_extractor.py +++ b/mot_jepa/datasets/dataset/feature_extractor/pred_bbox_feature_extractor.py @@ -19,8 +19,13 @@ class SupportedFeatures(enum.Enum): class PredictionBBoxFeatureExtractor(FeatureExtractor): BBOX_DIM = 5 # XYWHC KEYPOINTS_DIM = 35 # 17x2 XYC (per part) + 1 C (global) = 52 - APPEARANCE_DIM = 768 # 6x128 - SUPPORTED_FEATURES = {SupportedFeatures.BBOX, SupportedFeatures.KEYPOINTS, SupportedFeatures.APPEARANCE} + APPEARANCE_NUM_PARTS = 6 + APPEARANCE_PER_PART_DIM = 129 # 6x129 = 774 + SUPPORTED_FEATURES = { + SupportedFeatures.BBOX, + SupportedFeatures.KEYPOINTS, + SupportedFeatures.APPEARANCE, + } WORKER_ID_STEP = 1_000_000 @@ -40,6 +45,7 @@ def __init__( object_id_mapping=object_id_mapping, n_tracks=n_tracks ) + is_train = (index.split == 'train') feature_names = [SupportedFeatures(feature_name.lower()) for feature_name in feature_names] for feature_name in feature_names: assert feature_name in self.SUPPORTED_FEATURES, \ @@ -50,8 +56,8 @@ def __init__( # Augmentations self._extra_false_positives = extra_false_positives - self._random_appearance_jitter_ratio = random_appearance_jitter_ratio if index.split == 'train' else 0.0 - self._random_appearance_jitter_range = random_appearance_jitter_range if index.split == 'train' else 0 + self._random_appearance_jitter_ratio = random_appearance_jitter_ratio if is_train else 0.0 + self._random_appearance_jitter_range = random_appearance_jitter_range if is_train else 0 self._worker_id_counter = 0 @@ -70,18 +76,28 @@ def bbox_to_tensor(bbox: List[float], score: float) -> torch.Tensor: return torch.tensor([*bbox, score], dtype=torch.float32) @staticmethod - def initialize_features(feature_names: Set[SupportedFeatures], n_tracks: int, temporal_length: int) -> Dict[str, torch.Tensor]: + def initialize_features(feature_names: Set[SupportedFeatures], n_tracks: int, temporal_length: int, is_train: bool = False) -> Dict[str, torch.Tensor]: features: Dict[str, torch.Tensor] = {} if SupportedFeatures.BBOX in feature_names: - features[SupportedFeatures.BBOX.value] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.BBOX_DIM, dtype=torch.float32) + features[SupportedFeatures.BBOX.value] = \ + torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.BBOX_DIM, dtype=torch.float32) if SupportedFeatures.KEYPOINTS in feature_names: - features[SupportedFeatures.KEYPOINTS.value] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.KEYPOINTS_DIM, dtype=torch.float32) + features[SupportedFeatures.KEYPOINTS.value] = \ + torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.KEYPOINTS_DIM, dtype=torch.float32) if SupportedFeatures.APPEARANCE in feature_names: - features[SupportedFeatures.APPEARANCE.value] = torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.APPEARANCE_DIM, dtype=torch.float32) + features[SupportedFeatures.APPEARANCE.value] = \ + torch.zeros(n_tracks, temporal_length, PredictionBBoxFeatureExtractor.APPEARANCE_NUM_PARTS, PredictionBBoxFeatureExtractor.APPEARANCE_PER_PART_DIM, dtype=torch.float32) + return features @staticmethod - def _set_features(feature_names: Set[SupportedFeatures], features: Dict[str, torch.Tensor], object_index: int, clip_index: int, data: dict) -> None: + def set_features( + feature_names: Set[SupportedFeatures], + features: Dict[str, torch.Tensor], + object_index: int, + clip_index: int, + data: dict + ) -> None: if SupportedFeatures.BBOX in feature_names: bbox = [*data['bbox_xywh'], data['bbox_conf']] features[SupportedFeatures.BBOX.value][object_index, clip_index, :] = torch.tensor(bbox, dtype=torch.float32) @@ -89,8 +105,8 @@ def _set_features(feature_names: Set[SupportedFeatures], features: Dict[str, tor keypoints = sum([d[:2] for d in data['keypoints_xyc']], []) + [data['keypoints_conf']] features[SupportedFeatures.KEYPOINTS.value][object_index, clip_index, :] = torch.tensor(keypoints, dtype=torch.float32) if SupportedFeatures.APPEARANCE in feature_names: - embs = [[e * float(visibility) for e in emb] for emb, visibility in zip(data['appearance_embeddings'], data['appearance_visibility'])] - features[SupportedFeatures.APPEARANCE.value][object_index, clip_index, :] = torch.tensor(sum(embs, []), dtype=torch.float32) + embs = [[*emb, visibility] for emb, visibility in zip(data['appearance_embeddings'], data['appearance_visibility'])] + features[SupportedFeatures.APPEARANCE.value][object_index, clip_index, :] = torch.tensor(embs, dtype=torch.float32) def _extract_extra_data( self, @@ -106,7 +122,12 @@ def _extract_extra_data( n_object_ids = len(object_ids) video_clip_part.mask.fill_(True) # Override mask - features: Dict[str, torch.Tensor] = self.initialize_features(self._feature_names, self._n_tracks, temporal_length) + features: Dict[str, torch.Tensor] = self.initialize_features( + feature_names=self._feature_names, + n_tracks=self._n_tracks, + temporal_length=temporal_length + ) + for clip_index, frame_index in enumerate(range(start_index, end_index)): extra_data = self._extra_features_reader.read(scene_name, frame_index) object_id_to_extra_data_lookup: Dict[str, dict] = {raw['object_id']: raw for raw in extra_data if raw['object_id'] is not None} @@ -116,7 +137,13 @@ def _extract_extra_data( if data is None: continue - self._set_features(self._feature_names, features, object_index, clip_index, data) + self.set_features( + feature_names=self._feature_names, + features=features, + object_index=object_index, + clip_index=clip_index, + data=data + ) video_clip_part.mask[object_index, clip_index] = False # Add extra false positives (specific augmentation type) @@ -127,7 +154,13 @@ def _extract_extra_data( extra_false_positives = extra_false_positives[:n_extra_positives] for data_index, object_index in enumerate(range(n_object_ids, n_object_ids + n_extra_positives)): data = extra_false_positives[data_index] - self._set_features(self._feature_names, features, object_index, clip_index, data) + self.set_features( + feature_names=self._feature_names, + features=features, + object_index=object_index, + clip_index=clip_index, + data=data + ) video_clip_part.mask[object_index, clip_index] = False # Setting ID is complex as it need has to be unique and not match any dataset ID @@ -151,7 +184,13 @@ def _extract_extra_data( r = random.random() if r < self._random_appearance_jitter_ratio: - self._set_features([SupportedFeatures.APPEARANCE], features, object_index, clip_index, data) + self.set_features( + feature_names=[SupportedFeatures.APPEARANCE], + features=features, + object_index=object_index, + clip_index=clip_index, + data=data + ) video_clip_part.features = features return video_clip_part diff --git a/mot_jepa/datasets/dataset/mot.py b/mot_jepa/datasets/dataset/mot.py index 884428c..ba54c4c 100644 --- a/mot_jepa/datasets/dataset/mot.py +++ b/mot_jepa/datasets/dataset/mot.py @@ -15,6 +15,7 @@ import cv2 import numpy as np +from mot_jepa.datasets.dataset.feature_extractor.pred_bbox_feature_extractor import SupportedFeatures from torch.utils.data import Dataset from mot_jepa.datasets.dataset.augmentations import Augmentation @@ -55,6 +56,7 @@ def __init__( clip_sampling_step: Clip sampling step (subsamples dataset) """ self._index = index + self._is_train = (index.split == 'train') # Track parameters max_tracks = index.get_max_tracks() diff --git a/mot_jepa/trainer/losses/__init__.py b/mot_jepa/trainer/losses/__init__.py index 8527d29..9473112 100644 --- a/mot_jepa/trainer/losses/__init__.py +++ b/mot_jepa/trainer/losses/__init__.py @@ -4,10 +4,12 @@ IDLevelInfoNCE, MultiFeatureLoss, ) +from mot_jepa.trainer.losses.bce import ClipLevelBCE __all__ = [ 'ClipLevelInfoNCE', 'BatchLevelInfoNCE', 'IDLevelInfoNCE', 'MultiFeatureLoss', + 'ClipLevelBCE', ] diff --git a/mot_jepa/trainer/losses/bce.py b/mot_jepa/trainer/losses/bce.py new file mode 100644 index 0000000..c00c555 --- /dev/null +++ b/mot_jepa/trainer/losses/bce.py @@ -0,0 +1,249 @@ +from typing import Dict, Optional, List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def torch_combine(xs: List[torch.Tensor], dtype: torch.dtype) -> torch.Tensor: + """ + Concatenate tensors or return empty tensor if list is empty. + + Args: + xs: List of tensors to concatenate + dtype: Data type for empty tensor when xs is empty + + Returns: + Concatenated tensor or empty tensor with specified dtype + """ + if len(xs) > 0: + return torch.cat(xs) + return torch.empty(0, dtype=dtype) + + +class ClipLevelBCE(nn.Module): + """ + Binary Cross-Entropy loss computed separately for each clip in the batch. + + Performs binary classification between track and detection embeddings using + their differences. Uses BCE loss to classify whether embeddings belong to + the same identity (positive) or different identities (negative). + """ + + def __init__( + self, + pos_weight: Optional[float] = 10.0, + assoc_threshold: float = 0.5, + fp_label_threshold: int = 1_000_000 + ) -> None: + super().__init__() + pos_weight = torch.tensor([pos_weight]) if pos_weight is not None else None + self._loss_func = nn.BCEWithLogitsLoss(reduction='mean', pos_weight=pos_weight) + self._assoc_threshold = assoc_threshold + self._fp_label_threshold = fp_label_threshold + + # State + self._k = 0 + + def forward( + self, + logits: torch.Tensor, + track_mask: torch.Tensor, + detection_mask: torch.Tensor, + track_ids: Optional[torch.Tensor] = None, + det_ids: Optional[torch.Tensor] = None, + logits_dict: Optional[Dict[str, torch.Tensor]] = None, + ) -> Dict[str, torch.Tensor]: + """ + Compute BCE loss using embedding differences for binary classification. + + Args: + logits: Logits (B, N, M, 1) + track_mask: Track mask (B, N, T), True=missing + detection_mask: Detection mask (B, N), True=missing + track_ids: Track identifiers (B, N) - required + det_ids: Detection identifiers (B, N) - required + logits_dict: Dictionary containing modality-specific logits (B, N, M, 1) + + Returns: + Dictionary containing loss and additional debug information + """ + B = logits.shape[0] + _ = logits_dict # unused + + agg_track_mask = track_mask.all(dim=-1) + agg_track_ids = track_ids.max(dim=-1).values + mask = ~agg_track_mask.unsqueeze(-1) | ~detection_mask.unsqueeze(-2) + id_match_mask = (agg_track_ids.unsqueeze(-1) == det_ids.unsqueeze(-2)).float() + loss = self._loss_func(logits[mask], id_match_mask[mask]) + + # Accuracy metrics (clip level like BatchLevelInfoNCE) + filtered_track_labels_list = [] + filtered_det_labels_list = [] + track_predictions_list = [] + det_predictions_list = [] + + probas = torch.sigmoid(logits) + with torch.no_grad(): + for b_i in range(B): + sub_mask = mask[b_i] + if not bool(sub_mask.any().item()): + continue + + sub_track_labels = agg_track_ids[b_i] + sub_det_labels = det_ids[b_i] + sub_probas = probas[b_i] + sub_probas[~sub_mask] = -1 + + if torch.numel(sub_probas) > 0: + track_max_probas, track_max_indices = torch.max(sub_probas, dim=1) + track_max_indices[track_max_probas <= self._assoc_threshold] = -1 + track_max_indices = track_max_indices[~agg_track_mask[b_i]] + track_predictions = sub_track_labels[track_max_indices] + sub_track_labels = sub_track_labels[~agg_track_mask[b_i]] + sub_track_labels[sub_track_labels >= self._fp_label_threshold] = -1 + + det_max_probas, det_max_indices = torch.max(sub_probas, dim=0) + det_max_indices[det_max_probas <= self._assoc_threshold] = -1 + det_max_indices = det_max_indices[~detection_mask[b_i]] + det_predictions = sub_det_labels[det_max_indices] + sub_det_labels = sub_det_labels[~detection_mask[b_i]] + sub_det_labels[sub_det_labels >= self._fp_label_threshold] = -1 + + filtered_track_labels_list.append(sub_track_labels) + filtered_det_labels_list.append(sub_det_labels) + track_predictions_list.append(track_predictions) + det_predictions_list.append(det_predictions) + + self._k += 1 + if self._k % 100 == 0: + print( + f'{probas[mask].mean()=}', + f'{probas[mask].max()=}', + f'{id_match_mask[mask].sum() / id_match_mask[mask].numel()=}', + f'{id_match_mask[mask].numel()=}' + ) + print( + f'{filtered_track_labels_list[0]=}', + f'{filtered_det_labels_list[0]=}', + f'{track_predictions_list[0]=}', + f'{det_predictions_list[0]=}', + end='\n\n' + ) + + filtered_track_labels = torch_combine(filtered_track_labels_list, dtype=torch.long) + filtered_det_labels_list = torch_combine(filtered_det_labels_list, dtype=torch.long) + track_predictions = torch_combine(track_predictions_list, dtype=torch.long) + det_predictions = torch_combine(det_predictions_list, dtype=torch.long) + + return { + 'loss': loss, + 'track_loss': loss, + 'det_loss': loss, + 'track_labels': filtered_track_labels, + 'det_labels': filtered_det_labels_list, + 'track_predictions': track_predictions, + 'det_predictions': det_predictions, + 'track_mask': None, + 'det_mask': None + } + + +class MultiFeatureBCELoss(nn.Module): + """Compose losses over multimodal and modality-specific embeddings.""" + + def __init__( + self, + mm_loss: nn.Module, + per_feature_losses: Optional[Dict[str, nn.Module]] = None, + per_feature_weights: Optional[Dict[str, float]] = None, + ) -> None: + """Initialize the composite loss. + + Args: + mm_loss: Loss applied to the fused multimodal features. + per_feature_losses: Optional mapping from modality name to the + loss used for that modality. + per_feature_weights: Optional mapping providing weights for each + modality-specific loss. Defaults to 1.0 when not provided. + """ + super().__init__() + self._mm_loss = mm_loss + self._per_feature_losses = per_feature_losses or {} + self._per_feature_weights = per_feature_weights or {} + + def forward( + self, + logits: torch.Tensor, + track_mask: torch.Tensor, + detection_mask: torch.Tensor, + track_ids: Optional[torch.Tensor] = None, + det_ids: Optional[torch.Tensor] = None, + logits_dict: Optional[Dict[str, torch.Tensor]] = None, + ) -> Dict[str, torch.Tensor]: + result = self._mm_loss( + logits, + track_mask, + detection_mask, + track_ids=track_ids, + det_ids=det_ids, + logits_dict=logits_dict, + ) + + total_loss = result['loss'] + + if self._per_feature_losses: + if logits_dict is None: + raise ValueError('MultiFeatureBCELoss requires logits dictionary.') + for key, loss_fn in self._per_feature_losses.items(): + if key not in logits_dict: + raise KeyError(f'Missing logits for modality "{key}".') + sub_result = loss_fn( + logits_dict[key], + track_mask, + detection_mask, + track_ids=track_ids, + det_ids=det_ids, + logits_dict={key: logits_dict[key]}, + ) + weight = self._per_feature_weights.get(key, 1.0) + result[f'{key}_loss'] = sub_result['loss'] + total_loss = total_loss + weight * sub_result['loss'] + + result['loss'] = total_loss + return result + + +def run_test() -> None: + """Test function for ClipLevelBCE.""" + B, N, E = 2, 4, 8 + + # Create dummy difference embeddings + diff_x = torch.randn(B, N, E) + + # Create masks: some valid differences, some identity matches + diff_mask = torch.tensor([[1, 1, 0, 1], [1, 0, 1, 1]], dtype=torch.bool) + id_match_mask = torch.tensor([[1, 0, 1, 1], [1, 1, 0, 1]], dtype=torch.bool) + + # Create optional feature dictionary + diff_feature_dict = {'additional_feature': torch.randn(B, N, E)} + + # Test the loss + loss_fn = ClipLevelBCE() + outputs = loss_fn( + diff_x=diff_x, + diff_mask=diff_mask, + diff_feature_dict=diff_feature_dict, + id_match_mask=id_match_mask + ) + + print("BCE Loss outputs:") + for key, value in outputs.items(): + if isinstance(value, torch.Tensor): + print(f"{key}: {value.shape} - {value.item() if value.numel() == 1 else 'tensor'}") + else: + print(f"{key}: {value}") + + +if __name__ == '__main__': + run_test() diff --git a/mot_jepa/trainer/losses/infonce.py b/mot_jepa/trainer/losses/infonce.py index 4dc00e8..3124955 100644 --- a/mot_jepa/trainer/losses/infonce.py +++ b/mot_jepa/trainer/losses/infonce.py @@ -109,7 +109,7 @@ def forward( track_predictions = torch_combine(track_predictions_list, dtype=torch.long) det_predictions = torch_combine(det_predictions_list, dtype=torch.long) - loss = sum(losses) / len(losses) + loss = torch.stack(losses).mean() return { 'loss': loss, 'track_loss': loss, diff --git a/mot_jepa/trainer/trainer.py b/mot_jepa/trainer/trainer.py index 5552746..b821d76 100644 --- a/mot_jepa/trainer/trainer.py +++ b/mot_jepa/trainer/trainer.py @@ -21,8 +21,13 @@ from mot_jepa.common.conventions import LAST_CKPT from mot_jepa.trainer import torch_distrib_utils from mot_jepa.trainer import torch_helper -from mot_jepa.trainer.losses.base import VideoClipLoss from mot_jepa.trainer.metrics import LossDictMeter, AccuracyMeter +from mot_jepa.architectures.tdcp.core import ( + MultiModalTDCP, + TrackDetectionContrastivePrediction, + MultiModalTDSP, + TrackDetectionSimilarityPrediction, +) logger = logging.getLogger('Trainer') @@ -33,7 +38,7 @@ class ContrastiveTrainer: def __init__( self, model: nn.Module, - loss_func: VideoClipLoss, + loss_func: nn.Module, optimizer: Optimizer, scheduler: LRScheduler, n_epochs: int, @@ -42,7 +47,8 @@ def __init__( metric_monitor: str = 'val-epoch/loss', metric_monitor_minimize: bool = True, gradient_clip: Optional[float] = None, - mixed_precision: bool = False + mixed_precision: bool = False, + device: Optional[str] = None ): """ Args: @@ -63,7 +69,7 @@ def __init__( self._rank = int(os.environ.get('RANK', -1)) self._local_rank = int(os.environ.get('LOCAL_RANK', -1)) self._world_size = int(os.environ.get('WORLD_SIZE', 1)) - self._device = f'cuda:{max(0, self._local_rank)}' if torch.cuda.is_available() else 'cpu' + self._device = self._setup_device(self._local_rank, device) # Trainer state self._model = model @@ -93,6 +99,27 @@ def __init__( # Finish self._log_trainer_configuration() + @staticmethod + def _setup_device(local_rank: int, device: Optional[str]) -> str: + """ + Setup device. + + Args: + local_rank: Local rank + device: Device to use + + Returns: + Device + """ + if local_rank == -1: + if device is None: + return f'cuda:0' if torch.cuda.is_available() else 'cpu' + else: + return device + else: + assert torch.cuda.is_available(), 'CUDA is not available!' + return f'cuda:{local_rank}' + @property def model(self) -> Union[nn.Module, DDP]: """ @@ -132,6 +159,7 @@ def _on_start(self) -> None: Pre-process code before a train/eval process. """ self._model.to(self._device) + self._loss_func.to(self._device) torch_helper.optimizer_to(self._optimizer, device=self._device) if self._use_ddp: @@ -231,23 +259,40 @@ def _forward_and_loss(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Te with autocast(enabled=self._mixed_precision): model_output = self._model(track_x, track_mask, det_x, det_mask) - if len(model_output) == 4: - track_features, det_features, track_feat_dict, det_feat_dict = model_output - else: - track_features, det_features = model_output - track_feat_dict = None - det_feat_dict = None - - loss_dict = self._loss_func( - track_features, - det_features, - track_mask, - det_mask, - track_feat_dict, - det_feat_dict, - track_ids, - det_ids - ) + # TODO: Refactor + if isinstance(self._model, (MultiModalTDCP, TrackDetectionContrastivePrediction)): + if isinstance(self._model, MultiModalTDCP): + track_features, det_features, track_feat_dict, det_feat_dict = model_output + else: + track_features, det_features = model_output + track_feat_dict = None + det_feat_dict = None + + loss_dict = self._loss_func( + track_features, + det_features, + track_mask, + det_mask, + track_feat_dict, + det_feat_dict, + track_ids, + det_ids + ) + elif isinstance(self._model, (MultiModalTDSP, TrackDetectionSimilarityPrediction)): + if isinstance(self._model, MultiModalTDSP): + logits, logits_dict = model_output + else: + logits = model_output + logits_dict = None + + loss_dict = self._loss_func( + logits, + track_mask, + det_mask, + track_ids, + det_ids, + logits_dict + ) return loss_dict @@ -395,14 +440,7 @@ def _ddp_setup(self) -> None: local_rank: Optional[int] = int(os.environ.get('LOCAL_RANK', self._local_rank)) if local_rank == -1: - logger.info('Using single-GPU training mode.') - - if torch.cuda.is_available(): - logger.info('Using "cuda:0" (default) for single-GPU training.') - self._local_rank = 0 - else: - logger.info('Using CPU for single-GPU training.') - self._local_rank = 'cpu' + logger.info(f'Using "{self._device}" for single-node training.') else: if not dist.is_initialized(): # Allows `_setup(model)` to be called multiple times with a different (or same) model diff --git a/tools/analysis/cameltrack/feature_extraction.py b/tools/analysis/cameltrack/feature_extraction.py index 69763f3..7ebc073 100644 --- a/tools/analysis/cameltrack/feature_extraction.py +++ b/tools/analysis/cameltrack/feature_extraction.py @@ -4,7 +4,7 @@ import shutil import zipfile from pathlib import Path -from typing import Dict, Any, Tuple, List +from typing import Dict, Any, Tuple, List, Optional import hydra import numpy as np @@ -13,6 +13,7 @@ from motrack.library.cv import BBox from motrack.tracker.matching.utils import hungarian from tqdm import tqdm +from pathlib import Path from mot_jepa.common.project import CONFIGS_PATH from mot_jepa.config_parser import GlobalConfig @@ -26,9 +27,15 @@ class CamelTrackParser: - def __init__(self, states_path: str, temporary_dirpath: str): + def __init__( + self, + states_path: str, + temporary_dirpath: str, + samples_path: Optional[str] = None + ): self._states_path = states_path self._temporary_path = temporary_dirpath + self._samples_path = samples_path # State self._scene_mapping: Dict[str, int] = {} @@ -43,13 +50,19 @@ def scene_mapping(self) -> Dict[str, int]: def scene_files(self) -> Dict[str, Dict[str, str]]: return self._scene_files + @property + def has_samples(self) -> bool: + return self._samples_path is not None + def open(self) -> None: + temporary_states_path = os.path.join(self._temporary_path, 'states') + Path(temporary_states_path).mkdir(parents=True, exist_ok=True) with zipfile.ZipFile(self._states_path, 'r') as zip_ref: - zip_ref.extractall(self._temporary_path) - pickle_filenames = [filename for filename in os.listdir(self._temporary_path) if filename.endswith('.pkl')] - pickle_filepaths = [os.path.join(self._temporary_path, filename) for filename in pickle_filenames] + zip_ref.extractall(temporary_states_path) + pickle_filenames = [filename for filename in os.listdir(temporary_states_path) if filename.endswith('.pkl')] + pickle_filepaths = [os.path.join(temporary_states_path, filename) for filename in pickle_filenames] - # Extract scene mapping + logger.info(f'Extracting scene mapping...') self._scene_mapping.clear() for filename, filepath in zip(pickle_filenames, pickle_filepaths): if not filepath.endswith('_image.pkl'): @@ -64,7 +77,7 @@ def open(self) -> None: self._scene_mapping[scene_name] = video_id self._scene_mapping = dict(sorted(self._scene_mapping.items())) - # Extract scene files + logger.info(f'Extracting scene files...') self._scene_files.clear() reverse_scene_mapping = {v: k for k, v in self._scene_mapping.items()} for filename, filepath in zip(pickle_filenames, pickle_filepaths): @@ -80,6 +93,20 @@ def open(self) -> None: else: raise ValueError(f'Unexpected filename "{filename}"!') + logger.info(f'Extracting track ids...') + temporary_samples_path = os.path.join(self._temporary_path, 'samples') + Path(temporary_samples_path).mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(self._samples_path, 'r') as zip_ref: + zip_ref.extractall(temporary_samples_path) + pickle_filenames = [filename for filename in os.listdir(temporary_samples_path) if filename.endswith('.pkl')] + pickle_filepaths = [os.path.join(temporary_samples_path, filename) for filename in pickle_filenames] + + for filename, filepath in zip(pickle_filenames, pickle_filepaths): + video_id = int(filename.replace('.pkl', '').replace('sample_', '')) + scene_name = reverse_scene_mapping[video_id] + self._scene_files[scene_name]['samples'] = filepath + + def close(self): if os.path.exists(self._temporary_path): shutil.rmtree(self._temporary_path) @@ -91,30 +118,55 @@ def __enter__(self) -> 'CamelTrackParser': def __exit__(self, exc_type, exc_val, exc_tb): self.close() - def get_scene_dfs(self, scene: str) -> Tuple[pd.DataFrame, pd.DataFrame]: + def get_scene_dfs(self, scene: str) -> Dict[str, pd.DataFrame]: if scene not in self._cache: - with open(self._scene_files[scene]['features'], 'rb') as f: + scene_files = self._scene_files[scene] + scene_data: Dict[str, pd.DataFrame] = {} + + with open(scene_files['features'], 'rb') as f: df_features = pickle.load(f) - with open(self._scene_files[scene]['image'], 'rb') as f: + scene_data['features'] = df_features + with open(scene_files['image'], 'rb') as f: df_image = pickle.load(f) - self._cache[scene] = (df_features, df_image) + scene_data['image'] = df_image + if 'samples' in scene_files: + with open(scene_files['samples'], 'rb') as f: + df_samples = pickle.load(f) + scene_data['samples'] = df_samples + + self._cache[scene] = scene_data return self._cache[scene] def get(self, scene: str, frame_index: int) -> List[dict]: - df_features, df_image = self.get_scene_dfs(scene) + dfs = self.get_scene_dfs(scene) + df_image = dfs['image'] + + if 'samples' in dfs: + df_data = dfs['samples'] + has_samples = True + else: + df_data = dfs['features'] + has_samples = False + image_id = int(df_image[df_image.frame == frame_index].id.iloc[0]) - df_frame = df_features[df_features.image_id == image_id] + df_data = df_data[df_data.image_id == image_id] result = [] - for _, row in df_frame.iterrows(): - result.append({ + for _, row in df_data.iterrows(): + detection_data = { 'bbox_xywh': row.bbox_ltwh.tolist(), 'bbox_conf': row.bbox_conf, 'keypoints_xyc': row.keypoints_xyc.tolist(), 'keypoints_conf': row.keypoints_conf, 'appearance_embeddings': row.embeddings.tolist(), 'appearance_visibility': row.visibility_scores.tolist() - }) + } + + if has_samples: + detection_data['object_id'] = f'{scene}_{int(row.person_id)}' + detection_data['occlusions'] = [(df_data.loc[idx].person_id, iou) for idx, iou in row.occlusions] + + result.append(detection_data) return result @@ -172,11 +224,12 @@ def add_track_ids(pred_frame_data: List[dict], gt_frame_data: List[FrameObjectDa @pipeline.task('cameltrack-features-extraction') def main(cfg: GlobalConfig) -> None: # Hardcoded stuff - SPLIT = 'test' + SPLIT = 'train' is_test = (SPLIT == 'test') CAMELTRACK_STATES_PATH = f'/media/home/cameltrack-states/dancetrack-{SPLIT}.pklz' - TEMPORARY_DIRPATH = '/media/home/cameltrack-states/tmp' - EXTRACTED_OUTPUT_PATH = '/media/home/cameltrack-states/extracted-features' + CAMELTRACK_SAMPLES_PATH = f'/media/home/data/DanceTrack/states/camel_training/camel_{SPLIT}.pklz' + TEMPORARY_DIRPATH = '/media/home/cameltrack-states/extraction-tmp' + EXTRACTED_OUTPUT_PATH = '/media/home/cameltrack-states/extracted-features-v2' dataset_index = dataset_index_factory( name=cfg.dataset.index.type, @@ -188,6 +241,7 @@ def main(cfg: GlobalConfig) -> None: n_total_matches, n_total_unmatches = 0, 0 with CamelTrackParser( states_path=CAMELTRACK_STATES_PATH, + samples_path=CAMELTRACK_SAMPLES_PATH, temporary_dirpath=TEMPORARY_DIRPATH ) as parser: features_writer = ExtraFeaturesWriter(EXTRACTED_OUTPUT_PATH) @@ -198,7 +252,7 @@ def main(cfg: GlobalConfig) -> None: pred_frame_data = parser.get(scene_name, frame_index) pred_frame_data = postprocess_data(scene_info, pred_frame_data) - if not is_test: + if not is_test or not parser.has_samples: object_ids = dataset_index.get_objects_present_in_scene_at_frame(scene_name, frame_index) gt_frame_data = [dataset_index.get_object_data_label_by_frame_index(object_id, frame_index) for object_id in object_ids] pred_frame_data, n_matches, n_unmatches = add_track_ids(pred_frame_data, gt_frame_data) diff --git a/tools/inference.py b/tools/inference.py index a51ec0c..89bb23f 100755 --- a/tools/inference.py +++ b/tools/inference.py @@ -19,7 +19,7 @@ from torch.nn import functional as F from tqdm import tqdm -from mot_jepa.architectures.tdcp.core import MultiModalTDCP +from mot_jepa.architectures.tdcp.core import MultiModalTDCP, MultiModalTDSP from mot_jepa.common import conventions from mot_jepa.common.project import CONFIGS_PATH from mot_jepa.config_parser import GlobalConfig @@ -45,6 +45,7 @@ def __init__( model: nn.Module, extra_features_reader: ExtraFeaturesReader, device: str, + detection_threshold: float = 0.4, sim_threshold: float = 0.5, initialization_threshold: int = 1, remember_threshold: int = 30, @@ -58,13 +59,14 @@ def __init__( self._transform = transform - self._model: MultiModalTDCP = model + self._model: MultiModalTDCP | MultiModalTDSP = model self._model.to(device) self._model.eval() self._feature_names = set([SupportedFeatures(feature_name) for feature_name in self._model.feature_names]) self._extra_features_reader = extra_features_reader + self._detection_threshold = detection_threshold self._sim_threshold = sim_threshold self._initialization_threshold = initialization_threshold @@ -105,7 +107,7 @@ def _convert_data( if relative_index < 0: continue - PredictionBBoxFeatureExtractor._set_features( + PredictionBBoxFeatureExtractor.set_features( feature_names=self._feature_names, features=observed_features, object_index=t_i, @@ -128,7 +130,7 @@ def _convert_data( unobserved_temporal_mask[:n_detections] = False for d_i, data in enumerate(objects_data): - PredictionBBoxFeatureExtractor._set_features( + PredictionBBoxFeatureExtractor.set_features( feature_names=self._feature_names, features=unobserved_features, object_index=d_i, @@ -171,34 +173,49 @@ def _association( data = self._convert_data(tracklets, objects_data, frame_index) data = self._transform(data) data.apply(lambda x: x.unsqueeze(0).to(self._device)) - track_mm_features, det_mm_features, _, _ = self._model( - data.observed.features, - data.observed.mask, - data.unobserved.features, - data.unobserved.mask - ) - track_mm_features = track_mm_features[0][:n_tracks] - det_mm_features = det_mm_features[0][:n_detections] - track_mm_features = F.normalize(track_mm_features, dim=-1).cpu() - det_mm_features = F.normalize(det_mm_features, dim=-1).cpu() - - ALPHA = 0.0 # 0.65 - if ALPHA > 0.0: - ema_track_mm_features = torch.zeros_like(track_mm_features) - for t_i in range(track_mm_features.shape[0]): - tracklet = tracklets[t_i] - track_ema = tracklet.get('track_ema') - if track_ema is None: - ema_track_mm_features[t_i] = track_mm_features[t_i] - else: - ema_track_mm_features[t_i] = ALPHA * track_ema + (1 - ALPHA) * track_mm_features[t_i] - ema_track_mm_features[t_i] = F.normalize(ema_track_mm_features[t_i], dim=-1) - tracklet.set('track_ema', ema_track_mm_features[t_i]) + + if isinstance(self._model, MultiModalTDCP): + track_mm_features, det_mm_features, _, _ = self._model( + data.observed.features, + data.observed.mask, + data.unobserved.features, + data.unobserved.mask + ) + track_mm_features = track_mm_features[0][:n_tracks] + det_mm_features = det_mm_features[0][:n_detections] + track_mm_features = F.normalize(track_mm_features, dim=-1).cpu() + det_mm_features = F.normalize(det_mm_features, dim=-1).cpu() + + ALPHA = 0.0 + if ALPHA > 0.0: + ema_track_mm_features = torch.zeros_like(track_mm_features) + for t_i in range(track_mm_features.shape[0]): + tracklet = tracklets[t_i] + track_ema = tracklet.get('track_ema') + if track_ema is None: + ema_track_mm_features[t_i] = track_mm_features[t_i] + else: + ema_track_mm_features[t_i] = ALPHA * track_ema + (1 - ALPHA) * track_mm_features[t_i] + ema_track_mm_features[t_i] = F.normalize(ema_track_mm_features[t_i], dim=-1) + tracklet.set('track_ema', ema_track_mm_features[t_i]) + else: + ema_track_mm_features = track_mm_features + + cost_matrix = (ema_track_mm_features @ det_mm_features.T).numpy() + cost_matrix = 1 - (cost_matrix + 1) / 2 # [-1, 1] -> [0, 1] + elif isinstance(self._model, MultiModalTDSP): + logits, _ = self._model( + data.observed.features, + data.observed.mask, + data.unobserved.features, + data.unobserved.mask + ) + probas = torch.sigmoid(logits).cpu().numpy() + cost_matrix = 1 - probas[0, :n_tracks, :n_detections] + else: - ema_track_mm_features = track_mm_features + raise ValueError(f'Unsupported model type: {type(self._model)}') - cost_matrix = (ema_track_mm_features @ det_mm_features.T).numpy() - cost_matrix = 1 - (cost_matrix + 1) / 2 # [-1, 1] -> [0, 1] cost_matrix[cost_matrix > sim_threshold] = np.inf return hungarian(cost_matrix) @@ -212,7 +229,7 @@ def track(self, _, _ = frame, detections # Ignored (for now) scene_name = self.get_scene() objects_data = self._extra_features_reader.read(scene_name, frame_index) - objects_data = [data for data in objects_data if data['bbox_conf'] > 0.4] + objects_data = [data for data in objects_data if data['bbox_conf'] > self._detection_threshold] detections = [PredBBox.create(BBox.from_xywh(*data['bbox_xywh']), label='pedestrian', conf=data['bbox_conf']) for data in objects_data] # Remove deleted @@ -263,124 +280,6 @@ def track(self, return tracklets -# class MyByteTracker(MyTracker): -# def __init__( -# self, -# transform: Transform, -# model: nn.Module, -# device: str, -# sim_threshold: float = 0.5, -# initialization_threshold: int = 1, -# remember_threshold: int = 30, -# clip_length: Optional[int] = None, -# new_tracklet_detection_threshold: float = 0.7, -# use_conf: bool = True, -# detection_threshold: float = 0.6 -# ): -# super().__init__( -# transform=transform, -# model=model, -# device=device, -# sim_threshold=sim_threshold, -# initialization_threshold=initialization_threshold, -# remember_threshold=remember_threshold, -# clip_length=clip_length, -# new_tracklet_detection_threshold=new_tracklet_detection_threshold, -# use_conf=use_conf -# ) -# self._detection_threshold = detection_threshold -# -# def track(self, -# tracklets: List[Tracklet], -# detections: List[PredBBox], -# frame_index: int, -# frame: Optional[np.ndarray] = None -# ) -> List[Tracklet]: -# tracklets = [t for t in tracklets if t.state != TrackletState.DELETED] -# -# # (1) Split detections into low and high -# high_detections = [d for d in detections if d.conf >= self._detection_threshold] -# high_det_indices = [i for i, d in enumerate(detections) if d.conf >= self._detection_threshold] -# low_detections = [d for d in detections if d.conf < self._detection_threshold] -# low_det_indices = [i for i, d in enumerate(detections) if d.conf < self._detection_threshold] -# -# # (2) Match high detections with tracklets with states ACTIVE and LOST using HighMatchAlgorithm -# tracklets_active_and_lost_indices, tracklets_active_and_lost = \ -# unpack_n([(i, t) for i, t in enumerate(tracklets) if t.is_tracked], n=2) -# high_matches, remaining_tracklet_indices, high_unmatched_detections_indices = \ -# self._association(tracklets_active_and_lost, high_detections, frame_index) -# high_matches = [(tracklets_active_and_lost_indices[t_i], high_det_indices[d_i]) for t_i, d_i in high_matches] -# high_unmatched_detections_indices = [high_det_indices[d_i] for d_i in high_unmatched_detections_indices] -# remaining_tracklets = [tracklets_active_and_lost[t_i] for t_i in remaining_tracklet_indices] -# remaining_tracklet_indices = [tracklets_active_and_lost_indices[t_i] for t_i in remaining_tracklet_indices] -# -# # (3) Match remaining ACTIVE tracklets with low detections using LowMatchAlgorithm -# remaining_active_tracklet_indices, remaining_active_tracklets = \ -# unpack_n([(i, t) for i, t in zip(remaining_tracklet_indices, remaining_tracklets) -# if t.state == TrackletState.ACTIVE], n=2) -# remaining_lost_tracklet_indices = \ -# [i for i, t in enumerate(tracklets) if t.state == TrackletState.LOST and i in remaining_tracklet_indices] -# -# low_matches, low_unmatched_tracklet_indices, _ = \ -# self._association(remaining_active_tracklets, low_detections, frame_index) -# low_matches = [(remaining_active_tracklet_indices[t_i], low_det_indices[d_i]) for t_i, d_i in low_matches] -# unmatched_tracklet_indices = [remaining_active_tracklet_indices[t_i] for t_i in low_unmatched_tracklet_indices] + \ -# remaining_lost_tracklet_indices -# -# # (5) Match NEW tracklets with high detections using NewMatchAlgorithm -# remaining_high_detections = [detections[d_i] for d_i in high_unmatched_detections_indices] -# remaining_high_detection_indices = high_unmatched_detections_indices -# tracklets_new_indices, tracklets_new = \ -# unpack_n([(i, t) for i, t in enumerate(tracklets) if t.state == TrackletState.NEW], n=2) -# new_matches, new_unmatched_tracklets_indices, new_unmatched_detections_indices = \ -# self._association(tracklets_new, remaining_high_detections, frame_index) -# new_matches = [(tracklets_new_indices[t_i], high_unmatched_detections_indices[d_i]) for t_i, d_i in new_matches] -# new_unmatched_tracklets_indices = [tracklets_new_indices[t_i] for t_i in new_unmatched_tracklets_indices] -# new_unmatched_detections_indices = [remaining_high_detection_indices[d_i] for d_i in new_unmatched_detections_indices] -# -# # (6) Initialize new tracklets from unmatched high detections -# new_tracklets: List[Tracklet] = [] -# for d_i in new_unmatched_detections_indices: -# detection = detections[d_i] -# -# if self._new_tracklet_detection_threshold is not None and detection.conf < self._new_tracklet_detection_threshold: -# continue -# -# new_tracklet = Tracklet( -# bbox=copy.deepcopy(detection), -# frame_index=frame_index, -# _id=self._next_id, -# state=TrackletState.NEW if frame_index > self._initialization_threshold else TrackletState.ACTIVE, -# max_history=self._clip_length - 1 -# ) -# self._next_id += 1 -# new_tracklets.append(new_tracklet) -# -# # (7) Update matched tracklets -# all_matches = high_matches + low_matches + new_matches -# for t_i, d_i in all_matches: -# tracklet = tracklets[t_i] -# detection = detections[d_i] -# -# new_state = TrackletState.ACTIVE -# if tracklet.state == TrackletState.NEW and tracklet.total_matches + 1 < self._initialization_threshold: -# new_state = TrackletState.NEW -# tracklet.update(detection, frame_index, state=new_state) -# -# # (8) Delete new unmatched and long-lost tracklets -# # Handle unmatched tracklets -# for t_i in (unmatched_tracklet_indices + new_unmatched_tracklets_indices): -# tracklet = tracklets[t_i] -# -# if tracklet.lost_time > self._remember_threshold or tracklet.state == TrackletState.NEW: -# tracklet.state = TrackletState.DELETED -# else: -# tracklet.state = TrackletState.LOST -# -# tracklets.extend(new_tracklets) -# return tracklets - - @torch.no_grad() @hydra.main(config_path=CONFIGS_PATH, config_name='default', version_base='1.2') @@ -435,13 +334,21 @@ def main(cfg: GlobalConfig) -> None: model=model, extra_features_reader=extra_features_reader, device=cfg.resources.accelerator, - remember_threshold=30, + remember_threshold=50, use_conf=True, - sim_threshold=0.90, + detection_threshold=0.4, + sim_threshold=0.985, initialization_threshold=1, new_tracklet_detection_threshold=0.9 ) + # DanceTrack exp111: + # remember_threshold=50, + # detection_threshold=0.4, + # sim_threshold=0.985, + # initialization_threshold=1, + # new_tracklet_detection_threshold=0.9 + scene_names = dataset_index.scenes for scene_name in tqdm(scene_names): scene_info = dataset_index.get_scene_info(scene_name) diff --git a/tools/train.py b/tools/train.py index 5f65fff..f10cfe4 100644 --- a/tools/train.py +++ b/tools/train.py @@ -116,6 +116,7 @@ def main(cfg: GlobalConfig) -> None: logger.warning(f'Using "{cfg.train.checkpoint_cfg.resume_from}" as starting checkpoint.') state_dict = torch.load(cfg.train.checkpoint_cfg.resume_from) model.load_state_dict(state_dict['model']) + logger.info(f'Model:\n{model}') loss_func = cfg.train.build_loss_func() optimizer = cfg.train.build_optimizer(model.parameters()) @@ -137,6 +138,7 @@ def main(cfg: GlobalConfig) -> None: n_epochs=cfg.train.max_epochs, gradient_clip=cfg.train.gradient_clip, mixed_precision=cfg.train.mixed_precision, + device=cfg.resources.accelerator, tensorboard_log_dirpath=tensorboard_log_dirpath, checkpoints_dirpath=checkpoints_dirpath, diff --git a/train_with_pretrained_encoders.sh b/train_with_pretrained_encoders.sh index 5d94907..da86cff 100755 --- a/train_with_pretrained_encoders.sh +++ b/train_with_pretrained_encoders.sh @@ -1,3 +1,3 @@ -for cfg in appearance bbox_only keypoints default; do +for cfg in appearance bbox keypoints default; do uv run tools/train.py --config-name="$cfg" train.truncate=true done