百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

【Python机器学习系列】建立梯度提升模型预测心脏疾病

bigegpt 2025-05-27 12:49 19 浏览

这是Python机器学习系列原创文章,我的第204篇原创文章。

一、引言

对于表格数据,一套完整的机器学习建模流程如下:

针对不同的数据集,有些步骤不适用即不需要做,其中橘红色框为必要步骤,由于数据质量较高,本文有些步骤跳过了,跳过的步骤将单独出文章总结!同时欢迎大家关注翻看我之前的一些相关文章。


GradientBoostingClassifier是一种基于梯度提升算法的分类器,它是scikit-learn库中的一个类。梯度提升是一种集成学习方法,通过组合多个弱学习器(通常是决策树,梯度提升决策树GBDT)来构建一个更强大的分类器。梯度提升模型的基本思想是利用梯度下降来最小化损失函数,以逐步优化模型的预测能力。在每一轮迭代中,模型会计算当前模型对样本的预测值与实际值之间的残差,然后使用一个新的弱学习器来拟合这个残差。通过迭代地拟合残差,每个弱学习器都会以一定的学习率加入到模型中,最终得到一个强大的集成模型。
本文将实现基于心脏疾病数据集建立梯度提升模型对心脏疾病患者进行分类预测的完整过程。

二、实现过程

导入必要的库

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

1、准备数据

data = pd.read_csv(r'Dataset.csv')
df = pd.DataFrame(data)

df:

数据基本信息:

print(df.head())
print(df.info())
print(df.shape)
print(df.columns)
print(df.dtypes)
cat_cols = [col for col in df.columns if df[col].dtype == "object"] # 类别型变量名
num_cols = [col for col in df.columns if df[col].dtype != "object"] # 数值型变量名

2、提取特征变量和目标变量

target = 'target'
features = df.columns.drop(target)
print(data["target"].value_counts()) # 顺便查看一下样本是否平衡

3、数据集划分

# df = shuffle(df)
X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=0)

4、模型的构建与训练

# 模型的构建与训练
model = GradientBoostingClassifier()
model.fit(X_train, y_train)

参数详解:

from sklearn.ensemble import GradientBoostingClassifier
# 全部参数
GradientBoostingClassifier(loss='log_loss', 
                            learning_rate=0.1, 
                            n_estimators=100, 
                            subsample=1.0, 
                            criterion='friedman_mse', 
                            min_samples_split=2, 
                            min_samples_leaf=1, 
                            min_weight_fraction_leaf=0.0, 
                            max_depth=3, 
                            min_impurity_decrease=0.0, 
                            init=None, 
                            random_state=None, 
                            max_features=None, 
                            verbose=0, 
                            max_leaf_nodes=None, 
                            warm_start=False, 
                            validation_fraction=0.1, 
                            n_iter_no_change=None, 
                            tol=0.0001, 
                            ccp_alpha=0.0)
  • loss:损失函数的类型。默认为deviance,表示使用对数似然损失函数进行分类。可以选择exponential,表示使用指数损失函数进行分类。
  • learning_rate:学习率,控制每个弱学习器的贡献。较小的学习率会使模型收敛得更慢,但可能会获得更好的性能。默认为0.1。
  • n_estimators:弱学习器(决策树)的数量。默认为100。
  • subsample:用于训练每个弱学习器的样本子集的比例。默认为1.0,表示使用全部样本。可以设置小于1.0的值来降低方差,防止过拟合。
  • criterion:决策树节点分裂的标准。默认为friedman_mse,表示使用Friedman均方误差作为分裂标准。可以选择mse,表示使用均方误差,或mae,表示使用平均绝对误差。
  • max_depth:决策树的最大深度。默认为3。增加深度可以增加模型的复杂度,但也容易导致过拟合。
  • min_samples_split:决策树节点分裂所需的最小样本数。默认为2。如果某个节点的样本数少于该值,则不会再进行分裂。
  • min_samples_leaf:叶节点所需的最小样本数。默认为1。如果叶节点的样本数少于该值,则不会进行进一步的分裂。
  • max_features:每个决策树节点考虑的特征数量。可以是整数、浮点数或字符串。默认为None,表示考虑所有特征。可以选择sqrt,表示考虑特征数量的平方根,或log2,表示考虑特征数量的对数。
  • random_state:随机种子。可以用于重现实验结果。
  • verbose:控制训练过程中的输出信息的详细程度。默认为0,表示不输出任何信息。较大的值会增加输出信息的数量。

5、模型的推理与评价

y_pred = model.predict(X_test)
y_scores = model.predict_proba(X_test)
acc = accuracy_score(y_test, y_pred) # 准确率acc
cm = confusion_matrix(y_test, y_pred) # 混淆矩阵
cr = classification_report(y_test, y_pred) # 分类报告
fpr, tpr, thresholds = roc_curve(y_test, y_scores[:, 1], pos_label=1) # 计算ROC曲线和AUC值,绘制ROC曲线
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()

cm:

cr:

ROC:

三、小结

本文利用scikit-learn(一个常用的机器学习库)实现了基于心脏疾病数据集建立梯度提升模型对心脏疾病患者进行分类预测的完整过程。

