在数据增强、蒸馏剪枝下ERNIE3.0模型性能提高

项目链接: aistudio.baidu.com/aistudio/pr…

以CBLUE数据会集医疗搜索检索词目的分类为例:

本项目首要讲解了数据增强和数据蒸馏的计划,并在后面章节进行效果展现,成果预览:

模型 ACC Precision Recall F1 average_of_acc_and_f1
ERNIE 3.0 Base 0.80255 0.9317147 0.908284 0.919850 0.86120
ERNIE 3.0 Base+数据增强 0.7979539 0.901004 0.92899 0.91478 0.8563
ERNIE 3.0 Base+取舍保存比0.5 0.79846 0.951257 0.89497 0.92225 0.8603
ERNIE 3.0 Base +取舍保存比2/3 0.8092071 0.9415384 0.905325 0.923076 0.86614

gensim装置最新版本:pip install gensim

tqdm装置:pip install tqdm

LAC装置最新版本:pip install lac


Gensim库介绍

Gensim是在做自然语言处理时较为常常用到的一个东西库,首要用来以无监督的方式从原始的非结构化文本傍边来学习到文本躲藏层的主题向量表达。

首要包括TF-IDF,LSA,LDA,word2vec,doc2vec等多种模型。

Tqdm

是一个快速,可扩展的Python进展条,能够在 Python 长循环中增加一个进展提示信息,用户只需求封装任意的迭代器 tqdm(iterator)。目的为了程序显现的美观

中文词法剖析-LAC

LAC是一个联合的词法剖析模型,全体性地完成中文分词、词性标示、专名辨认使命。LAC既能够以为是Lexical Analysis of Chinese的首字母缩写,也能够以为是LAC Analyzes Chinese的递归缩写。

LAC根据一个堆叠的双向GRU结构,在长文本上精确复刻了百度AI敞开平台上的词法剖析算法。效果方面,分词、词性、专名辨认的全体精确率95.5%;独自评价专名辨认使命,F值87.1%(精确90.3,召回85.4%),总体略优于敞开平台版本。在效果优化的基础上,LAC的模型简洁高效,内存开销不到100M,而速度则比百度AI敞开平台提高了57%

LAC链接:www.paddlepaddle.org.cn/modelbasede…

!pip install –upgrade paddlenlp !pip install gensim !pip install tqdm !pip install lac

2.数据增强计划介绍

数据增强东西供给4种增强战略:遮盖、删去、同词性词替换、词向量近义词替换

在数据增强、蒸馏剪枝下ERNIE3.0分类模型性能提升

!unzip ERNIE-.zip -d ./ERNIE #增加ERNIE东西包

假如程序报错:
能够发现提示有一个.ipynb_checkpoints的文件。但当我去对应的文件夹找时底子看不到这个文件,所以猜想是一个躲藏文件。所以经过终端进入对应的目录:输入cd coco进入对应目录,输入ls -a显现所有文件。然后输入rm -rf .ipynb_checkpoints删去该文件。再次输入ls -a查看文件是否被删去。

下载词表,词表有1.7G会花点时刻。下面以情感剖析数据样例展现demo,看看数据增强的效果。

!wget -q --no-check-certificate http://bj.bcebos.com/wenxin-models/vec2.txt

python data_aug.py “输入文件夹的目录” “输出文件夹的目录”

  • data_aug.py脚本传参阐明
shell输入:
    python data_aug.py -h
