#!/usr/bin/env python
# coding: utf-8
import os
import re
from collections import Counter

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import BatchNormalization, Conv2D, MaxPooling2D, Activation, Flatten, Dropout, Dense
from tensorflow.keras.preprocessing.image import img_to_array, load_img, ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import accuracy_score, classification_report, roc_curve, auc, balanced_accuracy_score, \
    confusion_matrix, roc_auc_score
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from imblearn.keras import balanced_batch_generator
from imutils import paths

# Note: scipy.interp (an alias for np.interp, removed from recent SciPy releases)
# is no longer imported; np.interp is used directly below.

# Move images into per-class directories (skip if images were already moved)
campfire_df = pd.read_csv('dataset/campfire_subset.csv')
images = list(paths.list_images('dataset/'))

# for img_path in images:
#     # str.strip() removes characters from both ends rather than a prefix/suffix,
#     # so extract the object ID with a regex instead
#     obj_id = int(re.search(r'OBJID_(\d+)\.tif$', img_path).group(1))
#     if obj_id in campfire_df.OBJECTID.values:
#         damage = campfire_df.loc[campfire_df.OBJECTID == obj_id].iloc[0].DAMAGE
#         os.rename('dataset/OBJID_{}.tif'.format(obj_id), 'dataset/{0}/OBJID_{1}.tif'.format(damage, obj_id))
#     else:
#         os.rename('dataset/OBJID_{}.tif'.format(obj_id), 'dataset/Unburned (0%)/OBJID_{}.tif'.format(obj_id))

# Settings

EPOCHS = 50
INIT_LR = 1e-3
BS = 16
IMAGE_DIMS = (128, 128, 3)


# Create Dataset

def create_dataset(path, width, height, resample=None, random_state=0):
    """
    Converts a dataset of images in the directory structure {CLASS_LABEL}/{FILENAME}.{IMAGE_EXTENSION}
    to a list of 3D NumPy arrays and their corresponding labels. Images are resized to (width, height).
    The dataset can be resampled to balance the class distribution via resample='over'|'under'.
    # Arguments
        path: path to dataset
        width: width of resized image
        height: height of resized image
        (optional) resample: resample dataset using ROS ('over') / RUS ('under')
    # Returns
        data: A list of 3D NumPy arrays converted from images
        labels: A list of labels of the images
    """
    image_paths = list(paths.list_images(path))
    labels = [image_path.split(os.path.sep)[-2] for image_path in image_paths]

    if resample:
        if resample == 'over':
            sampler = RandomOverSampler(random_state=random_state)
        elif resample == 'under':
            sampler = RandomUnderSampler(random_state=random_state)
        else:
            raise ValueError("resample must be 'over' or 'under'")
        # The samplers expect a 2D feature array, so resample the paths as single-column rows
        image_paths = [[image_path] for image_path in image_paths]
        image_paths_resampled, labels = sampler.fit_resample(image_paths, labels)
        image_paths = image_paths_resampled.ravel()

    data = [img_to_array(load_img(img_path, target_size=(width, height))) for img_path in image_paths]

    return np.array(data, dtype="float") / 255.0, np.array(labels)

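# Example (a sketch, not run by default): random oversampling at load time
# duplicates minority-class images until every class matches the majority count.
# data_ros, labels_ros = create_dataset('dataset/', IMAGE_DIMS[0], IMAGE_DIMS[1], resample='over')
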
data, labels = create_dataset('dataset/', IMAGE_DIMS[0], IMAGE_DIMS[1])

# Cache the processed arrays so later runs can skip image loading
np.save('dataset/data.npy', data)
np.save('dataset/labels.npy', labels)

# Load Dataset

data, labels = np.load('dataset/data.npy'), np.load('dataset/labels.npy')

classes = np.unique(labels)
n_classes = len(classes)

print(labels)

fig, axs = plt.subplots(2, 3)
axs = axs.flatten()
fig.set_size_inches((8, 8))
for i, c in enumerate(classes):
    # Show one randomly chosen sample image per class
    axs[i].imshow(data[labels == c][np.random.choice(data[labels == c].shape[0], 1)[0]])
    axs[i].set_title(c)
# Drop the two unused subplots in the 2x3 grid
fig.delaxes(axs[5])
fig.delaxes(axs[4])
for ax in axs:
    ax.label_outer()

# Display the figure
plt.show()

# Create Training/Testing Data

X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=0, stratify=labels)

print(Counter(y_train))

print(Counter(y_test))


def class_distribution(arr):
    total = sum(Counter(arr).values())
    return {c: c_count / total for c, c_count in Counter(arr).items()}


print(class_distribution(y_train))

print(class_distribution(y_test))

# One-hot encode the string labels
lb = LabelBinarizer()
y_train_bin = lb.fit_transform(y_train)
y_test_bin = lb.transform(y_test)

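# Sanity check: LabelBinarizer one-hot encodes classes in the sorted order stored in
# lb.classes_, and inverse_transform maps one-hot rows back to the string labels.
print(lb.classes_)
print(y_train_bin[0], '->', lb.inverse_transform(y_train_bin[:1])[0])
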
# Model Building and Configuration


class MiniVGGNet:

    def __init__(self, name, input_shape, n_classes, init_lr, epochs, batch_size):
        self.model = MiniVGGNet.build(input_shape=input_shape, n_classes=n_classes)
        self.epochs = epochs
        self.batch_size = batch_size
        opt = Adam(lr=init_lr, decay=init_lr / epochs)
        self.model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
        # Ensure the checkpoint directory exists before the callbacks try to write to it
        model_dir = 'model_checkpoints/{}'.format(name)
        os.makedirs(model_dir, exist_ok=True)
        mcp_save = ModelCheckpoint(os.path.join(model_dir, 'model.h5'), save_best_only=True, monitor='val_loss', mode='min')
        csv_logger = CSVLogger(os.path.join(model_dir, 'log.csv'))
        self.callbacks = [mcp_save, csv_logger]

    def fit(self, X_train, y_train, X_test, y_test):
        return self.model.fit(
            X_train,
            y_train,
            batch_size=self.batch_size,
            validation_data=(X_test, y_test),
            epochs=self.epochs,
            callbacks=self.callbacks)

    def fit_generator(self, X_train, y_train, X_test, y_test, generator, steps_per_epoch):
        return self.model.fit_generator(
            generator,
            validation_data=(X_test, y_test),
            epochs=self.epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=self.callbacks)

    @staticmethod
    def build(input_shape, n_classes):
        model = Sequential()

        # Block 1: CONV => RELU => BN => POOL
        model.add(Conv2D(64, (3, 3), padding="same", input_shape=input_shape, data_format='channels_last'))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        # Block 2: (CONV => RELU => BN) * 2 => POOL
        model.add(Conv2D(128, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Conv2D(128, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        # Block 3: (CONV => RELU => BN) * 2 => POOL
        model.add(Conv2D(256, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Conv2D(256, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        # Fully connected head
        model.add(Flatten())
        model.add(Dense(1024))
        model.add(Activation("relu"))
        model.add(BatchNormalization())
        model.add(Dropout(0.5))

        # Softmax classifier
        model.add(Dense(n_classes))
        model.add(Activation("softmax"))

        return model


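# Optional: print a layer-by-layer summary of the architecture defined above.
MiniVGGNet.build(IMAGE_DIMS, n_classes).summary()
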
# Model Fitting

baseline = MiniVGGNet('baseline', IMAGE_DIMS, n_classes, INIT_LR, EPOCHS, BS)

# Training takes some time, depending on hardware
baseline.fit(X_train, y_train_bin, X_test, y_test_bin)

ros = MiniVGGNet('baseline_ros', IMAGE_DIMS, n_classes, INIT_LR, EPOCHS, BS)

# balanced_batch_generator validates its input as a 2D array, so passing the 4D
# image tensor directly raises "ValueError: Found array with dim 4. Estimator
# expected <= 2." Resample the flattened images instead, then reshape each batch.
ros_generator_flat, steps_per_epoch_ros = balanced_batch_generator(
    X_train.reshape(X_train.shape[0], -1),
    y_train_bin,
    sampler=RandomOverSampler(),
    batch_size=BS,
    random_state=0)


def reshape_batches(generator, image_dims):
    for X_batch, y_batch in generator:
        yield X_batch.reshape((-1,) + image_dims), y_batch


ros_generator = reshape_batches(ros_generator_flat, IMAGE_DIMS)

ros.fit_generator(X_train, y_train_bin, X_test, y_test_bin, ros_generator, steps_per_epoch_ros)

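# Optional sanity check (a sketch): draw a single batch from a fresh balanced
# generator and confirm that the oversampled class counts come out roughly even.
check_generator, _ = balanced_batch_generator(
    X_train.reshape(X_train.shape[0], -1),
    y_train_bin,
    sampler=RandomOverSampler(),
    batch_size=BS,
    random_state=0)
_, y_check = next(check_generator)
print(Counter(np.argmax(y_check, axis=-1)))
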
img_datagen = MiniVGGNet('baseline_datagen', IMAGE_DIMS, n_classes, INIT_LR, EPOCHS, BS)

img_data_generator = ImageDataGenerator(
    rotation_range=25,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest")

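# Preview one augmented batch to sanity-check the augmentation settings above
# (the fixed seed is only for reproducibility of this preview; np.clip guards
# against interpolation pushing pixel values slightly outside [0, 1]).
aug_batch, _ = next(img_data_generator.flow(X_train, y_train_bin, batch_size=BS, seed=0))
plt.imshow(np.clip(aug_batch[0], 0.0, 1.0))
plt.title('Example augmented training image')
plt.show()
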
img_datagen.fit_generator(X_train, y_train_bin, X_test, y_test_bin,
                          img_data_generator.flow(X_train, y_train_bin, batch_size=BS), X_train.shape[0] // BS)

# Model Evaluation

model = load_model('model_checkpoints/baseline_datagen/model.h5')

y_test_pred = model.predict(X_test)

y_test_pred_labels = np.argmax(y_test_pred, axis=-1)
y_test_labels = np.argmax(y_test_bin, axis=-1)

# One-hot encode the predicted labels
y_test_pred_bin = np.eye(n_classes, dtype=int)[y_test_pred_labels]

print(accuracy_score(y_test_labels, y_test_pred_labels))

print(balanced_accuracy_score(y_test_labels, y_test_pred_labels))

print(classification_report(y_test_labels, y_test_pred_labels, target_names=classes))

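# Cross-check: balanced accuracy is the unweighted mean of per-class recall,
# i.e. the confusion-matrix diagonal divided by the row sums, averaged.
cm_check = confusion_matrix(y_test_labels, y_test_pred_labels)
print(np.mean(np.diag(cm_check) / cm_check.sum(axis=1)))
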

def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax


# Strip the parenthesized damage percentages from the class names for plotting
class_names = [re.sub(r' ?\([^)]+\)', '', c) for c in classes]

ax = plot_confusion_matrix(y_test_labels, y_test_pred_labels, np.array(class_names), normalize=True,
                           title='Datagen Model Confusion Matrix')
plt.savefig('datagen_conf_matrix.png')
plt.show()

# Compute ROC curve and ROC area for each class. Use the predicted probabilities
# rather than the hard one-hot predictions, so the curve covers all thresholds.
fpr, tpr, roc_auc = {}, {}, {}
for i, c in enumerate(class_names):
    fpr[c], tpr[c], _ = roc_curve(y_test_bin[:, i], y_test_pred[:, i])
    roc_auc[c] = auc(fpr[c], tpr[c])

# Compute micro-average ROC curve and ROC area
fpr['micro'], tpr['micro'], _ = roc_curve(y_test_bin.ravel(), y_test_pred.ravel())
roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])

# Compute macro-average ROC curve and ROC area
lw = 2

# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[c] for c in class_names]))

# Then interpolate all ROC curves at these points
# (np.interp replaces scipy's interp, which newer SciPy versions no longer provide)
mean_tpr = np.zeros_like(all_fpr)
for c in class_names:
    mean_tpr += np.interp(all_fpr, fpr[c], tpr[c])

# Finally average it and compute AUC
mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average (area = {0:0.2f})'.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average (area = {0:0.2f})'.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

colors = ['aqua', 'darkorange', 'cornflowerblue', 'purple', 'green']
for c, color in zip(class_names, colors):
    plt.plot(fpr[c], tpr[c], color=color, lw=lw,
             label='{0} (area = {1:0.2f})'.format(c, roc_auc[c]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Datagen Multiclass ROC/AUC')
plt.legend(loc="lower right")
plt.savefig('datagen_roc_auc.png')
plt.show()


# Extras


def get_class_weights(y):
    # Weight each class by the ratio of the majority-class count to its own count
    counter = Counter(y)
    majority = max(counter.values())
    return {cls: majority / count for cls, count in counter.items()}


class_weights_train = get_class_weights(labels)
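
# Sketch of how these weights could be used (assumes the class ordering in
# lb.classes_): Keras' class_weight expects integer class indices, so remap
# the string-label keys before passing them to fit().
class_weight_idx = {i: class_weights_train[c] for i, c in enumerate(lb.classes_)}
# weighted = MiniVGGNet('baseline_weighted', IMAGE_DIMS, n_classes, INIT_LR, EPOCHS, BS)
# weighted.model.fit(X_train, y_train_bin, batch_size=BS, epochs=EPOCHS,
#                    validation_data=(X_test, y_test_bin), class_weight=class_weight_idx)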