diff --git a/notebooks/step-3-reusable-code.ipynb b/notebooks/step-3-reusable-code.ipynb
new file mode 100644
index 00000000..f38e79b9
--- /dev/null
+++ b/notebooks/step-3-reusable-code.ipynb
@@ -0,0 +1,786 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:17:31.460557Z",
+ "start_time": "2019-06-16T21:17:29.395297Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import itertools\n",
+ "import joblib\n",
+ "import json\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from sklearn.metrics import confusion_matrix, f1_score\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "import yaml"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/Users/jenif/course-ds-base\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Go to project root folder\n",
+ "%cd .."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'base': {'random_state': 42},\n",
+ " 'data': {'dataset_csv': 'data/raw/iris.csv',\n",
+ " 'features_path': 'data/processed/featured_iris.csv',\n",
+ " 'test_size': 0.2,\n",
+ " 'testset_path': 'data/processed/test_iris.csv',\n",
+ " 'trainset_path': 'data/processed/train_iris.csv'},\n",
+ " 'reports': {'confusion_matrix_image': 'reports/confusion_matrix.png',\n",
+ " 'metrics_file': 'reports/metrics.json'},\n",
+ " 'train': {'clf_params': {'C': 0.001,\n",
+ " 'max_iter': 100,\n",
+ " 'multi_class': 'multinomial',\n",
+ " 'solver': 'lbfgs'},\n",
+ " 'model_path': 'models/model.joblib'}}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Read config\n",
+ "import pprint\n",
+ "\n",
+ "with open('params.yaml') as conf_file:\n",
+ " config = yaml.safe_load(conf_file)\n",
+ "\n",
+ "pprint.pprint(config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Load dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:17:31.485189Z",
+ "start_time": "2019-06-16T21:17:31.473720Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " sepal length (cm) | \n",
+ " sepal width (cm) | \n",
+ " petal length (cm) | \n",
+ " petal width (cm) | \n",
+ " target | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 5.1 | \n",
+ " 3.5 | \n",
+ " 1.4 | \n",
+ " 0.2 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 4.9 | \n",
+ " 3.0 | \n",
+ " 1.4 | \n",
+ " 0.2 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 4.7 | \n",
+ " 3.2 | \n",
+ " 1.3 | \n",
+ " 0.2 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 4.6 | \n",
+ " 3.1 | \n",
+ " 1.5 | \n",
+ " 0.2 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 5.0 | \n",
+ " 3.6 | \n",
+ " 1.4 | \n",
+ " 0.2 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n",
+ "0 5.1 3.5 1.4 0.2 \n",
+ "1 4.9 3.0 1.4 0.2 \n",
+ "2 4.7 3.2 1.3 0.2 \n",
+ "3 4.6 3.1 1.5 0.2 \n",
+ "4 5.0 3.6 1.4 0.2 \n",
+ "\n",
+ " target \n",
+ "0 0 \n",
+ "1 0 \n",
+ "2 0 \n",
+ "3 0 \n",
+ "4 0 "
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Get data \n",
+ "\n",
+ "import pandas as pd\n",
+ "from sklearn.datasets import load_iris\n",
+ "\n",
+ "data = load_iris(as_frame=True)\n",
+ "dataset = data.frame\n",
+ "dataset.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0: setosa\n",
+ "1: versicolor\n",
+ "2: virginica\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[None, None, None]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# print labels for target values \n",
+ "\n",
+ "[print(f'{target}: {label}') for target, label in zip(data.target.unique(), data.target_names)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:17:32.328046Z",
+ "start_time": "2019-06-16T21:17:32.323611Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['sepal_length', 'sepal_width', 'petal_length', 'petal_width']"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# feature names\n",
+ "\n",
+ "dataset.columns = [colname.strip(' (cm)').replace(' ', '_') for colname in dataset.columns.tolist()]\n",
+ "\n",
+ "feature_names = dataset.columns.tolist()[:4]\n",
+ "feature_names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save raw data\n",
+ "dataset.to_csv(config['data']['dataset_csv'], index=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "# Feature engineering"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:02.150708Z",
+ "start_time": "2019-06-16T21:21:02.144518Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "dataset['sepal_length_to_sepal_width'] = dataset['sepal_length'] / dataset['sepal_width']\n",
+ "dataset['petal_length_to_petal_width'] = dataset['petal_length'] / dataset['petal_width']\n",
+ "\n",
+ "dataset = dataset[[\n",
+ " 'sepal_length', 'sepal_width', 'petal_length', 'petal_width',\n",
+ "# 'sepal_length_in_square', 'sepal_width_in_square', 'petal_length_in_square', 'petal_width_in_square',\n",
+ " 'sepal_length_to_sepal_width', 'petal_length_to_petal_width',\n",
+ " 'target'\n",
+ "]]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:02.987144Z",
+ "start_time": "2019-06-16T21:21:02.976092Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " sepal_length | \n",
+ " sepal_width | \n",
+ " petal_length | \n",
+ " petal_width | \n",
+ " sepal_length_to_sepal_width | \n",
+ " petal_length_to_petal_width | \n",
+ " target | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 5.1 | \n",
+ " 3.5 | \n",
+ " 1.4 | \n",
+ " 0.2 | \n",
+ " 1.457143 | \n",
+ " 7.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 4.9 | \n",
+ " 3.0 | \n",
+ " 1.4 | \n",
+ " 0.2 | \n",
+ " 1.633333 | \n",
+ " 7.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 4.7 | \n",
+ " 3.2 | \n",
+ " 1.3 | \n",
+ " 0.2 | \n",
+ " 1.468750 | \n",
+ " 6.5 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 4.6 | \n",
+ " 3.1 | \n",
+ " 1.5 | \n",
+ " 0.2 | \n",
+ " 1.483871 | \n",
+ " 7.5 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 5.0 | \n",
+ " 3.6 | \n",
+ " 1.4 | \n",
+ " 0.2 | \n",
+ " 1.388889 | \n",
+ " 7.0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " sepal_length sepal_width petal_length petal_width \\\n",
+ "0 5.1 3.5 1.4 0.2 \n",
+ "1 4.9 3.0 1.4 0.2 \n",
+ "2 4.7 3.2 1.3 0.2 \n",
+ "3 4.6 3.1 1.5 0.2 \n",
+ "4 5.0 3.6 1.4 0.2 \n",
+ "\n",
+ " sepal_length_to_sepal_width petal_length_to_petal_width target \n",
+ "0 1.457143 7.0 0 \n",
+ "1 1.633333 7.0 0 \n",
+ "2 1.468750 6.5 0 \n",
+ "3 1.483871 7.5 0 \n",
+ "4 1.388889 7.0 0 "
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save features\n",
+ "dataset.to_csv(config['data']['features_path'], index=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Split dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:07.438133Z",
+ "start_time": "2019-06-16T21:21:07.431649Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((120, 7), (30, 7))"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_dataset, test_dataset = train_test_split(\n",
+ " dataset, test_size=config['data']['test_size'],\n",
+ " random_state=config['base']['random_state']\n",
+ ")\n",
+ "train_dataset.shape, test_dataset.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save train and test sets\n",
+ "train_dataset.to_csv(config['data']['trainset_path'])\n",
+ "test_dataset.to_csv(config['data']['testset_path'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:10.932148Z",
+ "start_time": "2019-06-16T21:21:10.927844Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Get X and Y\n",
+ "\n",
+ "y_train = train_dataset.loc[:, 'target'].values.astype('int32')\n",
+ "X_train = train_dataset.drop('target', axis=1).values.astype('float32')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:55.427365Z",
+ "start_time": "2019-06-16T21:21:55.416431Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "LogisticRegression(C=0.001, multi_class='multinomial', random_state=42)"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "# Create a Logistic Regression classifier and fit it on the training data\n",
+ "\n",
+ "logreg = LogisticRegression(\n",
+ " **config['train']['clf_params'],\n",
+ " random_state=config['base']['random_state']\n",
+ ")\n",
+ "logreg.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['models/model.joblib']"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "joblib.dump(logreg, config['train']['model_path'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Evaluate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:55.875303Z",
+ "start_time": "2019-06-16T21:21:55.864724Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from src.report.visualization import plot_confusion_matrix"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:56.090756Z",
+ "start_time": "2019-06-16T21:21:56.086966Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Get X and Y\n",
+ "\n",
+ "y_test = test_dataset.loc[:, 'target'].values.astype('int32')\n",
+ "X_test = test_dataset.drop('target', axis=1).values.astype('float32')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:56.270245Z",
+ "start_time": "2019-06-16T21:21:56.265054Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "prediction = logreg.predict(X_test)\n",
+ "cm = confusion_matrix(prediction, y_test)\n",
+ "f1 = f1_score(y_true = y_test, y_pred = prediction, average='macro')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:56.493617Z",
+ "start_time": "2019-06-16T21:21:56.489929Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.9305555555555555"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# f1 score value\n",
+ "f1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save metrics\n",
+ "metrics = {\n",
+ " 'f1': f1\n",
+ "}\n",
+ "\n",
+ "with open(config['reports']['metrics_file'], 'w') as mf:\n",
+ " json.dump(\n",
+ " obj=metrics,\n",
+ " fp=mf,\n",
+ " indent=4\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2019-06-16T21:21:56.966279Z",
+ "start_time": "2019-06-16T21:21:56.726149Z"
+ }
+ },
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'np' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m--------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcm_plot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplot_confusion_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnormalize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/course-ds-base/src/report/visualization.py\u001b[0m in \u001b[0;36mplot_confusion_matrix\u001b[0;34m(cm, target_names, title, cmap, normalize)\u001b[0m\n\u001b[1;32m 39\u001b[0m \"\"\"\n\u001b[1;32m 40\u001b[0m \u001b[0maccuracy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcm\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mmisclass\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0maccuracy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcmap\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "cm_plot = plot_confusion_matrix(cm, data.target_names, normalize=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'cm_plot' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m--------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Save confusion matrix image\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mcm_plot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'reports'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'confusion_matrix_image'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m: name 'cm_plot' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "# Save confusion matrix image\n",
+ "cm_plot.savefig(config['reports']['confusion_matrix_image'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/report/visualization.py b/src/report/visualization.py
new file mode 100644
index 00000000..32952a43
--- /dev/null
+++ b/src/report/visualization.py
@@ -0,0 +1,74 @@
+import itertools
+import matplotlib.pyplot as plt
+import numpy as np
+
+def plot_confusion_matrix(cm,
+ target_names,
+ title='Confusion matrix',
+ cmap=None,
+ normalize=True):
+ """
+ given a sklearn confusion matrix (cm), make a nice plot
+
+ Arguments
+ ---------
+ cm: confusion matrix from sklearn.metrics.confusion_matrix
+
+ target_names: given classification classes such as [0, 1, 2]
+ the class names, for example: ['high', 'medium', 'low']
+
+ title: the text to display at the top of the matrix
+
+ cmap: the gradient of the values displayed from matplotlib.pyplot.cm
+ see http://matplotlib.org/examples/color/colormaps_reference.html
+ plt.get_cmap('jet') or plt.cm.Blues
+
+ normalize: If False, plot the raw numbers
+ If True, plot the proportions
+
+ Usage
+ -----
+ plot_confusion_matrix(cm = cm, # confusion matrix created by
+ # sklearn.metrics.confusion_matrix
+ normalize = True, # show proportions
+ target_names = y_labels_vals, # list of names of the classes
+ title = best_estimator_name) # title of graph
+ Citiation
+ ---------
+ http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
+ """
+ accuracy = np.trace(cm) / float(np.sum(cm))
+ misclass = 1 - accuracy
+
+ if cmap is None:
+ cmap = plt.get_cmap('Blues')
+
+ plt.figure(figsize=(8, 6))
+ plt.imshow(cm, interpolation='nearest', cmap=cmap)
+ plt.title(title)
+ plt.colorbar()
+
+ if target_names is not None:
+ tick_marks = np.arange(len(target_names))
+ plt.xticks(tick_marks, target_names, rotation=45)
+ plt.yticks(tick_marks, target_names)
+
+ if normalize:
+ cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
+
+ thresh = cm.max() / 1.5 if normalize else cm.max() / 2
+ for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+ if normalize:
+ plt.text(j, i, "{:0.4f}".format(cm[i, j]),
+ horizontalalignment="center",
+ color="white" if cm[i, j] > thresh else "black")
+ else:
+ plt.text(j, i, "{:,}".format(cm[i, j]),
+ horizontalalignment="center",
+ color="white" if cm[i, j] > thresh else "black")
+
+ plt.tight_layout()
+ plt.ylabel('True label')
+ plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
+
+ return plt.gcf()
\ No newline at end of file