diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f62f8e3..9b2b494 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: CI on: push: - branches: [ "main", "master" ] + branches: [ "**" ] pull_request: - branches: [ "main", "master" ] + branches: [ "**" ] jobs: test: diff --git a/src/spatial_transcript_former/models/interaction.py b/src/spatial_transcript_former/models/interaction.py index 695cc77..77ace8a 100644 --- a/src/spatial_transcript_former/models/interaction.py +++ b/src/spatial_transcript_former/models/interaction.py @@ -301,23 +301,28 @@ def forward( # Standard nn.TransformerEncoder suppresses weights for performance. x_layer = sequence for layer in self.fusion_engine.layers: - # Multi-head attention bit + # 1. Attention Block + qkv = layer.norm1(x_layer) if layer.norm_first else x_layer + # We need to call the internal self_attn with need_weights=True + # and average_attn_weights=False to get per-head maps. attn_output, attn_weights = layer.self_attn( - x_layer, - x_layer, - x_layer, + qkv, + qkv, + qkv, attn_mask=interaction_mask, key_padding_mask=pad_mask, need_weights=True, + average_attn_weights=False, ) + print( + f"DEBUG: Internal attn_weights shape: {attn_weights.shape}" + ) # DEBUG attentions.append(attn_weights) - # Rest of the layer (as per nn.TransformerEncoderLayer) + # Continue forward pass (matching nn.TransformerEncoderLayer logic) if layer.norm_first: - x_layer = x_layer + layer._sa_block( - layer.norm1(x_layer), interaction_mask, pad_mask - ) + x_layer = x_layer + layer._sa_block(qkv, interaction_mask, pad_mask) x_layer = x_layer + layer._ff_block(layer.norm2(x_layer)) else: x_layer = layer.norm1( diff --git a/tests/test_dataset_logic.py b/tests/test_dataset_logic.py new file mode 100644 index 0000000..92ae4d3 --- /dev/null +++ b/tests/test_dataset_logic.py @@ -0,0 +1,106 @@ +import torch +import numpy as np +import pytest +from spatial_transcript_former.data.dataset import ( + apply_dihedral_augmentation, + apply_dihedral_to_tensor, + normalize_coordinates, +) + + +def test_apply_dihedral_augmentation_all_ops(): + """Verify all 8 dihedral operations against expected transformations.""" + # Unit square coordinates + coords = torch.tensor([[1.0, 1.0]]) + + # Expected results for (1, 1) under each op + expected = { + 0: [1.0, 1.0], # Identity + 1: [1.0, -1.0], # 90 CCW: (x,y) -> (y,-x) + 2: [-1.0, -1.0], # 180: (x,y) -> (-x,-y) + 3: [-1.0, 1.0], # 270 CCW: (x,y) -> (-y,x) + 4: [-1.0, 1.0], # Flip H: (-x,y) + 5: [1.0, -1.0], # Flip V: (x,-y) + 6: [1.0, 1.0], # Transpose: (y,x) + 7: [-1.0, -1.0], # Anti-transpose: (-y,-x) + } + + for op, exp in expected.items(): + out, _ = apply_dihedral_augmentation(coords, op=op) + assert torch.allclose(out, torch.tensor([exp])), f"Failed op {op}" + + +def test_dihedral_composition_properties(): + """Verify mathematical properties of the D4 group.""" + coords = torch.randn(10, 2) + + # Flip H (4) twice is identity + out, _ = apply_dihedral_augmentation(coords, op=4) + out2, _ = apply_dihedral_augmentation(out, op=4) + assert torch.allclose(out2, coords) + + # Rotate 90 (1) four times is identity + curr = coords + for _ in range(4): + curr, _ = apply_dihedral_augmentation(curr, op=1) + assert torch.allclose(curr, coords) + + # Transpose (6) is its own inverse + out, _ = apply_dihedral_augmentation(coords, op=6) + out2, _ = apply_dihedral_augmentation(out, op=6) + assert torch.allclose(out2, coords) + + +def test_normalize_coordinates_boundaries(): + """Verify step_size thresholds (0.5 and 2.0).""" + # Test step_size < 2.0 (Identity) + # x_vals: [0, 1.9] -> step 1.9 + c1 = np.array([[0.0, 0.0], [1.9, 0.0]]) + assert np.allclose(normalize_coordinates(c1), c1) + + # Test step_size == 2.0 (Normalize) + c2 = np.array([[0.0, 0.0], [2.0, 0.0]]) + assert np.allclose(normalize_coordinates(c2), [[0, 0], [1, 0]]) + + # Test valid_steps filtering (steps <= 0.5 are ignored) + # x_vals: [0, 0.5, 3.0] -> steps [0.5, 2.5]. + # valid_steps should only see 2.5 + c3 = np.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]]) + # step_size = 2.5. 0.5/2.5 = 0.2 -> rounds to 0. 3.0/2.5 = 1.2 -> rounds to 1 + assert np.allclose(normalize_coordinates(c3), [[0, 0], [0, 0], [1, 0]]) + + # x_vals: [0, 0.51, 3.0] -> steps [0.51, 2.49]. + # step_size = 0.51. 0.51/0.51 = 1. 3.0/0.51 = 5.88 -> 6 + # But wait, step_size 0.51 < 2.0, so it remains identity + c4 = np.array([[0.0, 0.0], [0.51, 0.0], [3.0, 0.0]]) + assert np.allclose(normalize_coordinates(c4), c4) + + +def test_apply_dihedral_to_tensor_consistency(): + """Verify all tensor ops match coordinate ops for a single point.""" + # Use a 3x3 tensor with a single hot spot at (2,0) -> row 0, col 2 + # Coordinates in centered frame for 3x3: + # (-1, -1) (0, -1) (1, -1) + # (-1, 0) (0, 0) (1, 0) + # (-1, 1) (0, 1) (1, 1) + # Point (1, 0) is index [1, 2] + + img = torch.zeros((1, 3, 3)) + img[0, 1, 2] = 1.0 + coords = torch.tensor([[1.0, 0.0]]) + + for op in range(8): + # Transform coord + aug_coords, _ = apply_dihedral_augmentation(coords, op=op) + ax, ay = aug_coords[0] + + # Transform image + aug_img = apply_dihedral_to_tensor(img, op) + + # Map back to indices: row = ay + 1, col = ax + 1 + row, col = int(ay + 1), int(ax + 1) + assert aug_img[0, row, col] == 1.0, f"Inconsistent mapping for op {op}" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_dataset_mocks.py b/tests/test_dataset_mocks.py new file mode 100644 index 0000000..93a7c5a --- /dev/null +++ b/tests/test_dataset_mocks.py @@ -0,0 +1,121 @@ +import torch +import numpy as np +import pytest +from unittest.mock import MagicMock, patch +from spatial_transcript_former.data.dataset import HEST_Dataset, HEST_FeatureDataset + + +@pytest.fixture +def mock_h5_file(): + with patch("h5py.File") as mock_file: + mock_instance = mock_file.return_value + # Mock image data + mock_instance.__getitem__.side_effect = lambda key: { + "img": np.zeros((10, 224, 224, 3), dtype=np.uint8) + }[key] + yield mock_instance + + +def test_hest_dataset_augmentation_consistency(mock_h5_file): + """Verify that HEST_Dataset applies the same augmentation to pixels and coords.""" + # We need neighborhood_indices to trigger apply_dihedral_augmentation + coords = np.array([[10.0, 20.0], [30.0, 40.0]]) + genes = np.zeros((2, 100)) + indices = np.array([0, 1]) + neighborhood_indices = np.array([[1]]) # center 0 has neighbor 1 + + ds = HEST_Dataset( + h5_path="mock.h5", + spatial_coords=coords, + gene_matrix=genes, + indices=indices, + neighborhood_indices=neighborhood_indices, + coords_all=coords, + augment=True, + ) + + # We want to check if apply_dihedral_to_tensor and apply_dihedral_augmentation + # are called with the same 'op'. + with ( + patch( + "spatial_transcript_former.data.dataset.apply_dihedral_to_tensor" + ) as mock_tensor_aug, + patch( + "spatial_transcript_former.data.dataset.apply_dihedral_augmentation" + ) as mock_coord_aug, + ): + + mock_tensor_aug.side_effect = lambda img, op: img + mock_coord_aug.side_effect = lambda coords, op: (coords, op) + + _ = ds[0] + + # Check that both mocks were called + assert mock_tensor_aug.called, "apply_dihedral_to_tensor was not called" + assert mock_coord_aug.called, "apply_dihedral_augmentation was not called" + + # Check that the 'op' argument matches + tensor_op = mock_tensor_aug.call_args[0][1] + coord_op = mock_coord_aug.call_args[1]["op"] + assert tensor_op == coord_op + + +def test_hest_feature_dataset_neighborhood_dropout(): + """Verify that HEST_FeatureDataset correctly zeros out neighbors during augmentation.""" + n_neighbors = 2 + # Ensure features, coords, and barcodes all match in length (3) + feats = torch.ones((3, 128)) + coords = torch.zeros((3, 2)) + barcodes = [b"p0", b"p1", b"p2"] + + mock_gene_matrix = np.zeros((3, 10)) + mock_mask = [True, True, True] # Must match length of barcodes + mock_names = ["gene1"] + + with ( + patch("torch.load") as mock_load, + patch( + "spatial_transcript_former.data.dataset.load_gene_expression_matrix" + ) as mock_gene_load, + ): + + mock_load.return_value = { + "features": feats, + "coords": coords, + "barcodes": barcodes, + } + mock_gene_load.return_value = (mock_gene_matrix, mock_mask, mock_names) + + ds = HEST_FeatureDataset( + feature_path="mock.pt", + h5ad_path="mock.h5ad", + n_neighbors=n_neighbors, + augment=True, + ) + + # Run multiple times to trigger the stochastic dropout + dropout_occurred = False + for _ in range(100): + f, _, _ = ds[0] + # Center (index 0) should NEVER be zero + assert not torch.all(f[0] == 0) + + # Check if any neighbor is zero + if torch.any(torch.all(f[1:] == 0, dim=1)): + dropout_occurred = True + + assert dropout_occurred, "Neighborhood dropout augmentation was never triggered" + + +def test_hest_dataset_log1p_logic(mock_h5_file): + """Verify that log1p is applied to genes when enabled.""" + coords = np.array([[10.0, 20.0]]) + genes = np.array([[10.0]]) + + ds_no_log = HEST_Dataset("mock.h5", coords, genes, log1p=False) + _, g_no_log, _ = ds_no_log[0] + assert g_no_log[0] == 10.0 + + ds_log = HEST_Dataset("mock.h5", coords, genes, log1p=True) + _, g_log, _ = ds_log[0] + assert torch.allclose(g_log[0], torch.log1p(torch.tensor(10.0))) diff --git a/tests/test_interactions.py b/tests/test_interactions.py index 6847538..41d23bd 100644 --- a/tests/test_interactions.py +++ b/tests/test_interactions.py @@ -137,30 +137,34 @@ def test_attention_extraction(): # attentions is list of weights [layers] for i, attn in enumerate(attentions): print(f"Testing Layer {i}...") - # attn is (B, T, T) - assert attn.shape == (1, p + s, p + s) + # attn is (B, H, T, T) + h = model.fusion_engine.layers[0].self_attn.num_heads + assert attn.shape == (1, h, p + s, p + s) - # We expect blocked regions to have 0 attention - h2p_region = attn[0, p:, :p] - h2h_region = attn[0, p:, p:] + # We expect blocked regions to have 0 attention across all heads + # h2p_region is (H, s, p) + h2p_region = attn[0, :, p:, :p] + h2h_region = attn[0, :, p:, p:] - # For h2h, we must ignore diagonal - h2h_off_diag = h2h_region.clone() - h2h_off_diag.fill_diagonal_(0) + # For h2h, we must ignore diagonal within each head + # We can just check that the entire (H, s, s) block is 0 except the diag + h2h_zeroed = h2h_region.clone() + for head_idx in range(h): + h2h_zeroed[head_idx].fill_diagonal_(0) print(f"Layer {i} h2p attention max: {h2p_region.max().item():.2e}") - print(f"Layer {i} h2h off-diag attention max: {h2h_off_diag.max().item():.2e}") + print(f"Layer {i} h2h off-diag attention max: {h2h_zeroed.max().item():.2e}") assert ( h2p_region.max() < 1e-10 ), f"Layer {i} h2p attention should be zero when blocked" assert ( - h2h_off_diag.max() < 1e-10 + h2h_zeroed.max() < 1e-10 ), f"Layer {i} h2h attention should be zero when blocked" - # Check that allowed regions have non-zero attention - p2p_region = attn[0, :p, :p] - p2h_region = attn[0, :p, p:] + # Check that allowed regions have non-zero attention in at least one head + p2p_region = attn[0, :, :p, :p] + p2h_region = attn[0, :, :p, p:] print(f"Layer {i} p2p attention max: {p2p_region.max().item():.2e}") print(f"Layer {i} p2h attention max: {p2h_region.max().item():.2e}") diff --git a/tests/test_losses.py b/tests/test_losses.py index e9ffbeb..e5cfd34 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -134,6 +134,20 @@ def test_gradient_flow(self, tensors_2d): loss.backward() assert preds.grad is not None + def test_pcc_fallback_n1(self): + """Verify the N=1 fallback (batch-wise correlation) is robust.""" + # preds/target: (B, 1, G). With B=2, N=1 + preds = torch.tensor([[[1.0, 2.0]], [[2.0, 3.0]]]) + target = torch.tensor([[[1.0, 2.0]], [[2.0, 3.0]]]) + + # Perfect correlation => loss 0 + loss = PCCLoss()(preds, target) + assert loss.item() == pytest.approx(0.0, abs=1e-5) + + # Anti-correlation => loss 2 + loss_anti = PCCLoss()(preds, -target) + assert loss_anti.item() == pytest.approx(2.0, abs=1e-5) + # --------------------------------------------------------------------------- # CompositeLoss @@ -160,6 +174,17 @@ def test_alpha_zero_is_mse(self, tensors_2d): comp_val = CompositeLoss(alpha=0.0)(preds, target) assert torch.allclose(mse_val, comp_val, atol=1e-6) + def test_default_alpha(self, tensors_2d): + """Ensure the default alpha is 1.0.""" + preds, target = tensors_2d + loss_default = CompositeLoss()(preds, target) + loss_explicit = CompositeLoss(alpha=1.0)(preds, target) + assert torch.allclose(loss_default, loss_explicit) + + # Ensure it's NOT 0.0 + loss_mse = MaskedMSELoss()(preds, target) + assert not torch.allclose(loss_default, loss_mse) + def test_mask_support(self, tensors_3d): """CompositeLoss should handle masks in 3D mode.""" preds, target, mask = tensors_3d @@ -188,6 +213,21 @@ def test_different_alphas(self, tensors_2d): assert loss_low.item() != pytest.approx(loss_high.item(), abs=0.01) +class TestMaskedHuber: + def test_3d_mask_impact(self, tensors_3d): + """Verify that padding mask works correctly for Huber 3D.""" + from spatial_transcript_former.training.losses import MaskedHuberLoss + + preds, target, mask = tensors_3d + + loss_fn = MaskedHuberLoss() + loss_masked = loss_fn(preds, target, mask=mask) + loss_unmasked = loss_fn(preds, target) + + assert not torch.allclose(loss_masked, loss_unmasked) + assert loss_masked.isfinite() + + # --------------------------------------------------------------------------- # ZINBLoss # --------------------------------------------------------------------------- @@ -268,6 +308,42 @@ def test_mask_support(self): padded_grad = pi.grad[0, 5:, :] assert padded_grad.abs().sum() == 0.0 + def test_zinb_zero_vs_nonzero(self): + """Verify that ZINB treats 0 and non-zero targets differently (branch coverage).""" + from spatial_transcript_former.training.losses import ZINBLoss + + loss_fn = ZINBLoss() + + # B=1, G=2. G0=0 (zero-inflation branch), G1=10 (NB branch) + target = torch.tensor([[0.0, 10.0]]) + # Fix params for predictable results: pi=0.5, mu=1.0, theta=1.0 + pi = torch.tensor([[0.5, 0.5]]) + mu = torch.tensor([[1.0, 1.0]]) + theta = torch.tensor([[1.0, 1.0]]) + + # If we change target[0, 1] from 10 to 0, the loss should change significantly + loss1 = loss_fn((pi, mu, theta), target) + target2 = torch.tensor([[0.0, 0.0]]) + loss2 = loss_fn((pi, mu, theta), target2) + assert not torch.allclose(loss1, loss2) + + def test_zinb_extreme_stability(self): + """Verify stability with very large or small parameters (clamping logic).""" + from spatial_transcript_former.training.losses import ZINBLoss + + loss_fn = ZINBLoss() + target = torch.tensor([[0.0, 10.0]]) + + # mu and theta at extremes + pi = torch.tensor([[0.1, 0.1]]) + mu = torch.tensor([[1e-12, 1e12]]) + theta = torch.tensor([[1e12, 1e-12]]) + + loss = loss_fn((pi, mu, theta), target) + print(f"DEBUG: ZINB extreme loss value: {loss}") + assert torch.isfinite(loss).item(), f"Loss is not finite: {loss}" + assert not torch.isnan(loss).item(), f"Loss is NaN: {loss}" + # --------------------------------------------------------------------------- # AuxiliaryPathwayLoss @@ -415,6 +491,26 @@ def test_lambda_scaling(self, pathway_tensors): # term2 should be approx 2 * term1 assert term2.item() == pytest.approx(2 * term1.item(), rel=1e-4) + def test_auxiliary_lambda_sensitivity(self, pathway_tensors): + """Verify that changing lambda actually scales the pathway component.""" + gene_preds, targets, pw_preds, pw_matrix, mask = pathway_tensors + base = MaskedMSELoss() + + aux_low = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.1) + aux_high = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=10.0) + + loss_low = aux_low(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + loss_high = aux_high(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + + # high lambda should force a different loss value + assert not torch.allclose(loss_low, loss_high) + + # Verify that lambda=0 exactly matches gene loss + aux_zero = AuxiliaryPathwayLoss(pw_matrix, base, lambda_pathway=0.0) + loss_zero = aux_zero(gene_preds, targets, mask=mask, pathway_preds=pw_preds) + gene_only = base(gene_preds, targets, mask=mask) + assert torch.allclose(loss_zero, gene_only) + def test_hallmark_integration(self): """Test with a real (though small) MSigDB Hallmark matrix.""" from spatial_transcript_former.data.pathways import get_pathway_init diff --git a/tests/test_losses_robust.py b/tests/test_losses_robust.py new file mode 100644 index 0000000..019ecd2 --- /dev/null +++ b/tests/test_losses_robust.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +import pytest +from spatial_transcript_former.training.losses import ZINBLoss, PCCLoss + + +def test_zinb_gradient_flow(): + """Verify that gradients flow to all three parameters of ZINB.""" + zinb = ZINBLoss() + + B, G = 4, 10 + # Inputs require gradients + pi = torch.full((B, G), 0.5, requires_grad=True) + mu = torch.full((B, G), 10.0, requires_grad=True) + theta = torch.full((B, G), 1.0, requires_grad=True) + + target = torch.randint(0, 100, (B, G)).float() + + loss = zinb((pi, mu, theta), target) + loss.backward() + + assert pi.grad is not None + assert mu.grad is not None + assert theta.grad is not None + + assert not torch.allclose(pi.grad, torch.zeros_like(pi.grad)) + assert not torch.allclose(mu.grad, torch.zeros_like(mu.grad)) + assert not torch.allclose(theta.grad, torch.zeros_like(theta.grad)) + + +def test_zinb_stability(): + """Verify ZINB handles zeros and very large counts without NaNs.""" + zinb = ZINBLoss() + + # Extreme inputs + pi = torch.tensor([[1e-8, 0.5, 1.0 - 1e-8]]) + mu = torch.tensor([[1e-8, 100.0, 1e6]]) + theta = torch.tensor([[1e-8, 1.0, 1e6]]) + + # Extreme targets + target = torch.tensor([[0.0, 10.0, 1e6]]) + + loss = zinb((pi, mu, theta), target) + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + +def test_pcc_edge_cases(): + """Verify PCC fallback and masking edge cases.""" + pcc = PCCLoss() + + # N=1 Case (fallback to batch-wise) + preds = torch.randn(4, 1, 10, requires_grad=True) + target = torch.randn(4, 1, 10) + + loss = pcc(preds, target) + assert loss.requires_grad + assert not torch.isnan(loss) + + # mask all but 1 in spatial dim -> N=1 fallback + preds_3d = torch.randn(2, 5, 3, requires_grad=True) + target_3d = torch.randn(2, 5, 3) + mask = torch.ones(2, 5, dtype=torch.bool) + mask[:, 0] = False # only first spot valid + + loss_masked = pcc(preds_3d, target_3d, mask=mask) + assert not torch.isnan(loss_masked) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_pathways_robust.py b/tests/test_pathways_robust.py new file mode 100644 index 0000000..6917b0c --- /dev/null +++ b/tests/test_pathways_robust.py @@ -0,0 +1,42 @@ +import torch +import pytest +from spatial_transcript_former.data.pathways import build_membership_matrix + + +def test_build_membership_matrix_integrity(): + """Verify that the membership matrix correctly maps genes to pathways.""" + pathway_dict = { + "PATHWAY_A": ["GENE_1", "GENE_2"], + "PATHWAY_B": ["GENE_2", "GENE_3"], + } + gene_list = ["GENE_1", "GENE_2", "GENE_3", "GENE_4"] + + matrix, names = build_membership_matrix(pathway_dict, gene_list) + + assert names == ["PATHWAY_A", "PATHWAY_B"] + assert matrix.shape == (2, 4) + + # Pathway A: GENE_1, GENE_2 + assert matrix[0, 0] == 1.0 + assert matrix[0, 1] == 1.0 + assert matrix[0, 2] == 0.0 + assert matrix[0, 3] == 0.0 + + # Pathway B: GENE_2, GENE_3 + assert matrix[1, 0] == 0.0 + assert matrix[1, 1] == 1.0 + assert matrix[1, 2] == 1.0 + assert matrix[1, 3] == 0.0 + + +def test_build_membership_matrix_empty(): + """Check behavior with no matches.""" + pathway_dict = {"EMPTY": ["XYZ"]} + gene_list = ["ABC", "DEF"] + matrix, names = build_membership_matrix(pathway_dict, gene_list) + assert matrix.sum() == 0 + assert names == ["EMPTY"] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_spatial_interaction.py b/tests/test_spatial_interaction.py index b887043..b084ec4 100644 --- a/tests/test_spatial_interaction.py +++ b/tests/test_spatial_interaction.py @@ -209,3 +209,89 @@ def dummy_criterion(p, t, mask=None): assert torch.allclose( kwargs["rel_coords"], fake_coords ), "Validate engine passed wrong coordinate tensor!" + + +def test_spatial_encoder_normalization(): + """Verify LearnedSpatialEncoder handles extreme coords and centers them.""" + encoder = LearnedSpatialEncoder(64) + # Extreme coordinates: very far and very close + coords = torch.tensor([[[1000.0, 1000.0], [1000.1, 1000.1]]]) + normed = encoder._normalize_coords(coords) + + # Should be centered (mean 0) + assert torch.allclose(normed.mean(dim=1), torch.zeros(1, 2), atol=1e-5) + # Should be bounded by [-1, 1] + assert normed.abs().max() <= 1.0 + + # Verify forward doesn't crash + out = encoder(coords) + assert out.shape == (1, 2, 64) + + +def test_interaction_mask_bits(): + """Explicitly verify which bits are blocked in the interaction mask.""" + model = SpatialTranscriptFormer( + num_genes=50, interactions=["p2h", "h2p", "h2h"] + ) # No p2p + p, s = 2, 3 + mask = model._build_interaction_mask(p, s, torch.device("cpu")) + + # mask[i, j] is True if blocked + # p2p is index [0:p, 0:p]. Should be blocked (True) except diagonal + assert mask[0, 1] == True, "p2p interaction [0, 1] should be blocked" + + # p2h is index [0:p, p:]. Should be enabled (False) + assert mask[0, 2] == False, "p2h interaction [0, 2] should be enabled" + + # h2p is index [p:, 0:p]. Should be enabled (False) + assert mask[2, 0] == False, "h2p interaction [2, 0] should be enabled" + + # h2h is index [p:, p:]. Should be enabled (False) + assert mask[2, 3] == False, "h2h interaction [2, 3] should be enabled" + + +def test_temperature_scaling(): + """Verify log_temperature actually scales the pathway scores.""" + model = SpatialTranscriptFormer(num_genes=10, token_dim=64) + features = torch.randn(1, 4, 2048) + coords = torch.randn(1, 4, 2) + + # Initial scores with default temp + scores1 = model(features, rel_coords=coords, return_pathways=True)[1] + + # Manually increase log_temperature significantly + with torch.no_grad(): + model.log_temperature.fill_(10.0) # Massive temp + + scores2 = model(features, rel_coords=coords, return_pathways=True)[1] + + # Scores should be different and typically more extreme + assert not torch.allclose(scores1, scores2) + assert scores2.abs().max() > scores1.abs().max() + + +def test_return_attention_values(): + """Validate attention weight extraction logic.""" + model = SpatialTranscriptFormer( + num_genes=10, token_dim=64, n_heads=2, n_layers=2 + ).eval() + B, S, D = 1, 4, 2048 + features = torch.randn(B, S, D) + coords = torch.randn(B, S, 2) + P = model.num_pathways + + # [gene_expr, pw_scores, attentions] + with torch.no_grad(): + _, _, attentions = model( + features, rel_coords=coords, return_attention=True, return_pathways=True + ) + + assert len(attentions) == 2 # n_layers + for layer_attn in attentions: + # Expected shape: (B, n_heads, Total_T, Total_T) where Total_T = P + S + expected_shape = (B, 2, P + S, P + S) + assert layer_attn.shape == expected_shape + + # In eval mode, attention should sum to 1.0 across the last dimension (softmax) + sums = layer_attn.sum(dim=-1) + assert torch.allclose(sums, torch.ones_like(sums), atol=1e-6) diff --git a/tests/test_splitting_robust.py b/tests/test_splitting_robust.py new file mode 100644 index 0000000..09a4873 --- /dev/null +++ b/tests/test_splitting_robust.py @@ -0,0 +1,73 @@ +import pandas as pd +import pytest +import os +import tempfile +from spatial_transcript_former.data.splitting import split_hest_patients, main +import sys +from unittest.mock import patch + + +@pytest.fixture +def mock_metadata(): + """Create a temporary metadata CSV with known patient structure.""" + data = { + "id": ["S1", "S2", "S3", "S4", "S5", "S6"], + "patient": ["P1", "P1", "P2", "P2", "P3", None], # S6 has no patient + } + df = pd.DataFrame(data) + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + df.to_csv(f.name, index=False) + return f.name + + +def test_split_hest_patients_isolation(mock_metadata): + """Verify that patients are strictly isolated.""" + # With 3 patients (+ 1 unique fallback), one patient in val is 25% + train, val, test = split_hest_patients(mock_metadata, val_ratio=0.25, seed=42) + + # Check that no sample is in both + assert set(train).isdisjoint(set(val)) + + # Map back to patients + df = pd.read_csv(mock_metadata) + df["patient_filled"] = df["patient"].fillna(df["id"]) + + train_patients = set(df[df["id"].isin(train)]["patient_filled"]) + val_patients = set(df[df["id"].isin(val)]["patient_filled"]) + + # Critical check: No patient overlap + assert train_patients.isdisjoint(val_patients) + + # Cleanup + os.remove(mock_metadata) + + +def test_split_hest_patients_missing_id_fallback(): + """Verify that samples with missing patient IDs are treated as unique.""" + data = {"id": ["S1", "S2", "S3"], "patient": [None, None, None]} + df = pd.DataFrame(data) + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f: + df.to_csv(f.name, index=False) + path = f.name + + # With 3 unique "patients", split. Since test_size=0.34, 1 should be in val. + train, val, test = split_hest_patients(path, val_ratio=0.34, seed=42) + # Ensure total is 3 and val/train are not empty (since 0.34 * 3 = 1.02) + assert len(train) + len(val) == 3 + assert len(val) >= 1 + assert len(train) >= 1 + + os.remove(path) + + +def test_splitting_main_cli(mock_metadata): + """Verify that the CLI main function runs without error and respects args.""" + test_args = ["prog", mock_metadata, "--val_ratio", "0.5", "--seed", "123"] + with patch.object(sys, "argv", test_args): + # Should not raise exception + main() + os.remove(mock_metadata) + + +if __name__ == "__main__": + pytest.main([__file__])