作者简介:

读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历不定期持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。需要数据源码的朋友关注gzh:数据杂坛,或点击原文链接,联系作者。

原文链接:

【Python机器学习系列】建立梯度提升模型预测心脏疾病(完整实现过程)

相关推荐

方差分析简介(方差分析通俗理解)

介绍方差分析(ANOVA,AnalysisofVariance)是一种广泛使用的统计方法,用于比较两个或多个组之间的均值。单因素方差分析是方差分析的一种变体,旨在检测三个或更多分类组的均值是否存在...

正如404页面所预示,猴子正成为断网元凶--吧嗒吧嗒真好吃

吧嗒吧嗒,绘图:MakiNaro你可以通过加热、冰冻、水淹、模塑、甚至压溃压力来使网络光缆硬化。但用猴子显然是不行的。光缆那新挤压成型的塑料外皮太尼玛诱人了,无法阻挡一场试吃盛宴的举行。印度政府正...

Python数据可视化:箱线图多种库画法

概念箱线图通过数据的四分位数来展示数据的分布情况。例如:数据的中心位置,数据间的离散程度,是否有异常值等。把数据从小到大进行排列并等分成四份,第一分位数(Q1),第二分位数(Q2)和第三分位数(Q3)...

多组独立(完全随机设计)样本秩和检验的SPSS操作教程及结果解读

作者/风仕在上一期,我们已经讲完了两组独立样本秩和检验的SPSS操作教程及结果解读,这期开始讲多组独立样本秩和检验,我们主要从多组独立样本秩和检验介绍、两组独立样本秩和检验使用条件及案例的SPSS操作...

方差分析 in R语言 and Excel(方差分析r语言例题)

今天来写一篇实际中比较实用的分析方法,方差分析。通过方差分析,我们可以确定组别之间的差异是否超出了由于随机因素引起的差异范围。方差分析分为单因素方差分析和多因素方差分析,这一篇先介绍一下单因素方差分析...

可视化:前端数据可视化插件大盘点 图表/图谱/地图/关系图

前端数据可视化插件大盘点图表/图谱/地图/关系图全有在大数据时代,很多时候我们需要在网页中显示数据统计报表,从而能很直观地了解数据的走向,开发人员很多时候需要使用图表来表现一些数据。随着Web技术的...

matplotlib 必知的 15 个图(matplotlib各种图)

施工专题,我已完成20篇,施工系列几乎覆盖Python完整技术栈,目标只总结实践中最实用的东西,直击问题本质,快速帮助读者们入门和进阶:1我的施工计划2数字专题3字符串专题4列表专题5流程控制专题6编...

R ggplot2常用图表绘制指南(ggplot2绘制折线图)

ggplot2是R语言中强大的数据可视化包,基于“图形语法”(GrammarofGraphics),通过分层方式构建图表。以下是常用图表命令的详细指南,涵盖基本语法、常见图表类型及示例,适合...

Python数据可视化:从Pandas基础到Seaborn高级应用

数据可视化是数据分析中不可或缺的一环,它能帮助我们直观理解数据模式和趋势。本文将全面介绍Python中最常用的三种可视化方法。Pandas内置绘图功能Pandas基于Matplotlib提供了简洁的绘...

Python 数据可视化常用命令备忘录

本文提供了一个全面的Python数据可视化备忘单,适用于探索性数据分析(EDA)。该备忘单涵盖了单变量分析、双变量分析、多变量分析、时间序列分析、文本数据分析、可视化定制以及保存与显示等内容。所...

统计图的种类(统计图的种类及特点图片)

统计图是利用几何图形或具体事物的形象和地图等形式来表现社会经济现象数量特征和数量关系的图形。以下是几种常见的统计图类型及其适用场景:1.条形图(BarChart)条形图是用矩形条的高度或长度来表示...

实测,大模型谁更懂数据可视化?(数据可视化和可视化分析的主要模型)

大家好,我是Ai学习的老章看论文时,经常看到漂亮的图表,很多不知道是用什么工具绘制的,或者很想复刻类似图表。实测,大模型LaTeX公式识别,出乎预料前文,我用Kimi、Qwen-3-235B...

通过AI提示词让Deepseek快速生成各种类型的图表制作

在数据分析和可视化领域,图表是传达信息的重要工具。然而,传统图表制作往往需要专业的软件和一定的技术知识。本文将介绍如何通过AI提示词,利用Deepseek快速生成各种类型的图表,包括柱状图、折线图、饼...

数据可视化:解析箱线图(box plot)

箱线图/盒须图(boxplot)是数据分布的图形表示,由五个摘要组成:最小值、第一四分位数(25th百分位数)、中位数、第三四分位数(75th百分位数)和最大值。箱子代表四分位距(IQR)。IQR是...

[seaborn] seaborn学习笔记1-箱形图Boxplot

1箱形图Boxplot(代码下载)Boxplot可能是最常见的图形类型之一。它能够很好表示数据中的分布规律。箱型图方框的末尾显示了上下四分位数。极线显示最高和最低值,不包括异常值。seaborn中...