diff --git a/.gitignore b/.gitignore index c92267f..d695beb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +multirun/ +outputs/ .DS_Store .idea/ .cursor/ diff --git a/configs/appearance.yaml b/configs/appearance.yaml new file mode 100755 index 0000000..ce411a5 --- /dev/null +++ b/configs/appearance.yaml @@ -0,0 +1,11 @@ +defaults: + - the_global_config + - resources: default.yaml + - dataset: dancetrack_appearance.yaml + - train: id.yaml + - eval: default.yaml + - model_config: mm_appearance.yaml + - path: default.yaml + +experiment_name: exp74a-fromExp73-HalvedLrDoubleEpochs +dataset_name: DanceTrack diff --git a/configs/bbox_only.yaml b/configs/bbox_only.yaml new file mode 100755 index 0000000..1ae2a28 --- /dev/null +++ b/configs/bbox_only.yaml @@ -0,0 +1,11 @@ +defaults: + - the_global_config + - resources: default.yaml + - dataset: dancetrack_bbox.yaml + - train: batch.yaml + - eval: default.yaml + - model_config: mm_bboxes.yaml + - path: default.yaml + +experiment_name: exp74b-fromExp73-HalvedLrDoubleEpochs +dataset_name: DanceTrack diff --git a/configs/dataset/augmentations/appearance.yaml b/configs/dataset/augmentations/appearance.yaml new file mode 100644 index 0000000..046d84e --- /dev/null +++ b/configs/dataset/augmentations/appearance.yaml @@ -0,0 +1,10 @@ +_target_: mot_jepa.datasets.dataset.augmentations.base.CompositionAugmentation +augmentations: + - _target_: mot_jepa.datasets.dataset.augmentations.video.PointOcclusionAugmentations + drop_ratio: 0.3 + - _target_: mot_jepa.datasets.dataset.augmentations.video.LeftOrRightOcclusionAugmentations + drop_ratio: 0.2 + - _target_: mot_jepa.datasets.dataset.augmentations.video.IdentitySwitchAugmentation + switch_ratio: 0.3 + - _target_: mot_jepa.datasets.dataset.augmentations.appearance.AppearanceNoiseAugmentation + alpha: 0.5 diff --git a/configs/dataset/augmentations/default.yaml b/configs/dataset/augmentations/default.yaml index ff798dd..c5793b8 100644 --- a/configs/dataset/augmentations/default.yaml +++ b/configs/dataset/augmentations/default.yaml @@ -20,3 +20,7 @@ augmentations: switch_ratio: 0.3 - _target_: mot_jepa.datasets.dataset.augmentations.appearance.AppearanceNoiseAugmentation alpha: 0.5 + - _target_: mot_jepa.datasets.dataset.augmentations.video.SmartIdentitySwitchAugmentation + switch_ratio: 0.5 + iou_threshold: 0.5 + max_switch_ratio: 0.5 diff --git a/configs/dataset/augmentations/smart.yaml b/configs/dataset/augmentations/smart.yaml deleted file mode 100644 index afa7b28..0000000 --- a/configs/dataset/augmentations/smart.yaml +++ /dev/null @@ -1,12 +0,0 @@ -_target_: mot_jepa.datasets.dataset.augmentations.base.CompositionAugmentation -augmentations: - - _target_: mot_jepa.datasets.dataset.augmentations.bbox.BBoxGaussianNoiseAugmentation - proba: 0.5 - sigma: 0.05 - unobs_noise: true - - _target_: mot_jepa.datasets.dataset.augmentations.video.OcclusionAugmentations - drop_ratio: 0.5 - - _target_: mot_jepa.datasets.dataset.augmentations.video.SmartIdentitySwitchAugmentation - switch_ratio: 0.3 - iou_threshold: 0.4 - diff --git a/configs/dataset/dancetrack.yaml b/configs/dataset/dancetrack.yaml index 71b9f17..54fe756 100644 --- a/configs/dataset/dancetrack.yaml +++ b/configs/dataset/dancetrack.yaml @@ -2,6 +2,7 @@ defaults: - transform: scaled_bbox_keypoints.yaml - augmentations: default.yaml - feature_extractor: pred_bbox_keypoints_appearance.yaml + - sampler: scene_sampler.yaml index: type: mot diff --git a/configs/dataset/dancetrack_appearance.yaml b/configs/dataset/dancetrack_appearance.yaml new file mode 100644 index 0000000..3dba934 --- /dev/null +++ b/configs/dataset/dancetrack_appearance.yaml @@ -0,0 +1,19 @@ +defaults: + - transform: scaled_bbox_keypoints.yaml + - augmentations: default.yaml + - feature_extractor: pred_appearance.yaml + +index: + type: mot + params: + paths: + - /media/home/DanceTrack-orig/ + +n_tracks: 40 +clip_length: 50 +min_clip_tracks: 1 +clip_sampling_step: 1 +val_clip_sampling_step: 1 + +sampler: null +use_batch_sampler: false diff --git a/configs/dataset/dancetrack_bbox.yaml b/configs/dataset/dancetrack_bbox.yaml new file mode 100644 index 0000000..2a014d2 --- /dev/null +++ b/configs/dataset/dancetrack_bbox.yaml @@ -0,0 +1,19 @@ +defaults: + - transform: scaled_bbox_keypoints.yaml + - augmentations: default.yaml + - feature_extractor: pred_bbox.yaml + +index: + type: mot + params: + paths: + - /media/home/DanceTrack-orig/ + +n_tracks: 40 +clip_length: 50 +min_clip_tracks: 1 +clip_sampling_step: 1 +val_clip_sampling_step: 1 + +sampler: null +use_batch_sampler: false diff --git a/configs/dataset/dancetrack_keypoints.yaml b/configs/dataset/dancetrack_keypoints.yaml new file mode 100644 index 0000000..b35db3c --- /dev/null +++ b/configs/dataset/dancetrack_keypoints.yaml @@ -0,0 +1,19 @@ +defaults: + - transform: scaled_bbox_keypoints.yaml + - augmentations: default.yaml + - feature_extractor: pred_keypoints.yaml + +index: + type: mot + params: + paths: + - /media/home/DanceTrack-orig/ + +n_tracks: 40 +clip_length: 50 +min_clip_tracks: 1 +clip_sampling_step: 1 +val_clip_sampling_step: 1 + +sampler: null +use_batch_sampler: false diff --git a/configs/dataset/feature_extractor/pred_appearance.yaml b/configs/dataset/feature_extractor/pred_appearance.yaml new file mode 100644 index 0000000..e12878a --- /dev/null +++ b/configs/dataset/feature_extractor/pred_appearance.yaml @@ -0,0 +1,6 @@ +extractor_type: pred_bbox +extractor_params: + prediction_path: /media/home/cameltrack-states/extracted-features + extra_false_positives: true + feature_names: + - appearance \ No newline at end of file diff --git a/configs/dataset/feature_extractor/pred_keypoints.yaml b/configs/dataset/feature_extractor/pred_keypoints.yaml new file mode 100644 index 0000000..e225f30 --- /dev/null +++ b/configs/dataset/feature_extractor/pred_keypoints.yaml @@ -0,0 +1,6 @@ +extractor_type: pred_bbox +extractor_params: + prediction_path: /media/home/cameltrack-states/extracted-features + extra_false_positives: true + feature_names: + - keypoints \ No newline at end of file diff --git a/configs/dataset/sampler/scene_sampler.yaml b/configs/dataset/sampler/scene_sampler.yaml index f0266ba..d096ac9 100644 --- a/configs/dataset/sampler/scene_sampler.yaml +++ b/configs/dataset/sampler/scene_sampler.yaml @@ -1,3 +1,3 @@ _target_: 'mot_jepa.datasets.dataset.sampler.scene_sampler.SceneBatchSamplerWithRepeat.from_dataset' -n_scenes: 3 -n_frames: 2 \ No newline at end of file +n_scenes: 4 +n_frames: 8 \ No newline at end of file diff --git a/configs/dataset/transform/identity.yaml b/configs/dataset/transform/identity.yaml new file mode 100644 index 0000000..3ca948a --- /dev/null +++ b/configs/dataset/transform/identity.yaml @@ -0,0 +1 @@ +_target_: mot_jepa.datasets.dataset.transform.IdentityTransform \ No newline at end of file diff --git a/configs/dataset/transform/scaled_bbox_keypoints_v2.yaml b/configs/dataset/transform/scaled_bbox_keypoints_v2.yaml new file mode 100644 index 0000000..4dc6730 --- /dev/null +++ b/configs/dataset/transform/scaled_bbox_keypoints_v2.yaml @@ -0,0 +1,19 @@ +_target_: mot_jepa.datasets.dataset.transform.ComposeTransform +transforms: + - _target_: mot_jepa.datasets.dataset.transform.BBoxXYWHtoXYXY + keep_wh: true + - _target_: mot_jepa.datasets.dataset.transform.BBoxMinMaxScaling + - _target_: mot_jepa.datasets.dataset.transform.FeatureFODStandardization + coord_mean: + bbox: [0.5, 0.5, 0.5, 0.5, 0.00, 0.00, 0.5] + keypoints: [[35, [0.5]]] + coord_std: + bbox: [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0] + keypoints: [[17, [0.1, 0.1]], 1.0] + fod_mean: + bbox: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0,0.0] + keypoints: [[35, [0.0]]] + fod_std: + bbox: [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 1.0] + keypoints: [[17, [0.05, 0.05]], 1.0] + fod_time_scaled: true \ No newline at end of file diff --git a/configs/default.yaml b/configs/default.yaml index ce33e12..d81e69e 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -2,10 +2,10 @@ defaults: - the_global_config - resources: default.yaml - dataset: dancetrack.yaml - - train: default.yaml + - train: mm.yaml - eval: default.yaml - model_config: mm_bboxes_keypoints_appearance.yaml - path: default.yaml -experiment_name: exp44-fromExp41-BiggerAppearanceHiddenDim +experiment_name: exp74-fromExp73-HalvedLrDoubleEpochs dataset_name: DanceTrack diff --git a/configs/keypoints.yaml b/configs/keypoints.yaml new file mode 100755 index 0000000..e00c44c --- /dev/null +++ b/configs/keypoints.yaml @@ -0,0 +1,11 @@ +defaults: + - the_global_config + - resources: default.yaml + - dataset: dancetrack_keypoints.yaml + - train: batch.yaml + - eval: default.yaml + - model_config: mm_keypoints.yaml + - path: default.yaml + +experiment_name: exp74k-fromExp73-HalvedLrDoubleEpochs +dataset_name: DanceTrack diff --git a/configs/model_config/mm_appearance.yaml b/configs/model_config/mm_appearance.yaml new file mode 100644 index 0000000..b80a0af --- /dev/null +++ b/configs/model_config/mm_appearance.yaml @@ -0,0 +1,23 @@ +_target_: mot_jepa.architectures.tdcp.core.build_mm_tdcp_model +mm_dim: 512 +common_params: + hidden_dim: 256 + dropout: 0.1 + track_encoder_n_heads: 8 + track_encoder_n_layers: 2 + track_encoder_ffn_dim: 512 + projector_intermediate_dim: 512 + interaction_encoder_enable: true + interaction_encoder_n_heads: 8 + interaction_encoder_n_layers: 2 + interaction_encoder_ffn_dim: 512 +per_feature_params: + appearance: + hidden_dim: 512 + feature_encoder_type: parts_appearance + feature_encoder_params: + emb_size: 128 + hidden_dim: 512 + track_encoder_enable_motion_encoder: false +aggregator_type: sum +aggregator_params: {} diff --git a/configs/model_config/mm_bboxes.yaml b/configs/model_config/mm_bboxes.yaml index a79f37c..b202f4f 100644 --- a/configs/model_config/mm_bboxes.yaml +++ b/configs/model_config/mm_bboxes.yaml @@ -1,4 +1,5 @@ _target_: mot_jepa.architectures.tdcp.core.build_mm_tdcp_model +mm_dim: 256 common_params: hidden_dim: 256 dropout: 0.1 @@ -12,6 +13,8 @@ common_params: interaction_encoder_ffn_dim: 512 per_feature_params: bbox: - input_dim: 5 + feature_encoder_type: motion + feature_encoder_params: + input_dim: 5 aggregator_type: sum -aggregator_params: {} \ No newline at end of file +aggregator_params: {} diff --git a/configs/model_config/mm_bboxes_keypoints_appearance.yaml b/configs/model_config/mm_bboxes_keypoints_appearance.yaml index d083d7f..8d06f8f 100644 --- a/configs/model_config/mm_bboxes_keypoints_appearance.yaml +++ b/configs/model_config/mm_bboxes_keypoints_appearance.yaml @@ -27,12 +27,21 @@ per_feature_params: emb_size: 128 hidden_dim: 512 track_encoder_enable_motion_encoder: false -aggregator_type: query +aggregator_type: transformer aggregator_params: hidden_dim: ${model_config.mm_dim} - num_heads: 8 + n_heads: 8 + n_layers: 2 + dropout: 0.1 + +per_feature_checkpoint: + bbox: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp74b-fromExp73-HalvedLrDoubleEpochs/checkpoints/last.pt + keypoints: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp74k-fromExp73-HalvedLrDoubleEpochs/checkpoints/last.pt + appearance: /media/home/MOT-JEPA-outputs/experiments/DanceTrack/exp74a-fromExp73-HalvedLrDoubleEpochs/checkpoints/last.pt -drop_mm_probas: - bbox: 0.1 - keypoints: 0.1 - appearance: 0.4 +object_interaction_encoder_enable: true +object_interaction_encoder_params: + hidden_dim: ${model_config.mm_dim} + n_heads: 8 + n_layers: 2 + dropout: 0.1 \ No newline at end of file diff --git a/configs/model_config/mm_keypoints.yaml b/configs/model_config/mm_keypoints.yaml new file mode 100644 index 0000000..108a255 --- /dev/null +++ b/configs/model_config/mm_keypoints.yaml @@ -0,0 +1,20 @@ +_target_: mot_jepa.architectures.tdcp.core.build_mm_tdcp_model +mm_dim: 256 +common_params: + hidden_dim: 256 + dropout: 0.1 + track_encoder_n_heads: 8 + track_encoder_n_layers: 2 + track_encoder_ffn_dim: 512 + projector_intermediate_dim: 512 + interaction_encoder_enable: true + interaction_encoder_n_heads: 8 + interaction_encoder_n_layers: 2 + interaction_encoder_ffn_dim: 512 +per_feature_params: + keypoints: + feature_encoder_type: motion + feature_encoder_params: + input_dim: 35 +aggregator_type: sum +aggregator_params: {} \ No newline at end of file diff --git a/configs/resources/default.yaml b/configs/resources/default.yaml index 19828da..f12e0f5 100644 --- a/configs/resources/default.yaml +++ b/configs/resources/default.yaml @@ -1,4 +1,4 @@ -batch_size: 12 -val_batch_size: 6 +batch_size: 8 +val_batch_size: 4 accelerator: 'cuda:0' num_workers: 12 \ No newline at end of file diff --git a/configs/train/base.yaml b/configs/train/base.yaml new file mode 100644 index 0000000..b3530bf --- /dev/null +++ b/configs/train/base.yaml @@ -0,0 +1,15 @@ +max_epochs: 20 +gradient_clip: 1.0 +mixed_precision: true + +loss_config: + _target_: mot_jepa.trainer.losses.infonce.IDLevelInfoNCE + +optimizer_config: + _target_: torch.optim.AdamW + lr: 5e-5 + weight_decay: 1e-2 + +scheduler_config: + _target_: mot_jepa.trainer.scheduler.create_warmup_cosine_annealing_scheduler + n_warmup_epochs: 2 diff --git a/configs/train/batch.yaml b/configs/train/batch.yaml new file mode 100644 index 0000000..aa188ab --- /dev/null +++ b/configs/train/batch.yaml @@ -0,0 +1,6 @@ +defaults: + - base.yaml + - _self_ + +loss_config: + _target_: mot_jepa.trainer.losses.infonce.BatchLevelInfoNCE \ No newline at end of file diff --git a/configs/train/id.yaml b/configs/train/id.yaml new file mode 100644 index 0000000..27580e8 --- /dev/null +++ b/configs/train/id.yaml @@ -0,0 +1,6 @@ +defaults: + - base.yaml + - _self_ + +loss_config: + _target_: mot_jepa.trainer.losses.infonce.IDLevelInfoNCE \ No newline at end of file diff --git a/configs/train/default.yaml b/configs/train/mm.yaml similarity index 53% rename from configs/train/default.yaml rename to configs/train/mm.yaml index 4e6c9dc..f281bae 100644 --- a/configs/train/default.yaml +++ b/configs/train/mm.yaml @@ -1,11 +1,11 @@ -max_epochs: 10 -gradient_clip: null -mixed_precision: true +defaults: + - base.yaml + - _self_ loss_config: _target_: mot_jepa.trainer.losses.infonce.MultiFeatureLoss mm_loss: - _target_: mot_jepa.trainer.losses.infonce.ClipLevelInfoNCE + _target_: mot_jepa.trainer.losses.infonce.IDLevelInfoNCE per_feature_losses: bbox: _target_: mot_jepa.trainer.losses.infonce.BatchLevelInfoNCE @@ -13,11 +13,7 @@ loss_config: _target_: mot_jepa.trainer.losses.infonce.BatchLevelInfoNCE appearance: _target_: mot_jepa.trainer.losses.infonce.IDLevelInfoNCE - -optimizer_config: - _target_: torch.optim.Adam - lr: 1e-4 - -scheduler_config: - _target_: mot_jepa.trainer.scheduler.create_warmup_cosine_annealing_scheduler - n_warmup_epochs: 1 + per_feature_weights: + bbox: 0.3 + keypoints: 0.3 + appearance: 0.3 diff --git a/docker/Dockerfile b/docker/Dockerfile index 5e64c4a..3b2124e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,4 +22,14 @@ RUN pip install uv RUN uv init RUN uv sync +RUN apt-get install -y git python3.10-venv +RUN git clone https://github.com/Megvii-BaseDetection/YOLOX /YOLOX +WORKDIR /YOLOX +RUN python3 -m venv venv +RUN /bin/bash -c "source /YOLOX/venv/bin/activate && pip install -U pip setuptools" +RUN /bin/bash -c "source /YOLOX/venv/bin/activate && pip install torch torchvision torchaudio" +RUN /bin/bash -c "source /YOLOX/venv/bin/activate && pip install --no-build-isolation -v -e ." + +WORKDIR /work + CMD ["bash"] diff --git a/mot_jepa/architectures/tdcp/aggregators.py b/mot_jepa/architectures/tdcp/aggregators.py index 2419e01..c0d0d92 100644 --- a/mot_jepa/architectures/tdcp/aggregators.py +++ b/mot_jepa/architectures/tdcp/aggregators.py @@ -104,12 +104,48 @@ def forward(self, features: Sequence[Tensor]) -> Tensor: return einops.rearrange(pooled, '(b n) e -> b n e', b=B, n=N) +class TDCPTransformer(nn.Module): + def __init__(self, n_features: int, hidden_dim: int, n_heads: int = 4, n_layers: int = 1, dropout: float = 0.0): + super().__init__() + self._type_emb = nn.Parameter(torch.randn(n_features, hidden_dim)) # [M, E] + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_dim, nhead=n_heads, + dim_feedforward=2 * hidden_dim, + batch_first=True, + dropout=dropout + ) + self._enc = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=n_layers + ) + self._q_proj = nn.Linear(hidden_dim, hidden_dim) + + self._out_norm = nn.LayerNorm(hidden_dim) + self._out_proj = nn.Linear(hidden_dim, hidden_dim) + + def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor: + # x: list of [B, N, E] -> [B, N, M, E] + x = torch.stack(features, dim=2) + B, N, M, E = x.shape + x = x + self._type_emb.view(1, 1, M, E) # add type token + x = einops.rearrange(x, 'b n m e -> (b n) m e') + h = self._enc(x) # [B*N, M, E] + # build a data-conditioned query: mean-pooled context + q = self._q_proj(h.mean(dim=1, keepdim=True)) # [B*N, 1, E] + attn = torch.softmax((q @ h.transpose(1, 2)) / (E**0.5), dim=-1) # [B*N, 1, M] + pooled = torch.mean((attn @ h), dim=-2) # [B*N, E] + projected = self._out_proj(self._out_norm(pooled)) + return einops.rearrange(projected, '(b n) e -> b n e', b=B, n=N) + + + TDCP_AGGREGATOR_CATALOG = { 'sum': TDCPSumAggregator, 'linear_sum': TDCPLinearSumAggregator, 'static_softmax': TDCPStaticSoftmaxSum, 'attn': TDCPAttnWeightedSum, - 'query': TDCPQueryAttentionPool + 'query': TDCPQueryAttentionPool, + 'transformer': TDCPTransformer } diff --git a/mot_jepa/architectures/tdcp/core.py b/mot_jepa/architectures/tdcp/core.py index 5ac49d1..597d1f1 100644 --- a/mot_jepa/architectures/tdcp/core.py +++ b/mot_jepa/architectures/tdcp/core.py @@ -1,7 +1,6 @@ """Core TDCP model combining track and detection encoders.""" import copy -import random from typing import List, Tuple, Optional, Dict, Any, Set import torch @@ -13,6 +12,9 @@ from mot_jepa.architectures.tdcp.object_interaction_encoder import ObjectInteractionEncoder from mot_jepa.architectures.tdcp.projector import TrackToDetectionProjector from mot_jepa.architectures.tdcp.track_encoder import TrackEncoder +import logging + +logger = logging.getLogger('Architecture') class TrackDetectionContrastivePrediction(nn.Module): @@ -78,7 +80,10 @@ def forward( if self._object_interaction_encoder is not None: agg_track_mask = track_mask.all(dim=-1) projected_features, det_features = self._object_interaction_encoder( - projected_features, agg_track_mask, det_features, det_mask + projected_features, + agg_track_mask, + det_features, + det_mask ) return projected_features, det_features @@ -90,66 +95,27 @@ def __init__( tdcps: Dict[str, TrackDetectionContrastivePrediction], mm_dim: int, aggregator: TDCPAggregator, - drop_mm_probas: Optional[Dict[str, float]] = None + object_interaction_encoder: Optional[ObjectInteractionEncoder] = None ): super().__init__() self._tdcps = nn.ModuleDict(tdcps) - self._mm_linear_layers = nn.ModuleList([nn.Linear(tdcp.output_dim, mm_dim) for tdcp in tdcps.values()]) + self._mm_linear_layers = nn.ModuleList([ + nn.Linear(tdcp.output_dim, mm_dim) + for tdcp in tdcps.values() + ]) self._aggregator = aggregator - self._drop_mm_probas = drop_mm_probas - - if self._drop_mm_probas is not None: - assert set(self._drop_mm_probas.keys()) == set(self._tdcps.keys()), \ - f'drop_mm_probas keys must match tdcps keys. Got {set(self._drop_mm_probas.keys())} and {set(self._tdcps.keys())}' + self._object_interaction_encoder = object_interaction_encoder @property def feature_names(self) -> Set[str]: return set(self._tdcps.keys()) - def _drop_mm_tokens(self, mm_track_features: List[torch.Tensor], mm_det_features: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - """ - Apply token dropping to multimodal features during training. - - Args: - mm_track_features: List of track feature tensors - mm_det_features: List of detection feature tensors - - Returns: - Tuple of (modified_track_features, modified_det_features) - """ - # Skip dropping during evaluation - if not self.training or self._drop_mm_probas is None: - return mm_track_features, mm_det_features - - # Generate drop flags for each modality - feature_names = list(self._tdcps.keys()) - drop_flags: List[bool] = [] - for feature_name in feature_names: - flag = random.random() < self._drop_mm_probas[feature_name] - drop_flags.append(flag) - - # edge-case: If all drop flags are True, return original features - if all(drop_flags): - return mm_track_features, mm_det_features - - # Apply dropping to each modality - filtered_track_features = [] - filtered_det_features = [] - for i, (feature_name, track_feat, det_feat) in enumerate(zip(feature_names, mm_track_features, mm_det_features)): - if not drop_flags[i]: - filtered_track_features.append(track_feat) - filtered_det_features.append(det_feat) - - return filtered_track_features, filtered_det_features - - - def forward( self, track_features: Dict[str, torch.Tensor], track_mask: torch.Tensor, det_features: Dict[str, torch.Tensor], - det_mask: torch.Tensor + det_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: assert set(track_features.keys()) == set(det_features.keys()) == set(self._tdcps.keys()) @@ -158,15 +124,23 @@ def forward( track_x=track_features[key], track_mask=track_mask, det_x=det_features[key], - det_mask=det_mask + det_mask=det_mask, ) mm_track_features = [lin_layer(mm_feat) for lin_layer, mm_feat in zip(self._mm_linear_layers, list(track_features.values()))] mm_det_features = [lin_layer(mm_feat) for lin_layer, mm_feat in zip(self._mm_linear_layers, list(det_features.values()))] - mm_track_features, mm_det_features = self._drop_mm_tokens(mm_track_features, mm_det_features) agg_track_features = self._aggregator(mm_track_features) agg_det_features = self._aggregator(mm_det_features) + if self._object_interaction_encoder is not None: + agg_track_mask = track_mask.all(dim=-1) + agg_track_features, agg_det_features = self._object_interaction_encoder( + agg_track_features, + agg_track_mask, + agg_det_features, + det_mask + ) + return agg_track_features, agg_det_features, track_features, det_features @@ -250,23 +224,43 @@ def build_mm_tdcp_model( mm_dim: int, aggregator_type: str, aggregator_params: Dict[str, Any], - drop_mm_probas: Optional[Dict[str, float]] = None + per_feature_checkpoint: Optional[Dict[str, str]] = None, + object_interaction_encoder_enable: bool = False, + object_interaction_encoder_params: Optional[Dict[str, Any]] = None ) -> MultiModalTDCP: + per_feature_checkpoint = per_feature_checkpoint or {} + tdcps: Dict[str, TrackDetectionContrastivePrediction] = {} for feature_name in per_feature_params: params = tdcp_utils.merge_configs(common_params, per_feature_params[feature_name]) tdcps[feature_name] = build_tdcp_model(**params) + if feature_name in per_feature_checkpoint: + logger.info(f'Loading checkpoint for {feature_name} from {per_feature_checkpoint[feature_name]}') + state_dict = torch.load(per_feature_checkpoint[feature_name])['model'] + state_dict = { + k.replace(f'_tdcps.{feature_name}.', ''): v + for k, v in state_dict.items() + } + state_dict = {k: v for k, v in state_dict.items() if not k.startswith('_mm_linear_layers.')} + tdcps[feature_name].load_state_dict(state_dict) aggregator = tdcp_aggregator_factory( aggregator_type=aggregator_type, aggregator_params=aggregator_params, n_features=len(per_feature_params) ) + + if object_interaction_encoder_enable: + assert object_interaction_encoder_params is not None + object_interaction_encoder = ObjectInteractionEncoder(**object_interaction_encoder_params) + else: + object_interaction_encoder = None + return MultiModalTDCP( tdcps=tdcps, aggregator=aggregator, mm_dim=mm_dim, - drop_mm_probas=drop_mm_probas + object_interaction_encoder=object_interaction_encoder ) diff --git a/mot_jepa/architectures/tdcp/object_interaction_encoder.py b/mot_jepa/architectures/tdcp/object_interaction_encoder.py index 74140c6..f34b0fc 100644 --- a/mot_jepa/architectures/tdcp/object_interaction_encoder.py +++ b/mot_jepa/architectures/tdcp/object_interaction_encoder.py @@ -60,7 +60,6 @@ def forward( Tuple ``(track_features, detection_features)`` with updated representations after self-attention. """ - N = track_x.shape[1] x = torch.cat([track_x, det_x], dim=1) diff --git a/mot_jepa/architectures/tdcp/track_encoder.py b/mot_jepa/architectures/tdcp/track_encoder.py index 64a9085..8b4dc13 100644 --- a/mot_jepa/architectures/tdcp/track_encoder.py +++ b/mot_jepa/architectures/tdcp/track_encoder.py @@ -94,8 +94,7 @@ def forward(self, x: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: Returns: Tensor of shape ``(B, N, D)`` representing encoded tracks. """ - - B, N, T, D = x.shape # batch, objects, temporal, dim + B, N, _, _ = x.shape # batch, objects, temporal, dim # Tokenize x = einops.rearrange(x, 'b n t e -> t (b n) e') diff --git a/mot_jepa/datasets/dataset/augmentations/video.py b/mot_jepa/datasets/dataset/augmentations/video.py index c4d5e2e..559abb2 100644 --- a/mot_jepa/datasets/dataset/augmentations/video.py +++ b/mot_jepa/datasets/dataset/augmentations/video.py @@ -141,6 +141,9 @@ def __init__(self, switch_ratio: float, iou_threshold: float = 0.1, max_switch_r self._max_switch_ratio = max_switch_ratio def apply(self, data: VideoClipData) -> VideoClipData: + if 'bbox' not in data.observed.features: + return data + n_tracks, _ = data.observed.mask.shape candidate_matrix, candidate_pair_matrix = self._compute_switch_candidates(data) @@ -229,17 +232,10 @@ def _get_switchable_timesteps(mask, idx_a, idx_b, candidate_vector): return valid_timesteps.tolist() @staticmethod - def _switch(data, idx_a, idx_b, start_time): - _, clip_length = data.observed.mask.shape - switch_max_len = clip_length - start_time - 1 - assert switch_max_len >= 0 - if switch_max_len == 0: - return - swap_index = start_time + random.randint(1, switch_max_len) - + def _switch(data, idx_a, idx_b, swap_index): for feature_key in data.observed.features: data.observed.features[feature_key][[idx_a, idx_b], swap_index] = data.observed.features[feature_key][[idx_b, idx_a], swap_index] - data.observed.mask[[idx_a, idx_b], swap_index] = data.observed.mask[[idx_b, idx_a], swap_index] + data.observed.mask[[idx_a, idx_b], swap_index] = data.observed.mask[[idx_b, idx_a], swap_index] def test_identity_switch_augmentation(): diff --git a/mot_jepa/datasets/dataset/feature_extractor/feature_extractor.py b/mot_jepa/datasets/dataset/feature_extractor/feature_extractor.py index b2a17fc..95fdc73 100644 --- a/mot_jepa/datasets/dataset/feature_extractor/feature_extractor.py +++ b/mot_jepa/datasets/dataset/feature_extractor/feature_extractor.py @@ -113,13 +113,13 @@ def _extract_common_part( f'Removing at random...') object_ids = object_ids[:self._n_tracks] - ts = torch.arange(start_time, start_time + temporal_length, dtype=torch.long) \ - .unsqueeze(0).repeat(self._n_tracks, 1) + ts = torch.zeros(self._n_tracks, temporal_length, dtype=torch.long) ids = torch.full_like(ts, fill_value=-1) mask = torch.ones(self._n_tracks, temporal_length, dtype=torch.bool) for clip_index, frame_index in enumerate(range(start_index, end_index)): for object_index, object_id in enumerate(object_ids): + ts[object_index, clip_index] = frame_index ids[object_index, clip_index] = self._object_id_mapping[object_id] mask[object_index, clip_index] = False diff --git a/mot_jepa/datasets/dataset/index/mot.py b/mot_jepa/datasets/dataset/index/mot.py index 7c93c7e..659ad75 100644 --- a/mot_jepa/datasets/dataset/index/mot.py +++ b/mot_jepa/datasets/dataset/index/mot.py @@ -62,7 +62,7 @@ def __init__( sequence_list: Optional[List[str]] = None, label_type: LabelType = LabelType.GROUND_TRUTH, skip_corrupted: bool = False, - test: bool = False + test: bool = True ) -> None: """ Args: @@ -335,7 +335,7 @@ def _parse_labels(self, scene_infos: SceneInfoIndex, test: bool = False) \ n_labels = 0 if test: # Return empty labels - return data, n_labels + return data, present_object_ids, n_labels for scene_name, scene_info in tqdm(scene_infos.items(), unit='scene', total=len(scene_infos), desc=f'Indexing'): scene_info = self._scene_info_index[scene_name] diff --git a/mot_jepa/datasets/dataset/transform/bbox.py b/mot_jepa/datasets/dataset/transform/bbox.py index e70b17e..f2e55dc 100644 --- a/mot_jepa/datasets/dataset/transform/bbox.py +++ b/mot_jepa/datasets/dataset/transform/bbox.py @@ -8,12 +8,32 @@ class BBoxXYWHtoXYXY(Transform): - def __init__(self): + def __init__(self, keep_wh: bool = False): super().__init__(name='bbox_xywh_to_xyxy') + self._keep_wh = keep_wh def apply(self, data: VideoClipData) -> VideoClipData: - data.observed.features['bbox'][..., 2:4] = data.observed.features['bbox'][..., :2] + data.observed.features['bbox'][..., 2:4] - data.unobserved.features['bbox'][..., 2:4] = data.unobserved.features['bbox'][..., :2] + data.unobserved.features['bbox'][..., 2:4] + if 'bbox' not in data.observed.features: + return data + + observed_bottom_xy = data.observed.features['bbox'][..., :2] + data.observed.features['bbox'][..., 2:4] + unobserved_bottom_xy = data.unobserved.features['bbox'][..., :2] + data.unobserved.features['bbox'][..., 2:4] + + if not self._keep_wh: + data.observed.features['bbox'][..., 2:4] = observed_bottom_xy + data.unobserved.features['bbox'][..., 2:4] = unobserved_bottom_xy + else: + data.observed.features['bbox'] = torch.cat([ + data.observed.features['bbox'][..., :2], + observed_bottom_xy, + data.observed.features['bbox'][..., 2:] + ], dim=-1) + data.unobserved.features['bbox'] = torch.cat([ + data.unobserved.features['bbox'][..., :2], + unobserved_bottom_xy, + data.unobserved.features['bbox'][..., 2:] + ], dim=-1) + return data @@ -33,6 +53,9 @@ def __init__( def apply(self, data: VideoClipData) -> VideoClipData: for feature_name in self._feature_names: + if feature_name not in data.observed.features: + continue + data.observed.features[feature_name] = \ (data.observed.features[feature_name] - self._coord_mean[feature_name]) / self._coord_std[feature_name] # Centralize data.unobserved.features[feature_name] = \ @@ -65,6 +88,9 @@ def __init__( def apply(self, data: VideoClipData) -> VideoClipData: for feature_name in self._feature_names: + if feature_name not in data.observed.features: + continue + features = data.observed.features[feature_name] mask = data.observed.mask @@ -77,7 +103,8 @@ def apply(self, data: VideoClipData) -> VideoClipData: if features_n.shape[0] == 0: continue - features_n[1:, :] = (features_n[1:, :] - features_n[:-1, :]) / (ts_n[1:, :] - ts_n[:-1, :]) + ts_diff = torch.clamp(ts_n[1:, :] - ts_n[:-1, :], min=1) + features_n[1:, :] = (features_n[1:, :] - features_n[:-1, :]) / (ts_diff) features_n[0, :] = 0 fod[n][~mask[n]] = features_n else: @@ -105,39 +132,68 @@ def __init__(self): super().__init__(name='bbox_min_max_scaling') def apply(self, data: VideoClipData) -> VideoClipData: - # Concatenate all bboxes and masks - all_bboxes = torch.cat([ - data.observed.features['bbox'], # (N, T, 5) - data.unobserved.features['bbox'].unsqueeze(1) # (N, 1, 5) - ], dim=1).view(-1, data.observed.features['bbox'].shape[-1]) # (N * (T + 1), 5) - - all_masks = torch.cat([ - data.observed.mask, # (N, T) - data.unobserved.mask.unsqueeze(1) # (N, 1) - ], dim=1).view(-1) # (N * (T + 1)) + if 'bbox' not in data.observed.features and 'keypoints' not in data.observed.features: + return data + + bboxes_list: List[torch.Tensor] = [] + masks_list: List[torch.Tensor] = [] + + # Collect coordinates and masks + if 'bbox' in data.observed.features: + N, T, _ = data.observed.features['bbox'].shape + bboxes_list.append(data.observed.features['bbox'][:, :, :4].reshape(N, 2 * T, 2)) # (N, 2 * T, 2) + masks_list.append(data.observed.mask.repeat(1, 2)) + bboxes_list.append(data.unobserved.features['bbox'][:, :4].unsqueeze(1).reshape(N, 2 * 1, 2)) # (N, 2, 2) + masks_list.append(data.unobserved.mask.unsqueeze(1).repeat(1, 2)) + if 'keypoints' in data.observed.features: + N, T, D = data.observed.features['keypoints'].shape + assert (D - 1) % 2 == 0 + n_coords = (D - 1) // 2 + bboxes_list.append(data.observed.features['keypoints'][:, :, :-1].reshape(N, T * n_coords, 2)) # (N, T * n_coords, 2) + masks_list.append(data.observed.mask.repeat(1, n_coords)) + bboxes_list.append(data.unobserved.features['keypoints'][:, :-1].unsqueeze(1).reshape(N, n_coords, 2)) # (N, n_coords, 2) + masks_list.append(data.unobserved.mask.unsqueeze(1).repeat(1, n_coords)) - # Invert masks: 0 = valid, 1 = invalid - valid_mask = ~all_masks + # Concatenate all bboxes and masks + all_bboxes = torch.cat(bboxes_list, dim=1) # (N, X, 2) + all_masks = torch.cat(masks_list, dim=1) # (N, X) - valid_bboxes = all_bboxes[valid_mask][:, :4].reshape(-1, 2) + # Compute min and max values + valid_bboxes = all_bboxes[~all_masks].view(-1, 2) min_val = valid_bboxes.min(dim=0).values max_val = valid_bboxes.max(dim=0).values scale = max_val - min_val scale[scale == 0] = 1 # prevent division by zero - # Apply scaling - data.observed.features['bbox'][:, :, :2] = (data.observed.features['bbox'][:, :, :2] - min_val) / scale - data.unobserved.features['bbox'][:, :2] = (data.unobserved.features['bbox'][:, :2] - min_val) / scale - data.observed.features['bbox'][:, :, 2:4] = (data.observed.features['bbox'][:, :, 2:4] - min_val) / scale - data.unobserved.features['bbox'][:, 2:4] = (data.unobserved.features['bbox'][:, 2:4] - min_val) / scale + if 'bbox' in data.observed.features: + # Update bboxes + data.observed.features['bbox'][:, :, :2] = (data.observed.features['bbox'][:, :, :2] - min_val) / scale + data.unobserved.features['bbox'][:, :2] = (data.unobserved.features['bbox'][:, :2] - min_val) / scale + data.observed.features['bbox'][:, :, 2:4] = (data.observed.features['bbox'][:, :, 2:4] - min_val) / scale + data.unobserved.features['bbox'][:, 2:4] = (data.unobserved.features['bbox'][:, 2:4] - min_val) / scale - # Zero-out masked entries - data.observed.features['bbox'][data.observed.mask] = 0 - data.unobserved.features['bbox'][data.unobserved.mask] = 0 + # Zero-out masked entries + data.observed.features['bbox'][data.observed.mask] = 0 + data.unobserved.features['bbox'][data.unobserved.mask] = 0 if 'keypoints' in data.observed.features: - data.observed.features['keypoints'][:, :, :2] = (data.observed.features['keypoints'][:, :, :2] - min_val) / scale - data.unobserved.features['keypoints'][:, :2] = (data.unobserved.features['keypoints'][:, :2] - min_val) / scale + N, T, D = data.observed.features['keypoints'].shape + assert (D - 1) % 2 == 0 + n_coords = (D - 1) // 2 + + # Rescale keypoints + observed_keypoints = data.observed.features['keypoints'][:, :, :-1].reshape(N, T * n_coords, 2) + unobserved_keypoints = data.unobserved.features['keypoints'][:, :-1].unsqueeze(1).reshape(N, n_coords, 2) + observed_keypoints = (observed_keypoints - min_val) / scale + unobserved_keypoints = (unobserved_keypoints - min_val) / scale + + # Update keypoints + data.observed.features['keypoints'][:, :, :-1] = observed_keypoints.reshape(N, T, n_coords * 2) + data.unobserved.features['keypoints'][:, :-1] = unobserved_keypoints.reshape(N, n_coords * 2) + + # Zero-out masked entries + data.observed.features['keypoints'][data.observed.mask] = 0 + data.unobserved.features['keypoints'][data.unobserved.mask] = 0 return data diff --git a/mot_jepa/trainer/losses/infonce.py b/mot_jepa/trainer/losses/infonce.py index fa8cf35..4dc00e8 100644 --- a/mot_jepa/trainer/losses/infonce.py +++ b/mot_jepa/trainer/losses/infonce.py @@ -10,11 +10,11 @@ def torch_combine(xs: List[torch.Tensor], dtype: torch.dtype) -> torch.Tensor: """ Concatenate tensors or return empty tensor if list is empty. - + Args: xs: List of tensors to concatenate dtype: Data type for empty tensor when xs is empty - + Returns: Concatenated tensor or empty tensor with specified dtype """ @@ -26,7 +26,7 @@ def torch_combine(xs: List[torch.Tensor], dtype: torch.dtype) -> torch.Tensor: class ClipLevelInfoNCE(VideoClipLoss): """ InfoNCE loss computed separately for each clip in the batch. - + Applies InfoNCE to tracks and detections within each clip as positive pairs. Uses cosine similarity and averages over non-zero valid pairs per clip. """ @@ -80,8 +80,6 @@ def forward( track_predictions_list = [] det_predictions_list = [] with torch.no_grad(): - track_labels = torch.arange(N).to(track_x).unsqueeze(0).repeat(B, 1).long() - det_labels = torch.arange(N).to(det_x).unsqueeze(0).repeat(B, 1).long() for b_i in range(B): combined_mask = ~agg_track_mask[b_i] & ~detection_mask[b_i] if not bool(combined_mask.any().item()): @@ -97,9 +95,9 @@ def forward( n_sub_tracks = sub_track_x.shape[0] n_sub_det = sub_det_x.shape[0] if n_sub_tracks > 0 and n_sub_det > 0: - distances = torch.cdist(sub_track_x, sub_det_x, p=2) - sub_track_predictions = torch.argmin(distances, dim=1) - sub_det_predictions = torch.argmin(distances, dim=0) + distances = sub_track_x @ sub_det_x.T + sub_track_predictions = sub_track_labels[torch.argmax(distances, dim=1)] + sub_det_predictions = sub_det_labels[torch.argmax(distances, dim=0)] filtered_track_labels_list.append(sub_track_labels) filtered_det_labels_list.append(sub_det_labels) @@ -128,7 +126,7 @@ def forward( class BatchLevelInfoNCE(VideoClipLoss): """ InfoNCE loss computed across the entire batch. - + Allows cross-clip matching by treating all tracks and detections as potential positive pairs. Uses cosine similarity and averages over non-zero valid pairs. """ @@ -197,9 +195,9 @@ def forward( n_sub_tracks = sub_track_x.shape[0] n_sub_det = sub_det_x.shape[0] if n_sub_tracks > 0 and n_sub_det > 0: - distances = torch.cdist(sub_track_x, sub_det_x, p=2) - sub_track_predictions = torch.argmin(distances, dim=1) - sub_det_predictions = torch.argmin(distances, dim=0) + distances = sub_track_x @ sub_det_x.T + sub_track_predictions = sub_track_labels[torch.argmax(distances, dim=1)] + sub_det_predictions = sub_det_labels[torch.argmax(distances, dim=0)] filtered_track_labels_list.append(sub_track_labels) filtered_det_labels_list.append(sub_det_labels) @@ -227,7 +225,7 @@ def forward( class IDLevelInfoNCE(VideoClipLoss): """ InfoNCE loss using object IDs to determine positive pairs. - + Uses explicit identity information rather than spatial/temporal correspondence. Requires track_ids and det_ids parameters. Uses cosine similarity. """ @@ -248,7 +246,7 @@ def forward( ) -> Dict[str, torch.Tensor]: """ Compute InfoNCE loss using object IDs for positive pairs. - + Args: track_x: Track embeddings (B, N, E) det_x: Detection embeddings (B, N, E) @@ -258,10 +256,10 @@ def forward( det_feature_dict: Optional modality-specific detection features track_ids: Track identifiers (B, N) - required det_ids: Detection identifiers (B, N) - required - + Returns: Dictionary with loss, predictions, and evaluation metrics - + Raises: ValueError: If track_ids or det_ids not provided """ @@ -309,9 +307,9 @@ def forward( n_sub_tracks = sub_track_x.shape[0] n_sub_det = sub_det_x.shape[0] if n_sub_tracks > 0 and n_sub_det > 0: - distances = torch.cdist(sub_track_x, sub_det_x, p=2) - sub_track_predictions = torch.argmin(distances, dim=1) - sub_det_predictions = torch.argmin(distances, dim=0) + distances = sub_track_x @ sub_det_x.T + sub_track_predictions = sub_track_labels[torch.argmax(distances, dim=1)] + sub_det_predictions = sub_det_labels[torch.argmax(distances, dim=0)] filtered_track_labels_list.append(sub_track_labels) filtered_det_labels_list.append(sub_det_labels) @@ -412,14 +410,14 @@ def run_test() -> None: track_x = torch.tensor([ [ [0, 1], - [0, 1], - [0, 1], + [1, 0], + [1, 1], [0, 0] ], [ [0, 1], - [0, 1], - [0, 1], + [1, 0], + [1, 1], [0, 0] ] ], dtype=torch.float32) @@ -427,14 +425,14 @@ def run_test() -> None: det_x = torch.tensor([ [ [0, 1], - [0, 1], - [0, 1], + [1, 0], + [1, 1], [0, 0] ], [ [0, 1], - [0, 1], - [0, 1], + [1, 0], + [1, 1], [0, 0] ] ], dtype=torch.float32) diff --git a/mot_jepa/trainer/trainer.py b/mot_jepa/trainer/trainer.py index 1661597..5552746 100644 --- a/mot_jepa/trainer/trainer.py +++ b/mot_jepa/trainer/trainer.py @@ -231,14 +231,14 @@ def _forward_and_loss(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Te with autocast(enabled=self._mixed_precision): model_output = self._model(track_x, track_mask, det_x, det_mask) - if isinstance(model_output, tuple) and len(model_output) == 4: + if len(model_output) == 4: track_features, det_features, track_feat_dict, det_feat_dict = model_output else: track_features, det_features = model_output track_feat_dict = None det_feat_dict = None - return self._loss_func( + loss_dict = self._loss_func( track_features, det_features, track_mask, @@ -249,6 +249,8 @@ def _forward_and_loss(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Te det_ids ) + return loss_dict + def _train_epoch(self, train_loader: 'DataLoader') -> Dict[str, float]: """ Train for one epoch. diff --git a/tools/analysis/bbox_stats.py b/tools/analysis/bbox_stats.py index 4e45277..cdbee57 100644 --- a/tools/analysis/bbox_stats.py +++ b/tools/analysis/bbox_stats.py @@ -43,14 +43,14 @@ def main(cfg: GlobalConfig) -> None: for i in tqdm(range(n_samples), unit='sample', desc='Calculating bbox statistics', total=n_samples): data = train_dataset.get_raw(i) - observed_bboxes = data.observed_bboxes[~data.observed_temporal_mask] - unobserved_bboxes = data.unobserved_bboxes[~data.unobserved_temporal_mask] + observed_bboxes = data.observed.features['bbox'][~data.observed.mask] + unobserved_bboxes = data.unobserved.features['bbox'][~data.unobserved.mask] bboxes = torch.cat([observed_bboxes, unobserved_bboxes], dim=0) - fod = torch.zeros_like(data.observed_bboxes) - fod[:, 1:, :] = data.observed_bboxes[:, 1:, :] - data.observed_bboxes[:, :-1, :] - fod[:, 1:, :] = fod[:, 1:, :] * (1 - data.observed_temporal_mask[:, :-1].unsqueeze(-1).repeat(1, 1, data.observed_bboxes.shape[-1]).float()) - fod = fod[~data.observed_temporal_mask] + fod = torch.zeros_like(data.observed.features['bbox']) + fod[:, 1:, :] = data.observed.features['bbox'][:, 1:, :] - data.observed.features['bbox'][:, :-1, :] + fod[:, 1:, :] = fod[:, 1:, :] * (1 - data.observed.mask[:, :-1].unsqueeze(-1).repeat(1, 1, data.observed.features['bbox'].shape[-1]).float()) + fod = fod[~data.observed.mask] bboxes_sum = bboxes.sum(dim=0) bboxes_sum2 = torch.square(bboxes).sum(dim=0) diff --git a/tools/analysis/cameltrack/feature_extraction.py b/tools/analysis/cameltrack/feature_extraction.py index 0b4dc5f..69763f3 100644 --- a/tools/analysis/cameltrack/feature_extraction.py +++ b/tools/analysis/cameltrack/feature_extraction.py @@ -172,7 +172,8 @@ def add_track_ids(pred_frame_data: List[dict], gt_frame_data: List[FrameObjectDa @pipeline.task('cameltrack-features-extraction') def main(cfg: GlobalConfig) -> None: # Hardcoded stuff - SPLIT = 'val' + SPLIT = 'test' + is_test = (SPLIT == 'test') CAMELTRACK_STATES_PATH = f'/media/home/cameltrack-states/dancetrack-{SPLIT}.pklz' TEMPORARY_DIRPATH = '/media/home/cameltrack-states/tmp' EXTRACTED_OUTPUT_PATH = '/media/home/cameltrack-states/extracted-features' @@ -194,14 +195,18 @@ def main(cfg: GlobalConfig) -> None: for scene_name in tqdm(scenes, desc='Parsing extra features', unit='scene'): scene_info = dataset_index.get_scene_info(scene_name) for frame_index in range(scene_info.seqlength): - object_ids = dataset_index.get_objects_present_in_scene_at_frame(scene_name, frame_index) # TODO: Check if +1 is needed - gt_frame_data = [dataset_index.get_object_data_label_by_frame_index(object_id, frame_index) for object_id in object_ids] # TODO: Check if +1 is needed - pred_frame_data = parser.get(scene_name, frame_index) pred_frame_data = postprocess_data(scene_info, pred_frame_data) - pred_frame_data, n_matches, n_unmatches = add_track_ids(pred_frame_data, gt_frame_data) - n_total_matches += n_matches - n_total_unmatches += n_unmatches + + if not is_test: + object_ids = dataset_index.get_objects_present_in_scene_at_frame(scene_name, frame_index) + gt_frame_data = [dataset_index.get_object_data_label_by_frame_index(object_id, frame_index) for object_id in object_ids] + pred_frame_data, n_matches, n_unmatches = add_track_ids(pred_frame_data, gt_frame_data) + n_total_matches += n_matches + n_total_unmatches += n_unmatches + else: + n_total_matches += 0 + n_total_unmatches += 0 features_writer.write(scene_name, frame_index, pred_frame_data) diff --git a/tools/inference.py b/tools/inference.py old mode 100644 new mode 100755 index 95427c2..142585a --- a/tools/inference.py +++ b/tools/inference.py @@ -30,6 +30,10 @@ from mot_jepa.datasets.dataset.transform import Transform from mot_jepa.utils import pipeline from mot_jepa.utils.extra_features import ExtraFeaturesReader +import logging + + +logger = logging.getLogger('Inference') @@ -71,11 +75,91 @@ def __init__( self._next_id = 0 self._use_conf = use_conf - def _association( + def _convert_data( self, tracklets: List[Tracklet], objects_data: List[dict], frame_index: int + ) -> VideoClipData: + n_tracks = len(tracklets) + n_detections = len(objects_data) + + # Determine maximum size + N = max(n_tracks, n_detections) + + # Observed initialization + observed_ts = torch.zeros(N, self._clip_length, dtype=torch.long) + observed_temporal_mask = torch.ones(N, self._clip_length, dtype=torch.bool) + observed_features = PredictionBBoxFeatureExtractor.initialize_features( + feature_names=self._feature_names, + n_tracks=N, + temporal_length=self._clip_length, + ) + + time_offset = frame_index - self._clip_length + for t_i, tracklet in enumerate(tracklets): + for frame_info in tracklet.history: + hist_frame_index = frame_info.frame_index + data = frame_info.data + relative_index = hist_frame_index - time_offset + if relative_index < 0: + continue + + PredictionBBoxFeatureExtractor._set_features( + feature_names=self._feature_names, + features=observed_features, + object_index=t_i, + clip_index=relative_index, + data=data + ) + observed_ts[t_i, relative_index] = hist_frame_index + observed_temporal_mask[t_i, relative_index] = False + + # Unobserved initialization + unobserved_features = PredictionBBoxFeatureExtractor.initialize_features( + feature_names=self._feature_names, + n_tracks=N, + temporal_length=1, + ) + unobserved_ts = torch.zeros(N, dtype=torch.long) + unobserved_temporal_mask = torch.ones(N, dtype=torch.bool) + + unobserved_ts[:n_detections] = frame_index + unobserved_temporal_mask[:n_detections] = False + + for d_i, data in enumerate(objects_data): + PredictionBBoxFeatureExtractor._set_features( + feature_names=self._feature_names, + features=unobserved_features, + object_index=d_i, + clip_index=0, + data=data + ) + + # Remove temporal dimension + unobserved_features = {k: v[:, 0] for k, v in unobserved_features.items()} + + return VideoClipData( + observed=VideoClipPart( + ids=None, + ts=observed_ts, + mask=observed_temporal_mask, + features=observed_features + ), + unobserved=VideoClipPart( + ids=None, + ts=unobserved_ts, + mask=unobserved_temporal_mask, + features=unobserved_features + ) + ) + + def _association( + self, + tracklets: List[Tracklet], + objects_data: List[dict], + frame_index: int, + sim_threshold: float ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]: n_tracks = len(tracklets) n_detections = len(objects_data) @@ -87,20 +171,35 @@ def _association( data = self._convert_data(tracklets, objects_data, frame_index) data = self._transform(data) data.apply(lambda x: x.unsqueeze(0).to(self._device)) - track_mm_features, det_mm_features, track_all_features, det_all_features = self._model( + track_mm_features, det_mm_features, _, _ = self._model( data.observed.features, data.observed.mask, data.unobserved.features, data.unobserved.mask ) - track_mm_features = track_mm_features[0].cpu() - det_mm_features = det_mm_features[0].cpu() - track_mm_features = F.normalize(track_mm_features, dim=-1) - det_mm_features = F.normalize(det_mm_features, dim=-1) - - cost_matrix = (track_mm_features[:n_tracks] @ det_mm_features[:n_detections].T).numpy() + track_mm_features = track_mm_features[0][:n_tracks] + det_mm_features = det_mm_features[0][:n_detections] + track_mm_features = F.normalize(track_mm_features, dim=-1).cpu() + det_mm_features = F.normalize(det_mm_features, dim=-1).cpu() + + ALPHA = 0.65 + if ALPHA > 0.0: + ema_track_mm_features = torch.zeros_like(track_mm_features) + for t_i in range(track_mm_features.shape[0]): + tracklet = tracklets[t_i] + track_ema = tracklet.get('track_ema') + if track_ema is None: + ema_track_mm_features[t_i] = track_mm_features[t_i] + else: + ema_track_mm_features[t_i] = ALPHA * track_ema + (1 - ALPHA) * track_mm_features[t_i] + ema_track_mm_features[t_i] = F.normalize(ema_track_mm_features[t_i], dim=-1) + tracklet.set('track_ema', ema_track_mm_features[t_i]) + else: + ema_track_mm_features = track_mm_features + + cost_matrix = (ema_track_mm_features @ det_mm_features.T).numpy() cost_matrix = 1 - (cost_matrix + 1) / 2 # [-1, 1] -> [0, 1] - cost_matrix[cost_matrix > self._sim_threshold] = np.inf + cost_matrix[cost_matrix > sim_threshold] = np.inf return hungarian(cost_matrix) @@ -118,7 +217,7 @@ def track(self, # Remove deleted tracklets = [t for t in tracklets if t.state != TrackletState.DELETED] - matches, unmatched_tracklets, unmatched_detections = self._association(tracklets, objects_data, frame_index) + matches, unmatched_tracklets, unmatched_detections = self._association(tracklets, objects_data, frame_index, sim_threshold=self._sim_threshold) # Handle matches for t_i, d_i in matches: @@ -164,86 +263,6 @@ def track(self, return tracklets - def _convert_data( - self, - tracklets: List[Tracklet], - objects_data: List[dict], - frame_index: int - ) -> VideoClipData: - n_tracks = len(tracklets) - n_detections = len(objects_data) - - # Determine maximum size - N = max(n_tracks, n_detections) - - # Observed initialization - observed_ts = torch.zeros(N, self._clip_length, dtype=torch.long) - observed_temporal_mask = torch.ones(N, self._clip_length, dtype=torch.bool) - observed_features = PredictionBBoxFeatureExtractor.initialize_features( - feature_names=self._feature_names, - n_tracks=N, - temporal_length=self._clip_length, - ) - - time_offset = frame_index - self._clip_length - for t_i, tracklet in enumerate(tracklets): - for frame_info in tracklet.history: - hist_frame_index = frame_info.frame_index - data = frame_info.data - relative_index = hist_frame_index - time_offset - if relative_index < 0: - continue - - PredictionBBoxFeatureExtractor._set_features( - feature_names=self._feature_names, - features=observed_features, - object_index=t_i, - clip_index=relative_index, - data=data - ) - observed_ts[t_i, relative_index] = hist_frame_index - observed_temporal_mask[t_i, relative_index] = False - - # Unobserved initialization - unobserved_features = PredictionBBoxFeatureExtractor.initialize_features( - feature_names=self._feature_names, - n_tracks=N, - temporal_length=1, - ) - unobserved_ts = torch.zeros(N, dtype=torch.long) - unobserved_temporal_mask = torch.ones(N, dtype=torch.bool) - - unobserved_ts[:n_detections] = frame_index - unobserved_temporal_mask[:n_detections] = False - - for d_i, data in enumerate(objects_data): - PredictionBBoxFeatureExtractor._set_features( - feature_names=self._feature_names, - features=unobserved_features, - object_index=d_i, - clip_index=0, - data=data - ) - - # Remove temporal dimension - unobserved_features = {k: v[:, 0] for k, v in unobserved_features.items()} - - return VideoClipData( - observed=VideoClipPart( - ids=None, - ts=observed_ts, - mask=observed_temporal_mask, - features=observed_features - ), - unobserved=VideoClipPart( - ids=None, - ts=unobserved_ts, - mask=unobserved_temporal_mask, - features=unobserved_features - ) - ) - - # class MyByteTracker(MyTracker): # def __init__( # self, @@ -377,6 +396,7 @@ def main(cfg: GlobalConfig) -> None: ) model = cfg.build_model() + logger.info(f'Loading model from {cfg.eval.checkpoint}.') state_dict = torch.load(cfg.eval.checkpoint) model.load_state_dict(state_dict['model']) @@ -417,7 +437,9 @@ def main(cfg: GlobalConfig) -> None: device=cfg.resources.accelerator, remember_threshold=30, use_conf=True, - sim_threshold=0.5 + sim_threshold=0.90, + initialization_threshold=1, + new_tracklet_detection_threshold=0.9 ) scene_names = dataset_index.scenes diff --git a/train_with_pretrained_encoders.sh b/train_with_pretrained_encoders.sh new file mode 100755 index 0000000..5d94907 --- /dev/null +++ b/train_with_pretrained_encoders.sh @@ -0,0 +1,3 @@ +for cfg in appearance bbox_only keypoints default; do + uv run tools/train.py --config-name="$cfg" train.truncate=true +done