I. Overview
1. What is TensorFlow:
2. Why use TensorFlow for object detection on a phone:
II. Preparation
1. Download the TensorFlow repository (https://github.com/tensorflow/tensorflow)
2. A Python environment (I used Python 3.5)
3. Install TensorFlow. There are two builds, CPU and GPU. The GPU build trains models much faster, so choose according to your hardware. The usual route is to install Anaconda and run conda install tensorflow. I used the GPU build, which requires an NVIDIA graphics card: conda install tensorflow-gpu
4. Android Studio (I used version 2.3.3)
5. Download libtensorflow_inference.so and libandroid_tensorflow_inference_java.jar (both can also be built from the TensorFlow source). Link: https://pan.baidu.com/s/1tN_nNqfy6JC272J17VaWTg password: boat
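III. Training a custom TensorFlow model
1. Prepare the dataset
Put the training images in tensorflow_master/tensorflow/examples/image_retraining/data/train, one subfolder per category; retrain.py uses the subfolder names as the class labels.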
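retrain.py turns each subfolder of --image_dir into one class. The plain-Java sketch below approximates how recent versions of the script name those classes (folder name lowercased, each run of characters other than a-z and 0-9 replaced by a single space, folders visited in sorted order), which helps predict what output_labels.txt will contain. It is an approximation written for this article, not TensorFlow code: RetrainLabels is a hypothetical name, and older retrain.py versions did not sort the folders.

```java
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

// Hypothetical helper (not part of TensorFlow): approximates how recent
// retrain.py versions derive class labels from the subfolders of --image_dir.
final class RetrainLabels {

    // Lowercase the folder name and collapse every run of characters other
    // than a-z / 0-9 into one space, e.g. "Sun_Flowers" -> "sun flowers".
    static String labelFor(String dirName) {
        return dirName.toLowerCase(Locale.ROOT).replaceAll("[^a-z0-9]+", " ");
    }

    // List the class subfolders in sorted order and map each one to its label.
    static List<String> labelsIn(File imageDir) {
        List<String> labels = new ArrayList<>();
        File[] children = imageDir.listFiles(File::isDirectory);
        if (children == null) {
            return labels; // not a directory, or not readable
        }
        Arrays.sort(children); // path order, like sorting the folder paths in Python
        for (File child : children) {
            labels.add(labelFor(child.getName()));
        }
        return labels;
    }

    public static void main(String[] args) {
        File imageDir = new File(args.length > 0 ? args[0] : "data/train");
        System.out.println(labelsIn(imageDir));
    }
}
```

Running it from the image_retraining folder prints the expected labels in order; a folder named "Sun_Flowers", for instance, would show up as "sun flowers".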
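2. Prepare the pre-trained model
Training needs the ImageNet pre-trained Inception weights, which consist of 4 files (classify_image_graph_def.pb, imagenet_2012_challenge_label_map_proto.pbtxt, imagenet_synset_to_human_label_map.txt, inception-2015-12-05.tgz). Download link: https://pan.baidu.com/s/1JlDbYy4NHD7qy3Or5lDtSg password: i3jo
Download them in advance and copy them into a model folder inside the image_retraining directory (create the folder if it does not exist). Otherwise retrain.py downloads them itself, which is very slow.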
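Because a missing file silently falls back to the slow download, a quick pre-flight check before a long training run may be worth it. A minimal sketch, assuming the model folder sits next to retrain.py as in the training command in the next step; ModelDirCheck is a hypothetical helper that only checks that the four files listed above exist.

```java
import java.io.File;
import java.util.ArrayList;
import java.util.List;

// Hypothetical pre-flight check: reports which of the pre-trained files
// listed in this article are missing from the folder passed as --model_dir.
final class ModelDirCheck {

    static final String[] REQUIRED = {
        "classify_image_graph_def.pb",
        "imagenet_2012_challenge_label_map_proto.pbtxt",
        "imagenet_synset_to_human_label_map.txt",
        "inception-2015-12-05.tgz"
    };

    static List<String> missingFiles(File modelDir) {
        List<String> missing = new ArrayList<>();
        for (String name : REQUIRED) {
            if (!new File(modelDir, name).isFile()) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        File modelDir = new File(args.length > 0 ? args[0] : "model");
        List<String> missing = missingFiles(modelDir);
        System.out.println(missing.isEmpty() ? "model/ is complete" : "Missing: " + missing);
    }
}
```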
3. Start training
cd into the tensorflow_master\tensorflow\examples\image_retraining folder and run:
python retrain.py --bottleneck_dir bottleneck --how_many_training_steps 4000 --model_dir model/ --output_graph output_graph.pb --output_labels output_labels.txt --image_dir data/train/
When training finishes, it writes output_graph.pb and output_labels.txt to the current directory.
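Line i of output_labels.txt names class i of the final_result output, and the Android classifier later in this article relies on that mapping: it reads one label per line and logs the label count next to the output layer size. The plain-Java sketch below mirrors that reading step and adds a count check, so a mismatched label file can be caught before it reaches the phone. LabelFile and its methods are hypothetical names; the expected class count would be, for example, the number of category folders in data/train.

```java
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: reads output_labels.txt the way the Android classifier
// does (one label per line, list index = class index).
final class LabelFile {

    static List<String> read(Reader source) {
        List<String> labels = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(source)) {
            String line;
            while ((line = br.readLine()) != null) {
                labels.add(line);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return labels;
    }

    // The number of labels should equal the size of the final_result output.
    static boolean matchesClassCount(List<String> labels, long numClasses) {
        return labels.size() == numClasses;
    }

    public static void main(String[] args) throws IOException {
        String path = args.length > 0 ? args[0] : "output_labels.txt";
        List<String> labels = read(new FileReader(path));
        System.out.println(labels.size() + " labels: " + labels);
        if (args.length > 1) {
            long expected = Long.parseLong(args[1]);
            System.out.println(matchesClassCount(labels, expected) ? "OK" : "Mismatch");
        }
    }
}
```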
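4. The model produced in the previous step cannot be used on Android directly; it needs one conversion step. The official explanation:
To use v3 Inception model, strip the DecodeJpeg Op from your retrained model first:
cd into the tensorflow_master\tensorflow\python\tools folder, copy the output_graph.pb generated in the previous step into it, and run:
python strip_unused.py --input_graph=output_graph.pb --output_graph=output.pb --input_node_names="Mul" --output_node_names="final_result" --input_binary=true
This generates output.pb in the same directory. strip_unused.py removes everything before the Mul node (including DecodeJpeg), so the stripped graph takes an already decoded, resized and normalized image as its input.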
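IV. Integrating into the Android project
1. After creating the project, create two directories, assets and jniLibs, under \app\src\main, and copy the output.pb and output_labels.txt generated earlier into assets.
2. Inside jniLibs create an armeabi-v7a folder and copy libtensorflow_inference.so into jniLibs\armeabi-v7a.
3. Add libandroid_tensorflow_inference_java.jar to the project (if you are not sure how, search for how to add a jar in Android Studio).
4. Create a new class (Classifier.java):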
import android.graphics.Bitmap;
import android.graphics.RectF;

import java.util.List;

/**
 * Created by amitshekhar on 06/03/17.
 * Generic interface for interacting with different recognition engines.
 */
public interface Classifier {

    /** An immutable result returned by a Classifier describing what was recognized. */
    public class Recognition {
        /**
         * A unique identifier for what has been recognized. Specific to the class, not the instance of
         * the object.
         */
        private final String id;

        /** Display name for the recognition. */
        private final String title;

        /** A sortable score for how good the recognition is relative to others. Higher should be better. */
        private final Float confidence;

        /** Optional location within the source image for the location of the recognized object. */
        private RectF location;

        public Recognition(
                final String id, final String title, final Float confidence, final RectF location) {
            this.id = id;
            this.title = title;
            this.confidence = confidence;
            this.location = location;
        }

        public String getId() {
            return id;
        }

        public String getTitle() {
            return title;
        }

        public Float getConfidence() {
            return confidence;
        }

        public RectF getLocation() {
            return new RectF(location);
        }

        public void setLocation(RectF location) {
            this.location = location;
        }

        @Override
        public String toString() {
            String resultString = "";
            if (id != null) {
                resultString += "[" + id + "] ";
            }
            if (title != null) {
                resultString += title + " ";
            }
            if (confidence != null) {
                resultString += String.format("(%.1f%%) ", confidence * 100.0f);
            }
            if (location != null) {
                resultString += location + " ";
            }
            return resultString.trim();
        }
    }

    List<Recognition> recognizeImage(Bitmap bitmap);

    void enableStatLogging(final boolean debug);

    String getStatString();

    void close();
}
5. Create the class that implements recognition (TensorFlowImageClassifier.java):
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.support.v4.os.TraceCompat;
import android.util.Log;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Vector;

import www.demo04.com.util.tensorflow.Classifier;

/**
 * Created by amitshekhar on 06/03/17.
 * A classifier specialized to label images using TensorFlow.
 */
public class TensorFlowImageClassifier implements Classifier {

    private static final String TAG = "ImageClassifier";

    // Only return this many results with at least this confidence.
    private static final int MAX_RESULTS = 2;
    private static final float THRESHOLD = 0.1f;

    // Config values.
    private String inputName;
    private String outputName;
    private int inputSize;
    private int imageMean;
    private float imageStd;

    // Pre-allocated buffers.
    private Vector<String> labels = new Vector<String>();
    private int[] intValues;
    private float[] floatValues;
    private float[] outputs;
    private String[] outputNames;

    private TensorFlowInferenceInterface inferenceInterface;

    private boolean runStats = false;

    private TensorFlowImageClassifier() {
    }

    /**
     * Initializes a native TensorFlow session for classifying images.
     *
     * @param assetManager  The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize     The input size. A square image of inputSize x inputSize is assumed.
     * @param imageMean     The assumed mean of the image values.
     * @param imageStd      The assumed std of the image values.
     * @param inputName     The label of the image input node.
     * @param outputName    The label of the output node.
     * @throws IOException  If the label file cannot be read.
     */
    public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String labelFilename,
            int inputSize,
            int imageMean,
            float imageStd,
            String inputName,
            String outputName)
            throws IOException {
        TensorFlowImageClassifier c = new TensorFlowImageClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read the label names into memory (one label per line).
        // TODO(andrewharp): make this handle non-assets.
        String actualFilename = labelFilename.split("file:///android_asset/")[1];
        Log.i(TAG, "Reading labels from: " + actualFilename);
        BufferedReader br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                c.labels.add(line);
            }
        } finally {
            br.close();
        }

        c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

        // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
        int numClasses =
                (int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
        Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);

        // Ideally, inputSize could have been retrieved from the shape of the input operation. Alas,
        // the placeholder node for input in the graphdef typically used does not specify a shape, so it
        // must be passed in as a parameter.
        c.inputSize = inputSize;
        c.imageMean = imageMean;
        c.imageStd = imageStd;

        // Pre-allocate buffers.
        c.outputNames = new String[]{outputName};
        c.intValues = new int[inputSize * inputSize];
        c.floatValues = new float[inputSize * inputSize * 3];
        c.outputs = new float[numClasses];

        /* Debug only: log an operation name to check the graph's node names.
        if (c.inferenceInterface != null && c.inferenceInterface.graph() != null
                && c.inferenceInterface.graph().operations() != null) {
            Iterator<Operation> operations = c.inferenceInterface.graph().operations();
            Log.e("operation : ", "" + operations.next().name());
        }
        */

        return c;
    }

    @Override
    public List<Recognition> recognizeImage(final Bitmap bitmap) {
        // Log this method so that it can be analyzed with systrace.
        TraceCompat.beginSection("recognizeImage");

        // Preprocess the image data from 0-255 int to normalized float based
        // on the provided parameters. The bitmap must be exactly inputSize x inputSize,
        // because intValues and floatValues were sized for that.
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
        }

        // Copy the input data into TensorFlow.
        inferenceInterface.feed(inputName, floatValues, new long[]{1, inputSize, inputSize, 3});

        // Run the inference call.
        inferenceInterface.run(outputNames, runStats);

        // Copy the output Tensor back into the output array.
        inferenceInterface.fetch(outputName, outputs);

        // Find the best classifications. The (capacity, comparator) constructor
        // works on all Android API levels.
        PriorityQueue<Recognition> pq =
                new PriorityQueue<Recognition>(
                        3,
                        new Comparator<Recognition>() {
                            @Override
                            public int compare(Recognition lhs, Recognition rhs) {
                                // Intentionally reversed to put high confidence at the head of the queue.
                                return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                            }
                        });
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
                pq.add(
                        new Recognition(
                                "" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
            }
        }
        final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
        int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
        for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
        }
        TraceCompat.endSection(); // "recognizeImage"
        return recognitions;
    }

    @Override
    public void enableStatLogging(boolean debug) {
        runStats = debug;
    }

    @Override
    public String getStatString() {
        return inferenceInterface.getStatString();
    }

    @Override
    public void close() {
        inferenceInterface.close();
    }
}
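The result selection at the end of recognizeImage does not depend on Android, so it can be checked off-device. The sketch below is a hypothetical, plain-Java copy of that logic: keep scores above the threshold, order them with a max-first PriorityQueue, return at most maxResults. TopK and select are made-up names, and it returns label/score pairs instead of Recognition objects to avoid the android.graphics.RectF dependency.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

// Hypothetical, Android-free copy of the selection step in recognizeImage():
// scores above the threshold, highest first, at most maxResults of them.
final class TopK {

    static List<Map.Entry<String, Float>> select(
            float[] outputs, List<String> labels, int maxResults, float threshold) {
        // Reversed comparison so the highest score sits at the head of the queue.
        Comparator<Map.Entry<String, Float>> byScoreDesc =
                (lhs, rhs) -> Float.compare(rhs.getValue(), lhs.getValue());
        PriorityQueue<Map.Entry<String, Float>> pq =
                new PriorityQueue<>(Math.max(1, outputs.length), byScoreDesc);
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > threshold) {
                String label = labels.size() > i ? labels.get(i) : "unknown";
                pq.add(new AbstractMap.SimpleEntry<String, Float>(label, outputs[i]));
            }
        }
        List<Map.Entry<String, Float>> result = new ArrayList<>();
        int n = Math.min(pq.size(), maxResults);
        for (int i = 0; i < n; ++i) {
            result.add(pq.poll());
        }
        return result;
    }

    public static void main(String[] args) {
        float[] outputs = {0.05f, 0.70f, 0.20f, 0.05f};
        List<String> labels = Arrays.asList("daisy", "roses", "tulips", "sunflowers");
        System.out.println(select(outputs, labels, 2, 0.1f)); // [roses=0.7, tulips=0.2]
    }
}
```

With the values in main, roses and tulips are returned and both 0.05 scores are dropped by the threshold; raising maxResults would not add them back.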
6. In the Activity that performs recognition, define:
private static final int INPUT_SIZE = 299;
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 128;
private static final String INPUT_NAME = "Mul";
private static final String OUTPUT_NAME = "final_result";
private static final String MODEL_FILE = "file:///android_asset/output.pb";
private static final String LABEL_FILE ="file:///android_asset/output_labels.txt";
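Why IMAGE_MEAN and IMAGE_STD are both 128: strip_unused.py made the Mul node the graph's input, and as far as I understand the Inception v3 graph, Mul is the output of the graph's own preprocessing, which maps each 0-255 channel value v to (v - 128) / 128. The app therefore has to perform that mapping itself, which is what the loop in recognizeImage does. The sketch below isolates that arithmetic in plain Java (InceptionPreprocess is a hypothetical name) so it can be checked without a device.

```java
import java.util.Arrays;

// Hypothetical helper: the same per-pixel normalization as recognizeImage(),
// turning packed ARGB ints (as from Bitmap.getPixels) into interleaved R, G, B floats.
final class InceptionPreprocess {

    static float[] toFloatRgb(int[] argbPixels, int imageMean, float imageStd) {
        float[] out = new float[argbPixels.length * 3];
        for (int i = 0; i < argbPixels.length; ++i) {
            int val = argbPixels[i];
            out[i * 3] = (((val >> 16) & 0xFF) - imageMean) / imageStd;    // red
            out[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd; // green
            out[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;         // blue (alpha is ignored)
        }
        return out;
    }

    public static void main(String[] args) {
        // One opaque pixel with R=128, G=255, B=0.
        float[] f = toFloatRgb(new int[]{0xFF80FF00}, 128, 128f);
        System.out.println(Arrays.toString(f)); // [0.0, 0.9921875, -1.0]
    }
}
```

Every channel therefore lands in [-1, 1), which is the range the Mul node would have produced from the original JPEG pipeline.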
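Add a method that initializes TensorFlow on a background thread (classifier is a Classifier field of the Activity, and executor is an Executor such as Executors.newSingleThreadExecutor()):
private void initTensorFlowAndLoadModel() {
    executor.execute(new Runnable() {
        @Override
        public void run() {
            try {
                classifier = TensorFlowImageClassifier.create(getAssets(), MODEL_FILE, LABEL_FILE,
                        INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME, OUTPUT_NAME);
            } catch (final Exception e) {
                throw new RuntimeException("Error initializing TensorFlow!", e);
            }
        }
    });
}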
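Model loading is slow, which is why it runs on an executor. One drawback of the Runnable version is that an exception thrown inside run() goes to the worker thread's uncaught-exception handler (on Android this normally crashes the app) instead of reaching the code that needs the classifier. A hedged alternative, shown in plain Java so it runs without Android: submit a Callable and keep the Future. BackgroundLoader is a hypothetical helper, not part of the TensorFlow library.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical helper: loads an expensive resource (such as the classifier)
// on a background thread and hands load failures back to the caller.
final class BackgroundLoader<T> {
    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    private final Future<T> future;

    BackgroundLoader(Callable<T> loader) {
        future = executor.submit(loader);
    }

    // True once loading has finished, successfully or not.
    boolean isDone() {
        return future.isDone();
    }

    // Blocks until loading finishes; a load error surfaces as ExecutionException.
    T get() throws InterruptedException, ExecutionException {
        return future.get();
    }

    void shutdown() {
        executor.shutdown();
    }

    public static void main(String[] args) throws Exception {
        BackgroundLoader<String> loader = new BackgroundLoader<>(() -> "model loaded");
        System.out.println(loader.get()); // model loaded
        loader.shutdown();
    }
}
```

In the Activity, the Callable would wrap the TensorFlowImageClassifier.create(...) call shown above, and get() should be called from a worker thread, or only after isDone() returns true, so the UI thread never blocks.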
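According to the official instructions the input image should be 299 × 299. I tried several other sizes and all of them failed: some were too large, and others triggered an error saying the size was not a multiple of 2048. That is expected, since the graph and the buffers in TensorFlowImageClassifier are both sized for INPUT_SIZE × INPUT_SIZE, so the bitmap must be exactly that size. The simplest fix is to resize the bitmap before recognition, which takes one line: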
rightBitmap = Bitmap.createScaledBitmap(rightBitmap, 299, 299, true);
final List<Classifier.Recognition> results = classifier.recognizeImage(rightBitmap);
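createScaledBitmap stretches a non-square photo to 299 × 299, which distorts its aspect ratio. If that hurts accuracy, one option (an alternative, not what I did above) is to cut the largest centered square first and then scale it. The crop-rectangle arithmetic is plain Java; CenterSquare is a hypothetical name, and on Android the result would be passed to Bitmap.createBitmap(source, x, y, size, size) before calling createScaledBitmap.

```java
import java.util.Arrays;

// Hypothetical helper: computes the largest square centered in a
// width x height image, returned as {x, y, size}.
final class CenterSquare {

    static int[] crop(int width, int height) {
        if (width <= 0 || height <= 0) {
            throw new IllegalArgumentException("empty image: " + width + "x" + height);
        }
        int size = Math.min(width, height);
        return new int[]{(width - size) / 2, (height - size) / 2, size};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(crop(640, 480))); // [80, 0, 480]
    }
}
```

The trade-off is that anything near the left/right (or top/bottom) edges is discarded, so this suits photos where the object is roughly centered.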
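The returned results is a List of Classifier.Recognition objects. Each entry holds the predicted class name (getTitle()) and its confidence (getConfidence()); the list is sorted from highest to lowest confidence and contains at most MAX_RESULTS entries above THRESHOLD.
In my tests this was far more accurate than the OpenCV approach I had used before.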