-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransferlearning_tensowflow.py
More file actions
205 lines (147 loc) · 9.62 KB
/
transferlearning_tensowflow.py
File metadata and controls
205 lines (147 loc) · 9.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# -*- coding: utf-8 -*-
"""TransferLearning_Tensowflow.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1lLGY6gNiEmqI0GTD4GoxyT4YKFcY9S9I
**Objective**:
Run prediction (inference) on a TensorFlow Datasets dataset using pre-trained models from TensorFlow Hub.
**Importing the necessary libraries**
"""
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
"""**Loading the DATASET**
Here I'm using the CIFAR-10 dataset, taking only the first 0.1% of each split because processing the full dataset takes a long time.
"""
# Fetch CIFAR-10 through TensorFlow Datasets, keeping only the first 0.1% of
# the train and test splits. `as_supervised=True` yields (image, label) pairs
# and `with_info=True` additionally returns the dataset metadata.
(train_dataset, test_dataset), dataset_info = tfds.load(
    name='cifar10',
    split=['train[:0.1%]', 'test[:0.1%]'],
    with_info=True,
    as_supervised=True,
)
"""**Loading the Pre trained model from tensowflow**
Model 1: MobileNetV2
"""
# MobileNetV2 classification module, fetched from TensorFlow Hub by its handle.
mobile_net_model_url = (
    "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
)
mobile_net_model = hub.load(handle=mobile_net_model_url)
# Define class names for CIFAR-10 (read from the dataset metadata).
class_names = dataset_info.features['label'].names

# Running tallies for the MobileNetV2 evaluation pass.
total_samples_mobile_net = 0
correct_predictions_mobile_net = 0
true_labels_mobile_net = []
predicted_labels_mobile_net = []

# Iterate through the test dataset and calculate accuracy for MobileNetV2.
# This loop must run BEFORE the confusion matrix below is built; otherwise the
# matrix is computed from the still-empty label lists.
# NOTE(review): the hub module looks like an ImageNet classifier, so the argmax
# index is presumably an ImageNet class id rather than a CIFAR-10 label id --
# confirm before reading the accuracy below as a CIFAR-10 score.
for image, label in test_dataset:
    # Scale to float32 in [0, 1] first: convert_image_dtype only rescales
    # integer inputs, and tf.image.resize already returns float32, so doing the
    # conversion after the resize would leave the pixels in [0, 255].
    # (assumes tfds yields uint8 images -- TODO confirm)
    image_mobile_net = tf.image.convert_image_dtype(image, tf.float32)
    # Resize to the 224x224 input size used for this model.
    image_mobile_net = tf.image.resize(image_mobile_net, (224, 224))
    # Add the batch dimension the model call expects.
    image_mobile_net = tf.expand_dims(image_mobile_net, axis=0)

    logits_values_mobile_net = mobile_net_model(image_mobile_net)
    # argmax over a batch of one has shape (1,); unwrap it to a plain int so
    # the stored predictions are scalars, like the true labels.
    predicted_label_mobile_net = int(
        tf.argmax(logits_values_mobile_net, axis=-1).numpy()[0]
    )
    true_label = int(label.numpy())

    true_labels_mobile_net.append(true_label)
    predicted_labels_mobile_net.append(predicted_label_mobile_net)

    total_samples_mobile_net += 1
    if predicted_label_mobile_net == true_label:
        correct_predictions_mobile_net += 1

# Guard against an empty test split so the script does not divide by zero.
if total_samples_mobile_net:
    accuracy_mobile_net = correct_predictions_mobile_net / total_samples_mobile_net
else:
    accuracy_mobile_net = 0.0
print(f'MobileNetV2 Accuracy: {accuracy_mobile_net * 100:.2f}%')

"""**Calculating the confusion matrix for the model**"""
# Calculate the confusion matrix for MobileNetV2.
# `labels` pins the matrix to one row/column per CIFAR-10 class so it always
# lines up with the `class_names` tick labels, even when the small subset does
# not contain every class. Pairs whose predicted id is not one of these labels
# are left out of the matrix.
confusion_mobile_net = confusion_matrix(
    true_labels_mobile_net,
    predicted_labels_mobile_net,
    labels=list(range(len(class_names))),
)

