
Commit 59ade32

Add files via upload
1 parent 0f5e0e0 commit 59ade32

2 files changed: +477 -0 lines changed
Lines changed: 384 additions & 0 deletions
@@ -0,0 +1,384 @@
#!/usr/bin/env python
# coding: utf-8
import os
import re
from collections import Counter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import BatchNormalization, Conv2D, MaxPooling2D, Activation, Flatten, Dropout, Dense
from tensorflow.keras.preprocessing.image import img_to_array, load_img, ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import accuracy_score, classification_report, roc_curve, auc, balanced_accuracy_score, \
    confusion_matrix, roc_auc_score
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from imblearn.keras import balanced_batch_generator
from imutils import paths

# Move images (unless they have already been moved)
campfire_df = pd.read_csv('dataset/campfire_subset.csv')
images = list(paths.list_images('dataset/'))

# for img_path in images:
#     obj_id = int(img_path.strip('dataset/OBJID_').strip('.tif'))
#     if obj_id in campfire_df.OBJECTID.values:
#         damage = campfire_df.loc[campfire_df.OBJECTID == obj_id].iloc[0].DAMAGE
#         os.rename('dataset/OBJID_{}.tif'.format(obj_id), 'dataset/{0}/OBJID_{1}.tif'.format(damage, obj_id))
#     else:
#         os.rename('dataset/OBJID_{}.tif'.format(obj_id), 'dataset/Unburned (0%)/OBJID_{}.tif'.format(obj_id))
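# Note: str.strip() removes a *set of characters*, not a prefix/suffix, so the
# OBJID parsing above only works because the IDs are purely numeric. A more
# robust alternative (untested sketch, assuming filenames like OBJID_123.tif):
#     match = re.search(r'OBJID_(\d+)\.tif$', img_path)
#     obj_id = int(match.group(1)) if match else None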

# Settings

EPOCHS = 50
INIT_LR = 1e-3
BS = 16
IMAGE_DIMS = (128, 128, 3)
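# EPOCHS and BS set the training length and batch size, INIT_LR seeds the Adam
# learning rate (decayed over the run), and IMAGE_DIMS is the (width, height,
# channels) size every image is resized to before training.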

# Create Dataset

def create_dataset(path, width, height, resample=None, random_state=0):
    """
    Converts a dataset of images in the directory structure {CLASS_LABEL}/{FILENAME}.{IMAGE_EXTENSION}
    to a list of 3D NumPy arrays and their corresponding labels. Images are resized to (width, height).
    The dataset is resampled to balance the class distribution when resample='over'|'under'.

    # Arguments
        path: path to dataset
        width: width of resized image
        height: height of resized image
        (optional) resample: resample dataset using ROS ('over') / RUS ('under')

    # Returns
        data: A list of 3D NumPy arrays converted from images
        labels: A list of labels of the images
    """
    image_paths = list(paths.list_images(path))
    labels = [image_path.split(os.path.sep)[-2] for image_path in image_paths]

    if resample:
        if resample == 'over':
            sampler = RandomOverSampler(random_state=random_state)
        elif resample == 'under':
            sampler = RandomUnderSampler(random_state=random_state)
        image_paths = [[image_path] for image_path in image_paths]
        image_paths_resampled, labels = sampler.fit_resample(image_paths, labels)
        image_paths = image_paths_resampled.ravel()

    data = [img_to_array(load_img(img_path, target_size=(width, height))) for img_path in image_paths]

    return np.array(data, dtype="float") / 255.0, np.array(labels)


data, labels = create_dataset('dataset/', IMAGE_DIMS[0], IMAGE_DIMS[1])

np.save('dataset/data.npy', data)
np.save('dataset/labels.npy', labels)
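# To balance the classes at load time instead, the same call could use the
# resample flag defined above, e.g. (untested sketch):
#     data, labels = create_dataset('dataset/', IMAGE_DIMS[0], IMAGE_DIMS[1], resample='over')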

# Load Dataset

data, labels = np.load('dataset/data.npy'), np.load('dataset/labels.npy')

classes = np.unique(labels)
n_classes = len(classes)

print(labels)

fig, axs = plt.subplots(2, 3)
axs = axs.flatten()
fig.set_size_inches((8, 8))
for i, c in enumerate(classes):
    axs[i].imshow(data[labels == c][np.random.choice(data[labels == c].shape[0], 1)[0]])
    axs[i].set_title(c)
fig.delaxes(axs[5])
fig.delaxes(axs[4])
for ax in axs:
    ax.label_outer()

# explicitly render the chart
plt.show()

# Create Training/Testing Data

X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=0, stratify=labels)

print(Counter(y_train))
print(Counter(y_test))


def class_distribution(arr):
    counts = Counter(arr)
    total = sum(counts.values())
    return {c: c_count / total for c, c_count in counts.items()}


print(class_distribution(y_train))
print(class_distribution(y_test))

lb = LabelBinarizer()
y_train_bin = lb.fit_transform(y_train)
y_test_bin = lb.transform(y_test)

# Model Building and Configuration

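# MiniVGGNet wraps a small VGG-style CNN (three convolutional blocks of 64, 2x128,
# and 2x256 filters with batch normalization, max pooling, and dropout, followed by
# a 1024-unit dense layer and a softmax output) together with its Adam optimizer,
# best-val_loss checkpointing, and CSV logging.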
class MiniVGGNet:

    def __init__(self, name, input_shape, n_classes, init_lr, epochs, batch_size):
        self.model = MiniVGGNet.build(input_shape=input_shape, n_classes=n_classes)
        self.epochs = epochs
        self.batch_size = batch_size
        opt = Adam(lr=init_lr, decay=init_lr / epochs)
        self.model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
        # make sure the checkpoint directory exists before ModelCheckpoint writes to it
        os.makedirs('model_checkpoints/{}'.format(name), exist_ok=True)
        model_filepath = 'model_checkpoints/{}/model.h5'.format(name)
        mcp_save = ModelCheckpoint(model_filepath, save_best_only=True, monitor='val_loss', mode='min')
        csv_logger = CSVLogger('model_checkpoints/{}/log.csv'.format(name))
        self.callbacks = [mcp_save, csv_logger]

    def fit(self, X_train, y_train, X_test, y_test):
        return self.model.fit(
            X_train,
            y_train,
            batch_size=self.batch_size,
            validation_data=(X_test, y_test),
            epochs=self.epochs,
            callbacks=self.callbacks)

    def fit_generator(self, X_train, y_train, X_test, y_test, generator, steps_per_epoch):
        return self.model.fit_generator(
            generator,
            validation_data=(X_test, y_test),
            epochs=self.epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=self.callbacks)

    @staticmethod
    def build(input_shape, n_classes):
        model = Sequential()

        model.add(Conv2D(64, (3, 3), padding="same", input_shape=input_shape, data_format='channels_last'))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Conv2D(128, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Conv2D(128, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Conv2D(256, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Conv2D(256, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.add(Dense(1024))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Dropout(0.5))

        model.add(Dense(n_classes))
        model.add(Activation("softmax"))

        return model

# Model Fitting

baseline = MiniVGGNet('baseline', IMAGE_DIMS, n_classes, INIT_LR, EPOCHS, BS)

# takes some time to run, depending on hardware (resume here)
baseline.fit(X_train, y_train_bin, X_test, y_test_bin)

ros = MiniVGGNet('baseline_ros', IMAGE_DIMS, n_classes, INIT_LR, EPOCHS, BS)

# Raises "ValueError: Found array with dim 4. Estimator expected <= 2." because
# the sampler only accepts 2-D feature arrays
ros_generator, steps_per_epoch_ros = balanced_batch_generator(
    X_train,
    y_train_bin,
    sampler=RandomOverSampler(),
    batch_size=BS,
    random_state=0)

ros.fit_generator(X_train, y_train_bin, X_test, y_test_bin, ros_generator, steps_per_epoch_ros)
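# One possible workaround (untested sketch): resample row indices in 2-D, then
# index back into the 4-D image tensor before fitting:
#     idx = np.arange(X_train.shape[0]).reshape(-1, 1)
#     idx_res, y_res = RandomOverSampler(random_state=0).fit_resample(idx, y_train)
#     ros.fit(X_train[idx_res.ravel()], lb.transform(y_res), X_test, y_test_bin)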

img_datagen = MiniVGGNet('baseline_datagen', IMAGE_DIMS, n_classes, INIT_LR, EPOCHS, BS)

img_data_generator = ImageDataGenerator(
    rotation_range=25,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest")

img_datagen.fit_generator(X_train, y_train_bin, X_test, y_test_bin,
                          img_data_generator.flow(X_train, y_train_bin, batch_size=BS), X_train.shape[0] // BS)

# Model Evaluation

model = load_model('model_checkpoints/baseline_datagen/model.h5')

y_test_pred = model.predict(X_test)

y_test_pred_labels = np.argmax(y_test_pred, axis=-1)
y_test_labels = np.argmax(y_test_bin, axis=-1)

y_test_pred_bin = []
for l in y_test_pred_labels:
    pred_bin = [0] * n_classes
    pred_bin[l] = 1
    y_test_pred_bin.append(pred_bin)
y_test_pred_bin = np.array(y_test_pred_bin)

print(accuracy_score(y_test_labels, y_test_pred_labels))
print(balanced_accuracy_score(y_test_labels, y_test_pred_labels))

print(classification_report(y_test_labels, y_test_pred_labels, target_names=classes))
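# The one-hot loop above can also be written as a single NumPy indexing step
# (equivalent sketch):
#     y_test_pred_bin = np.eye(n_classes, dtype=int)[y_test_pred_labels]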


def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax


class_names = [re.sub(r' ?\([^)]+\)', '', c) for c in classes]

ax = plot_confusion_matrix(y_test_labels, y_test_pred_labels, np.array(class_names), normalize=True,
                           title='Datagen Model Confusion Matrix')
plt.savefig('datagen_conf_matrix.png')
plt.show()
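# Note: the per-class ROC curves below are computed from the hard one-hot
# predictions in y_test_pred_bin, which yields only a few threshold points per
# curve. Using the softmax scores instead would trace the full curves, e.g.
# (untested sketch): roc_curve(y_test_bin[:, i], y_test_pred[:, i])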

# Compute ROC curve and ROC area for each class
fpr, tpr, roc_auc = {}, {}, {}
for i, c in enumerate(class_names):
    fpr[c], tpr[c], _ = roc_curve(y_test_bin[:, i], y_test_pred_bin[:, i])
    roc_auc[c] = auc(fpr[c], tpr[c])

# Compute micro-average ROC curve and ROC area
fpr['micro'], tpr['micro'], _ = roc_curve(y_test_bin.ravel(), y_test_pred_bin.ravel())
roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])

# Compute macro-average ROC curve and ROC area
lw = 2

# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[c] for c in class_names]))

# Then interpolate all ROC curves at these points
mean_tpr = np.zeros_like(all_fpr)
for c in class_names:
    mean_tpr += np.interp(all_fpr, fpr[c], tpr[c])

# Finally average it and compute AUC
mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average (area = {0:0.2f})'
               ''.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average (area = {0:0.2f})'
               ''.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

colors = ['aqua', 'darkorange', 'cornflowerblue', 'purple', 'green']
for c, color in zip(class_names, colors):
    plt.plot(fpr[c], tpr[c], color=color, lw=lw,
             label='{0} (area = {1:0.2f})'
                   ''.format(c, roc_auc[c]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Datagen Multiclass ROC/AUC')
plt.legend(loc="lower right")
plt.savefig('datagen_roc_auc.png')
plt.show()


# Extras

def get_class_weights(y):
    counter = Counter(y)
    majority = max(counter.values())
    return {cls: float(majority / count) for cls, count in counter.items()}


class_weights_train = get_class_weights(labels)
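# class_weights_train is keyed by class name and is not used above. One way to
# apply it (untested sketch, with a hypothetical 'baseline_weighted' run) is to
# re-key it by the LabelBinarizer's class order and pass it to Keras:
#     class_weight = {i: class_weights_train[c] for i, c in enumerate(lb.classes_)}
#     weighted = MiniVGGNet('baseline_weighted', IMAGE_DIMS, n_classes, INIT_LR, EPOCHS, BS)
#     weighted.model.fit(X_train, y_train_bin, batch_size=BS, epochs=EPOCHS,
#                        validation_data=(X_test, y_test_bin), class_weight=class_weight)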
