diff --git a/python/caffe/io.py b/python/caffe/io.py index e1759beb587..28dda0beefd 100644 --- a/python/caffe/io.py +++ b/python/caffe/io.py @@ -256,7 +256,14 @@ def set_mean(self, in_, mean): if len(ms) != 3: raise ValueError('Mean shape invalid') if ms != self.inputs[in_][1:]: - raise ValueError('Mean shape incompatible with input shape.') + print(self.inputs[in_]) + in_shape = self.inputs[in_][1:] + m_min, m_max = mean.min(), mean.max() + normal_mean = (mean - m_min) / (m_max - m_min) + mean = resize_image(normal_mean.transpose((1,2,0)), + in_shape[1:]).transpose((2,0,1)) * \ + (m_max - m_min) + m_min + #raise ValueError('Mean shape incompatible with input shape.') self.mean[in_] = mean def set_input_scale(self, in_, scale): diff --git a/python/classify.py b/python/classify.py index 4544c51b4c2..17a672dfca3 100755 --- a/python/classify.py +++ b/python/classify.py @@ -5,6 +5,7 @@ By default it configures and runs the Caffe reference ImageNet model. """ import numpy as np +import pandas as pd import os import sys import argparse @@ -80,6 +81,17 @@ def main(argv): help="Order to permute input channels. The default converts " + "RGB -> BGR since BGR is the Caffe default by way of OpenCV." ) + parser.add_argument( + "--labels_file", + default=os.path.join(pycaffe_dir, + "../data/ilsvrc12/synset_words.txt"), + help="Readable label definition file." + ) + parser.add_argument( + "--print_results", + action='store_true', + help="Write output text to stdout rather than serializing to a file." + ) parser.add_argument( "--ext", default='jpg', @@ -93,6 +105,10 @@ def main(argv): mean, channel_swap = None, None if args.mean_file: mean = np.load(args.mean_file) + else: + # channel-wise mean + mean = np.array([104,117,123]) + if args.channel_swap: channel_swap = [int(s) for s in args.channel_swap.split(',')] @@ -126,12 +142,48 @@ def main(argv): # Classify. start = time.time() - predictions = classifier.predict(inputs, not args.center_only) + scores = classifier.predict(inputs, not args.center_only).flatten() print("Done in %.2f s." % (time.time() - start)) +# The script has been updated to support --print_results option. +# Ref - http://stackoverflow.com/questions/37265197/classify-py-is-not-taking-argument-print-results +# However, the labels format supported here has been modified, such that the file can have shorttext +# corresponding to the category classes instead of the general format. +# The commented part correspond to the general format of labels file which has mapping between +# synset_id and the text. + + if args.print_results: +# with open(args.labels_file) as f: +# labels_df = pd.DataFrame([ +# { +# 'synset_id': l.strip().split(' ')[0], +# 'name': ' '.join(l.strip().split(' ')[1:]).split(',')[0] +# } +# for l in f.readlines() +# ]) +# labels_df.synset_id = labels_df.synset_id.astype(np.int64) +# labels = labels_df.sort('synset_id')['name'].values + + labels_file = open(args.labels_file, 'r') + labels = labels_file.readlines() + + indices = (-scores).argsort()[:5] +# predictions = labels[indices] +# print(predictions) +# meta = [ +# (p, '%.5f' % scores[i]) +# for i, p in zip(indices, predictions) +# ] +# print meta + print("---------------------------------") + print("The top 5 predictions are") + for i in indices: + print('%.4f %s' % (scores[i] , labels[i].strip('\n'))) + print("---------------------------------") + # Save print("Saving results into %s" % args.output_file) - np.save(args.output_file, predictions) + np.save(args.output_file, scores) if __name__ == '__main__':