本文为稀土技能社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!
老规矩,发文先上图。
不知什么时候起,我被读者惯出这个毛病。或许说,是我培养出读者这个好习惯。你评论问一句,我就来一篇万字长文。往前翻翻,最近几篇都是这样。我只写咱们想看的。
之前给咱们安利过,不论什么渠道都能调用AI能力,是拿着YOLOv8举的比如。我感谢上面那些朋友,他们能提出这些问题,说明是真的有需求,并且动手操作了。你写的文章有人看,这叫被知己欣赏。
因而,我打算详细拆解YOLO
的导出,并以tflite
格局的生成、导入,以及在移动端的详细代码运用为例,给上面的问题一个答案。一起,告知咱们,渠道仅仅一个环境,了解原理可融会贯通,透穿渠道。
一、模型转化
故事开端了。
我媳妇网购东西后,喜欢比价。她买俩吹风机,一个叫①号,一个叫②号,到货了也每天去渠道看看详情页。所以,我给她做了一款Android客户端,让她一扫描就进详情页,不用再到订单里去找。
这儿边选用了YOLOv8
的方针检测技能,先练习生成.pt
权重文件,再导出为.tflite
模型文件,终究放入Android项目中完成检测功能。下面来看看作用,检测速度是毫秒级别的。
标示和练习办法,参考之前文章《YOLOv8运用教程》,这是前置常识。今日的要点,从生成的.pt
权重文件开端。
先验证一下我的best_num.pt
文件,它会去辨认test
文件下图片里的①号、②号。
from ultralytics import YOLO
model = YOLO("best_num.pt")
model.predict(source="test", save=True, save_txt=True)
1.1 pt文件到tflite文件
好,没问题。下面转化为.tflite
格局的文件。
from ultralytics import YOLO
model = YOLO('best_num.pt')
model.export(format='tflite')
留意,我可点运转了啊!接下来你会看到好长时刻的转圈。由于好多库没有装置。
如读者反馈,这个进程或许会报错,报错的原因得看详细的过错信息。大多数过错和环境抵触有关,比如你原来有个2.0,此刻它自动去装置个3.2,或许就会发生过错。因而,我建议你整一个全新的虚拟环境去做。别怕费事,给每个项目配一个专属空间,会减少许多不必要的费事。
假如是装置好了,那仍是很快的。只需要16.5秒。
咱们从日志中能够看到,它阅历了一番曲折的转化。
1.2 格局转化的道路
起先,它是一个PyTorch
模型的.pt
文件,名称叫best_num.pt
。然后,它被转化为onnx
格局的best_num.onnx
。
ONNX
的全称是Open Neural Network Exchange
(开放式神经网络交流)。它是由微软、Facebook、IBM等科技公司在2017年一起建议的一种机制,能够完成不同深度学习框架(如PyTorch、TensorFlow、Caffe2等)模型之间的相互转化。因而,onnx
格局是必经之路。
随后,又用onnx2tf
东西,以命令行的办法将ONNX
模型转化为TensorFlow SavedModel
格局,并以best_num_saved_model
文件夹保存。这个格局是序列化TensorFlow
模型用的。
紧接着,发动TensorFlow Lite
的导出进程,将TensorFlow SavedModel
模型转化为best_num_float32.tflite
的TensorFlow Lite
格局。
呜呼呀!我这5.9MB
的.pt
文件,终究竟然被转为11.6MB
的.tflite
。这显然不行,在App里太大了!
别的,我看到saved_model
文件下有许多.tflite
文件。它们的姓名还带着数字:float32.tflite
、float16.tflite
……这是什么状况?
-
float32.tflite:全精度模型。参数都以32位浮点数(float32)存储。精度高,运转速度相对会慢。
-
float16.tflite:半精度模型。参数都以16位浮点数(float16)存储。巨细是全精度模型的一半,运转速度会快一些。
在Android终端设备上的推理,咱们需要的是更快,而非特别准确。由于假如要求准确,不考虑时刻,我上传到服务端去处理好不好。
说的很有道理,咱们能够经过量化来改善巨细和速度问题。只需要加一个参数model.export(……, int8=True)
,再运转一下。
ONNX: export success 2.2s, saved as 'best_num.onnx' (11.6 MB)
TensorFlow SavedModel: running 'onnx2tf -i "best_num.onnx" -o "best_num_saved_model" -nuo --verbosity info -oiqt -qt per-tensor'
TensorFlow SavedModel: export success 212.0s, saved as 'best_num_saved_model' (38.6 MB)
TensorFlow Lite: starting export with tensorflow 2.13.0...
TensorFlow Lite: export success 0.0s, saved as 'best_num_saved_model\best_num_int8.tflite' (3.0 MB)
Export complete (213.7s)
这次耗时长,用了213.7秒,终究导出模型的巨细为3.0MB
。我满足了,这个巨细放到app才合适。
新增的文件如下:
它给出一个best_num_int8.tflite
作为最优选择,这是什么状况?
1.3 模型的量化
它叫int8量化模型。此模型被量化时,它将浮点数值映射到8-bit
的整数范围,并保存了映射关系。当模型进行推理时,这些整数能够被从头解说为挨近原始的浮点数值。
量化技能,能在减小模型巨细和进步履行速度的一起,依然保持相对高的精度。
转化成功喽。
下面咱们就来拆解它,了解如何剖析,咱们给它传什么数据,以及它又会返给咱们怎样的成果!
二、模型文件剖析
上面的best_num_int8.tflite
模型是咱们自己练习并转化的。因而,咱们了解它的结构和收支参数。
现在换一个故事,有人给了一个xxx.tflite
,让你去调用。此刻你该如何做呢?其实,用代码就能够剖析出来许多有用的信息。
以下操作,用Python
和Android
都能够完成。鉴于Python
简练,所以先用它快速演示作用,后边咱们还会用Android
再做一遍。
假定best_num_int8.tflite
便是那个xxx.tflite
文件,咱们用代码来将它阅读一下。
2.1 Interpreter解说器
关于tflite
文件的解析,TensorFlow提供了一个Interpreter
类。
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path='xxx.tflite')
interpreter.allocate_tensors()
print(interpreter.get_tensor_details())
经过interpreter
的get_tensor_details
,能够获取整个网络结构的信息。
那么,你获取到这些结构信息,有什么用呢?
2.2 网络结构与层
诶,它界说了整个数据流通的格局和操作。这就相当于一套操作过程,讲述了在哪个过程会把什么做怎样的处理。咱们把模型比作馒头出产流水线机器,那榜首步是放入面粉,随后对面粉加水,和面,揉面,拉长,切割,终究产出馒头。
因而,一旦咱们了解了这套馒头机的流程。咱们就能清楚地知道,在机器入口该依照怎样的频率倒入多少数的面粉,然后在出口能收成什么形状、多少重量的馒头。
其实这与把图片传给模型,它告知你里边有什么物体类似,都是一个加工收拾的进程。
假如把模型做的简略一些,能够是下面这样:
尽管它不智能,但我信任这有助于你更好地理解结构。
除了用代码读取,许多网站也能够阅读模型文件的结构。比如这个网站 netron.app/ 。
这说明这些文件都是揭露可读的,并没有什么特殊加密。
弱水三千,只取一瓢。模型百层,我只关心输入、输出(如同面粉与馒头)。倘若读出它们,倒也不难,一行代码可成。
# 获取输入层
input_details = interpreter.get_input_details()
# 获取输出层
output_details = interpreter.get_output_details()
打印一下看看:
input_details:
[{'name': 'serving_default_images:0',
'index': 0,
'shape': array([ 1, 640, 640, 3]),
'shape_signature': array([ 1, 640, 640, 3]),
'dtype': numpy.float32,
'quantization': (0.0, 0),
'quantization_parameters': {'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0},
'sparsity_parameters': {}}]
output_details:
[{'name': 'PartitionedCall:0',
'index': 410,
'shape': array([ 1, 6, 8400]),
'shape_signature': array([ 1, 6, 8400]),
'dtype': numpy.float32,
'quantization': (0.0, 0),
'quantization_parameters': {'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0},
'sparsity_parameters': {}}]
这儿边,英语单词翻译过来便是解说。比如sparsity parameters是有关稀疏性参数的信息。这个比如中,值都为空,表明没有进行稀疏性优化。
2.3 形状和数据类型
今日咱们只重视2项:
-
shape:描述该层数据的形状。输入层的形状是[1, 640, 640, 3],这意味着该层接纳一个四维数组,榜首维是样本数(批次巨细,几张图片),后边三个维度别离表明一张图片的高度、宽度和颜色通道数。关于输入层数据形式的安排,一般不难。难的是对输出层的剖析和处理,后边一章咱们会要点拆解。
-
dtype:表明该层数据的数据类型。咱们看到两个层的dtype都是numpy.float32,也便是32位的浮点数。
好了,tflite
文件咱们剖析完了。我也知道,你略微有点懂,但还不至于全懂。看完后边,或许会有所改善。
三、用Python调用模型文件
下面,咱们先把输入层的数据安排好,然后调用文件试试看。
咱们这个模型的输入层的形状是[1, 640, 640, 3]
。
其实打开是这样:
图片列表 | 图片数据 |
---|---|
图片1 | 640*640个像素点,每个点用(R,G,B)3种色值表明 |
图片2 | 640*640个像素点,每个点用(R,G,B)3种色值表明 |
图片…… | 640*640个像素点,每个点用(R,G,B)3种色值表明 |
咱们先疏忽多个图片,只考虑一张图片的状况,这样能简略些。
来读这么一张图片。
3.1 图片的读取
咱们读取一下它的数值:
import cv2
image = cv2.imread('num.jpg')
print(image.shape)
cv2.imread
会把一张图片读取成矩阵数据。image.shape
是数据的形状,由3部分构成:图片高/矩阵行数,图片宽/矩阵列数,颜色通道数。
这张图片的shape
打印出来是(2162, 2883, 3)
。这表明图片尺寸为28832162,通道数为3。
咱们假如在它的外面套一层,加一个[]
,它就能够变成(1, 2162, 2883, 3)
。可是,现在咱们首要要把它的尺寸变为(640, 640, 3)
,由于输入层的格局是(1, 640, 640, 3)
。
到这儿,你或许有点质疑,这640
是怎样来的,谁规定的?我用960
行不行?
兄弟,不行的。你没法用小麦粒充任面粉往馒头机里放。
3.2 输入数据的预处理
真想要追根溯源,得说你当初用YOLOv8
练习时,只履行了一句model.train(data="num.yaml", epochs=80)
,并没有做其他设置。而未设置的,会走一个默许装备,这个装备在Lib\site-packages\ultralytics\cfg
下,姓名叫default.yaml
,里边就有一个imgsz
便是640
。一个值,表明640640是正方形,两个值能够设置宽与高。
看我文章,跟听书似的,能涨不少周边常识。
咱们要检测的图片,或许来自摄像头,或许来自用户上传,这个咱们不能约束。咱们要做的是将图片修改成640640
。
def pre_img(image):
height, width, _ = image.shape
# 等比例缩放
if height > width:
new_height = 640
new_width = int(640 * width / height)
else:
new_width = 640
new_height = int(640 * height / width)
image_resized = cv2.resize(image, (new_width, new_height))
# 创立一个640*640的白色布景图画
background = np.ones((640, 640, 3), dtype=np.uint8) * 255
# 将缩放后的图画粘贴到布景图画的中心位置
start_x = (640 - new_width) // 2
start_y = (640 - new_height) // 2
background[start_y:start_y+new_height, start_x:start_x+new_width] = image_resized
return background
resize_image = pre_img(image)
为了凑一幅640640
的图画,咱们选用的处理办法是:不论图片巨细,先让它顶着边扩大或许缩小到640640
的框里,然后布景设为白色。
此刻打印resize_image.shape
就看到了久违的(640, 640, 3)
。
留意,要开端调用模型了!
3.3 履行推理
调用很简略,代码加注释,保证你一看就会!
# 单张图片数据转为浮点型
input_image_f32 = resize_image.astype(dtype=np.float32)/ 255
# 外面包一层[]组成[1, 640, 640, 3]
input_data = np.expand_dims(input_image_f32, axis=0)
# 将input_data数据塞给输入层,从索引找到
input_index = input_details[0]['index'] # 输入的索引
interpreter.set_tensor(input_index, input_data)
# 跑一跑
interpreter.invoke()
# 将输出层的数据拿出来,从索引确定输出层
output_index = output_details[0]['index'] # 输出层的索引
detect_scores = interpreter.get_tensor(output_index)
print(detect_scores.shape, detect_scores)
终究来数据了,便是那个detect_scores
。
3.4 输出数据剖析
detect_scores.shape:
(1, 6, 8400)
detect_scores:
array([[[7.9783527e-03, 2.5762582e-02, 3.6012750e-02, ...,
8.1423753e-01, 8.3901447e-01, 9.1082019e-01],
...
1.8597868e-03, 1.8911671e-03, 1.9312450e-03]]], dtype=float32)
输出数据的形状是(1, 6, 8400)
。
看输入数据的形状,有经验的老CV师傅,尚且能猜到是图片数据。但现在看这个输出数据的形状,就真的需要你对YOLOv8
算法略微了解才行喽。
我给咱们解说一下,这些维度都代表什么。
解说之前,得再往回倒历史,YOLO
是You Only Look Once的简称。这种算法,只需要在图上扫一遍就够了。由于有的算法,需要对图片扫描多遍才干完成方针检测。
所以,YOLO
会设置一个最小网格作为根本单位,划分出十分多的大巨细小的框。然后检测这些框里边是否有方针,以及是某种物体类型的可行性。
我只练习标示了①②两类方针,所以分类数量是2。
下面就容易理解这个模型的输出啦。
3.4.1 输出层格局解析
维度数值 | 解说 |
---|---|
1 | 图片批次巨细,有几张图片。1代表一张 |
8400 | 一张图中划分出的8400个小区域 |
6 | 6个数代表 (中心点x, 中心点y, 宽度w, 高度h, 分类1的得分, 分类2的得分) |
咱们依然只重视一张图片,并且把数据处理一下。
# 降维 (1, 6, 8400) -> (6, 8400)
detect_score = np.squeeze(detect_scores)
# 转化 (6, 8400) -> (8400, 6)
output_data = np.transpose(detect_score)
打印一个数据看看print(output_data[0])
,输出为:
[0.01971355 0.01480704 0.04122782 0.03146162 0.00014795 0.0001379 ]
这是8400个框中第1个框的数据,6位数便是上面表格里对应的6个含义。
我想画一下这些框。可是能够幻想,画面必定就糊了。咱们这样,只画出类别概率大于某个数值的框。
# 前4个是矩形框 x, y, width, height
boxes = output_data[:, :4]
# 后2个是①的概率,②的概率
scores = output_data[:, 4:]
# 核算每个鸿沟框最高的得分
max_scores = np.max(scores, axis=1)
# 找到满足必定准确率的框【修改点在这儿】
keep = max_scores >= 0.6
# 得到符合条件的鸿沟框和得分
filtered_boxes, filtered_scores = boxes[keep], scores[keep]
rimge = resize_image.copy()
(height, width) = rimge.shape[:2]
for i, box in enumerate(filtered_boxes):
x,y,w,h = box
# x,y 是中心点的坐标,并且是占宽高的百分比
x1,y1 = (x-w/2)*width, (y-w/2)*height
x2,y2 = (x+w/2)*width, (y+w/2)*height
cv2.rectangle(rimge, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 1)
下图是我画出概率大于0.01和0.601的框,能够看出差异仍是挺明显的。
如同咱们已经从输出数据,找到了方针和位置。
等会儿……如同还有一个问题,框的重复状况比较严重。发生的原因便是前面说的8400个框。
3.4.2 NMS非极大值按捺
看下图,这3个区域,都是合格的网格,并且也都检测到了方针。你不能说它们谁有错!
这可……怎样办?
此刻,你再看开篇那张图,有读者说“是不是NMS不包含?”。我说他们真的有需求,并且用心看了是有原因的。NMS
全称是Non-Maximum Suppression,换成中国话便是“非极大值按捺”。
浅显来讲,便是扫除同类弱者,因而叫非极大值按捺。比如IT界要选出各个开发言语的代表人物,来了1000多口儿,300多Java,600多PHP。咱们一比照,啊,都是干Java的,都搞多并发,留一个最好的,剩下的多并发走人。那边有两个人一比照,你是Java,我是PHP,咱们是两类人,没抵触,都留下。终究,必定就剩下最具有代表性的人了。
咱们选用哪些个框的方案,也是相同的道理。技能完成上,就用到了IoU
。不是I LOVE U
啊,是IoU
。全称是intersection over union,便是……你甭管叫啥。我告知你怎样处理,上代码。
def iou(box1, box2):
# 核算交集区域的坐标
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
# 核算交集区域的面积
inter_area = max(0, x2 - x1 + 1) * max(0, y2 - y1 + 1)
# 核算两个鸿沟框的面积
box1_area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1)
box2_area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1)
# 核算IoU
iou = inter_area / float(box1_area + box2_area - inter_area)
return iou
这个iou
办法的输出,完成的是两个矩形框的交集除以并集。先求出box1
、box2
的面积,再求出box1
和box2
重合的面积。终究,用重合面积除以两个框合占的面积。
来一个图就理解啦。
其实便是重合度,0表明不重合,1表明完全相同,0.6表明堆叠60%。
那下面,咱们就去郁闷……不是,去按捺非极大值就行啦。
# 按捺非极大值办法
def non_max_suppression(boxes, scores, threshold=0.8):
# 创立一个用于存储保存的鸿沟框的列表
keep = []
# 对得分进行排序
order = scores.argsort()[::-1]
# 循环直到一切鸿沟框都被检查
while order.size > 0:
# 将当前最大得分的鸿沟框添加到keep中
i = order[0]
keep.append(i)
# 核算剩下鸿沟框与当前鸿沟框的IoU
ious = np.array([iou(boxes[i], boxes[j]) for j in order[1:]])
# 找到与当前鸿沟框IoU小于阈值的鸿沟框
inds = np.where(ious <= threshold)[0]
# 更新order,只保存那些与当前鸿沟框IoU小于阈值的鸿沟框
order = order[inds + 1]
return keep
# 核算每个鸿沟框的最高得分
max_scores = np.max(filtered_scores, axis=1)
# 进行处理
keep = non_max_suppression(filtered_boxes, max_scores)
# 终究留下的候选框
final_boxes = filtered_boxes[keep]
final_scores = filtered_scores[keep]
# 方针的索引
indexs = np.argmax(final_scores, axis=1)
上面代码,将这些高质量的候选框,先依照得分进行排序,然后拿最高分跟其他候选比照。但凡重合度高的,去掉,重合率低的,保存。这个操作就完成了一山不容二虎。
3.5 呈现终究成果
咱们把之前画框的代码,略微改动一下。
for i, box in enumerate(final_boxes):
……
color_v = (0, 0, 255) if indexs[i] == 0 else (255, 0, 0)
cv2.rectangle(rimge, (int(x1), int(y1)), (int(x2), int(y2)), color_v, 2)
加了一个判别,假如是第①个类别用红色,第②个类别用蓝色。
运转作用如下:
怎样样,咱们用.tflite
格局完成了方针检测。这与在PyTorch
下的.pt
文件是相同的作用。
那个读者问,是不是不包含nms
?兄弟,有许多成熟的类库能够一句话调用。可是,退一万步讲,就算咱用原生代码自己去写一套,也没有多少行代码。
所以我讲原理很重要,渠道仅仅一个媒介。
下面,咱们就前往Android的世界,再去完成这一套流程。
四、用Android调用模型文件
首要声明,存在比我下面讲的,还要简略的完成办法。这个我是知道的。
比如多导入以下两个包,能够很方便地处理关于模型加载,图画与数据转化,乃至NMS
的问题。那样,没几行代码。
implementation 'org.tensorflow:tensorflow-lite-support:0.3.0'
implementation 'org.tensorflow:tensorflow-lite-task-vision:0.3.0'
可是,我吹了牛了,我说原理能够不受渠道约束。因而,我只导入根本的tensorflow-lite
包,用来加载tflite
文件。其他全用Java代码来写(Kotlin也相同)。
implementation 'org.tensorflow:tensorflow-lite:2.5.0'
4.1 加载模型并推理
首要,build.gradle
导入上面最根本的tensorflow-lite
包。然后,将咱们的best_num_int8.tflite
文件,拷贝到assets
文件下。
我的文件结构如下所示:
其间,DetectTool.java
是我自己写的一个检测东西类,担任加载tflite
模型,处理图片的缩放,以及剖析模型输出层的数据。NonMaxSuppression.java
也是自己手敲的一个处理非极大值按捺的算法类。
首要,加载tflite文件。
import org.tensorflow.lite.Interpreter;
public class DetectTool {
// 从Assets下加载.tflite文件
private static MappedByteBuffer loadModelFile(Context context, String fileName) throws IOException {
AssetFileDescriptor fileDescriptor = context.getAssets().openFd(fileName);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
// 构建Interpreter,这是tflite文件的解说器
public static Interpreter getInterpreter(Context context){
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(4);
Interpreter interpreter = null;
try {
interpreter = new Interpreter(loadModelFile(context, "best_num_int8.tflite"), options);
} catch (IOException e) {
throw new RuntimeException("Error loading model file.", e);
}
return interpreter;
}
}
留意,履行这一步时,需要在build.gradle
中装备不要紧缩.tflite
文件(默许是紧缩的)。
android {
// 新增:不要紧缩tflite文件
aaptOptions {
noCompress "tflite"
}
此刻,你就能够在Activity
中运用Interpreter
了。
// 构建解说器
Interpreter interpreter = DetectTool.getInterpreter(this);
// 将要处理的Bitmap图画缩放为640640
Bitmap resize_bitmap = resizeBitmap(bitmap, 640);
// 转化为输入层(1, 640, 640, 3)结构的float数组
float[][][][] input_arr = bitmapToFloatArray(resize_bitmap);
// 构建一个空的输出结构
float[][][] outArray = new float[1][6][8400];
// 运转解说器,input_arr是输入,它会将成果写到outArray中
interpreter.run(input_arr, outArray);
你依然能够用interpreter
的各种get办法获取输入输出的层信息。可是,依据前面咱们已经了解了它的结构,因而现在能够直接构建对应的结构。
4.2 输入预处理详解
其间,resizeBitmap
办法与bitmapToFloatArray
办法是自己写的。
resizeBitmap
用于图片尺寸缩放。
public static Bitmap resizeBitmap(Bitmap source, int maxSize) {
int outWidth;
int outHeight;
int inWidth = source.getWidth();
int inHeight = source.getHeight();
if(inWidth > inHeight){
outWidth = maxSize;
outHeight = (inHeight * maxSize) / inWidth;
} else {
outHeight = maxSize;
outWidth = (inWidth * maxSize) / inHeight;
}
Bitmap resizedBitmap = Bitmap.createScaledBitmap(source, outWidth, outHeight, false);
Bitmap outputImage = Bitmap.createBitmap(maxSize, maxSize, Bitmap.Config.ARGB_8888);
Canvas canvas = new Canvas(outputImage);
canvas.drawColor(Color.WHITE);
int left = (maxSize - outWidth) / 2;
int top = (maxSize - outHeight) / 2;
canvas.drawBitmap(resizedBitmap, left, top, null);
return outputImage;
bitmapToFloatArray
是构建输入层的数据格局。
public static float[][][][] bitmapToFloatArray(Bitmap bitmap) {
int height = bitmap.getHeight();
int width = bitmap.getWidth();
// 初始化一个float数组
float[][][][] result = new float[1][height][width][3];
for (int i = 0; i < height; ++i) {
for (int j = 0; j < width; ++j) {
// 获取像素值
int pixel = bitmap.getPixel(j, i);
// 将RGB值别离并进行标准化(假定你需要将颜色值标准化到0-1之间)
result[0][i][j][0] = ((pixel >> 16) & 0xFF) / 255.0f;
result[0][i][j][1] = ((pixel >> 8) & 0xFF) / 255.0f;
result[0][i][j][2] = (pixel & 0xFF) / 255.0f;
}
}
return result;
}
Bitmap
是图片,能够是一张本地图片文件,也能够是从相机的预览回调传来的每一帧图画。
只需经过interpreter.run(input_arr, outArray)
后,outArray
中就有了成果数据,它的形状便是咱们熟悉的那个(1, 6, 8400)
。
用python
时,咱们全程是手写算法。在Java
中,相同能够做到。
4.3 输出数据的处理
// 取出(1, 6, 8400)中的(6, 8400)
float[][] matrix_2d = outArray[0];
// (6, 8400)变为(8400, 6)
float[][] outputMatrix = new float[8400][6];
for (int i = 0; i < 8400; i++) {
for (int j = 0; j < 6; j++) {
outputMatrix[i][j] = matrix_2d[j][i];
}
}
float threshold = 0.6f; // 类别准确率挑选
float non_max = 0.8f; // nms非极大值按捺
ArrayList<float[]> boxes = new ArrayList<>();
ArrayList<Float> maxScores = new ArrayList();
for (float[] detection : outputMatrix) {
// 6位数中的后两位是两类的相信度
float[] score = Arrays.copyOfRange(detection, 4, 6);
float maxValue = score[0];
float maxIndex = 0;
for(int i=1; i < score.length;i++){
if(score[i] > maxValue){ // 找出最大的一项
maxValue = score[i];
maxIndex = i;
}
}
if (maxValue >= threshold) { // 假如相信度超过60%则记录
detection[4] = maxIndex;
detection[5] = maxValue;
boxes.add(detection); // 挑选后的框
maxScores.add(maxValue); // 挑选后的准确率
}
}
这段完成和python
差异很大。由于原生Java
代码在处理矩阵上根本全赖循环。它不像python
能够一句话获取矩阵的横向平均值、竖向最大值。
因而,我将那6位数中的detection[4]
设置为最大值的分类索引,detection[5]
存储最大值的分值。
到这儿,咱们就获取到了分类概率大于60%的一切备选框。这时相同会出现框重复的状况。需要做一个NMS
。
public class NonMaxSuppression {
public static float iou(float[] box1, float[] box2) {
float x1 = Math.max(box1[0], box2[0]);
float y1 = Math.max(box1[1], box2[1]);
float x2 = Math.min(box1[2], box2[2]);
float y2 = Math.min(box1[3], box2[3]);
float interArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
float box1Area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1);
float box2Area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1);
return interArea / (box1Area + box2Area - interArea);
}
public static List<float[]> nonMaxSuppression(List<float[]> boxes, List<Float> scores, float threshold){
List<float[]> result = new ArrayList<>();
while (!boxes.isEmpty()) {
int bestScoreIdx = scores.indexOf(Collections.max(scores));
float[] bestBox = boxes.get(bestScoreIdx);
result.add(bestBox);
boxes.remove(bestScoreIdx);
scores.remove(bestScoreIdx);
List<float[]> newBoxes = new ArrayList<>();
List<Float> newScores = new ArrayList<>();
for (int i = 0; i < boxes.size(); i++) {
if (iou(bestBox, boxes.get(i)) < threshold) {
newBoxes.add(boxes.get(i));
newScores.add(scores.get(i));
}
}
boxes = newBoxes;
scores = newScores;
}
return result;
}
}
iou
的核算简直和python
的处理相同。nonMaxSuppression
则依据Java
语法特性,变化了一些。
可是,原理是不变。都是先依照分数排名,然后疏忽和高分重合度高的,收录重合率低的。
终究的result
是终究成果,它是一个列表,每个子项里边6个数,别离是:中心点x、中心点y、框的宽width、框的高height、属于哪一类class_index、相信概率值。
便是这样,Android也成功完成了。你在程序里调用就能够。
五、小结
5.1 源码分享
我已经将python代码、Java两个类,以及我的pt、tflite文件,还有测验图片,上传到Github上了。期望得到咱们的指导 github.com/hlwgy/yolo2… 。
不仅仅Android能够调用,其他官网支撑的渠道相同能够用。期望本文能处理你的一些困惑,哪怕仅仅是让你有些感触,我也是不白写啊。写的太多了,的编辑器都卡爆了。
本篇文章的代码有点多,能读到这儿的,都是好哥们。你能够持续提问,我会持续写文章答复。
我觉得AI
技能并不难,并且它距离现实生活也不远。期望咱们能一起去探究并应用它。
我是@TF男孩,一个从事人工智能的程序员。