shell输出:
    usage: data_aug.py [-h] [-n AUG_TIMES] [-c COLUMN_NUMBER] [-u UNK]
                       [-t TRUNCATE] [-r POS_REPLACE] [-w W2V_REPLACE]
                       [-e ERNIE_REPLACE] [--unk_token UNK_TOKEN]
                       input output
    main
    positional arguments:
      input                                                #原始待增强数据文件地点文件夹,带label的,一个或多个文本列
      output                                               #输出文件途径
    optional arguments:
      -h, --help            show this help message and exit
      -n AUG_TIMES, --aug_times AUG_TIMES                  #数据集数目放大n倍,output行数为inputn+1-c COLUMN_NUMBER, --column_number COLUMN_NUMBER      #明文文件中所要增强列的列序号,多列用逗号切割,如:1,2
      -u UNK, --unk UNK                                    #unk 增强战略的概率
      -t TRUNCATE, --truncate TRUNCATE                     #truncate 增强战略的概率
      -r POS_REPLACE, --pos_replace POS_REPLACE            #pos_replace 增强战略的概率
      -w W2V_REPLACE, --w2v_replace W2V_REPLACE            #w2v_replace 增强战略的概率
      --unk_token UNK_TOKEN                    

分类问题中:引荐运用前三种即可,w2v词向量近义词替换能够不必,花费时刻太长。

!python data_aug.py --unk 0.25 --truncate 0.25 --pos 0.5 --w2v 0 ./train ./output
demo成果展现:
机器 背面 好像 被 撕 了 张 什么 标签 , 残 胶 还在 。 可是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪	0
机器 背面 好像 被 撕 了 张 什么 标签 , 胶 还在 。 可是 又 看 不 出 是 什么 标签 不见 了 , 该 有 的 都 在 , 怪	0
机器 背面 了 张 什么 标签 , 残 胶 还在 。 可是 又 看 不 出 是 什么 标签  了 , 该在 , 怪	0
呵呵 , 尽管 表皮 看上去 不错 很 精美 , 可是 我 仍是 能 看得出来 是 盗 的 。 可是 里边 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴道 。	0
呵呵 , 尽管 表皮 看上去 不错 很 精美 , 可是 我 仍是 能 看得出来 是 盗 的 。 可是 里边 的 内容 真 的 不错 , 我妈 爱 看 , 我自己 也 学 着 找 一些 穴道 	0
呵呵 , 尽管 表皮 看上去 不错 很 精美 , 可是 我 还 能 看得出来 是 盗。 可是 里边 的 内容 真 的 不错 , 我 妈 爱 看 ,学 着 找 	0
尽管 表皮 看上去 不错 很 精美 , 可是 我 仍是 能 看得出来 是 盗 的 。 可是 里边 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴道 。	0
 表皮 看上去 不错 很 精美 , 可是 我 仍是 能 看得出来 是 盗 的 。 可是 里边 的 内容 真 的 不错 , 我 妈 爱 看 , 我自己 也 学 着 找 一些 穴道 。	0
地舆 方位 佳 , 在 市中心 。 酒店 服务 好 、 早餐 种类 丰厚 。 我 住 的 商务 数码 房 电脑 宽带 速度 满足 , 房间 还算 洁净 , 离 湖南路小吃街 近 。	1
地舆 方位 佳 , 在 市中心 。 酒店 服务 好 、 早餐 种类 丰厚 。 我 住 的 商务 数码 房 电脑 宽带 速度 满足 , 房间 还算 洁净 , 离 湖南路小吃街 近。。	1
地舆 方位 佳 , 在 市中心 。 酒店 服务 好 、 早餐 种类 丰厚 。 我 住 的 商务 数码 房 电脑 宽带 速度 满足 , 机器 还算 洁净 , 离 湖南路小吃街 近 。	1
地舆 方位 佳 , 在 市中心 。 酒店 服务 好 、 早餐 种类 丰厚 。 我 住 的 商务 数码 房 电脑 宽带 速度 满足 , 房间 还算 洁净 , 离 湖南路小吃街 近 。	1
地舆 方位 佳 , 在 市中心 。 酒店 服务 好 、 早餐 种类 丰厚 。 我 住 的 商务 数码 房 电脑 宽
我 看 是 书 的 还 能够 , 可是 我 订 的 书 迟迟 还 到 能 半个月 , 都 没有 收到 打电话 也 没

2.0 补充nlpcda一键中文数据增强东西(NLP Chinese Data Augmentation )

一键中文数据增强东西,支撑:

