Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions 06-tutorial-classification/6.4-train-with-dataset-and-model.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
# 使用数据集、模型进行训练和评估

## 训练
训练使用官方自带的例子: https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/train_cifar10.py

### 设置数据集
数据集挂载在 `/workspace/mnt/data/` 目录下,所以可以使用如下参数设置:

训练集 `data-train`: `/workspace/mnt/data/train/cache/data.rec`

测试集 `data-val`: `/workspace/mnt/data/val/cache/data.rec`

训练样本数量 `num-examples`: `wc /workspace/mnt/data/train/cache/data.idx|awk '{print $1}'`

训练类别 `num-classes`: `wc /workspace/mnt/data/train/labels.csv|awk '{print $1}'`

加载数据集命令参数: `python train_cifar10.py --data-train=${data-train} --data-val=${data-val} --num-examples=${num-examples} --num-classes=${num-classes}`

### 设置模型
模型相关文件挂载在 `/workspace/mnt/model/` 目录下, `train_cifar10.py` 中没有加载模型结构文件的处理,需要稍作修改:

加个 symbol 参数: `parser.add_argument('--symbol', type=str, help='the symbol file')`

加载 symbol: `sym = mx.symbol.load(args.symbol)`

这样就可以用 `python train_cifar10.py --symbol /workspace/mnt/model/symbol.json` 加载模型结构文件了。

目录 `/workspace/mnt/model/state` 用于存放模型参数,可以用 `--model-prefix=/workspace/mnt/model/state/train01` 以生成模型参数在此目录下(注意容量有上限)。同时也可以供后续训练进行 finetune。

一个常用的做法是,在 portal 上将最后生成的 epoch 记录在模型属性里(`load-epoch`),供下次读取继续: `jq .load-epoch /workspace/mnt/model/attrs`

加载模型命令参数: `python train_cifar10.py --symbol /workspace/mnt/model/symbol.json --model-prefix=/workspace/mnt/model/state/train01 --load-epoch=${load-epoch}`

### 完整训练主代码
原读取数据和模型部分已注释掉,改为从命令参数读取:

```python
import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet, data, fit
from common.util import download_file
import mxnet as mx

def download_cifar10():
data_dir="data"
fnames = (os.path.join(data_dir, "cifar10_train.rec"),
os.path.join(data_dir, "cifar10_val.rec"))
download_file('http://data.mxnet.io/data/cifar10/cifar10_val.rec', fnames[1])
download_file('http://data.mxnet.io/data/cifar10/cifar10_train.rec', fnames[0])
return fnames

if __name__ == '__main__':
# download data
#(train_fname, val_fname) = download_cifar10()

# parse args
parser = argparse.ArgumentParser(description="train cifar10",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
fit.add_fit_args(parser)
data.add_data_args(parser)
data.add_data_aug_args(parser)
data.set_data_aug_level(parser, 2)

parser.add_argument('--symbol', type=str, help='the symbol file')

parser.set_defaults(
# network
network = 'resnet',
num_layers = 110,
# data
#data_train = train_fname,
#data_val = val_fname,
num_classes = 10,
num_examples = 50000,
image_shape = '3,28,28',
pad_size = 4,
# train
batch_size = 128,
num_epochs = 300,
lr = .05,
lr_step_epochs = '200,250',
)
args = parser.parse_args()

# load network
#from importlib import import_module
#net = import_module('symbols.'+args.network)
#sym = net.get_symbol(**vars(args))
sym = mx.symbol.load(args.symbol)

# train
fit.fit(args, sym, data.get_rec_iter)
```

命令参数示例:
```
python train_cifar10.py \
--gpus=0 \
--num-epochs=30 \
--data-train=/workspace/mnt/data/train/cache/data.rec \
--data-val=/workspace/mnt/data/val/cache/data.rec \
--num-examples=$(wc /workspace/mnt/data/train/cache/data.idx|awk '{print $1}') \
--num-classes=$(wc /workspace/mnt/data/train/labels.csv|awk '{print $1}') \
--symbol /workspace/mnt/model/symbol.json \
--model-prefix=/workspace/mnt/model/state/train01 \
--load-epoch=$(jq .load-epoch /workspace/mnt/model/attrs)
```

