前言 本文评论了处理不平衡数据集和进步机器学习模型功能的各种技巧和战略,包括的一些技能包括重采样技能、代价灵敏学习、运用恰当的功能目标、集成办法和其他战略。
作者:Emine Bozku
来历:DeepHub IMBA
欢迎重视大众号CV技能攻略,专心于核算机视觉的技能总结、最新技能盯梢、经典论文解读、CV招聘信息。
不平衡数据集是指一个类中的示例数量与另一类中的示例数量明显不同的情况。例如在一个二元分类问题中,一个类只占总样本的一小部分,这被称为不平衡数据集。类不平衡会在构建机器学习模型时导致很多问题。
不平衡数据集的首要问题之一是模型或许会偏向大都类,从而导致猜测少量类的功能欠安。这是由于模型经过操练以最小化过错率,而且当大都类被过度代表时,模型倾向于更频繁地猜测大都类。这会导致更高的准确率得分,但少量类别得分较低。
另一个问题是,当模型露出于新的、看不见的数据时,它或许无法很好地泛化。这是由于该模型是在倾斜的数据集上操练的,或许无法处理测试数据中的不平衡。
在本文中,咱们将评论处理不平衡数据集和进步机器学习模型功能的各种技巧和战略。将包括的一些技能包括重采样技能、代价灵敏学习、运用恰当的功能目标、集成办法和其他战略。经过这些技巧,能够为不平衡的数据集构建有用的模型。
处理不平衡数据集的技巧
重采样技能是处理不平衡数据集的最流行办法之一。这些技能触及减少大都类中的示例数量或增加少量类中的示例数量。
欠采样能够从大都类中随机删除示例以减小其巨细并平衡数据集。这种技能简单易行,但会导致信息丢失,由于它会丢弃一些大都类示例。
过采样与欠采样相反,过采样随机仿制少量类中的示例以增加其巨细。这种技能或许会导致过度拟合,由于模型是在少量类的重复示例上操练的。
SMOTE是一种更高档的技能,它创立少量类的组成示例,而不是仿制现有示例。这种技能有助于在不引进重复项的情况下平衡数据集。
代价灵敏学习(Cost-sensitive learning)是另一种可用于处理不平衡数据集的技能。在这种办法中,不同的过错分类本钱被分配给不同的类别。这意味着与过错分类大都类示例比较,模型因过错分类少量类示例而遭到更严峻的惩罚。
在处理不平衡的数据集时,运用恰当的功能目标也很重要。准确性并不总是最好的目标,由于在处理不平衡的数据集时它或许会产生误导。相反,运用 AUC-ROC等目标能够更好地指示模型功能。
集成办法,例如 bagging 和 boosting,也能够有用地对不平衡数据集进行建模。这些办法结合了多个模型的猜测以进步全体功能。Bagging 触及独立操练多个模型并对它们的猜测进行均匀,而 boosting 触及按顺序操练多个模型,其间每个模型都企图纠正前一个模型的过错。
重采样技能、本钱灵敏学习、运用恰当的功能目标和集成办法是一些技巧和战略,能够协助处理不平衡的数据集并进步机器学习模型的功能。
在不平衡数据集上进步模型功能的战略
搜集更大都据是在不平衡数据集上进步模型功能的最直接战略之一。经过增加少量类中的示例数量,模型将有更多信息可供学习,而且不太或许偏向大都类。当少量类中的示例数量非常少时,此战略特别有用。
生成组成样本是另一种可用于进步模型功能的战略。组成样本是人工创立的样本,与少量类中的真实样本相似。这些样本能够运用 SMOTE等技能生成,该技能经过在现有示例之间进行插值来创立组成示例。生成组成样本有助于平衡数据集并为模型供给更多示例以供学习。
运用领域知识来重视重要样本也是一种可行的战略,经过辨认数据会集信息量最大的示例来进步模型功能。例如,假如咱们正在处理医学数据集,或许知道某些症状或实验室成果更能标明某种疾病。经过重视这些例子能够进步模型准确猜测少量类的能力。
最后能够运用反常检测等高档技能来辨认和重视少量类示例。这些技能可用于辨认与大都类不同且或许是少量类示例的示例。这能够经过辨认数据会集信息量最大的示例来协助进步模型功能。
在搜集更大都据、生成组成样本、运用领域知识专心于重要样本以及运用反常检测等先进技能是一些可用于进步模型在不平衡数据集上的功能的战略。这些战略能够协助平衡数据集,为模型供给更多示例以供学习,并辨认数据会集信息量最大的示例。
不平衡数据集的操练
这儿咱们运用信用卡诈骗分类的数据集演示处理不平衡数据的办法
import pandas as pd
import numpy as np
from sklearn.preprocessing import RobustScaler
from sklearn.linear\_model import LogisticRegression
from sklearn.model\_selection import train\_test\_split
from sklearn.metrics import accuracy\_score
from sklearn.metrics import confusion\_matrix, classification\_report,f1\_score,recall\_score,roc\_auc\_score, roc\_curve
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc,rcParams
import itertools
import warnings
warnings.filterwarnings\("ignore", category\=DeprecationWarning\)
warnings.filterwarnings\("ignore", category\=FutureWarning\)
warnings.filterwarnings\("ignore", category\=UserWarning\)
读取数据:
df \= pd.read\_csv\("creditcard.csv"\)
df.head\(\)
print\("Number of observations : " ,len\(df\)\)
print\("Number of variables : ", len\(df.columns\)\)
#Number of observations : 284807
#Number of variables : 31
检查数据集信息:
df.info\(\)
\<class 'pandas.core.frame.DataFrame'\>
RangeIndex: 284807 entries, 0 to 284806
Data columns \(total 31 columns\):
\# Column Non-Null Count Dtype
\--- \------ \-------------- \-----
0 Time 284807 non\-null float64
1 V1 284807 non\-null float64
2 V2 284807 non\-null float64
3 V3 284807 non\-null float64
4 V4 284807 non\-null float64
5 V5 284807 non\-null float64
6 V6 284807 non\-null float64
7 V7 284807 non\-null float64
8 V8 284807 non\-null float64
9 V9 284807 non\-null float64
10 V10 284807 non\-null float64
11 V11 284807 non\-null float64
12 V12 284807 non\-null float64
13 V13 284807 non\-null float64
14 V14 284807 non\-null float64
15 V15 284807 non\-null float64
16 V16 284807 non\-null float64
17 V17 284807 non\-null float64
18 V18 284807 non\-null float64
19 V19 284807 non\-null float64
20 V20 284807 non\-null float64
21 V21 284807 non\-null float64
22 V22 284807 non\-null float64
23 V23 284807 non\-null float64
24 V24 284807 non\-null float64
25 V25 284807 non\-null float64
26 V26 284807 non\-null float64
27 V27 284807 non\-null float64
28 V28 284807 non\-null float64
29 Amount 284807 non\-null float64
30 Class 284807 non\-null int64
dtypes: float64\(30\), int64\(1\)
memory usage: 67.4 MB
检查分类类别:
f,ax\=plt.subplots\(1,2,figsize\=\(18,8\)\)
df\['Class'\].value\_counts\(\).plot.pie\(explode\=\[0,0.1\],autopct\='\%1.1f\%\%',ax\=ax\[0\],shadow\=True\)
ax\[0\].set\_title\('dalm'\)
ax\[0\].set\_ylabel\(''\)
sns.countplot\('Class',data\=df,ax\=ax\[1\]\)
ax\[1\].set\_title\('Class'\)
plt.show\(\)
rob\_scaler \= RobustScaler\(\)
df\['Amount'\] \= rob\_scaler.fit\_transform\(df\['Amount'\].values.reshape\(\-1,1\)\)
df\['Time'\] \= rob\_scaler.fit\_transform\(df\['Time'\].values.reshape\(\-1,1\)\)
df.head\(\)
创立基类模型:
X \= df.drop\("Class", axis\=1\)
y \= df\["Class"\]
X\_train, X\_test, y\_train, y\_test \= train\_test\_split\(X, y, test\_size\=0.20, random\_state\=123456\)
model \= LogisticRegression\(random\_state\=123456\)
model.fit\(X\_train, y\_train\)
y\_pred \= model.predict\(X\_test\)
accuracy \= accuracy\_score\(y\_test, y\_pred\)
print\("Accuracy: \%.3f"\%\(accuracy\)\)
咱们创立的模型的准确率评分为0.999。咱们能够说咱们的模型很完美吗?混杂矩阵是一个用来描绘分类模型的真实值在测试数据上的功能的表。它包括4种不同的估量值和实际值的组合。
def plot\_confusion\_matrix\(cm, classes,
title\='Confusion matrix',
cmap\=plt.cm.Blues\):
plt.rcParams.update\(\{'font.size': 19\}\)
plt.imshow\(cm, interpolation\='nearest', cmap\=cmap\)
plt.title\(title,fontdict\=\{'size':'16'\}\)
plt.colorbar\(\)
tick\_marks \= np.arange\(len\(classes\)\)
plt.xticks\(tick\_marks, classes, rotation\=45,fontsize\=12,color\="blue"\)
plt.yticks\(tick\_marks, classes,fontsize\=12,color\="blue"\)
rc\('font', weight\='bold'\)
fmt \= '.1f'
thresh \= cm.max\(\)
for i, j in itertools.product\(range\(cm.shape\[0\]\), range\(cm.shape\[1\]\)\):
plt.text\(j, i, format\(cm\[i, j\], fmt\),
horizontalalignment\="center",
color\="red"\)
plt.ylabel\('True label',fontdict\=\{'size':'16'\}\)
plt.xlabel\('Predicted label',fontdict\=\{'size':'16'\}\)
plt.tight\_layout\(\)
plot\_confusion\_matrix\(confusion\_matrix\(y\_test, y\_pred\=y\_pred\), classes\=\['Non Fraud','Fraud'\],
title\='Confusion matrix'\)
- 非诈骗类共进行了56875次猜测,其间56870次(TP)正确,5次(FP)过错。
- 诈骗类共进行了87次猜测,其间31次(FN)过错,56次(TN)正确。
该模型能够猜测诈骗状况,准确率为0.99。但当检查混杂矩阵时,诈骗类的过错猜测率相当高。也就是说该模型正确地猜测了非诈骗类的概率为0.99。可是非诈骗类的观测值的数量高于诈骗类的观测值的数量,这拉搞了咱们对准确率的核算,而且咱们更加重视的是诈骗类的准确率,所以咱们需求一个目标来衡量它的功能。
选择正确的目标
在处理不平衡数据集时,选择正确的目标来评估模型的功能非常重要。传统目标,如准确性、准确度和召回率,或许不适用于不平衡的数据集,由于它们没有考虑数据中类别的散布。
常常用于不平衡数据集的一个目标是 F1 分数。F1 分数是准确率和召回率的调和均匀值,它供给了两个目标之间的平衡。核算如下:
F1 = 2 * (precision * recall) / (precision + recall)
另一个常常用于不平衡数据集的目标是 AUC-ROC。AUC-ROC 衡量模型区分正类和负类的能力。它是经过绘制不同分类阈值下的TPR与FPR来核算的。AUC-ROC 值的范围从 0.5(随机猜测)到 1.0(完美分类)。
print\(classification\_report\(y\_test, y\_pred\)\)
precision recall f1\-score support
0 1.00 1.00 1.00 56875
1 0.92 0.64 0.76 87
accuracy 1.00 56962
macro avg 0.96 0.82 0.88 56962
weighted avg 1.00 1.00 1.00 56962
回来对0(非诈骗)类的猜测有多少是正确的。检查混杂矩阵,56870 + 31 = 56901个非诈骗类猜测,其间56870个猜测正确。0类的精度值挨近1 (56870 / 56901)
回来对1 (诈骗)类的猜测有多少是正确的。检查混杂矩阵,5 + 56 = 61个诈骗类别猜测,其间56个被正确估量。0类的精度为0.92 (56 / 61),能够看到差别仍是很大的
过采样
经过仿制少量类样原本稳定数据集。
随机过采样:经过增加从少量集体中随机选择的样原本平衡数据集。假如数据集很小,能够运用这种技能。或许会导致过拟合。randomoverampler办法接受sampling_strategy参数,当sampling_strategy = ‘ minority ‘被调用时,它会增加minority类的数量,使其与majority类的数量相等。
咱们能够在这个参数中输入一个浮点值。例如,假定咱们的少量集体人数为1000人,大都集体人数为100人。假如咱们说sampling_strategy = 0.5,少量类将被增加到500
y\_train.value\_counts\(\)
0 227440
1 405
Name: Class, dtype: int64
from imblearn.over\_sampling import RandomOverSampler
oversample \= RandomOverSampler\(sampling\_strategy\='minority'\)
X\_randomover, y\_randomover \= oversample.fit\_resample\(X\_train, y\_train\)
采样后操练
model.fit\(X\_randomover, y\_randomover\)
y\_pred \= model.predict\(X\_test\)
plot\_confusion\_matrix\(confusion\_matrix\(y\_test, y\_pred\=y\_pred\), classes\=\['Non Fraud','Fraud'\],
title\='Confusion matrix'\)
运用随机过采样后,操练模型的精度值为0.97,出现了下降。可是从混杂矩阵来看,模型的诈骗类的正确估量率有所进步。
SMOTE 过采样:从少量集体中随机选取一个样本。然后,为这个样本找到k个最近的街坊。从k个最近的街坊中随机选取一个,将其与从少量类中随机选取的样本组合在特征空间中形成线段,形成组成样本。
from imblearn.over\_sampling import SMOTE
oversample = SMOTE\(\)
X\_smote, y\_smote = oversample.fit\_resample\(X\_train, y\_train\)
运用SMOTE后的数据操练
model.fit\(X\_smote, y\_smote\)
y\_pred = model.predict\(X\_test\)
accuracy = accuracy\_score\(y\_test, y\_pred\)
plot\_confusion\_matrix\(confusion\_matrix\(y\_test, y\_pred=y\_pred\), classes=\['Non Fraud','Fraud'\],
title='Confusion matrix'\)
能够看到与基线模型比较,诈骗的准确率有所进步,可是比随机过采样有所下降,这或许是数据集的原因,由于SMOTE采样会生成心的数据,所以并不合适一切的数据集。
总结
在这篇文章中,咱们评论了处理不平衡数据集和进步机器学习模型功能的各种技巧和战略。不平衡的数据集或许是机器学习中的一个常见问题,并或许导致在猜测少量类时体现欠安。
本文介绍了一些可用于平衡数据集的重采样技能,如欠采样、过采样和SMOTE。还评论了本钱灵敏学习和运用恰当的功能目标,如AUC-ROC,这能够供给更好的模型功能指示。
处理不平衡的数据集是具有挑战性的,但经过遵循本文评论的技巧和战略,能够树立有用的模型准确猜测少量集体。重要的是要记住最佳办法将取决于特定的数据集和问题,为了获得最佳成果,或许需求结合各种技能。因此,实验不同的技能并运用恰当的目标评估它们的功能是很重要的。
作者:Emine Bozku
欢迎重视大众号CV技能攻略,专心于核算机视觉的技能总结、最新技能盯梢、经典论文解读、CV招聘信息。
【技能文档】《从零建立pytorch模型教程》122页PDF下载
QQ沟通群:444129970。群内有大佬担任回答我们的日常学习、科研、代码问题。
模型布置沟通群:732145323。用于核算机视觉方面的模型布置、高功能核算、优化加速、技能学习等方面的沟通。
其它文章
U-Net在2022年相关研讨的论文引荐
用少于256KB内存完成边际操练,开支不到PyTorch千分之一
PyTorch 2.0 重磅发布:一行代码提速 30%
Hinton 最新研讨:神经网络的未来是前向-前向算法
聊聊核算机视觉入门
FRNet:上下文感知的特征强化模块
DAMO-YOLO | 超越一切YOLO,统筹模型速度与精度
《医学图画切割》综述,详述六大类100多个算法
怎么高效完成矩阵乘?万文长字带你从CUDA初学者的角度入门
近似乘法对卷积神经网络的影响
BT-Unet:医学图画切割的自监督学习框架
语义切割该怎么走下去?
轻量级模型规划与布置总结
从CVPR22出发,聊聊CAM是怎么激活咱们文章的热度!
入门必读系列(十六)经典CNN规划演变的要害总结:从VGGNet到EfficientNet
入门必读系列(十五)神经网络不work的原因总结
入门必读系列(十四)CV论文常见英语单词总结
入门必读系列(十三)高效阅览论文的办法
入门必读系列(十二)池化各要点与各办法总结
TensorRT教程(三)TensorRT的安装教程
TensorRT教程(一)初次介绍TensorRT
TensorRT教程(二)TensorRT进阶介绍
核算机视觉中的高效阅览论文的办法总结
核算机视觉中的神经网络可视化东西与项目
核算机视觉中的transformer模型立异思路总结
核算机视觉中的传统特征提取办法总结
核算机视觉中的数据预处理与模型操练技巧总结
核算机视觉中的图画标示东西总结
核算机视觉中的数据增强办法总结
核算机视觉中的注意力机制技能总结
核算机视觉中的特征金字塔技能总结
核算机视觉中的池化技能总结
核算机视觉中的高效阅览论文的办法总结
核算机视觉中的论文立异的常见思路总结
神经网络中的归一化办法总结
神经网络的初始化办法总结