1.随机实体替换 2.近义词 3.近义近音字替换 4.随机字删去(内部细节:数字时刻日期片段,内容不会删) 5.NER类 BIO 数据增强 6.随机置换邻近的字:研表究明,汉字序顺并不定一影响文字的阅览理解<<是乱序的 7.中文等价字替换(1 一 壹 ①,2 二 贰 ②) 8.翻译互转完成的增强 9.运用simbert做生成式相似句生成

参阅链接: 一键中文数据增强包 ; NLP数据增强、bert数据增强、EDA:pip install nlpcda nlpcda一键中文数据增强东西

3.数据蒸馏技能

ERNIE数据蒸馏三步

Step 1. 运用ERNIE模型对输入标示数据对进行fine-tune,得到Teacher Model

Step 2. 运用ERNIE Service对以下无监督数据进行猜测:

  • 用户供给的大规模无标示数据,需与标示数据同源
  • 对标示数据进行数据增强,具体增强战略
  • 对无标示数据和数据增强数据进行必定份额混合

Step 3. 运用步骤2的数据练习出Student Model

数据增强

目前选用三种数据增强战略战略,关于不必的使命能够特定的份额混合。三种数据增强战略包括:

增加噪声:对原始样本中的词,以必定的概率(如0.1)替换为”UNK”标签

同词性词替换:对原始样本中的所有词,以必定的概率(如0.1)替换为本数据集钟随机一个同词性的词

N-sampling:从原始样本中,随机选取方位截取长度为m的片段作为新的样本,其中片段的长度m为0到原始样本长度之间的随机值

在数据增强、蒸馏剪枝下ERNIE3.0分类模型性能提升
模型取舍,根据 PaddleNLP 的 Trainer API 发布供给了模型裁剪 API。裁剪 API 支撑用户对 ERNIE 等Transformers 类下游使命微调模型进行裁剪。

具体效果在下一节展现,先装置好paddleslim库

4.根据ERNIR3.0文本模型微调

加载已有数据集:CBLUE数据会集医疗搜索检索词目的分类(练习)

数据集定义: 以揭露数据集CBLUE数据会集医疗搜索检索词目的分类(KUAKE-QIC)使命为示例,在练习集上进行模型微调,并在开发集上运用精确率Accuracy评价模型体现。

数据集默许为:默许为”cblue”。

save_dir:保存练习模型的目录;默许保存在当时目录checkpoint文件夹下。

dataset:练习数据集;默许为”cblue”。

dataset_dir:本地数据集途径,数据集途径中应包含train.txt,dev.txt和label.txt文件;默许为None。

task_name:练习数据集;默许为”KUAKE-QIC”。

max_seq_length:ERNIE模型运用的最大序列长度,最大不能超过512, 若呈现显存缺乏,请适当调低这一参数;默许为128。

model_name:挑选预练习模型;默许为”ernie-3.0-base-zh”。

device: 选用什么设备进行练习,可选cpu、gpu、xpu、npu。如运用gpu练习,可运用参数gpus指定GPU卡号。

batch_size:批处理巨细,请结合显存情况进行调整,若呈现显存缺乏,请适当调低这一参数;默许为32。

learning_rate:Fine-tune的最大学习率;默许为6e-5。

weight_decay:控制正则项力度的参数,用于防止过拟合,默许为0.01。

early_stop:挑选是否运用早停法(EarlyStopping);默许为False。

early_stop_nums:在设定的早停练习次序内,模型在开发集上体现不再上升,练习停止;默许为4。 epochs: 练习次序,默许为100。

warmup:是否运用学习率warmup战略;默许为False。

warmup_proportion:学习率warmup战略的份额数,假如设为0.1,则学习率会在前10%steps数从0慢慢增长到learning_rate, 而后再缓慢衰减;默许为0.1。

logging_steps: 日志打印的间隔steps数,默许5。

init_from_ckpt: 模型初始checkpoint参数地址,默许None。

seed:随机种子,默许为3。