## 评估
评估使用官方自带的例子:https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/score.py

### 设置数据集
和训练类似。

### 设置模型
如果是训练好的模型,其生成的参数权重在 `/workspace/mnt/model/state` 目录下。

```
# ls /workspace/mnt/model/state
train01-0001.params train01-0002.params train01-0003.params train01-symbol.json
```

评估时加载对应的模型即可。

### 完整评估主代码
原读取数据和模型部分已注释掉,改为从命令参数读取:

```
import argparse
from common import modelzoo, find_mxnet
import mxnet as mx
import time
import os
import logging

def score(model, data_val, metrics, gpus, batch_size, rgb_mean=None, mean_img=None,
image_shape='3,224,224', data_nthreads=4, label_name='softmax_label', max_num_examples=None,
model_prefix=None, load_epoch=None):
# create data iterator
data_shape = tuple([int(i) for i in image_shape.split(',')])
if mean_img is not None:
mean_args = {'mean_img':mean_img}
elif rgb_mean is not None:
rgb_mean = [float(i) for i in rgb_mean.split(',')]
mean_args = {'mean_r':rgb_mean[0], 'mean_g':rgb_mean[1],
'mean_b':rgb_mean[2]}

data = mx.io.ImageRecordIter(
path_imgrec = data_val,
label_width = 1,
preprocess_threads = data_nthreads,
batch_size = batch_size,
data_shape = data_shape,
label_name = label_name,
rand_crop = False,
rand_mirror = False,
**mean_args)

sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, load_epoch)
#if isinstance(model, str):
# # download model
# dir_path = os.path.dirname(os.path.realpath(__file__))
# (prefix, epoch) = modelzoo.download_model(
# model, os.path.join(dir_path, 'model'))
# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
#elif isinstance(model, tuple) or isinstance(model, list):
# assert len(model) == 3
# (sym, arg_params, aux_params) = model
#else:
# raise TypeError('model type [%s] is not supported' % str(type(model)))

# create module
if gpus == '':
devs = mx.cpu()
else:
devs = [mx.gpu(int(i)) for i in gpus.split(',')]

mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name,])
mod.bind(for_training=False,
data_shapes=data.provide_data,
label_shapes=data.provide_label)
mod.set_params(arg_params, aux_params)
if not isinstance(metrics, list):
metrics = [metrics,]
tic = time.time()
num = 0
for batch in data:
mod.forward(batch, is_train=False)
for m in metrics:
mod.update_metric(m, batch.label)
num += batch_size
if max_num_examples is not None and num > max_num_examples:
break
return (num / (time.time() - tic), )


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='score a model on a dataset')
parser.add_argument('--model', type=str, help = 'the model name.')
parser.add_argument('--gpus', type=str, default='0')
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--rgb-mean', type=str, default='0,0,0')
parser.add_argument('--data-val', type=str, required=True)
parser.add_argument('--image-shape', type=str, default='3,224,224')
parser.add_argument('--data-nthreads', type=int, default=4,
help='number of threads for data decoding')
parser.add_argument('--model-prefix', type=str, required=True)
parser.add_argument('--load-epoch', type=int, required=True)
args = parser.parse_args()

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

metrics = [mx.metric.create('acc'),
mx.metric.create('top_k_accuracy', top_k = 5)]

(speed,) = score(metrics = metrics, **vars(args))
logging.info('Finished with %f images per second', speed)

for m in metrics:
logging.info(m.get())
```

命令参数示例:
```
python score.py \
--gpus=0 \
--data-val=/workspace/mnt/data/val/cache/data.rec \
--load-epoch=1 \
--model-prefix=/workspace/mnt/model/state/train01 \
--image-shape=3,32,32
```