Skip to content

Train with smaller classes #50

@KazuhideMimura

Description

@KazuhideMimura

Hi, thank you for sharing this nice program.

The following changes to utils.py allow training with fewer than 360 rotation classes. Rotating among only 4 classes (multiples of 90 degrees) may be especially useful when training on a small dataset.
There seems to be something wrong with the visualization (function display_examples), but I was able to train with four classes.

def angle_difference(x, y, nb_classes=360):
    """
    Return the smallest angular distance, in degrees, between two
    rotation classes.

    x and y are class indices; each class spans ``360 // nb_classes``
    degrees, so with the default of 360 classes an index is simply an
    angle in degrees. nb_classes must divide 360 evenly.
    """
    assert 360 % nb_classes == 0, 'nb_classes should be a divisor of 360'
    degrees_per_class = 360 // nb_classes
    # distance measured in one direction around the circle, in degrees
    one_way_gap = abs(x - y) * degrees_per_class
    # gaps above 180 degrees are shorter when measured the other way round
    return 180 - abs(one_way_gap - 180)


def angle_error(y_true, y_pred):
    """
    Mean absolute angular difference, in degrees, between the true
    and the predicted rotations. Each rotation is represented as a
    binary (one-hot) vector; the number of classes is read from the
    second dimension of y_pred.
    """
    # recover the class index from each binary vector
    true_class = K.argmax(y_true)
    predicted_class = K.argmax(y_pred)
    per_sample_error = angle_difference(
        true_class,
        predicted_class,
        nb_classes=y_pred.shape[1]
    )
    # cast to the backend float type before averaging
    return K.mean(K.cast(K.abs(per_sample_error), K.floatx()))

(omitted)

@Class RotNetDataGenerator

def __init__(self, input, input_shape=None, color_mode='rgb', batch_size=64,
                one_hot=True, preprocess_func=None, rotate=True, crop_center=False,
                crop_largest_rect=False, shuffle=False, seed=None, nb_classes = 360): # nb_classes is added
        """
        Store the generator configuration (excerpt: the remainder of
        the constructor is elided in this snippet).

        nb_classes is the number of rotation classes and must be a
        divisor of 360. It is kept on the instance together with
        unit_angle = 360 // nb_classes (e.g. 360 classes -> 1,
        4 classes -> 90).
        """
        
        # NOTE(review): assert statements are removed when Python runs with -O,
        # so this check is not enforced in optimized mode.
        assert 360 % nb_classes == 0, 'nb_classes should be a divisor of 360' # inserted

        # images / filenames start empty here; presumably they are filled from
        # `input` in the elided part of the constructor -- TODO confirm
        self.images = None
        self.filenames = None
        self.input_shape = input_shape
        self.color_mode = color_mode
        self.batch_size = batch_size
        self.one_hot = one_hot
        self.preprocess_func = preprocess_func
        self.rotate = rotate
        self.crop_center = crop_center
        self.crop_largest_rect = crop_largest_rect
        self.shuffle = shuffle
        self.nb_classes = nb_classes # added
        # integer step obtained by splitting 360 into nb_classes equal parts
        self.unit_angle = 360 // nb_classes # added
        # remainder of the constructor omitted in the original snippet
        # (`seed` is not used in the lines shown here)
        ...

(omitted)