#修改后的练习文件train_new2.py ,首要运用了paddlenlp.metrics.glue的AccuracyAndF1:精确率及F1-score,可用于GLUE中的MRPC 和QQP使命
#不过吐槽一下:    return (acc,precision,recall,f1,(acc + f1) / 2,) 最终一个目标竟然是加权均匀.....
!python train_new2.py --warmup --early_stop --epochs 10 --save_dir "./checkpoint2" --batch_size 16 --model_name ernie-3.0-base-zh

练习成果部分展现:

[2022-08-16 19:58:36,834] [    INFO] - global step 1280, epoch: 3, batch: 412, loss: 0.23292, acc: 0.87106, speed: 16.54 step/s
[2022-08-16 19:58:37,392] [    INFO] - global step 1290, epoch: 3, batch: 422, loss: 0.22339, acc: 0.87130, speed: 17.94 step/s
[2022-08-16 19:58:37,960] [    INFO] - global step 1300, epoch: 3, batch: 432, loss: 0.22791, acc: 0.87182, speed: 17.68 step/s
(acc, precision, recall, f1, average_of_acc_and_f1):(0.8025575447570332, 0.9317147192716236, 0.908284023668639, 0.9198501872659175, 0.8612038660114754)

[2022-08-16 20:01:36,060] [ INFO] – Early stop! [2022-08-16 20:01:36,060] [ INFO] – Save best accuracy text classification model in ./checkpoint2

4.1 加载自定义数据集(并经过数据增强练习)

从本地文件创立数据集

运用本地数据集来练习咱们的文本分类模型,本项目支撑运用固定格式本地数据集文件进行练习 假如需求对本地数据集进行数据标示,能够参阅文本分类使命doccano数据标示运用指南进行文本分类数据标示。[这个放到下个项目讲解]

本项目将以CBLUE数据会集医疗搜索检索词目的分类(KUAKE-QIC)使命为例进行介绍怎么加载本地固定格式数据集进行练习:

本地数据集目录结构如下:

data/
├── train.txt # 练习数据集文件
├── dev.txt # 开发数据集文件
├── label.txt # 分类标签文件
└── data.txt # 可选,待猜测数据文件

部分成果展现

[2022-08-16 23:43:18,093] [    INFO] - global step 2400, epoch: 2, batch: 234, loss: 0.60859, acc: 0.84437, speed: 19.27 step/s
(acc, precision, recall, f1, average_of_acc_and_f1):(0.7979539641943734, 0.9010043041606887, 0.9289940828402367, 0.9147851420247632, 0.8563695531095683)
[2022-08-16 23:43:24,522] [    INFO] - Save best F1 text classification model in ./checkpoint3
[2022-08-16 23:43:24,523] [    INFO] - best F1 performence has been updated: 0.91450 --> 0.91479

4.2 数据蒸馏

!unset CUDA_VISIBLE_DEVICES
!python -m paddle.distributed.launch --gpus "0" prune.py \
    --device "gpu" \
    --output_dir "./prune" \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --learning_rate 3e-5 \
    --num_train_epochs 5 \
    --logging_steps 10 \
    --save_steps 50 \
    --seed 3 \
    --dataset_dir "KUAKE_QIC" \
    --max_seq_length 128 \
    --params_dir "./checkpoint3" \
    --width_mult '0.5'

部分成果展现:

[2022-08-17 14:22:30,954] [    INFO] - width_mult: 0.5, eval loss: 0.63535, acc: 0.79847
(acc, precision, recall, f1, average_of_acc_and_f1):(0.7984654731457801, 0.9512578616352201, 0.8949704142011834, 0.9222560975609755, 0.8603607853533778)
[2022-08-17 14:22:35,870] [    INFO] - Save best F1 text classification model in ./prune/0.5
[2022-08-17 14:22:35,870] [    INFO] - best F1 performence has been updated: 0.92226 --> 0.92226
!unset CUDA_VISIBLE_DEVICES
!python -m paddle.distributed.launch --gpus "0" prune.py \
    --device "gpu" \
    --output_dir "./prune" \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --learning_rate 3e-5 \
    --num_train_epochs 5 \
    --logging_steps 10 \
    --save_steps 50 \
    --seed 3 \
    --dataset_dir "KUAKE_QIC" \
    --max_seq_length 128 \
    --params_dir "./checkpoint3" \
    --width_mult '2/3'
