This repository was archived by the owner on Aug 1, 2024. It is now read-only.

Bias in Multiblock Mask Collator #60

@joeagriffith

I've been sampling masks from the multi-block mask collator and plotting them to see what they look like, and I believe I've found a bias that could significantly affect the training of any model that uses this class.

The pattern below appears consistently for batch sizes >= 128: many patches near the centre of the image are never included in any of the enc_masks. This behaviour only occurs with allow_overlap=False.
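To make the allow_overlap=False condition concrete, here is a minimal sketch on a toy 14x14 patch grid. It assumes, as the parameter name suggests, that when overlap is disallowed every patch inside a prediction block is dropped from the encoder (context) block. The helpers block and context_indices are hypothetical and not part of MaskCollator, and the block positions are arbitrary.

```python
import torch

GRID = 14  # 14x14 patch grid, i.e. patch indices 0..195


def block(top, left, height, width, grid=GRID):
    """Boolean grid with one rectangular block of patches set to True."""
    m = torch.zeros(grid, grid, dtype=torch.bool)
    m[top:top + height, left:left + width] = True
    return m


def context_indices(enc_block, pred_blocks, allow_overlap):
    """Hypothetical helper: flat patch indices kept in the encoder mask.

    Assumption: with allow_overlap=False, any patch covered by a
    prediction block is removed from the encoder block.
    """
    keep = enc_block.clone()
    if not allow_overlap:
        for p in pred_blocks:
            keep &= ~p  # drop patches shared with this prediction block
    # Flat indices in row-major order, like a flattened 14x14 mask
    return torch.nonzero(keep.flatten()).squeeze(1)


enc = block(2, 2, 10, 10)                         # 100 patches
preds = [block(5, 5, 4, 4), block(9, 9, 4, 4)]    # share 16 and 9 patches with enc

print(len(context_indices(enc, preds, allow_overlap=True)))   # 100
print(len(context_indices(enc, preds, allow_overlap=False)))  # 75
```

With overlap disallowed, the encoder mask loses the 16 + 9 = 25 patches it shares with the two prediction blocks.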

Here are four examples I sampled using the code below, with no cherry-picking.
Each image shows a batch of 128 enc masks generated with the default arguments of the multi-block mask collator.
Each pixel represents a patch and is white if and only if that patch is included in at least one of the masks in the batch.
Repro code below.

[Four images: the union of the enc masks for each of four sampled batches, shown as a 14x14 patch grid]

```python
import torch
import matplotlib.pyplot as plt
from src.mask import MaskCollator

collator = MaskCollator()  # default arguments, which use allow_overlap=False

# Dummy 224x224 RGB images; the pattern shows up for any batch size >= 128
batch = [torch.randn(3, 224, 224) for _ in range(1024)]
images, enc_masks, pred_masks = collator(batch)

def display_mask(mask):
    # mask: tensor of patch indices in [0, 195], either a single mask [K]
    # or several masks [B, K]. Displays the 14x14 patch grid, with a cell
    # white if its index appears in any of the given masks.
    grid = torch.zeros(14 * 14)
    grid[mask.flatten()] = 1
    plt.imshow(grid.view(14, 14), cmap='gray')
    plt.show()

# enc_masks[0][:] shows the union over the whole batch;
# replace ':' with an integer to visualise a single mask
display_mask(enc_masks[0][:])
```
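The binary union above shows only whether a patch is ever used, not how often. As an optional diagnostic, here is a minimal sketch that counts how often each patch appears across the masks in a batch, so rarely chosen patches show up as dark cells rather than simply black or white. The helpers inclusion_frequency and never_covered are hypothetical, and they assume an index tensor of shape [B, K] (for example enc_masks[0]) with no repeated index within a row.

```python
import torch


def inclusion_frequency(masks, num_patches=196, grid=14):
    """Fraction of masks in which each patch appears, as a grid x grid tensor.

    masks: integer tensor of patch indices with shape [B, K].
    Assumes no index is repeated within a single row.
    """
    B = masks.shape[0]
    counts = torch.bincount(masks.flatten(), minlength=num_patches).float()
    return (counts / B).view(grid, grid)


def never_covered(masks, num_patches=196):
    """Flat indices of patches that appear in none of the masks."""
    counts = torch.bincount(masks.flatten(), minlength=num_patches)
    return torch.nonzero(counts == 0).squeeze(1)


# Toy batch of two masks over the same 14x14 grid
toy = torch.tensor([[0, 1, 2], [1, 2, 3]])
print(inclusion_frequency(toy)[0, :4].tolist())  # [0.5, 1.0, 1.0, 0.5]
print(len(never_covered(toy)))                   # 192

# With the repro above, one could plot the heatmap instead of the union:
# plt.imshow(inclusion_frequency(enc_masks[0]), cmap='gray')
```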