# Display the confusion matrix as a heatmap for MobileNetV2
plt.figure(figsize=(8, 6))
sns.heatmap(
    confusion_mobile_net,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel("Predicted Labels (MobileNetV2)")
plt.ylabel("True Labels")
plt.title("Confusion Matrix (MobileNetV2)")
plt.show()
"""Unfortunately, the matrix reveals a challenging scenario where the model didn't make any correct predictions within this particular dataset subset. The off-diagonal elements highlight misclassifications, indicating that the model struggled to distinguish between different classes. The dataset might exhibit class imbalance, which can further hinder the model's ability to learn effectively. With a limited amount of data, it becomes a difficult task for the model to generalize and make accurate predictions. Was able to use only 0.1% of the model due to high processing time. May be in the future if I could add more dataset to the pre trained model, the prediction might improve"""
"""**Few Examples of the predicted Output**"""
# Flatten the stored predictions into a 1-D array of class indices. This works
# whether each stored prediction is a plain int or a one-element array.
predicted_labels_mobile_net_flat = np.array(predicted_labels_mobile_net).ravel()

# Convert predicted labels to class names for MobileNetV2.
# The predicted index itself is looked up. (Taking np.argmax of a one-element
# prediction always yields 0 and would report class_names[0] for every image.)
# Indices outside the CIFAR-10 range get an explicit placeholder instead of
# raising IndexError.
predicted_class_names_mobile_net = [
    class_names[index]
    if 0 <= index < len(class_names)
    else f"<not a CIFAR-10 class: {index}>"
    for index in predicted_labels_mobile_net_flat.tolist()
]

# Convert true labels to class names for MobileNetV2
true_class_names_mobile_net = [class_names[i] for i in true_labels_mobile_net]

# Print true and predicted labels for up to three examples from the test
# dataset for MobileNetV2 (fewer if the subset is smaller than three images).
for example_number, (true_name, predicted_name) in enumerate(
    zip(true_class_names_mobile_net[:3], predicted_class_names_mobile_net[:3]),
    start=1,
):
    print(f"Example {example_number}:")
    print(f"True Label: {true_name}")
    print(f"Predicted Label: {predicted_name}\n")
"""These predicted examples illustrate the variability in the model's performance, indicating room for improvement, possibly through fine-tuning or training on a more diverse dataset to enhance its ability to correctly classify a wider range of objects."""
# Release the resources used by MobileNetV2 once its evaluation is finished:
# first reset the Keras backend session state, then drop this script's
# reference to the loaded model object.
tf.keras.backend.clear_session()
del mobile_net_model # Remove the script-level reference to the MobileNetV2 model
"""**Loading the Pre trained model from tensowflow**
Model 2: ResNet-50
"""
# ResNet-50 classification module, fetched from TensorFlow Hub by its handle.
resnet_model_url = (
    "https://tfhub.dev/tensorflow/resnet_50/classification/1"
)
resnet_model = hub.load(handle=resnet_model_url)
# Running tallies for the ResNet-50 evaluation pass.
total_samples_resnet = 0
correct_predictions_resnet = 0
true_labels_resnet = []
predicted_labels_resnet = []

# Iterate through the test dataset and calculate accuracy for ResNet-50.
# NOTE(review): the hub module looks like an ImageNet classifier, so the argmax
# index is presumably an ImageNet class id rather than a CIFAR-10 label id --
# confirm before reading the accuracy below as a CIFAR-10 score.
for image, label in test_dataset:
    # Scale to float32 in [0, 1] first: convert_image_dtype only rescales
    # integer inputs, and tf.image.resize already returns float32, so doing the
    # conversion after the resize would leave the pixels in [0, 255].
    # (assumes tfds yields uint8 images -- TODO confirm)
    image_resnet = tf.image.convert_image_dtype(image, tf.float32)
    # Resize to the 224x224 input size used for this model.
    image_resnet = tf.image.resize(image_resnet, (224, 224))
    # Add the batch dimension the model call expects.
    image_resnet = tf.expand_dims(image_resnet, axis=0)

    logits_values_resnet = resnet_model(image_resnet)
    # argmax over a batch of one has shape (1,); unwrap it to a plain int so
    # the stored predictions are scalars, like the true labels.
    predicted_label_resnet = int(
        tf.argmax(logits_values_resnet, axis=-1).numpy()[0]
    )
    true_label = int(label.numpy())

    true_labels_resnet.append(true_label)
    predicted_labels_resnet.append(predicted_label_resnet)

    total_samples_resnet += 1
    if predicted_label_resnet == true_label:
        correct_predictions_resnet += 1

# Guard against an empty test split so the script does not divide by zero.
if total_samples_resnet:
    accuracy_resnet = correct_predictions_resnet / total_samples_resnet
else:
    accuracy_resnet = 0.0
print(f'ResNet-50 Accuracy: {accuracy_resnet * 100:.2f}%')
"""**Calculating the confusion matrix for the model**"""
# Calculate the confusion matrix for ResNet-50.
# `labels` pins the matrix to one row/column per CIFAR-10 class so it always
# lines up with the `class_names` tick labels, even when the small subset does
# not contain every class. Pairs whose predicted id is not one of these labels
# are left out of the matrix. np.ravel flattens the predictions in case each
# one was stored as a one-element array.
confusion_resnet = confusion_matrix(
    true_labels_resnet,
    np.ravel(predicted_labels_resnet),
    labels=list(range(len(class_names))),
)

# Display the confusion matrix as a heatmap for ResNet-50
plt.figure(figsize=(8, 6))
sns.heatmap(
    confusion_resnet,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.xlabel("Predicted Labels (ResNet-50)")
plt.ylabel("True Labels")
plt.title("Confusion Matrix (ResNet-50)")
plt.show()
"""The confusion matrix for the ResNet-50 model shows that it sometimes gets things right, but it also faces difficulties. When it correctly predicts a label, that's shown on the diagonal line. However, there are instances where it predicts the wrong label, especially for classes like 9, 10, and 11. This might be because some classes have more examples than others in the dataset, making it harder for the model to learn them equally well. The overall accuracy of the model depends on how often it gets things right versus wrong, and this matrix helps us see where it struggles. To improve the model's performance, we might need to balance the classes better or fine-tune the model.
**Few Examples of the predicted Output**
"""
# Flatten the stored predictions into a 1-D array of class indices. This works
# whether each stored prediction is a plain int or a one-element array.
predicted_labels_resnet_flat = np.array(predicted_labels_resnet).ravel()

# Convert predicted labels to class names for ResNet-50.
# The predicted index itself is looked up. (Taking np.argmax of a one-element
# prediction always yields 0 and would report class_names[0] for every image.)
# Indices outside the CIFAR-10 range get an explicit placeholder instead of
# raising IndexError.
predicted_class_names_resnet = [
    class_names[index]
    if 0 <= index < len(class_names)
    else f"<not a CIFAR-10 class: {index}>"
    for index in predicted_labels_resnet_flat.tolist()
]

# Convert true labels to class names for ResNet-50
true_class_names_resnet = [class_names[i] for i in true_labels_resnet]

# Print true and predicted labels for up to three examples from the test
# dataset for ResNet-50 (fewer if the subset is smaller than three images).
for example_number, (true_name, predicted_name) in enumerate(
    zip(true_class_names_resnet[:3], predicted_class_names_resnet[:3]),
    start=1,
):
    print(f"Example {example_number}:")
    print(f"True Label: {true_name}")
    print(f"Predicted Label: {predicted_name}\n")
"""I have observed a notable discrepancy in the model's predictions. In the case of the ResNet-50 model used here, there are instances where the model's predictions do not align with the true labels, resulting in a lower accuracy.
Considering a couple of examples to illustrate this:
1. In the first example, the model predicted the label 'airplane,' while the true label for the image is 'horse.'
2. In the third example, the model predicted 'airplane,' even though the true label corresponds to 'frog.'
These instances of incorrect predictions significantly impact the overall accuracy.
**Overall Observation:**
Both models exhibited a mix of accurate and erroneous predictions. They faced challenges in distinguishing between specific classes, and class imbalance likely contributed to the difficulties. The confusion matrices for both models demonstrated similar patterns of correct and incorrect predictions, further emphasizing the need for addressing class imbalance and possibly fine-tuning the models to improve their overall accuracy. This suggests we might need to balance things better and fine-tune the models to make them better at figuring out what's what. While these models are good in many ways, they could do even better on this dataset with a bit of fine-tuning and balance.
"""