2022-08-17 14:53:45,544] [    INFO] - global step 3070, epoch: 2, batch: 904, loss: 0.709566, speed: 9.93 step/s
[2022-08-17 14:53:46,550] [    INFO] - global step 3080, epoch: 2, batch: 914, loss: 0.607238, speed: 9.94 step/s
[2022-08-17 14:53:47,558] [    INFO] - global step 3090, epoch: 2, batch: 924, loss: 0.718484, speed: 9.93 step/s
[2022-08-17 14:53:48,563] [    INFO] - global step 3100, epoch: 2, batch: 934, loss: 0.546288, speed: 9.95 step/s
[2022-08-17 14:53:50,206] [    INFO] - teacher model, eval loss: 0.66438, acc: 0.80358
[2022-08-17 14:53:50,207] [    INFO] - eval done total : 1.6434180736541748 s
[2022-08-17 14:53:53,568] [    INFO] - width_mult: 0.6666666666666666, eval loss: 0.60219, acc: 0.80921
(acc, precision, recall, f1, average_of_acc_and_f1):(0.8092071611253197, 0.9415384615384615, 0.9053254437869822, 0.923076923076923, 0.8661420421011213)
[2022-08-17 14:53:58,489] [    INFO] - Save best F1 text classification model in ./prune/0.6666666666666666
[2022-08-17 14:53:58,489] [    INFO] - best F1 performence has been updated: 0.92308 --> 0.92308

4.3 模型猜测

输入待猜测数据和数据标签对照列表,模型猜测数据对应的标签

运用默许数据进行猜测:

#也能够挑选运用本地数据文件data/data.txt进行猜测:
!python predict.py --params_path ./checkpoint3/ --dataset_dir ./KUAKE_QIC --device "cpu"
黑苦荞茶的成效与效果及食用方法 成效效果
交界痣会凸起吗 疾病表述
查看是否能怀孕挂什么科 就医主张
鱼油怎么吃咬破吃仍是直接咽下去 其他
幼儿挑食的生理原因是 病因剖析
!python predict.py \
    --device "cpu" \
    --dataset_dir ./KUAKE_QIC \
    --params_path "./prune/0.5" \

5.总结

本项目首要讲解了数据增强和数据蒸馏的计划,并在后面章节进行效果展现,现在进行汇总

模型 ACC Precision Recall F1 average_of_acc_and_f1
ERNIE 3.0 Base 0.80255 0.9317147 0.908284 0.919850 0.86120
ERNIE 3.0 Base+数据增强 0.7979539 0.901004 0.92899 0.91478 0.8563
ERNIE 3.0 Base+取舍保存比0.5 0.79846 0.951257 0.89497 0.92225 0.8603
ERNIE 3.0 Base +取舍保存比2/3 0.8092071 0.9415384 0.905325 0.923076 0.86614

剖析可得,

  • 首要数据增强后导致性能部分下降部分和预期的原因: 随机mask、删去会产生过多噪声样本影响成果,引荐只运用同义词替换,本次样本数据量满足,且ERNIE性能本就优越,数据增强对成果提高在较大样本集能够忽略。

  • 其次,能够看到经过数据蒸馏后,模型性能改变不大,甚至在取舍1/3之后,性能有小幅度提高

本次首要对分类模型参加数据增强、数据蒸馏,已经对性能目标进行细化,不只是ACC,个人比较关注F1情况,并作为保存模型依据。

展望: 后续将完善动态图和静态图转化部分,让蒸馏下来模型能够继续线上加载运用;其次将会考虑小样本学习在分类模型使用情况;最终将完成模型交融环节提高性能,并做可解释性剖析。

自己博客:blog.csdn.net/sinat_39620…