-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexplore.py
More file actions
21 lines (18 loc) · 914 Bytes
/
explore.py
File metadata and controls
21 lines (18 loc) · 914 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import matplotlib.pyplot as plt
def plot_label_distribution(labels, split, class_names):
    """Plot a bar chart of how many examples fall into each class.

    Parameters:
        labels: integer class index for each example. May be a plain Python
            list, a NumPy array, or anything ``numpy.asarray`` accepts.
        split: name of the split being plotted (e.g. 'train' or 'val');
            used in the chart title.
        class_names: names of the classes, where ``class_names[i]`` is the
            bar label for class index ``i``
            (e.g. "MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC").

    Returns:
        list[int]: number of examples per class, in ``class_names`` order.
        Labels outside ``range(len(class_names))`` are not counted.
    """
    import numpy as np

    # Convert to an array first: with a plain Python list, ``labels == c`` is
    # a single bool rather than an elementwise mask, so summing it raises
    # TypeError. ``ravel`` also flattens (N, 1)-shaped label arrays so each
    # count is a scalar.
    label_array = np.asarray(labels).ravel()
    counts = [
        int(np.count_nonzero(label_array == c))
        for c in range(len(class_names))
    ]

    # Plot the per-class counts as a bar chart.
    plt.title(f'{split} distribution')
    plt.bar(class_names, counts)
    plt.xlabel('Class')
    plt.ylabel('Num examples')
    plt.show()
    return counts