样本分布不平衡,机器学习准确率高又有什么用?
bigegpt 2025-03-05 13:34 13 浏览
前面无论是用全部变量还是筛选出的特征变量、无论如何十折交叉验证调参,获得的模型应用于测试集时虽然预测准确率能在90%以上,但与不基于任何信息的随机猜测相比,这个模型都是统计不显著的 (这一点可能意义也不大,样本不平衡时看模型整体准确性无意义)。一个原因应该是样本不平衡导致的。DLBCL组的样品数目约为FL组的3倍。不通过建模而只是盲猜结果为DLBCL即可获得75%的正确率。而FL组的预测准确率却很低。
而通常我们关注的是占少数的样本,如是否患病,我们更希望能尽量发现可能存在的疾病,提前采取措施。
因此如何处理非平衡样品是每一个算法应用于分类问题时都需要考虑的。
不平衡样本的模型构建中的影响主要体现在2个地方:
随机采样构建决策树时会有较大概率只拿到了样品多的分类,这些树将没有能力预测样品少的分类,从而构成无意义的决策树。
在决策树的每个分子节点所做的决策会倾向于整体分类纯度,因此样品少的分类对结果的贡献和影响少。
一般处理方式有下面4种:
Class weights: 样品少的类分类错误给予更高的罚分 (impose a heavier cost when errors are made in the minority class)
Down-sampling: 从样品多的类随机移除样品
Up-sampling: 在样品少的类随机复制样品 (randomly replicate instances in the minority class)
Synthetic minority sampling technique (SMOTE): 通过插值在样品少的类中合成填充样本
这些权重加权或采样技术对阈值依赖的评估指标如准确性等影响较大,它们相当于把决策阈值推向了ROC曲线中的”最优位置” (这在Boruta特征变量筛选部分有讲)。但这些权重加权或采样技术对ROC曲线通常影响不会太大。
基于模拟数据的样本不平衡处理
这里先通过一套模拟数据熟悉下处理流程,再应用于真实数据。采用caret包的twoClassSim函数生成包含20个有意义变量和10个噪音变量的数据集。该数据集包含5000个观察样品,分为两组,多数组和少数组的样品数目比例为50:1 (通过intercept参数控制)。
library(dplyr) # for data manipulation library(caret) # for model-building # install.packages("xts") # install.packages("quantmod") # wget https://cran.r-project.org/src/contrib/Archive/DMwR/DMwR_0.4.1.tar.gz # R CMD INSTALL DMwR_0.4.1.tar.gz library(DMwR) # for smote implementation # 或使用smotefamily代替 # library(smotefamily) # for smote implementation library(purrr) # for functional programming (map) library(pROC) # for AUC calculations set.seed(2969) imbal_train <- twoClassSim(5000, intercept = -25, linearVars = 20, noiseVars = 10) imbal_train$Class = ifelse(imbal_train$Class == "Class1", "Normal", "Disease") imbal_train$Class <- factor(imbal_train$Class, levels=c("Disease", "Normal")) imbal_test <- twoClassSim(5000, intercept = -25, linearVars = 20, noiseVars = 10) imbal_test$Class = ifelse(imbal_test$Class == "Class1", "Normal", "Disease") imbal_test$Class <- factor(imbal_test$Class, levels=c("Disease", "Normal")) prop.table(table(imbal_train$Class)) prop.table(table(imbal_test$Class))
样品构成
Disease Normal 0.0204 0.9796 Disease Normal 0.0252 0.9748
构建原始GBM模型
这里应用另外一种集成学习算法 (GBM, Gradient Boosting Machine)进行模型构建。GBM也是效果很好的集成学习算法,可以处理变量的互作和非线性关系。机器学习中常用的GBDT、XGBoost和LightGBM算法(或工具)都是基于梯度提升机(GBM)的算法思想。
先构建一个原始模型,重复5次10-折交叉验证寻找最优的模型超参数,采用AUC作为评估标准。这些概念如果不熟悉翻一下往期推文。
# Set up control function for training ctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 5, summaryFunction = twoClassSummary, classProbs = TRUE) # Build a standard classifier using a gradient boosted machine set.seed(5627) orig_fit <- train(Class ~ ., data = imbal_train, method = "gbm", verbose = FALSE, metric = "ROC", trControl = ctrl) # Build custom AUC function to extract AUC # from the caret model object test_roc % test_roc(data = imbal_test) %>% auc()
AUC值为0.95,还是很不错的。
Setting levels: control = Disease, case = Normal Setting direction: controls > cases Area under the curve: 0.9538
从confusion matrix (预测结果采用默认阈值)来看,Disease的分类效果一般,准确率(敏感性)只有30.6%。不管是Normal还是Disease都倾向于预测为Normal,特异性低,这是因为样品不平衡导致的。而我们通常更希望尽早发现疾病的存在。
predictions_train <- predict(orig_fit, newdata=imbal_test) confusionMatrix(predictions_train, imbal_test$Class)
Confusion Matrix and Statistics Reference Prediction Disease Normal Disease 38 17 Normal 88 4857 Accuracy : 0.979 95% CI : (0.9746, 0.9828) No Information Rate : 0.9748 P-Value [Acc > NIR] : 0.02954 Kappa : 0.4109 Mcnemar's Test P-Value : 8.415e-12 Sensitivity : 0.3016 Specificity : 0.9965 Pos Pred Value : 0.6909 Neg Pred Value : 0.9822 Prevalence : 0.0252 Detection Rate : 0.0076 Detection Prevalence : 0.0110 Balanced Accuracy : 0.6490 'Positive' Class : Disease
采用权重分配或抽样方式处理样品不平衡问题
这里应用的GBM模型自身有一个参数weights可以用于设置样品的权重;caret在trainControl函数中提供了sampling参数可以进行up-sample和down-sample,或其它任何算法的采样方式(这里用的是smotefamily::SMOTE函数进行采样)。
# Create model weights (they sum to one) # 给每一个观察一个权重 class1_weight = (1/table(imbal_train$Class)[['Normal']]) * 0.5 class2_weight = (1/table(imbal_train$Class)[["Disease"]]) * 0.5 model_weights <- ifelse(imbal_train$Class == "Normal", class1_weight, class2_weight) # Use the same seed to ensure same cross-validation splits ctrl$seeds <- orig_fit$control$seeds # Build weighted model weighted_fit <- train(Class ~ ., data = imbal_train, method = "gbm", verbose = FALSE, weights = model_weights, metric = "ROC", trControl = ctrl) # Build down-sampled model ctrl$sampling <- "down" down_fit <- train(Class ~ ., data = imbal_train, method = "gbm", verbose = FALSE, metric = "ROC", trControl = ctrl) # Build up-sampled model ctrl$sampling <- "up" up_fit <- train(Class ~ ., data = imbal_train, method = "gbm", verbose = FALSE, metric = "ROC", trControl = ctrl) # Build smote model ctrl$sampling <- "smote" smote_fit <- train(Class ~ ., data = imbal_train, method = "gbm", verbose = FALSE, metric = "ROC", trControl = ctrl)
计算下每个模型的AUC值
# Examine results for test set model_list <- list(original = orig_fit, weighted = weighted_fit, down = down_fit, up = up_fit, SMOTE = smote_fit) model_list_roc % map(test_roc, data = imbal_test) model_list_roc %>% map(auc)
样品加权模型获得的AUC值最高,其次是up-sample, SMOTE, down-sample,结果都比original有提高。
Setting levels: control = Disease, case = Normal Setting direction: controls > cases Setting levels: control = Disease, case = Normal Setting direction: controls > cases Setting levels: control = Disease, case = Normal Setting direction: controls > cases Setting levels: control = Disease, case = Normal Setting direction: controls > cases Setting levels: control = Disease, case = Normal Setting direction: controls > cases $original Area under the curve: 0.9538 $weighted Area under the curve: 0.9793 $down Area under the curve: 0.9667 $up Area under the curve: 0.9778 $SMOTE Area under the curve: 0.9744
绘制下ROC曲线,查看下模型的具体效果展示。样品加权的模型优于其它所有模型,原始模型在假阳性率0-25%时效果差于其它模型。好的模型是在较低假阳性率时具有较高的真阳性率。
results_list_roc <- list(NA) num_mod <- 1 for(the_roc in model_list_roc){ results_list_roc[[num_mod]] <- data_frame(TPR = the_roc$sensitivities, FPR = 1 - the_roc$specificities, model = names(model_list)[num_mod]) num_mod <- num_mod + 1 } results_df_roc <- bind_rows(results_list_roc) results_df_roc$model <- factor(results_df_roc$model, levels=c("original", "down","SMOTE","up","weighted")) # Plot ROC curve for all 5 models custom_col <- c("#000000", "#009E73", "#0072B2", "#D55E00", "#CC79A7") ggplot(aes(x = FPR, y = TPR, group = model), data = results_df_roc) + geom_line(aes(color = model), size = 1) + scale_color_manual(values = custom_col) + geom_abline(intercept = 0, slope = 1, color = "gray", size = 1) + theme_bw(base_size = 18) + coord_fixed(1)
ggplot(aes(x = FPR, y = TPR, group = model), data = results_df_roc) + geom_line(aes(color = model), size = 1) + facet_wrap(vars(model)) + scale_color_manual(values = custom_col) + geom_abline(intercept = 0, slope = 1, color = "gray", size = 1) + theme_bw(base_size = 18) + coord_fixed(1)
加权后的模型,总预测准确率降低了一点,但Disease的预测准确性升高了2.47倍,70.63%。
predictions_train <- predict(weighted_fit, newdata=imbal_test) confusionMatrix(predictions_train, imbal_test$Class)
结果如下
Confusion Matrix and Statistics Reference Prediction Disease Normal Disease 89 83 Normal 37 4791 Accuracy : 0.976 95% CI : (0.9714, 0.9801) No Information Rate : 0.9748 P-Value [Acc > NIR] : 0.3137 Kappa : 0.5853 Mcnemar's Test P-Value : 3.992e-05 Sensitivity : 0.7063 Specificity : 0.9830 Pos Pred Value : 0.5174 Neg Pred Value : 0.9923 Prevalence : 0.0252 Detection Rate : 0.0178 Detection Prevalence : 0.0344 Balanced Accuracy : 0.8447 'Positive' Class : Disease
从这套测试数据来看,设置权重获得的模型效果是最好的。但这不是绝对的,应用于自己的数据时,需要都尝试一下,看看自己的数据更适合哪种方式。
未完待续......
相关推荐
- 方差分析简介(方差分析通俗理解)
-
介绍方差分析(ANOVA,AnalysisofVariance)是一种广泛使用的统计方法,用于比较两个或多个组之间的均值。单因素方差分析是方差分析的一种变体,旨在检测三个或更多分类组的均值是否存在...
- 正如404页面所预示,猴子正成为断网元凶--吧嗒吧嗒真好吃
-
吧嗒吧嗒,绘图:MakiNaro你可以通过加热、冰冻、水淹、模塑、甚至压溃压力来使网络光缆硬化。但用猴子显然是不行的。光缆那新挤压成型的塑料外皮太尼玛诱人了,无法阻挡一场试吃盛宴的举行。印度政府正...
- Python数据可视化:箱线图多种库画法
-
概念箱线图通过数据的四分位数来展示数据的分布情况。例如:数据的中心位置,数据间的离散程度,是否有异常值等。把数据从小到大进行排列并等分成四份,第一分位数(Q1),第二分位数(Q2)和第三分位数(Q3)...
- 多组独立(完全随机设计)样本秩和检验的SPSS操作教程及结果解读
-
作者/风仕在上一期,我们已经讲完了两组独立样本秩和检验的SPSS操作教程及结果解读,这期开始讲多组独立样本秩和检验,我们主要从多组独立样本秩和检验介绍、两组独立样本秩和检验使用条件及案例的SPSS操作...
- 方差分析 in R语言 and Excel(方差分析r语言例题)
-
今天来写一篇实际中比较实用的分析方法,方差分析。通过方差分析,我们可以确定组别之间的差异是否超出了由于随机因素引起的差异范围。方差分析分为单因素方差分析和多因素方差分析,这一篇先介绍一下单因素方差分析...
- 可视化:前端数据可视化插件大盘点 图表/图谱/地图/关系图
-
前端数据可视化插件大盘点图表/图谱/地图/关系图全有在大数据时代,很多时候我们需要在网页中显示数据统计报表,从而能很直观地了解数据的走向,开发人员很多时候需要使用图表来表现一些数据。随着Web技术的...
- matplotlib 必知的 15 个图(matplotlib各种图)
-
施工专题,我已完成20篇,施工系列几乎覆盖Python完整技术栈,目标只总结实践中最实用的东西,直击问题本质,快速帮助读者们入门和进阶:1我的施工计划2数字专题3字符串专题4列表专题5流程控制专题6编...
- R ggplot2常用图表绘制指南(ggplot2绘制折线图)
-
ggplot2是R语言中强大的数据可视化包,基于“图形语法”(GrammarofGraphics),通过分层方式构建图表。以下是常用图表命令的详细指南,涵盖基本语法、常见图表类型及示例,适合...
- Python数据可视化:从Pandas基础到Seaborn高级应用
-
数据可视化是数据分析中不可或缺的一环,它能帮助我们直观理解数据模式和趋势。本文将全面介绍Python中最常用的三种可视化方法。Pandas内置绘图功能Pandas基于Matplotlib提供了简洁的绘...
- Python 数据可视化常用命令备忘录
-
本文提供了一个全面的Python数据可视化备忘单,适用于探索性数据分析(EDA)。该备忘单涵盖了单变量分析、双变量分析、多变量分析、时间序列分析、文本数据分析、可视化定制以及保存与显示等内容。所...
- 统计图的种类(统计图的种类及特点图片)
-
统计图是利用几何图形或具体事物的形象和地图等形式来表现社会经济现象数量特征和数量关系的图形。以下是几种常见的统计图类型及其适用场景:1.条形图(BarChart)条形图是用矩形条的高度或长度来表示...
- 实测,大模型谁更懂数据可视化?(数据可视化和可视化分析的主要模型)
-
大家好,我是Ai学习的老章看论文时,经常看到漂亮的图表,很多不知道是用什么工具绘制的,或者很想复刻类似图表。实测,大模型LaTeX公式识别,出乎预料前文,我用Kimi、Qwen-3-235B...
- 通过AI提示词让Deepseek快速生成各种类型的图表制作
-
在数据分析和可视化领域,图表是传达信息的重要工具。然而,传统图表制作往往需要专业的软件和一定的技术知识。本文将介绍如何通过AI提示词,利用Deepseek快速生成各种类型的图表,包括柱状图、折线图、饼...
- 数据可视化:解析箱线图(box plot)
-
箱线图/盒须图(boxplot)是数据分布的图形表示,由五个摘要组成:最小值、第一四分位数(25th百分位数)、中位数、第三四分位数(75th百分位数)和最大值。箱子代表四分位距(IQR)。IQR是...
- [seaborn] seaborn学习笔记1-箱形图Boxplot
-
1箱形图Boxplot(代码下载)Boxplot可能是最常见的图形类型之一。它能够很好表示数据中的分布规律。箱型图方框的末尾显示了上下四分位数。极线显示最高和最低值,不包括异常值。seaborn中...
- 一周热门
- 最近发表
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- linuxlink (65)
- pythonwget (67)
- androidinclude (65)
- libcrypto.so (74)
- linux安装minio (74)
- ubuntuunzip (67)
- vscode使用技巧 (83)
- secure-file-priv (67)
- vue阻止冒泡 (67)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)