夸克APP端智能:文档关键点检测实践与应用
作者:顺达
最近夸克端智能小组在做端上的实时文档检测,即输入一张RGB图像,得到文档的四个角的关键点的坐标。整个pipelines属于关键点检测算法,因此最近对相关领域的论文进行阅读和进行了实验尝试。
将关键点检测算法按照不同模块进行拆分,可以分成以下几个部分,每个部分都有相关的方法可以进行优化:
- 图片处理:包括数据光学增强,变换,resize,crop等操作,扩充图片的多样性;
- 编码:指的是在训练中,如何将坐标转换成所需要的label,用于监督模型的输出;
- 网络模型:指的是网络结构,可以有backbone/FPN/detection head等部分组成;
- 解码:指的是如何将模型推理的结果转换成所需要的坐标形式,如笛卡尔坐标系下的坐标。
Related Works
关键点检测中主要有两条技术方案:
- 类似人脸检测,模型输出的结果tensor通过fc层,直接得到一维的向量,通常是归一化后关键点坐标值;
- 类似人体姿态估计,模型输出的结果tensor通过argmax等方式,获取heatmap中相应大的坐标,最后将此坐标恢复至原图坐标。
近年来,基于heatmap来进行关键点检测的方案居多,其主要原因是基于heatmap的效果要好于使用全连接层进行回归的方案。所以,我们采用的方案也是基于heatmap的,下面是近几年的一些相关论文工作。
DSNT
[1] Nibali A , He Z , Morgan S , et al. Numerical Coordinate Regression with Convolutional Neural Networks[J]. 2018.
思路
目前,在模型输出的heatmap到数值坐标的转换中,有两种方式:
- 通过对heatmap中取argmax,得到相应最大的点,以此来转换成数值坐标。此种方式具有较好的空间泛化性,但是由于在训练中argmax是不可导的,通常使用heatmap来逼近编码的高斯热例图,这会导致损失函数与最终评价指标的不一致。其次,在推理阶段,只使用到最大响应的坐标点来计算数值坐标,而在训练阶段,所有坐标点都对损失有贡献。第三,通过heatmap转换成数值坐标,是会存在理论误差下限的;
- 通过在heatmap后接fc层,转换成数值坐标。此种方法让梯度从数值坐标回传到input中,但是结果严重依赖与数据分布(例如在训练集中,一个物体一直出现在坐标;而在测试集中,这个物体出现在右边,这样就会导致预测错误)。其次,通过fc转换,丢失了heatmap的空间信息。
针对上述的两种方案,作者兼容了这两种方案的优点(端到端优化和保持空间泛化性),提出一种可微分的方式来得到数值坐标。
具体步骤
- 模型的输出1KH*W个heatmaps,其中K表示关键点的数量;
- 将每个通道的heatmap归一化,让其值都为非负且和为1,从而得到 norm_heatmap 。这么做的目的是,使用归一化后的heatmap保证了预测的坐标位于heatmap的空间范围之内。同时, norm_heatmap 也可以理解成二维离散概率密度函数;
- 生成 X 和 Y 矩阵,\(X_{(i,j)} = \frac{2j-(w+1)}{w}\), \(Y_{(i,j)} = \frac{2i-(h+1)}{h}\),分别表示x轴的索引和y轴的索引。可以理解成将图片的左上角缩放到 (-1,-1) 和右下角缩放到 (1,1) ;
- 将X 和 Y 矩阵分别与 norm_heatmap 点乘,从而得到最终的数值坐标。这么做的原因是, norm_heatmap 表示概率密度函数, X 矩阵表示索引,两者点成表示预测x的均值。通过均值来表示最终的预测的坐标,这样的好处是,a)可微分;b)理论误差下限小。
损失函数loss
dsnt模块的损失函数由Euclidean loss 和JS正则约束组成。前者用于回归坐标,后者用于约束生成的热力图更加接近高斯分布。
\(L_{euc}(u,p) = ||p-u||_2 \)
\(L_D(Z,p) =JS(p(c)||N(p,I)))\)
优点
- 整套模型是端到端训练的,损失函数与测试指标能对应;
- 理论误差下限很小;
- 引入 X 矩阵和 Y 矩阵,可以理解成引入先验,让模型的学习难度降低;
- 低分辨率的效果依然不错。
缺点
在实验中,发现当关键点位于图片边缘时,预测结果不好。
DARK
[1] Zhang F , Zhu X , Dai H , et al. Distribution-Aware Coordinate Representation for Human Pose Estimation[J]. 2019.
思路
作者发现将 heatmaps 解码结果,对生成最终数值坐标存在较大影响。因此研究了标准的坐标解码方式的不足,提出一种分布已知的解码方式和编码方式,来提高模型的最终效果。
标准的坐标解码过程是,获得模型的 heatmaps 后,通过argmax找到最大响应点 m 和第二大响应点 s ,以此来计算最终的响应点 p :
\(p=m+0.25\frac{s-m}{\left | s-m \right |_2} \)
这个公式意味着最大响应点向第二大响应点偏移0.25个像素,这么做的目的是补偿量化误差。然后把响应点映射回原图:
\( \hat{p} = \lambda p \)
这也说明, heatmap 中最大响应点并不是与原图的关键点精确对应,只是大概位置。
基于上面的痛点,作者基于分布已知的前提(高斯分布),提出新的解码方式,解决如何从 heatmap 中得到精确的位置,最小化量化误差。同时,提出了配套的编码方式
具体步骤
解码
假设输出的 heatmap 符合高斯分布,那么 heatmap 就可以用下面函数表示
其中\(\mu\)表示关键点映射到 heatmap 的位置。我们需要求\(\mu\)的位置,因此将函数g转换成最大似然函数
对\(P(\mu )\)进行泰勒展开
其中,m表示在热力图中最大响应的位置。而\(\mu\)在热力图对应的是极点,存在以下性质
结合上述公式,可以得到
因此,为了得到 heatmap 中\(\mu\)的位置,可以通过 heatmap 的一阶导数与二阶导数求得。这步的作用是通过数学的方法来说明该移动距离。
前面提及了假设输出的 heatmap 符合高斯分布,实际情况是不符合的,实际可能是多峰,因此需要对 heatmap 进行调制,让其尽量满足这个前提。具体做法是用高斯核函数来平滑 heatmap ,同时为了保证幅值一致,要进行归一化。
\({h}'=K\circledast h\)
\({h}'=\frac{{h}'-min({h}')}{max({h}')-min({h}')}*max(h) \)
综上所述,步骤是:
- 对 heatmap 使用高斯核来调制,并且缩放;
- 求一阶导数和二阶导数,来得到\(\mu\);
- 将\(\mu\)映射回原图。
编码
编码指的是将关键点映射到 heatmap 上,并且生成高斯分布。
之前工作的做法是现对坐标进行下采样,然后将点进行量化(floor,ceil,round),最后使用量化后的坐标生成高斯函数。
因为量化是不可导的,存在量化误差,因此,作者提出不进行量化,使用float来生成高斯函数,这样就能生成无偏 heatmap 。
UDP
[1] Huang J , Zhu Z , Guo F , et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation[J]. 2019.
思路
作者从数据处理和坐标表示下手,以此来提高性能。作者发现,目前的数据处理方式是存在偏差的,特别是flip时,会与原数据不对齐;其次坐标表征也存在统计误差。这两个问题共同导致结果存在偏差。因此提出了一种数据处理方式unbiased data processing,解决图像转换和坐标转换带来的误差。
具体步骤
Unbiased Coordinate System Transformation
在测试中,通常使用翻转后的\({k}'_{o,flip}\)与原始的\({k}'_o\)进行叠加,来得到最终的预测结果。但是\({k}'_o\)与\({\hat{k}}'_o\)并不一致,存在偏差。可以看到翻转后的 heatmap 不与原来的 heatmap 对齐,会产生误差,与分辨率有关。
因此作者建议使用 unit length 来代替图片长度:\(w=w^p-1\)。这样翻转后的 heatmap 就对齐了。
Unbiased Keypoint Format Transformation
无偏的关键点转换方式应该是可逆的,即\(k=Decoding(Enoding(k))\)。因此,作者提出了两种方式:
- Combined classification and regression format
借鉴了目标检测中anchor的方式,假设需要预测的关键点\(k=(m,n)\),则将其转换成如下。其中C表示关键点的位置范围,X和Y表示需要预测的offset。最终解码就是在热力图C上取到argmax,然后对X与Y的热力图上拿到对应位置的offset,最后进行相加得到数值坐标。
- Classification format
与DARK方式一致,即使用泰勒展开来逼近真实位置。
AID
[1] Huang J , Zhu Z , Huang G , et al. AID: Pushing the Performance Boundary of Human Pose Estimation with Information Dropping Augmentation[J]. 2020.
贡献点
对于关键点检测,外观信息与约束信息同样重要。而以往的工作通常是过拟合外观信息,而忽略了约束信息。因此,本文希望通过information drop,可理解成掩膜,来强调约束信息。约束信息有利于在该关键点被遮挡时,预测出其准确位置。
而以往工作没有使用到information drop的原因是,使用该数据增强手段后指标下降。作者就通过实验,发现information drop是有助于提高模型精度的,但需要修改响应的训练策略:
- 加倍训练次数;
- 先使用没有mask的来训练,得到一个比较好的模型后,再把mask手段加入继续训练。
RSN
[1] Cai Y , Wang Z , Luo Z , et al. Learning Delicate Local Representations for Multi-Person Pose Estimation[J]. 2020.
贡献点
本文是2019年coco关键点检测冠军的方案。其本文的主要思想是,最大程度聚合具有相同空间尺寸的特征,以此来获得丰富的局部信息,局部信息有利于产生更加准确的位置。因此提出了RSN网络,如下图一所示。从图来看,即融合不同感受野特征。RSN的输出包含了low-level准确的空间信息与high-level语义信息,空间信息有助于定位,语义信息有助于分类。但是这两类信息给最终预测带来的影响权重是不一致的,需要使用到PRM模块来平衡,RPM模块本质就是一个通道注意力和空间注意力模块。
Lite-HRNet
[1] Yu C , Xiao B , Gao C , et al. Lite-HRNet: A Lightweight High-Resolution Network[J]. 2021.
贡献点
本文出了一个高效的高分辨率网络,是HRNet的轻量化版本,通过将ShuffleNet中的shuffle block引入到HRNet中。同时发现shuffleNet中大量使用了pointwise convolution(11卷积),是计算瓶颈,因此引入contional channel weight来取代shuffle block中的11卷积。网络的整体结构如下图所示。在模型中一致保留高分辨率特征,并不断融合high-level特征。
在前面提及到的contional channel weight如下所示。左边是ShuffleNet中的shuffle block,右图是contional channel weight。可以看到,采用新模块来取代了1*1的卷积,实现跨stage信息交流与局部信息交流。其具体做法包含了Cross-resolution weight computation和Spatial weight computation。这两个模块的本质是注意力机制。
实验优化结果
模型结构
本次模型借鉴了CenterNet/RetiaFace/DBFace中的相关工作。本次的使用了dsnt的方案。主要原因是:需要运行在端上,实时性是首要考虑因素。dsnt在低分辨率的优势明显。
MobileNet v3使用small版本,FPN中使用 Nearest Upsample + conv + bn + Relu 来进行上采样。在训练时使用了 keypoints , mask 和 center 分支;而预测时,只使用到了 keypoints 分支。
优化策略
在本次实验中,使用到了以下几种优化策略:
- 使用 mask 与 center 分支来辅助学习。其中mask表示文档的掩膜,center表示文档的中心点;
- 使用deep Supervise。使用4倍下采样特征图与8倍下采样特征图来进行训练,使用相同的loss函数来监督这两层;
- dsnt中对边缘点的效果不佳,因此,对图片进行padding,让点不再位于图片边缘;
- 数据增强策略,除了常规的光学扰动增强外,还对图片进行random crop、random erase和random flip等操作;
- 进行loss函数尝试工作,对于关键点分支的loss,尝试过 euclidean loss , l1 loss , l2 loss 和 smoothl1 loss ,最终 smoothl1 loss 的效果最佳。
评价指标
- MSE
用于在训练中评价验证集的均方误差。
\(mse = \frac{\sum |d_i - \hat{d_i}|_2^2}{N} \)
- oks-mAP
oks用于评估预测与真实关键点之间的相似度,mAP的评估方式类似coco[0.5:0.05:0.95]的评价方式,这里取[0.99:0.001:0.999]。其中,oks进行一定变换,\(d_{p,i}\)表示点的欧式距离,\(S_p\)表示该四边形的面积。
\(oks_{p,i} =e^{-\frac{d_{p,i}^2}{2S_p}} \)
- 耗时
耗时指的是在红米8上,用MNN推理框架跑模型的平均时间。
实验结果
先构建一个baseline,baseline的模型为 moblieNet v3 + fpn + ssh module + keypoints 分支 + dsnt ,其中,都没有使用上述优化策略,使用4倍下采样特征图作为输出。
在v2版本中替换不同的loss函数。
此外,还尝试过其他无效的tricks:
- 辅助任务有利于提高模型的指标,因此还加入了edge的分支来辅助学习。实验下来,加入该分支反而损害模型的指标。可能原因是edge是利用gt关键点来生成的,可能某些edge并不是对应文档真正的边缘;
- 现阶段是预测文档的4个角,因此在增加4个点来进行预测,分别是4条边的中心点,所以模型一共预测8个关键点。实验结果显示,指标也下降了。
Demo 演示
总结
综上所述,在端上的文档关键点检测领域中,目前尝试下来,是基于heatmap+dsnt的方案较优,oks-mAP的指标有提升空间。但是,对比使用fc层进行回归坐标的方式,基于heatmap的方案中存在一个不足是:无法根据文档的约束信息,来预测图片外的关键点坐标。此方案的不足,会导致文档内容的缺失,摆正效果不佳的情况。因此,后续需要弥补此不足。
如何高效开发端智能算法?MNN 工作台 Python 调试详解
关注我们,每周 3 篇移动技术实践&干货给你思考!
低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。
持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。
转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。
- 上一篇
面试官:说下你对方法区演变过程和内部结构的理解
之前我们已经了解过“运行时数据区”的程序计数器、虚拟机栈、本地方法栈和堆空间,今天我们就来了解一下最后一个模块——方法区。 简介 创建对象时内存分配简图 《Java虚拟机规范》中明确说明:“尽管所有的方法区在逻辑上属于堆的一部分,但一些简单的实现可能不会选择去进行垃圾收集或者进行压缩。” 虽然 Java 虚拟机规范把方法区描述为堆的一个逻辑部分,但是它却有一个别名叫做 Non-Heap(非堆),目的应该是与 Java 堆区分开来。所以,方法区可以看作是一块独立于 Java 堆的内存空间。 方法区与 Java 堆一样,是各个线程共享的内存区域。方法区在 JVM 启动时就会被创建,并且它的实际的物理内存空间是可以不连续的,关闭 JVM 就会释放这个区域的内存。 永久代、元空间 《java虚拟机规范》对如何实现方法区,不做统一要求。例如:BEA JRockit/IBM J9 中不存在永久代的概念。而对于 HotSpot 来说,在 jdk7 及以前,习惯上把方法区的实现称为永久代,而从 jdk8 开始,使用元空间取代了永久代。 方法区是 Java 虚拟机规范中的概念,而永久代和元空间是 Hot...
- 下一篇
面试官:您能说说序列化和反序列化吗?是怎么实现的?什么场景下需要它?
序列化和反序列化是Java中最基础的知识点,也是很容易被大家遗忘的,虽然天天使用它,但并不一定都能清楚的说明白。 我相信很多小伙伴们掌握的也就几句概念、关键字(Serializable)而已,如果深究问一下序列化和反序列化是如何实现、使用场景等,就可能不知所措了。 在每次我作为面试官,考察Java基础时,通常都会问到序列化、反序列化的知识点,用以衡量其Java基础如何。 当被问及Java序列化是什么? 反序列化是什么? 什么场景下会用到? 如果不用它,会出现什么问题等,一般大家回答也就是几句简单的概念而已,有的工作好几年的应聘者甚至连概念都说不清楚,一脸闷逼。 本文就序列化和反序列化展开深入的探讨,当被别人问及时,不至于一脸闷逼、尴尬,或许会为你以后的求职面试中增加一点点筹码。 一、基本概念 1、什么是序列化和反序列化 序列化是指 将Java对象转换为字节序列的过程 ,而反序列化则是 将字节序列转换为Java对象的过程 。 Java对象序列化是将实现了 Serializable 接口的对象转换成一个字节序列,能够通过网络传输、文件存储等方式传输 ,传输过程中却...
相关文章
文章评论
共有0条评论来说两句吧...
文章二维码
点击排行
推荐阅读
最新文章
- Springboot2将连接池hikari替换为druid,体验最强大的数据库连接池
- MySQL8.0.19开启GTID主从同步CentOS8
- CentOS6,7,8上安装Nginx,支持https2.0的开启
- SpringBoot2整合Thymeleaf,官方推荐html解决方案
- CentOS关闭SELinux安全模块
- CentOS7设置SWAP分区,小内存服务器的救世主
- Docker安装Oracle12C,快速搭建Oracle学习环境
- Docker快速安装Oracle11G,搭建oracle11g学习环境
- CentOS7编译安装Gcc9.2.0,解决mysql等软件编译问题
- SpringBoot2整合MyBatis,连接MySql数据库做增删改查操作