def _get_batches_of_transformed_samples(self, index_array):
    """
    Build one batch of rotated images and their rotation labels
    (excerpt: the tail of the method is elided in this snippet).

    For every index in index_array the image is taken from
    self.images, or read with cv2 from self.filenames when filenames
    are set. It is rotated by a random multiple of self.unit_angle
    (or by 0 when self.rotate is false) and stored with its class
    index, ``rotation_angle // self.unit_angle``.
    """
    # create array to hold the images
    batch_x = np.zeros((len(index_array),) + self.input_shape, dtype='float32')
    # create array to hold the labels
    batch_y = np.zeros(len(index_array), dtype='float32')

    # iterate through the current batch
    for i, j in enumerate(index_array):
        if self.filenames is None:
            image = self.images[j]
        else:
            # 1 when color_mode is 'rgb', otherwise 0; passed to cv2.imread
            # as its read flag (1 = color, 0 = grayscale)
            is_color = int(self.color_mode == 'rgb')
            image = cv2.imread(self.filenames[j], is_color)
            if is_color:
                # convert the loaded image from BGR to RGB channel order
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.rotate:
            # get a random angle: a class index drawn from [0, nb_classes)
            # scaled by the per-class step
            rotation_angle = self.unit_angle * np.random.randint(self.nb_classes)
        else:
            rotation_angle = 0

        # generate the rotated image
        rotated_image = generate_rotated_image(
            image,
            rotation_angle,
            size=self.input_shape[:2],
            crop_center=self.crop_center,
            crop_largest_rect=self.crop_largest_rect
        )

        # add dimension to account for the channels if the image is greyscale
        if rotated_image.ndim == 2:
            rotated_image = np.expand_dims(rotated_image, axis=2)

        # store the image and label in their corresponding batches
        batch_x[i] = rotated_image
        # the label is the class index, not the angle itself
        batch_y[i] = rotation_angle // self.unit_angle

    if self.one_hot:
        # convert the numerical labels to binary labels
        batch_y = to_categorical(batch_y, self.nb_classes) # modified
    else:
        # scale the class index by the number of classes, giving values in [0, 1)
        batch_y /= self.nb_classes
    # remainder of the method omitted in the original snippet
    ...

(omitted)

def display_examples(model, input, num_images=5, size=None, crop_center=False,
                    crop_largest_rect=False, preprocess_func=None, save_path=None,
                    nb_classes = 360): # nb_classes was added
    """
    Given a model that predicts the rotation angle of an image,
    and a NumPy array of images or a list of image paths, display
    the specified number of example images in three columns:
    Original, Rotated and Corrected.

    nb_classes is the number of rotation classes and must be a
    divisor of 360; random rotations are multiples of
    ``360 // nb_classes`` and class indices are converted back to
    angles by multiplying with that step.
    """
    assert 360 % nb_classes == 0, 'nb_classes should be a divisor of 360' # added
    unit_angle = 360 // nb_classes # added
    
    if isinstance(input, (np.ndarray)):
        images = input
        # the first three axes are read as (count, height, width)
        N, h, w = images.shape[:3]
        if not size:
            size = (h, w)
        # np.random.choice samples with replacement by default, so the same
        # image may be picked more than once
        indexes = np.random.choice(N, num_images)
        images = images[indexes, ...]
    else:
        # input is treated as a sequence of image paths; note that `size`
        # is not derived from the images in this branch
        images = []
        filenames = input
        N = len(filenames)
        indexes = np.random.choice(N, num_images)
        for i in indexes:
            image = cv2.imread(filenames[i])
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            images.append(image)
        images = np.asarray(images)

    x = []
    y = []
    for image in images:
        # random class index in [0, nb_classes) scaled by the per-class step
        rotation_angle = np.random.randint(nb_classes) * unit_angle
        rotated_image = generate_rotated_image(
            image,
            rotation_angle,
            size=size,
            crop_center=crop_center,
            crop_largest_rect=crop_largest_rect
        )
        x.append(rotated_image)
        # keep the class index as the label
        y.append(rotation_angle // unit_angle)

    x = np.asarray(x, dtype='float32')
    y = np.asarray(y, dtype='float32')

    # a 3-D stack has no channel axis; append one
    if x.ndim == 3:
        x = np.expand_dims(x, axis=3)

    y = to_categorical(y, nb_classes)

    # keep a copy of the rotated images taken before preprocess_func is applied
    x_rot = np.copy(x)

    if preprocess_func:
        x = preprocess_func(x)

    # convert the one-hot labels and the model output back to class indices
    y = np.argmax(y, axis=1)
    y_pred = np.argmax(model.predict(x), axis=1)

    plt.figure(figsize=(10.0, 2 * num_images))

    title_fontdict = {
        'fontsize': 14,
        'fontweight': 'bold'
    }

    fig_number = 0
    # NOTE(review): the issue author reports that the visualization looks
    # wrong; the plotting code is elided below, so the cause cannot be
    # determined from this excerpt.
    for rotated_image, true_angle, predicted_angle in zip(x_rot, y, y_pred):
        # convert class indices to the same units as rotation_angle above
        true_angle *= unit_angle # added
        predicted_angle *= unit_angle # added
        # remainder of the loop body omitted in the original snippet
        ...

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions