如何Keras自动编码器给极端罕见事件分类
bigegpt 2025-05-27 12:49 21 浏览
全文共7940字,预计学习时长30分钟或更长
本文将以一家造纸厂的生产为例,介绍如何使用自动编码器构建罕见事件分类器。
现实生活中罕见事件的数据集:
背景
1. 什么是极端罕见事件?
在罕见事件问题中,数据集是不平衡的。也就是说,正样本比负样本数量少。典型罕见事件问题的正样本数约占总数的5-10%。而在极端罕见的事件问题中,正样本数据只有不到1%。例如,本文使用的数据集里,这一比例只有约0.6%。
这种极端罕见的事件问题在现实世界中非常常见,例如,工厂中的机器故障或在网上点击购买时页面失踪。
对这些罕见事件进行分类非常有挑战性。近来,深度学习被广泛应用于分类中。然而正样本数太少不利于深度学习的应用。不论数据总量多大,深度学习的使用都会受制于阳性数据的数量。
2. 为什么要使用深度学习?
这个问题很合理。为什么不考虑使用其他机器学习方法呢?
答案很主观。我们总是可以采用某种机器学习方法来达到目的。为了使其成功,可以对负样本数据进行欠采样,以获得接近更平衡的数据集。由于只有0.6%的正样本数据,欠采样将会导致数据集大小约为原始数据集的1%。机器学习方法如SVM或Random Forest仍然适用于这种大小的数据集。然而,其准确性将受到限制。剩下约99%的数据中的信息将无法使用。
如果数据足够的话,深度学习或许更有效。它还能通过使用不同的体系结构实现模型改进的灵活性。因此,我们选择尝试使用深度学习的方法。
在本文中,我们将学习如何使用一个简单的全连接层自动编码器来构建罕见事件分类器。本文是为了演示如何使用自动编码器来实现极端罕见事件分类器的构建。用户可以自行探索自动编码器的不同架构和配置。
用自动编码器进行分类
用自动编码器分类类似于异常检测。在异常检测中,先学习正常过程的模式。任何不遵循此模式的都被归类为异常。对于罕见事件的二进制分类,可以采用类似的方法使用自动编码器。
1. 什么是自动编码器?
· 自动编码器由编码器和解码器两个模块组成。
· 编码器学习某一进程的隐含特性。这些特性通常在一个降低的维度中。
· 解码器可以根据这些隐含特性重新创建原始数据。
2. 如何使用自动编码器构建罕见事件分类?
· 将数据分为正标记和负标记两部分。
· 负标记的数据视为正常状态——无事件。
· 忽略正标记的数据,用负标记数据训练自动编码器。
· 所以重构误差的概率就很小。
· 然而,如果试图从稀有事件中重构数据,自动编码器就很难工作。
· 这就会造成在罕见事件中发生重构误差的概率比较高。
· 如此高的重构误差,并将其标记为罕见事件预测。
· 此过程与异常检测方法类似。
实际应用
1. 数据和问题
这是来自一家造纸厂关于纸张破损的二进制标记数据。纸张破损在造纸业是个很严重的问题。一起纸张破损就会造成数千美元损失,而造纸厂每天都会有数几起破损。这导致每年数百万美元的损失和工作风险。
因为生产过程本身的性质,很难检测到纸张破损。破损概率降低5%都能给厂家带来巨大的利益。
我们的数据包含15天内收集的约18,000行数据。y列包含两类标签,1表示纸张破损。其余列是预测,正标记样本约124例(约0.6%)。
2. 编码
导入所需的库。
%matplotlib inline import matplotlib.pyplot as plt import seaborn as sns import pandas as pd import numpy as np from pylab import rcParams import tensorflow as tf from keras.models import Model, load_model from keras.layers import Input, Dense from keras.callbacks import ModelCheckpoint, TensorBoard from keras import regularizers from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split from sklearn.metrics import confusion_matrix, precision_recall_curve from sklearn.metrics import recall_score, classification_report, auc, roc_curve from sklearn.metrics import precision_recall_fscore_support, f1_score from numpy.random import seed seed(1) from tensorflow import set_random_seed set_random_seed(2) SEED = 123 #used to help randomly select the data points DATA_SPLIT_PCT = 0.2 rcParams['figure.figsize'] = 8, 6 LABELS = ["Normal","Break"]
注意,我们正在为可重现结果设置随机种子。
3. 数据处理
现在,读取并准备数据。
df = pd.read_csv("
data/processminer-rare-event-mts - data.csv")这个罕见事件问题的目标就是在纸张破损发生之前就及时做出预测。我们试着在破损发生前四分钟就要预测到。为了建立这个模型,把标签向上移动两行(相当于4分钟)。只要df.y=df.y.shift(-2)就行了。然而针对这个问题,需要做出如下改变:如果第n行是阳性的,
· 令行(n-2)和(n-1)等于1。这将帮助分类器学会提前最多4分钟预测。
· 删除第n行。因为分类器不需要学会在事件发生时做出预测。
为了这个复杂的变化需要用下面的UDF。
sign = lambda x: (1, -1)[x < 0] def curve_shift(df, shift_by): ''' This function will shift the binary labels in a dataframe. The curve shift will be with respect to the 1s. For example, if shift is -2, the following process will happen: if row n is labeled as 1, then - Make row (n+shift_by):(n+shift_by-1) = 1. - Remove row n. i.e. the labels will be shifted up to 2 rows up. Inputs: df A pandas dataframe with a binary labeled column. This labeled column should be named as 'y'. shift_by An integer denoting the number of rows to shift. Output df A dataframe with the binary labels shifted by shift. ''' vector = df['y'].copy() for s in range(abs(shift_by)): tmp = vector.shift(sign(shift_by)) tmp = tmp.fillna(0) vector += tmp labelcol = 'y' # Add vector to the df df.insert(loc=0, column=labelcol+'tmp', value=vector) # Remove the rows with labelcol == 1. df = df.drop(df[df[labelcol] == 1].index) # Drop labelcol and rename the tmp col as labelcol df = df.drop(labelcol, axis=1) df = df.rename(columns={labelcol+'tmp': labelcol}) # Make the labelcol binary df.loc[df[labelcol] > 0, labelcol] = 1 return df
现在,将数据分为训练集、有效集和测试集。然后迅速使用只有0的数据子集来训练自动编码器。
df_train, df_test = train_test_split(df, test_size=DATA_SPLIT_PCT, random_state=SEED) df_train, df_valid = train_test_split(df_train, test_size=DATA_SPLIT_PCT, random_state=SEED) df_train_0 = df_train.loc[df['y'] == 0] df_train_1 = df_train.loc[df['y'] == 1] df_train_0_x = df_train_0.drop(['y'], axis=1) df_train_1_x = df_train_1.drop(['y'], axis=1)
4. 标准化
自动编码器最好使用标准化数据(转换为Gaussian、均值0、方差1)。
scaler = StandardScaler().fit(df_train_0_x) df_train_0_x_rescaled = scaler.transform(df_train_0_x) df_valid_0_x_rescaled = scaler.transform(df_valid_0_x) df_valid_x_rescaled = scaler.transform(df_valid.drop(['y'], axis = 1)) df_test_0_x_rescaled = scaler.transform(df_test_0_x) df_test_x_rescaled = scaler.transform(df_test.drop(['y'], axis = 1))
由自动编码器构造的分类器
1. 初始化
首先,初始化自动编码器架构。先构建一个简单的自动编码器,稍后再探索更复杂的架构和配置。
nb_epoch = 100 batch_size = 128 input_dim = df_train_0_x_rescaled.shape[1] #num of predictor variables, encoding_dim = 32 hidden_dim = int(encoding_dim / 2) learning_rate = 1e-3 input_layer = Input(shape=(input_dim, )) encoder = Dense(encoding_dim, activation="tanh", activity_regularizer=regularizers.l1(learning_rate))(input_layer) encoder = Dense(hidden_dim, activation="relu")(encoder) decoder = Dense(hidden_dim, activation='tanh')(encoder) decoder = Dense(input_dim, activation='relu')(decoder) autoencoder = Model(inputs=input_layer, outputs=decoder)
2. 训练
训练模型并将其保存在文件中。保存训练过的模型会为将来的分析省很多时间。
autoencoder.compile(metrics=['accuracy'], loss='mean_squared_error', optimizer='adam') cp = ModelCheckpoint(filepath="autoencoder_classifier.h5", save_best_only=True, verbose=0) tb = TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True) history=autoencoder.fit(df_train_0_x_rescaled,df_train_0_x_rescaled, epochs=nb_epoch, batch_size=batch_size, shuffle=True, validation_data=(df_valid_0_x_rescaled, df_valid_0_x_rescaled), verbose=1, callbacks=[cp, tb]).history
3. 分类
接下来将展示如何利用自动编码器重构误差来构造罕见事件分类器。
如前所述,如果重构误差较大,将其归类为纸张破损。需要确定这个阈值。
使用验证集来确定阈值。
valid_x_predictions = autoencoder.predict(df_valid_x_rescaled) mse = np.mean(np.power(df_valid_x_rescaled - valid_x_predictions, 2), axis=1) error_df = pd.DataFrame({'Reconstruction_error': mse, 'True_class': df_valid['y']}) precision_rt, recall_rt, threshold_rt = precision_recall_curve(error_df.True_class, error_df.Reconstruction_error) plt.plot(threshold_rt, precision_rt[1:], label="Precision",linewidth=5) plt.plot(threshold_rt, recall_rt[1:], label="Recall",linewidth=5) plt.title('Precision and recall for different threshold values') plt.xlabel('Threshold') plt.ylabel('Precision/Recall') plt.legend() plt.show()
现在,对测试数据进行分类。
不要根据测试数据来估计分类阈值,这会导致过度拟合。
test_x_predictions = autoencoder.predict(df_test_x_rescaled) mse = np.mean(np.power(df_test_x_rescaled - test_x_predictions, 2), axis=1) error_df_test = pd.DataFrame({'Reconstruction_error': mse, 'True_class': df_test['y']})error_df_test = error_df_test.reset_index() threshold_fixed = 0.85 groups = error_df_test.groupby('True_class') fig, ax = plt.subplots() for name, group in groups: ax.plot(group.index, group.Reconstruction_error, marker='o', ms=3.5, linestyle='', label= "Break" if name == 1 else "Normal") ax.hlines(threshold_fixed, ax.get_xlim()[0], ax.get_xlim()[1], colors="r", zorder=100, label='Threshold') ax.legend() plt.title("Reconstruction error for different classes") plt.ylabel("Reconstruction error") plt.xlabel("Data point index") plt.show();
图4中,阈值线上方的橙色和蓝色圆点分别表示真阳性和假阳性。可以看到上面有很多假阳性的点。为了看得更清楚,可以看一个混淆矩阵。
pred_y = [1 if e > threshold_fixed else 0 for e in error_df.Reconstruction_error.values] conf_matrix = confusion_matrix(error_df.True_class, pred_y) plt.figure(figsize=(12, 12)) sns.heatmap(conf_matrix, xticklabels=LABELS, yticklabels=LABELS, annot=True, fmt="d"); plt.title("Confusion matrix") plt.ylabel('True class') plt.xlabel('Predicted class') plt.show()
在32次的破损中,我们预测到了9次。注意,其中包括提前两或四分钟的预测。这一概率约为28%,对造纸业来说是一个不错的召回率。假阳性率约为6.3%。这对造纸厂来说不是最好的结果,但也不坏。
该模型还可以进一步改进,降低假阳性率以提高召回率。观察如下AUC值后探讨改进方法。
ROC曲线和AUC(Area Under Curve)
false_pos_rate, true_pos_rate, thresholds = roc_curve(error_df.True_class, error_df.Reconstruction_error)roc_auc = auc(false_pos_rate, true_pos_rate,) plt.plot(false_pos_rate, true_pos_rate, linewidth=5, label='AUC = %0.3f'% roc_auc)plt.plot([0,1],[0,1], linewidth=5) plt.xlim([-0.01, 1]) plt.ylim([0, 1.01]) plt.legend(loc='lower right') plt.title('Receiver operating characteristic curve (ROC)') plt.ylabel('True Positive Rate') plt.xlabel('False Positive Rate') plt.show()
值得注意的是,这是一个(多变量的)时间序列数据。我们并未考虑数据中的时间信息/模式。
留言 点赞 关注
我们一起分享AI学习与发展的干货
欢迎关注全平台AI垂类自媒体 “读芯术”
相关推荐
- 方差分析简介(方差分析通俗理解)
-
介绍方差分析(ANOVA,AnalysisofVariance)是一种广泛使用的统计方法,用于比较两个或多个组之间的均值。单因素方差分析是方差分析的一种变体,旨在检测三个或更多分类组的均值是否存在...
- 正如404页面所预示,猴子正成为断网元凶--吧嗒吧嗒真好吃
-
吧嗒吧嗒,绘图:MakiNaro你可以通过加热、冰冻、水淹、模塑、甚至压溃压力来使网络光缆硬化。但用猴子显然是不行的。光缆那新挤压成型的塑料外皮太尼玛诱人了,无法阻挡一场试吃盛宴的举行。印度政府正...
- Python数据可视化:箱线图多种库画法
-
概念箱线图通过数据的四分位数来展示数据的分布情况。例如:数据的中心位置,数据间的离散程度,是否有异常值等。把数据从小到大进行排列并等分成四份,第一分位数(Q1),第二分位数(Q2)和第三分位数(Q3)...
- 多组独立(完全随机设计)样本秩和检验的SPSS操作教程及结果解读
-
作者/风仕在上一期,我们已经讲完了两组独立样本秩和检验的SPSS操作教程及结果解读,这期开始讲多组独立样本秩和检验,我们主要从多组独立样本秩和检验介绍、两组独立样本秩和检验使用条件及案例的SPSS操作...
- 方差分析 in R语言 and Excel(方差分析r语言例题)
-
今天来写一篇实际中比较实用的分析方法,方差分析。通过方差分析,我们可以确定组别之间的差异是否超出了由于随机因素引起的差异范围。方差分析分为单因素方差分析和多因素方差分析,这一篇先介绍一下单因素方差分析...
- 可视化:前端数据可视化插件大盘点 图表/图谱/地图/关系图
-
前端数据可视化插件大盘点图表/图谱/地图/关系图全有在大数据时代,很多时候我们需要在网页中显示数据统计报表,从而能很直观地了解数据的走向,开发人员很多时候需要使用图表来表现一些数据。随着Web技术的...
- matplotlib 必知的 15 个图(matplotlib各种图)
-
施工专题,我已完成20篇,施工系列几乎覆盖Python完整技术栈,目标只总结实践中最实用的东西,直击问题本质,快速帮助读者们入门和进阶:1我的施工计划2数字专题3字符串专题4列表专题5流程控制专题6编...
- R ggplot2常用图表绘制指南(ggplot2绘制折线图)
-
ggplot2是R语言中强大的数据可视化包,基于“图形语法”(GrammarofGraphics),通过分层方式构建图表。以下是常用图表命令的详细指南,涵盖基本语法、常见图表类型及示例,适合...
- Python数据可视化:从Pandas基础到Seaborn高级应用
-
数据可视化是数据分析中不可或缺的一环,它能帮助我们直观理解数据模式和趋势。本文将全面介绍Python中最常用的三种可视化方法。Pandas内置绘图功能Pandas基于Matplotlib提供了简洁的绘...
- Python 数据可视化常用命令备忘录
-
本文提供了一个全面的Python数据可视化备忘单,适用于探索性数据分析(EDA)。该备忘单涵盖了单变量分析、双变量分析、多变量分析、时间序列分析、文本数据分析、可视化定制以及保存与显示等内容。所...
- 统计图的种类(统计图的种类及特点图片)
-
统计图是利用几何图形或具体事物的形象和地图等形式来表现社会经济现象数量特征和数量关系的图形。以下是几种常见的统计图类型及其适用场景:1.条形图(BarChart)条形图是用矩形条的高度或长度来表示...
- 实测,大模型谁更懂数据可视化?(数据可视化和可视化分析的主要模型)
-
大家好,我是Ai学习的老章看论文时,经常看到漂亮的图表,很多不知道是用什么工具绘制的,或者很想复刻类似图表。实测,大模型LaTeX公式识别,出乎预料前文,我用Kimi、Qwen-3-235B...
- 通过AI提示词让Deepseek快速生成各种类型的图表制作
-
在数据分析和可视化领域,图表是传达信息的重要工具。然而,传统图表制作往往需要专业的软件和一定的技术知识。本文将介绍如何通过AI提示词,利用Deepseek快速生成各种类型的图表,包括柱状图、折线图、饼...
- 数据可视化:解析箱线图(box plot)
-
箱线图/盒须图(boxplot)是数据分布的图形表示,由五个摘要组成:最小值、第一四分位数(25th百分位数)、中位数、第三四分位数(75th百分位数)和最大值。箱子代表四分位距(IQR)。IQR是...
- [seaborn] seaborn学习笔记1-箱形图Boxplot
-
1箱形图Boxplot(代码下载)Boxplot可能是最常见的图形类型之一。它能够很好表示数据中的分布规律。箱型图方框的末尾显示了上下四分位数。极线显示最高和最低值,不包括异常值。seaborn中...
- 一周热门
- 最近发表
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- linuxlink (65)
- pythonwget (67)
- androidinclude (65)
- libcrypto.so (74)
- linux安装minio (74)
- ubuntuunzip (67)
- vscode使用技巧 (83)
- secure-file-priv (67)
- vue阻止冒泡 (67)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)