Android配置tensorflow lite

按照官方网站的指导在项目的模块的构建文件build.gradle中配置中增加如下配置:

implementation 'org.tensorflow:tensorflow-lite:2.7.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.7.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
    implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
android{
   aaptOptions {
        noCompress "tflite"
    }
  defaultConfig {
        ndk {
            abiFilters 'armeabi-v7a', 'arm64-v8a'
        }
    }
 }

导入模型资源资源

创建将文《关于将Tesorflow的SavedModel模型转换成tflite模型》创建的模型model.tflite,导入到Android项目的assets目录中。

定义模型基本配置类BaseModelConfig

/**
 * 定义模型的基本配置类
 */
public abstract class BaseModelConfig{
    //每通道处理的字节数
    var numBytesPerChannel:Int = 0
    //定义批处理的个数
    var dimBatchSize:Int = 0
    //定义像素个数
    var dimPixelSize:Int = 0
    //定义图片的宽度
    var dimImgWidth:Int = 0
    //定义图片的高度
    var dimImgHeight:Int = 0
    //定义平均差
    var imageMean=0
    //定义图片的标准差
    var imageSTD:Float = 0.0F
    //定义模型的名称
    lateinit var modelName:String

    constructor() : super() {
        setConfigs()
    }
    /**
     * 将像素值转换成ByteBuffer
     * 增加图片的值
     */
    public abstract fun addImgValue(buffer: ByteBuffer,pixel:Int)

    /**
     * 配置
     */
    public abstract fun setConfigs()
}

定义FloatSavedModelConfig类

class FloatSavedModelConfig: BaseModelConfig() {
    public override fun setConfigs() {
        modelName="model.tflite"
        numBytesPerChannel = 4
        dimBatchSize = 1
        dimPixelSize = 1
        dimImgWidth = 28
        dimImgHeight = 28
        imageMean = 0
        imageSTD = 255.0f
    }

    override fun addImgValue(imgData: ByteBuffer, pixel: Int) {
        imgData.putFloat(((pixel  and 0xFF) - imageMean) / imageSTD)
    }
}

创建配置模型参数的工厂类

object ModelConfigFactory {
    const val FLOAT_SAVED_MODEL = "float_saved_model"
    const val QUANT_SAVED_MODEL = "quant_saved_model"

    fun getModelConfig(model: String): BaseModelConfig? =
        when(model) {
            FLOAT_SAVED_MODEL-> FloatSavedModelConfig()
            QUANT_SAVED_MODEL-> QuantSavedModelConfig()
            else->null
        }
}

定义图像分类器

class ImageClassifier {
    private val TAG = "FashionMNIST"
    private val RESULTS_TO_SHOW = 3

    lateinit var mTFLite: Interpreter

    lateinit var mModelPath:String
    var mNumBytesPerChannel = 0

    var mDimBatchSize = 0
    var mDimPixelSize = 0

    var mDimImgWidth = 0
    var mDimImgHeight = 0

    lateinit var mModelConfig:BaseModelConfig

    //定义标签检测的二维数组1x10
    val mLabelProbArray = Array(1) {
        FloatArray(
            10
        )
    }
    val labels = arrayListOf("T恤","裤子","帽头衫","连衣裙","外套","凉鞋","衬衫","运动鞋","包","靴子")

    //定义检测结果保持到优先队列中
    var mSortedLabels = PriorityQueue<Map.Entry<String, Float>>(
                        RESULTS_TO_SHOW) {
            o1, o2 -> o1?.value!!.compareTo(o2?.value!!)
    }

    /**
     * 配置参数
     */
    private fun initConfig(config: BaseModelConfig) {
        mModelConfig = config
        mNumBytesPerChannel = config.numBytesPerChannel
        mDimBatchSize = config.dimBatchSize
        mDimPixelSize = config.dimPixelSize
        mDimImgWidth = config.dimImgWidth
        mDimImgHeight = config.dimImgHeight
        mModelPath = config.modelName
    }

    constructor(modelConfig: String, activity: Activity) {
        // 初始化分类器的相关参数
        initConfig(ModelConfigFactory.getModelConfig(modelConfig)!!)

        // 使用配置参数初始化翻译器
        mTFLite = Interpreter(loadModelFile(activity)!!)
    }

    /**
     * 在Assets中的模型文件映射到内存中
     * */
    private fun loadModelFile(activity: Activity): MappedByteBuffer? {
        val fileDescriptor = activity.assets.openFd(mModelPath)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }

    /**
     * 将图片数据写入到ByteBuffer,加载到内存中
     * */
    protected fun convertBitmapToByteBuffer(bitmap: Bitmap?): ByteBuffer {
        val intValues = IntArray(mDimImgWidth * mDimImgHeight)
        //调整要处理的图片为28x28
        var tmp = scaleBitmap(bitmap)
        //将图片二值化
        tmp = binarized(tmp)

        //将二值化的图片加载到内存中
        tmp.getPixels(intValues,
            0, tmp.width, 0, 0, tmp.width, tmp.height
        )
        val imgData = ByteBuffer.allocateDirect(
            mNumBytesPerChannel * mDimBatchSize * mDimImgWidth * mDimImgHeight * mDimPixelSize
        )
        imgData.order(ByteOrder.nativeOrder())
        imgData.rewind()

        //将图片转换成像素实数数据
        var pixel = 0
        for (i in 0 until mDimImgWidth) {
            for (j in 0 until mDimImgHeight) {
                var value = intValues[pixel++]
                mModelConfig.addImgValue(imgData, value)
            }
        }
        return imgData
    }

