处理不平衡机器学习数据时需要了解的技术
bigegpt 2024-11-10 08:26 2 浏览
我们在处理真实世界机器学习数据集时遇到的主要挑战之一是数据的比例不平衡。欺诈检测是这类数据的最好例子。在本文中,我们将使用kaggle中的信用卡欺诈检测数据集(www.kaggle.com/mlg-ulb/creditcardfraud)。
在全部数据中,欺诈事件不到1%。这种具有相当少的来自特定类的实例的数据称为不平衡数据。
采样技术
过采样
对少数类(数据集中实例较少的类)的数据进行复制,以增加少数类的比例。这种技术的一个主要问题是过度拟合。
from imblearn.over_sampling import RandomOverSampler
oversample = RandomOverSampler(sampling_strategy='minority')
X_over, y_over = oversample.fit_resample(X_train, y_train)
欠采样
对来自多数类(机器学习数据集中实例较多的类)的数据进行采样,以减少多数类的比例。这种技术的一个主要问题是信息的丢失。
from imblearn.over_sampling import RandomUnderSampler
undersample = RandomUnderSampler(sampling_strategy='majority')
X_over, y_over = oversample.fit_resample(X_train, y_train)
合成少数类过采样技术(SMOTE)
我们将使用过采样技术生成样本,但不是盲目复制。SMOTE遵循以下步骤来生成数据。
- 对于少数类中的每个样本x,选择k个最近邻构成Q{y0,y1,…yk}(k的默认值为5)。
- 对少数类样本进行线性插值,得到新的样本x’。
from imblearn.over_sampling import SMOTE
sm = SMOTE(random_state = 2)
X_train_res, y_train_res = sm.fit_sample(X_train, y_train)
结合使用SMOTE和欠采样可获得更好的结果。
集成学习技巧
集成学习技术在不平衡数据上表现良好。集成技术结合了多个分类器的结果,以提高单个分类器的性能。
随机森林
随机森林是一种减少决策树分类器方差的集成学习技术。随机森林从建立在采样数据上的多个决策树中获得最佳解。
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=100,
bootstrap = True,
max_features = 'sqrt')
model.fit(X_train,y_train)
y_pred2= model.predict(X_test)
我们检测到72个欺诈(总共98个欺诈)。因此,检测到欺诈的概率为0.734。
XGBoost
随机森林并行构建树。在boosting技术中,通过校正先前训练的树的误差来训练树。
import xgboost as xgb
alg = xgb.XGBClassifier(learning_rate=0.1, n_estimators=140, max_depth=5,min_child_weight=3, gamma=0.2, subsample=0.6, colsample_bytree=1.0,objective='binary:logistic', nthread=4, scale_pos_weight=1, seed=27)
alg.fit(X_train, y_train, eval_metric='auc')
y_pred = alg.predict(X_test)
y_score = alg.predict_proba(X_test)[:,1]
我们检测到74起欺诈/ 98起欺诈。因此,检测到欺诈的概率为0.755。
Light GBM
Light GBM可以提高XGBoost的性能。Light GBM内存高效且与大型数据集兼容。
import lightgbm as lgbm
lgbm_clf = lgbm.LGBMClassifier(boosting_type='gbdt',
class_weight=None,
colsample_bytree=0.5112837457460335,importance_type='split',
learning_rate=0.02, max_depth=7, metric='None',
min_child_samples=195, min_child_weight=0.01,
min_split_gain=0.0,
n_estimators=3000, n_jobs=4, num_leaves=44, objective=None,
random_state=42, reg_alpha=2, reg_lambda=10, silent=True,
subsample=0.8137506311449016, subsample_for_bin=200000,
subsample_freq=0)
lgbm_clf.fit(X_train, y_train)
y_pred1 = lgbm_clf.predict(X_test)
y_score1 = lgbm_clf.predict_proba(X_test)[:,1]
我们检测到76起欺诈/ 98起欺诈。因此,检测到欺诈的概率为0.775。
深度学习技术
自编码器
自编码器尝试重建给定的输入。自编码器用于降维和深度异常检测。这些深度学习技术也可以应用于图像和视频。
我们将仅以正常交易来训练我们的自编码器。每当遇到欺诈检测时,自编码器都无法重建它。
autoencoder = tf.keras.models.Sequential([
tf.keras.layers.Dense(input_dim, activation='relu', input_shape=(input_dim, )),
tf.keras.layers.GaussianNoise(),
tf.keras.layers.Dense(latent_dim, activation='relu'),
tf.keras.layers.Dense(input_dim, activation='relu')
])
autoencoder.compile(optimizer='adam',
loss='mse',
metrics=['acc'])
autoencoder.summary()
现在,我们将训练自编码器,并观察正常交易和欺诈交易的重构情况。
X_test_transformed = pipeline.transform(X_test)
reconstructions = autoencoder.predict(X_test_transformed)
mse = np.mean(np.power(X_test - reconstructions, 2), axis=1)
欺诈交易的重构误差率很高。现在,我们需要设置将欺诈与正常交易区分开的阈值。
为了获得较高的精度值,我们可以采用较高的阈值,但是为了获得良好的召回率,我们需要降低它。
DevNet
Deviation Networks (DevNet)定义了高斯先验和基于Z分数的偏差损失,以使端到端神经异常评分学习器能够直接优化异常评分。
该网络中使用的损失函数为:
Lφ(x;Θ)=(1- y)| dev(x)| + y max(0,a-dev(x))
dev(x)=φ(x;Θ)?μR/σR
其中a是Z分数置信区间。
根据中心极限定理,我们可以得出结论:高斯分布拟合从网络获得的异常分数数据。我们将在实验中设置μ= 0和σ= 1,这有助于DevNet在不同机器学习数据集上实现稳定的检测性能。
对于所有正常交易(y = 0):
Lφ(x;Θ)=(1-0)| dev(x)| = | dev(x)|
对于所有欺诈交易(y = 1):
Lφ(x;Θ)= 1(max(0,a-dev(x)))= max(0,a-dev(x))
因此,偏差损失相当于将所有异常对象的异常评分与正常对象的异常评分进行统计显著性偏差。
网络的代码是:
def dev_network(input_shape):
x_input = Input(shape=input_shape) intermediate = Dense(1000
,activation='relu',
kernel_regularizer=regularizers.l2(0.01), name =
'hl1')(x_input)
intermediate = Dense(250, activation='relu',
kernel_regularizer=regularizers.l2(0.01), name =
'hl2')(intermediate)
intermediate = Dense(20, activation='relu',
kernel_regularizer=regularizers.l2(0.01), name =
'hl3')(intermediate)
intermediate = Dense(1, activation='linear', name = 'score')
(intermediate)
return Model(x_input, intermediate)
偏差损失的Pytyhon代码为:
def deviation_loss(y_true, y_pred):
confidence_margin = 5.
ref = K.variable(np.random.normal(loc = 0., scale= 1.0, size =
5000) , dtype='float32')
dev = (y_pred - K.mean(ref)) / K.std(ref)
inlier_loss = K.abs(dev)
outlier_loss = K.abs(K.maximum(confidence_margin - dev, 0.))
return K.mean((1 - y_true) * inlier_loss +
y_true * outlier_loss)
model = dev_network_d(input_shape)
model.compile(loss=deviation_loss, optimizer=rms)
指标
对于不平衡的机器学习数据,准确性不是一个好的度量标准。相反,我们可以考虑使用Recall和F1-score。
我们也可以从ROC曲线转换到Precision-Recall曲线。
ROC曲线介于真阳性率(召回率)与假阳性率之间。
精度对不平衡数据的变化更为敏感,因为负样本的数量相当高。
FPR = FP /(FP + TN)
Precision= TP /(TP + FP)
最后
- 为了从图像或视频相关数据中检测异常,首选深度学习。
- 与集成方法相比,深度学习中需要调整的参数更多。因此,理解模型对于深度学习的调整起着关键作用。
相关推荐
- 悠悠万事,吃饭为大(悠悠万事吃饭为大,什么意思)
-
新媒体编辑:杜岷赵蕾初审:程秀娟审核:汤小俊审签:周星...
- 高铁扒门事件升级版!婚宴上‘冲喜’老人团:我们抢的是社会资源
-
凌晨两点改方案时,突然收到婚庆团队发来的视频——胶东某酒店宴会厅,三个穿大红棉袄的中年妇女跟敢死队似的往前冲,眼瞅着就要扑到新娘的高额钻石项链上。要不是门口小伙及时阻拦,这婚礼造型团队熬了三个月的方案...
- 微服务架构实战:商家管理后台与sso设计,SSO客户端设计
-
SSO客户端设计下面通过模块merchant-security对SSO客户端安全认证部分的实现进行封装,以便各个接入SSO的客户端应用进行引用。安全认证的项目管理配置SSO客户端安全认证的项目管理使...
- 还在为 Spring Boot 配置类加载机制困惑?一文为你彻底解惑
-
在当今微服务架构盛行、项目复杂度不断攀升的开发环境下,SpringBoot作为Java后端开发的主流框架,无疑是我们手中的得力武器。然而,当我们在享受其自动配置带来的便捷时,是否曾被配置类加载...
- Seata源码—6.Seata AT模式的数据源代理二
-
大纲1.Seata的Resource资源接口源码2.Seata数据源连接池代理的实现源码3.Client向Server发起注册RM的源码4.Client向Server注册RM时的交互源码5.数据源连接...
- 30分钟了解K8S(30分钟了解微积分)
-
微服务演进方向o面向分布式设计(Distribution):容器、微服务、API驱动的开发;o面向配置设计(Configuration):一个镜像,多个环境配置;o面向韧性设计(Resista...
- SpringBoot条件化配置(@Conditional)全面解析与实战指南
-
一、条件化配置基础概念1.1什么是条件化配置条件化配置是Spring框架提供的一种基于特定条件来决定是否注册Bean或加载配置的机制。在SpringBoot中,这一机制通过@Conditional...
- 一招解决所有依赖冲突(克服依赖)
-
背景介绍最近遇到了这样一个问题,我们有一个jar包common-tool,作为基础工具包,被各个项目在引用。突然某一天发现日志很多报错。一看是NoSuchMethodError,意思是Dis...
- 你读过Mybatis的源码?说说它用到了几种设计模式
-
学习设计模式时,很多人都有类似的困扰——明明概念背得滚瓜烂熟,一到写代码就完全想不起来怎么用。就像学了一堆游泳技巧,却从没下过水实践,很难真正掌握。其实理解一个知识点,就像看立体模型,单角度观察总...
- golang对接阿里云私有Bucket上传图片、授权访问图片
-
1、为什么要设置私有bucket公共读写:互联网上任何用户都可以对该Bucket内的文件进行访问,并且向该Bucket写入数据。这有可能造成您数据的外泄以及费用激增,若被人恶意写入违法信息还可...
- spring中的资源的加载(spring加载原理)
-
最近在网上看到有人问@ContextConfiguration("classpath:/bean.xml")中除了classpath这种还有其他的写法么,看他的意思是想从本地文件...
- Android资源使用(android资源文件)
-
Android资源管理机制在Android的开发中,需要使用到各式各样的资源,这些资源往往是一些静态资源,比如位图,颜色,布局定义,用户界面使用到的字符串,动画等。这些资源统统放在项目的res/独立子...
- 如何深度理解mybatis?(如何深度理解康乐服务质量管理的5个维度)
-
深度自定义mybatis回顾mybatis的操作的核心步骤编写核心类SqlSessionFacotryBuild进行解析配置文件深度分析解析SqlSessionFacotryBuild干的核心工作编写...
- @Autowired与@Resource原理知识点详解
-
springIOCAOP的不多做赘述了,说下IOC:SpringIOC解决的是对象管理和对象依赖的问题,IOC容器可以理解为一个对象工厂,我们都把该对象交给工厂,工厂管理这些对象的创建以及依赖关系...
- java的redis连接工具篇(java redis client)
-
在Java里,有不少用于连接Redis的工具,下面为你介绍一些主流的工具及其特点:JedisJedis是Redis官方推荐的Java连接工具,它提供了全面的Redis命令支持,且...
- 一周热门
- 最近发表
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- resize函数 (64)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- mybatis大于等于 (64)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- linuxlink (65)
- pythonwget (67)
- androidinclude (65)
- logstashinput (65)
- hadoop端口 (65)
- vue阻止冒泡 (67)
- oracle时间戳转换日期 (64)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)