diff --git a/06-tutorial-classification/6.4-train-with-dataset-and-model.md b/06-tutorial-classification/6.4-train-with-dataset-and-model.md new file mode 100644 index 0000000..373ba53 --- /dev/null +++ b/06-tutorial-classification/6.4-train-with-dataset-and-model.md @@ -0,0 +1,234 @@ +# 使用数据集、模型进行训练和评估 + +## 训练 +训练使用官方自带的例子: https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/train_cifar10.py + +### 设置数据集 +数据集挂载在 `/workspace/mnt/data/` 目录下,所以可以使用如下参数设置: + +训练集 `data-train`: `/workspace/mnt/data/train/cache/data.rec` + +测试集 `data-val`: `/workspace/mnt/data/val/cache/data.rec` + +训练样本数量 `num-examples`: `wc /workspace/mnt/data/train/cache/data.idx|awk '{print $1}'` + +训练类别 `num-classes`: `wc /workspace/mnt/data/train/labels.csv|awk '{print $1}'` + +加载数据集命令参数: `python train_cifar10.py --data-train=${data-train} --data-val=${data-val} --num-examples=${num-examples} --num-classes=${num-classes}` + +### 设置模型 +模型相关文件挂载在 `/workspace/mnt/model/` 目录下, `train_cifar10.py` 中没有加载模型结构文件的处理,需要稍作修改: + +加个 symbol 参数: `parser.add_argument('--symbol', type=str, help='the symbol file')` + +加载 symbol: `sym = mx.symbol.load(args.symbol)` + +这样就可以用 `python train_cifar10.py --symbol /workspace/mnt/model/symbol.json` 加载模型结构文件了。 + +目录 `/workspace/mnt/model/state` 用于存放模型参数,可以用 `--model-prefix=/workspace/mnt/model/state/train01` 以生成模型参数在此目录下(注意容量有上限)。同时也可以供后续训练进行 finetune。 + +一个常用的做法是,在 portal 上将最后生成的 epoch 记录在模型属性里(`load-epoch`),供下次读取继续: `jq .load-epoch /workspace/mnt/model/attrs` + +加载模型命令参数: `python train_cifar10.py --symbol /workspace/mnt/model/symbol.json --model-prefix=/workspace/mnt/model/state/train01 --load-epoch=${load-epoch}` + +### 完整训练主代码 +原读取数据和模型部分已注释掉,改为从命令参数读取: + +```python +import os +import argparse +import logging +logging.basicConfig(level=logging.DEBUG) +from common import find_mxnet, data, fit +from common.util import download_file +import mxnet as mx + +def download_cifar10(): + data_dir="data" + fnames = (os.path.join(data_dir, "cifar10_train.rec"), + os.path.join(data_dir, "cifar10_val.rec")) + download_file('http://data.mxnet.io/data/cifar10/cifar10_val.rec', fnames[1]) + download_file('http://data.mxnet.io/data/cifar10/cifar10_train.rec', fnames[0]) + return fnames + +if __name__ == '__main__': + # download data + #(train_fname, val_fname) = download_cifar10() + + # parse args + parser = argparse.ArgumentParser(description="train cifar10", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + fit.add_fit_args(parser) + data.add_data_args(parser) + data.add_data_aug_args(parser) + data.set_data_aug_level(parser, 2) + + parser.add_argument('--symbol', type=str, help='the symbol file') + + parser.set_defaults( + # network + network = 'resnet', + num_layers = 110, + # data + #data_train = train_fname, + #data_val = val_fname, + num_classes = 10, + num_examples = 50000, + image_shape = '3,28,28', + pad_size = 4, + # train + batch_size = 128, + num_epochs = 300, + lr = .05, + lr_step_epochs = '200,250', + ) + args = parser.parse_args() + + # load network + #from importlib import import_module + #net = import_module('symbols.'+args.network) + #sym = net.get_symbol(**vars(args)) + sym = mx.symbol.load(args.symbol) + + # train + fit.fit(args, sym, data.get_rec_iter) +``` + +命令参数示例: +``` +python train_cifar10.py \ + --gpus=0 \ + --num-epochs=30 \ + --data-train=/workspace/mnt/data/train/cache/data.rec \ + --data-val=/workspace/mnt/data/val/cache/data.rec \ + --num-examples=$(wc /workspace/mnt/data/train/cache/data.idx|awk '{print $1}') \ + --num-classes=$(wc /workspace/mnt/data/train/labels.csv|awk '{print $1}') \ + --symbol /workspace/mnt/model/symbol.json \ + --model-prefix=/workspace/mnt/model/state/train01 \ + --load-epoch=$(jq .load-epoch /workspace/mnt/model/attrs) +``` + +## 评估 +评估使用官方自带的例子:https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/score.py + +### 设置数据集 +和训练类似。 + +### 设置模型 +如果是训练好的模型,其生成的参数权重在 `/workspace/mnt/model/state` 目录下。 + +``` +# ls /workspace/mnt/model/state +train01-0001.params train01-0002.params train01-0003.params train01-symbol.json +``` + +评估时加载对应的模型即可。 + +### 完整评估主代码 +原读取数据和模型部分已注释掉,改为从命令参数读取: + +``` +import argparse +from common import modelzoo, find_mxnet +import mxnet as mx +import time +import os +import logging + +def score(model, data_val, metrics, gpus, batch_size, rgb_mean=None, mean_img=None, + image_shape='3,224,224', data_nthreads=4, label_name='softmax_label', max_num_examples=None, + model_prefix=None, load_epoch=None): + # create data iterator + data_shape = tuple([int(i) for i in image_shape.split(',')]) + if mean_img is not None: + mean_args = {'mean_img':mean_img} + elif rgb_mean is not None: + rgb_mean = [float(i) for i in rgb_mean.split(',')] + mean_args = {'mean_r':rgb_mean[0], 'mean_g':rgb_mean[1], + 'mean_b':rgb_mean[2]} + + data = mx.io.ImageRecordIter( + path_imgrec = data_val, + label_width = 1, + preprocess_threads = data_nthreads, + batch_size = batch_size, + data_shape = data_shape, + label_name = label_name, + rand_crop = False, + rand_mirror = False, + **mean_args) + + sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, load_epoch) + #if isinstance(model, str): + # # download model + # dir_path = os.path.dirname(os.path.realpath(__file__)) + # (prefix, epoch) = modelzoo.download_model( + # model, os.path.join(dir_path, 'model')) + # sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + #elif isinstance(model, tuple) or isinstance(model, list): + # assert len(model) == 3 + # (sym, arg_params, aux_params) = model + #else: + # raise TypeError('model type [%s] is not supported' % str(type(model))) + + # create module + if gpus == '': + devs = mx.cpu() + else: + devs = [mx.gpu(int(i)) for i in gpus.split(',')] + + mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name,]) + mod.bind(for_training=False, + data_shapes=data.provide_data, + label_shapes=data.provide_label) + mod.set_params(arg_params, aux_params) + if not isinstance(metrics, list): + metrics = [metrics,] + tic = time.time() + num = 0 + for batch in data: + mod.forward(batch, is_train=False) + for m in metrics: + mod.update_metric(m, batch.label) + num += batch_size + if max_num_examples is not None and num > max_num_examples: + break + return (num / (time.time() - tic), ) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='score a model on a dataset') + parser.add_argument('--model', type=str, help = 'the model name.') + parser.add_argument('--gpus', type=str, default='0') + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--rgb-mean', type=str, default='0,0,0') + parser.add_argument('--data-val', type=str, required=True) + parser.add_argument('--image-shape', type=str, default='3,224,224') + parser.add_argument('--data-nthreads', type=int, default=4, + help='number of threads for data decoding') + parser.add_argument('--model-prefix', type=str, required=True) + parser.add_argument('--load-epoch', type=int, required=True) + args = parser.parse_args() + + logger = logging.getLogger() + logger.setLevel(logging.DEBUG) + + metrics = [mx.metric.create('acc'), + mx.metric.create('top_k_accuracy', top_k = 5)] + + (speed,) = score(metrics = metrics, **vars(args)) + logging.info('Finished with %f images per second', speed) + + for m in metrics: + logging.info(m.get()) +``` + +命令参数示例: +``` +python score.py \ + --gpus=0 \ + --data-val=/workspace/mnt/data/val/cache/data.rec \ + --load-epoch=1 \ + --model-prefix=/workspace/mnt/model/state/train01 \ + --image-shape=3,32,32 +```