    /**
     * 将图片二值化处理
     * 转换成二值图像
     * @param bmp
     * @return
     */
    fun binarized(bmp: Bitmap): Bitmap {
        val width = bmp.width
        val height = bmp.height
        val pixels = IntArray(width * height)
        //将图片的像素加载到数组中
        bmp.getPixels(pixels, 0, width, 0, 0, width, height)
        var alpha = 0xFF shl 24
        for (i in 0 until height) {
            for (j in 0 until width) {
                val grey = pixels[width * i + j]
                // 分离三原色
                alpha = grey and -0x1000000 shr 24
                var red = grey and 0x00FF0000 shr 16
                var green = grey and 0x0000FF00 shr 8
                var blue = grey and 0x000000FF
                val tmp = 180
                red = if (red > tmp) 255 else 0
                blue = if (blue > tmp) 255 else 0
                green = if (green > tmp) 255 else 0
                pixels[width * i + j] = alpha shl 24 or (red shl 16) or (green shl 8) or blue
                if (pixels[width * i + j] == -1) {
                    pixels[width * i + j] = -1
                } else {
                    pixels[width * i + j] = -16777216
                }
            }
        }
        // 新建图片
        val newBmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
        // 设置图片数据
        newBmp.setPixels(pixels, 0, width, 0, 0, width, height)
        return newBmp
    }

    /**
     * 将图片调整到规定的大小28x28
     */
    fun scaleBitmap(bmp: Bitmap?): Bitmap {
        return Bitmap.createScaledBitmap(bmp!!, mDimImgWidth, mDimImgHeight, true)
    }

    /**
     * 分类处理
     */
    fun doClassify(bitmap: Bitmap?): String? {
        // 将Bitmap图片转换成TFLite翻译器的可读的ByteBuffer
        val imgData = convertBitmapToByteBuffer(bitmap)

        // do run interpreter
        val startTime = System.nanoTime()
        mTFLite.run(imgData, mLabelProbArray)
        val endTime = System.nanoTime()
        Log.i(TAG, String.format(
                "运行识别的时间: %f ms",
                (endTime - startTime).toFloat() / 1000000.0f
            )
        )

        // 生成并返回结果
        return printTopKLabels()
    }

    /**
     * 打印检测排序在前几位的标签,并作为结果显示在UI界面中。
     */
    fun printTopKLabels(): String? {
        for (i in 0..9) {
            mSortedLabels.add(
                AbstractMap.SimpleEntry(
                    labels[i],
                    mLabelProbArray[0][i]
                )
            )
            if (mSortedLabels.size > RESULTS_TO_SHOW) {
                mSortedLabels.poll()
            }
        }
        val textToShow = StringBuffer()
        val size = mSortedLabels.size
        for (i in 0 until size) {
            val label = mSortedLabels.poll()
            textToShow.insert(0, String.format("\n%s   %4.8f", label.key, label.value))
        }
        return textToShow.toString()
    }

}

定义主活动MainActivity

在主活动中,主要处理如下操作:
(1)从图库中选择图片
(2)利用图像分类器检测图片中的内容,判断是FashionMnist数据集的哪种标签
(3)将检测的结果在移动终端的GUI界面中显示出来。

class MainActivity : AppCompatActivity() {
    private lateinit var binding: ActivityMainBinding
    val RequestCameraCode = 1
    val TAG = "FashionMNIST"
    companion object{
        var mIsFloat = true
    }
    private var bitmap: Bitmap? = null
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)

        //生成视图绑定对象
        binding = ActivityMainBinding.inflate(layoutInflater)
        //设置视图的根视图
        setContentView(binding.root)

        binding.imageView.setOnClickListener {
            val intent = Intent()
            intent.type = "image/*"
            intent.action = Intent.ACTION_GET_CONTENT
            startActivityForResult(intent,RequestCameraCode)
        }

        val spinnerAdapter = ArrayAdapter<String>(this,android.R.layout.simple_spinner_item,getChoices())
        binding.typeSpinner.adapter = spinnerAdapter
        
        binding.typeSpinner.onItemSelectedListener = object : OnItemSelectedListener {
            override fun onItemSelected(
                parent: AdapterView<*>?,
                view: View,
                position: Int,
                id: Long
            ) {
                mIsFloat = position == 0
            }

            override fun onNothingSelected(parent: AdapterView<*>?) {}
        }
    }

    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if(resultCode == RESULT_OK && requestCode == RequestCameraCode){
            val uri = data?.data
            try{
                //从图库中读取图片
                var bitmap = BitmapFactory.decodeStream(contentResolver.openInputStream(uri!!))
                //在图像视图ImageView中显示图片
                binding.imageView.setImageBitmap(bitmap)
                //判断模型类型
                val config = when(mIsFloat){
                    true->ModelConfigFactory.FLOAT_SAVED_MODEL
                    else->ModelConfigFactory.QUANT_SAVED_MODEL
                }
                //根据模型类型创建图像识别器
                val classifier = ImageClassifier(config,this)
                //检测并判断图像的类别
                val result = classifier.doClassify(bitmap)
                binding.labelTxt.text = result
                binding.tipTxt.visibility = View.GONE
            }catch(e: FileNotFoundException){
                Log.d(TAG,"没有找到指定的图像文件")
            }catch(e: IOException){
                Log.e(TAG,"初始化图像识别器失败")
            }

        }
    }
    /**
     * 返回可用模型的名称
     */
    private fun getChoices()= resources.getStringArray(R.array.model_names)

}

参考文献

李锡涵等 《简明的Tensorflow 2》人民邮电出版社 北京 P91-P96