Android配置tensorflow lite
按照官方网站的指导在项目的模块的构建文件build.gradle中配置中增加如下配置:
implementation 'org.tensorflow:tensorflow-lite:2.7.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.7.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.1.0'
implementation 'org.tensorflow:tensorflow-lite-metadata:0.1.0'
android{
aaptOptions {
noCompress "tflite"
}
defaultConfig {
ndk {
abiFilters 'armeabi-v7a', 'arm64-v8a'
}
}
}
导入模型资源资源
创建将文《关于将Tesorflow的SavedModel模型转换成tflite模型》创建的模型model.tflite,导入到Android项目的assets目录中。
定义模型基本配置类BaseModelConfig
/**
* 定义模型的基本配置类
*/
public abstract class BaseModelConfig{
//每通道处理的字节数
var numBytesPerChannel:Int = 0
//定义批处理的个数
var dimBatchSize:Int = 0
//定义像素个数
var dimPixelSize:Int = 0
//定义图片的宽度
var dimImgWidth:Int = 0
//定义图片的高度
var dimImgHeight:Int = 0
//定义平均差
var imageMean=0
//定义图片的标准差
var imageSTD:Float = 0.0F
//定义模型的名称
lateinit var modelName:String
constructor() : super() {
setConfigs()
}
/**
* 将像素值转换成ByteBuffer
* 增加图片的值
*/
public abstract fun addImgValue(buffer: ByteBuffer,pixel:Int)
/**
* 配置
*/
public abstract fun setConfigs()
}
定义FloatSavedModelConfig类
class FloatSavedModelConfig: BaseModelConfig() {
public override fun setConfigs() {
modelName="model.tflite"
numBytesPerChannel = 4
dimBatchSize = 1
dimPixelSize = 1
dimImgWidth = 28
dimImgHeight = 28
imageMean = 0
imageSTD = 255.0f
}
override fun addImgValue(imgData: ByteBuffer, pixel: Int) {
imgData.putFloat(((pixel and 0xFF) - imageMean) / imageSTD)
}
}
创建配置模型参数的工厂类
object ModelConfigFactory {
const val FLOAT_SAVED_MODEL = "float_saved_model"
const val QUANT_SAVED_MODEL = "quant_saved_model"
fun getModelConfig(model: String): BaseModelConfig? =
when(model) {
FLOAT_SAVED_MODEL-> FloatSavedModelConfig()
QUANT_SAVED_MODEL-> QuantSavedModelConfig()
else->null
}
}
定义图像分类器
class ImageClassifier {
private val TAG = "FashionMNIST"
private val RESULTS_TO_SHOW = 3
lateinit var mTFLite: Interpreter
lateinit var mModelPath:String
var mNumBytesPerChannel = 0
var mDimBatchSize = 0
var mDimPixelSize = 0
var mDimImgWidth = 0
var mDimImgHeight = 0
lateinit var mModelConfig:BaseModelConfig
//定义标签检测的二维数组1x10
val mLabelProbArray = Array(1) {
FloatArray(
10
)
}
val labels = arrayListOf("T恤","裤子","帽头衫","连衣裙","外套","凉鞋","衬衫","运动鞋","包","靴子")
//定义检测结果保持到优先队列中
var mSortedLabels = PriorityQueue<Map.Entry<String, Float>>(
RESULTS_TO_SHOW) {
o1, o2 -> o1?.value!!.compareTo(o2?.value!!)
}
/**
* 配置参数
*/
private fun initConfig(config: BaseModelConfig) {
mModelConfig = config
mNumBytesPerChannel = config.numBytesPerChannel
mDimBatchSize = config.dimBatchSize
mDimPixelSize = config.dimPixelSize
mDimImgWidth = config.dimImgWidth
mDimImgHeight = config.dimImgHeight
mModelPath = config.modelName
}
constructor(modelConfig: String, activity: Activity) {
// 初始化分类器的相关参数
initConfig(ModelConfigFactory.getModelConfig(modelConfig)!!)
// 使用配置参数初始化翻译器
mTFLite = Interpreter(loadModelFile(activity)!!)
}
/**
* 在Assets中的模型文件映射到内存中
* */
private fun loadModelFile(activity: Activity): MappedByteBuffer? {
val fileDescriptor = activity.assets.openFd(mModelPath)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
/**
* 将图片数据写入到ByteBuffer,加载到内存中
* */
protected fun convertBitmapToByteBuffer(bitmap: Bitmap?): ByteBuffer {
val intValues = IntArray(mDimImgWidth * mDimImgHeight)
//调整要处理的图片为28x28
var tmp = scaleBitmap(bitmap)
//将图片二值化
tmp = binarized(tmp)
//将二值化的图片加载到内存中
tmp.getPixels(intValues,
0, tmp.width, 0, 0, tmp.width, tmp.height
)
val imgData = ByteBuffer.allocateDirect(
mNumBytesPerChannel * mDimBatchSize * mDimImgWidth * mDimImgHeight * mDimPixelSize
)
imgData.order(ByteOrder.nativeOrder())
imgData.rewind()
//将图片转换成像素实数数据
var pixel = 0
for (i in 0 until mDimImgWidth) {
for (j in 0 until mDimImgHeight) {
var value = intValues[pixel++]
mModelConfig.addImgValue(imgData, value)
}
}
return imgData
}
/**
* 将图片二值化处理
* 转换成二值图像
* @param bmp
* @return
*/
fun binarized(bmp: Bitmap): Bitmap {
val width = bmp.width
val height = bmp.height
val pixels = IntArray(width * height)
//将图片的像素加载到数组中
bmp.getPixels(pixels, 0, width, 0, 0, width, height)
var alpha = 0xFF shl 24
for (i in 0 until height) {
for (j in 0 until width) {
val grey = pixels[width * i + j]
// 分离三原色
alpha = grey and -0x1000000 shr 24
var red = grey and 0x00FF0000 shr 16
var green = grey and 0x0000FF00 shr 8
var blue = grey and 0x000000FF
val tmp = 180
red = if (red > tmp) 255 else 0
blue = if (blue > tmp) 255 else 0
green = if (green > tmp) 255 else 0
pixels[width * i + j] = alpha shl 24 or (red shl 16) or (green shl 8) or blue
if (pixels[width * i + j] == -1) {
pixels[width * i + j] = -1
} else {
pixels[width * i + j] = -16777216
}
}
}
// 新建图片
val newBmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
// 设置图片数据
newBmp.setPixels(pixels, 0, width, 0, 0, width, height)
return newBmp
}
/**
* 将图片调整到规定的大小28x28
*/
fun scaleBitmap(bmp: Bitmap?): Bitmap {
return Bitmap.createScaledBitmap(bmp!!, mDimImgWidth, mDimImgHeight, true)
}
/**
* 分类处理
*/
fun doClassify(bitmap: Bitmap?): String? {
// 将Bitmap图片转换成TFLite翻译器的可读的ByteBuffer
val imgData = convertBitmapToByteBuffer(bitmap)
// do run interpreter
val startTime = System.nanoTime()
mTFLite.run(imgData, mLabelProbArray)
val endTime = System.nanoTime()
Log.i(TAG, String.format(
"运行识别的时间: %f ms",
(endTime - startTime).toFloat() / 1000000.0f
)
)
// 生成并返回结果
return printTopKLabels()
}
/**
* 打印检测排序在前几位的标签,并作为结果显示在UI界面中。
*/
fun printTopKLabels(): String? {
for (i in 0..9) {
mSortedLabels.add(
AbstractMap.SimpleEntry(
labels[i],
mLabelProbArray[0][i]
)
)
if (mSortedLabels.size > RESULTS_TO_SHOW) {
mSortedLabels.poll()
}
}
val textToShow = StringBuffer()
val size = mSortedLabels.size
for (i in 0 until size) {
val label = mSortedLabels.poll()
textToShow.insert(0, String.format("\n%s %4.8f", label.key, label.value))
}
return textToShow.toString()
}
}
定义主活动MainActivity
在主活动中,主要处理如下操作:
(1)从图库中选择图片
(2)利用图像分类器检测图片中的内容,判断是FashionMnist数据集的哪种标签
(3)将检测的结果在移动终端的GUI界面中显示出来。
class MainActivity : AppCompatActivity() {
private lateinit var binding: ActivityMainBinding
val RequestCameraCode = 1
val TAG = "FashionMNIST"
companion object{
var mIsFloat = true
}
private var bitmap: Bitmap? = null
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
//生成视图绑定对象
binding = ActivityMainBinding.inflate(layoutInflater)
//设置视图的根视图
setContentView(binding.root)
binding.imageView.setOnClickListener {
val intent = Intent()
intent.type = "image/*"
intent.action = Intent.ACTION_GET_CONTENT
startActivityForResult(intent,RequestCameraCode)
}
val spinnerAdapter = ArrayAdapter<String>(this,android.R.layout.simple_spinner_item,getChoices())
binding.typeSpinner.adapter = spinnerAdapter
binding.typeSpinner.onItemSelectedListener = object : OnItemSelectedListener {
override fun onItemSelected(
parent: AdapterView<*>?,
view: View,
position: Int,
id: Long
) {
mIsFloat = position == 0
}
override fun onNothingSelected(parent: AdapterView<*>?) {}
}
}
override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
super.onActivityResult(requestCode, resultCode, data)
if(resultCode == RESULT_OK && requestCode == RequestCameraCode){
val uri = data?.data
try{
//从图库中读取图片
var bitmap = BitmapFactory.decodeStream(contentResolver.openInputStream(uri!!))
//在图像视图ImageView中显示图片
binding.imageView.setImageBitmap(bitmap)
//判断模型类型
val config = when(mIsFloat){
true->ModelConfigFactory.FLOAT_SAVED_MODEL
else->ModelConfigFactory.QUANT_SAVED_MODEL
}
//根据模型类型创建图像识别器
val classifier = ImageClassifier(config,this)
//检测并判断图像的类别
val result = classifier.doClassify(bitmap)
binding.labelTxt.text = result
binding.tipTxt.visibility = View.GONE
}catch(e: FileNotFoundException){
Log.d(TAG,"没有找到指定的图像文件")
}catch(e: IOException){
Log.e(TAG,"初始化图像识别器失败")
}
}
}
/**
* 返回可用模型的名称
*/
private fun getChoices()= resources.getStringArray(R.array.model_names)
}
参考文献
李锡涵等 《简明的Tensorflow 2》人民邮电出版社 北京 P91-P96