神经网络与传统统计方法的简单对比
bigegpt 2024-11-10 08:26 66 浏览
传统的统计方法如OLS假设变量之间符合简单的线性关系或者高阶线性关系进行拟合(或函数逼近),然而,并不是所有关系都是简单的线性关系或者高阶线性关系,这时就需要借助神经网络 (neural network,NN)等方法来进行建模。神经网络可以在不需要知道函数关系具体形式的条件下近似各种函数关系。
预测模型
1. scikit-learn
下例使用scikit-learn 库中的 MLPRegressor 类,该类可用 DNN 进行回归估计。DNN 有时也被称为多层感知器(multi-layer perceptron,MLP)。从最终的MSE来看,结果并不完美,但是对一个配置简单的模型来说,效果已经非常不错了。
from sklearn.neural_network import MLPRegressor
# 生成样本数据
def f(x):
return 2 * x ** 2 - x ** 3 / 3
x = np.linspace(-2, 4, 25)
y = f(x)
# 实例化 MLPRegressor 对象
model = MLPRegressor(hidden_layer_sizes=3 * [256], learning_rate_init=0.03, max_iter=5000)
# 拟合或学习步骤。
model.fit(x.reshape(-1, 1), y)
# 预测步骤
y_ = model.predict(x.reshape(-1, 1))
MSE = ((y - y_) ** 2).mean()
MSE
# Out:
# 0.003216321978018745
样本和预测结果图
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'ro', label='sample data')
plt.plot(x, y_, lw=3.0, label='dnn estimation')
plt.legend();
2. Keras
下一个示例使用了 Keras 深度学习软件包中的序列模型 Sequential,对该模型每轮进行100次迭代训练,重复5轮。每轮训练之后,我们将更新并绘制由神经网络预测的近似值。如图显示,随着每一轮训练的近似值的准确率逐渐提高,MSE值逐渐降低。与之前的模型相似,最终结果并不完美,但是鉴于模型的简单性,它还是不错的。
import tensorflow as tf
tf.random.set_seed(100)
from keras.layers import Dense
from keras.models import Sequential
# 实例化 Sequential 模型对象
model = Sequential()
# 添加采用整流线性单元(ReLU)激活函数的全连接层作为隐藏层
model.add(Dense(256, activation='relu', input_dim=1))
# 添加线性激活的输出层
model.add(Dense(1, activation='linear'))
# 编译模型对象
model.compile(loss='mse', optimizer='rmsprop')
# 原始样本数据图
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'ro', label='sample data')
# 迭代训练指定次数
for _ in range(1, 6):
# 训练神经网络
model.fit(x, y, epochs=100, verbose=False)
# 预测近似值
y_ = model.predict(x)
# 计算当前的 MSE
MSE = ((y - y_.flatten()) ** 2).mean()
print(f'round={_} | MSE={MSE:.5f}')
# 绘制当前的近似结果
plt.plot(x, y_, '--', label=f'round={_}')
plt.legend();
# Out:
# round=1 | MSE=3.87256
# round=2 | MSE=0.92527
# round=3 | MSE=0.28527
# round=4 | MSE=0.13191
# round=5 | MSE=0.09568
从以上两个示例来看,相比OLS回归完美的复刻原有方程的系数,神经网络只能提供一个近似的预测,那么为什么还要使用神经网络呢?假设我们的数据不是通过预定义好的数学函数生成的,而是随机产生的特征和标签呢?下面我们再看一个例子,当然该示例仅用于说明,不具有实际意义。
# 随机生成测试数据
np.random.seed(0)
x = np.linspace(-1, 1)
y = np.random.random(len(x)) * 2 - 1
# 用不同的多次项OLS回归进行拟合
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'ro', label='sample data')
for deg in [1, 5, 9, 11, 13, 15]:
reg = np.polyfit(x, y, deg=deg)
y_ = np.polyval(reg, x)
MSE = ((y - y_) ** 2).mean()
print(f'deg={deg:2d} | MSE={MSE:.5f}')
plt.plot(x, np.polyval(reg, x), label=f'deg={deg}')
plt.legend();
# Out:
# deg= 1 | MSE=0.28153
# deg= 5 | MSE=0.27331
# deg= 9 | MSE=0.25442
# deg=11 | MSE=0.23458
# deg=13 | MSE=0.22989
# deg=15 | MSE=0.21672
明显可见,OLS 回归的效果并不理想。OLS回归假设我们可以通过有限个(基于多项式的)基函数的组合来逼近目标函数,由于样本数据集是随机生成的,因此在这种情况下,OLS 回归效果不佳。下面我们用神经网络来试下。
model = Sequential()
model.add(Dense(256, activation='relu', input_dim=1))
# 此处添加3个隐藏层
for _ in range(3):
model.add(Dense(256, activation='relu'))
model.add(Dense(1, activation='linear'))
model.compile(loss='mse', optimizer='rmsprop')
# 显示神经网络架构以及可训练参数的数量
model.summary()
# Out:
# Model: "sequential_1"
# _________________________________________________________________
# Layer (type) Output Shape Param #
# =================================================================
# dense_2 (Dense) (None, 256) 512
#
# dense_3 (Dense) (None, 256) 65792
#
# dense_4 (Dense) (None, 256) 65792
#
# dense_5 (Dense) (None, 256) 65792
#
# dense_6 (Dense) (None, 1) 257
#
# =================================================================
# Total params: 198,145
# Trainable params: 198,145
# Non-trainable params: 0
# _________________________________________________________________
%%time
plt.figure(figsize=(10, 6))
plt.plot(x, y, 'ro', label='sample data')
for _ in range(1, 8):
model.fit(x, y, epochs=500, verbose=False)
y_ = model.predict(x)
MSE = ((y - y_.flatten()) ** 2).mean()
print(f'round={_} | MSE={MSE:.5f}')
plt.plot(x, y_, '--', label=f'round={_}')
plt.legend();
# Out:
# round=1 | MSE=0.13428
# round=2 | MSE=0.08515
# round=3 | MSE=0.05811
# round=4 | MSE=0.04389
# round=5 | MSE=0.03376
# round=6 | MSE=0.00722
# round=7 | MSE=0.00644
# CPU times: user 22.8 s, sys: 3.97 s, total: 26.8 s
# Wall time: 12.1 s
尽管预测结果并不完美,但预测结果明显好于OLS。神经网络架构有近200000个可训练的参数(权重),与OLS 回归(最多使用15+1个参数)相比,这提供了相对较高的灵活性。
分类任务
神经网络也可以很容易地用于分类任务。考虑以下基于 Keras 实现神经网络分类,二元特征数据和二元标签数据是随机生成的。建模方面的主要调整是将输出层的激活函数从linear更改为sigmoid。虽然分类效果并不完美,但是也达到了很高的准确率。
# 创建随机特征数据和标签数据
f = 5
n = 10
np.random.seed(124812)
x = np.random.randint(0, 2, (n, f))
y = np.random.randint(0, 2, n)
model = Sequential()
model.add(Dense(256, activation='relu', input_dim=f))
# 输出层的激活函数为 sigmoid
model.add(Dense(1, activation='sigmoid'))
# 损失函数为 binary_crossentropy
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['acc'])
model.fit(x, y, epochs=50, verbose=False)
y_ = np.where(model.predict(x).flatten() > 0.5, 1, 0)
# 预测值与标签数据的比较结果
y == y_
# Out:
# array([ True, True, True, True, True, True, True, False, True, True])
# 绘制每轮训练的损失函数和准确率值
res = pd.DataFrame(model.history.history)
res.plot(figsize=(10, 6));
由以上示例说明,对比传统统计方法,神经网络的一些基本特征:
- 问题无关性
在给定一组特征值的情况下,神经网络方法的性能与需要预测或者分类的具体标签值是无关的。而统计方法(比如OLS 回归)可能对较小的一组问题表现良好,对其他问题则表现不太好或根本没有效果。
2. 增量学习
给定一个用来度量成功的目标,神经网络中的最佳权重是基于随机初始化和增量改进而逐步学习得到的。这些增量改进是在考虑预测值和样本标签值之间的差异后,通过神经网络反向传播权重更新来实现的。
3. 通用函数逼近器
有严格的数学定理表明神经网络(即使只有一个隐藏层)几乎可以逼近任何函数。
相关推荐
- 悠悠万事,吃饭为大(悠悠万事吃饭为大,什么意思)
-
新媒体编辑:杜岷赵蕾初审:程秀娟审核:汤小俊审签:周星...
- 高铁扒门事件升级版!婚宴上‘冲喜’老人团:我们抢的是社会资源
-
凌晨两点改方案时,突然收到婚庆团队发来的视频——胶东某酒店宴会厅,三个穿大红棉袄的中年妇女跟敢死队似的往前冲,眼瞅着就要扑到新娘的高额钻石项链上。要不是门口小伙及时阻拦,这婚礼造型团队熬了三个月的方案...
- 微服务架构实战:商家管理后台与sso设计,SSO客户端设计
-
SSO客户端设计下面通过模块merchant-security对SSO客户端安全认证部分的实现进行封装,以便各个接入SSO的客户端应用进行引用。安全认证的项目管理配置SSO客户端安全认证的项目管理使...
- 还在为 Spring Boot 配置类加载机制困惑?一文为你彻底解惑
-
在当今微服务架构盛行、项目复杂度不断攀升的开发环境下,SpringBoot作为Java后端开发的主流框架,无疑是我们手中的得力武器。然而,当我们在享受其自动配置带来的便捷时,是否曾被配置类加载...
- Seata源码—6.Seata AT模式的数据源代理二
-
大纲1.Seata的Resource资源接口源码2.Seata数据源连接池代理的实现源码3.Client向Server发起注册RM的源码4.Client向Server注册RM时的交互源码5.数据源连接...
- 30分钟了解K8S(30分钟了解微积分)
-
微服务演进方向o面向分布式设计(Distribution):容器、微服务、API驱动的开发;o面向配置设计(Configuration):一个镜像,多个环境配置;o面向韧性设计(Resista...
- SpringBoot条件化配置(@Conditional)全面解析与实战指南
-
一、条件化配置基础概念1.1什么是条件化配置条件化配置是Spring框架提供的一种基于特定条件来决定是否注册Bean或加载配置的机制。在SpringBoot中,这一机制通过@Conditional...
- 一招解决所有依赖冲突(克服依赖)
-
背景介绍最近遇到了这样一个问题,我们有一个jar包common-tool,作为基础工具包,被各个项目在引用。突然某一天发现日志很多报错。一看是NoSuchMethodError,意思是Dis...
- 你读过Mybatis的源码?说说它用到了几种设计模式
-
学习设计模式时,很多人都有类似的困扰——明明概念背得滚瓜烂熟,一到写代码就完全想不起来怎么用。就像学了一堆游泳技巧,却从没下过水实践,很难真正掌握。其实理解一个知识点,就像看立体模型,单角度观察总...
- golang对接阿里云私有Bucket上传图片、授权访问图片
-
1、为什么要设置私有bucket公共读写:互联网上任何用户都可以对该Bucket内的文件进行访问,并且向该Bucket写入数据。这有可能造成您数据的外泄以及费用激增,若被人恶意写入违法信息还可...
- spring中的资源的加载(spring加载原理)
-
最近在网上看到有人问@ContextConfiguration("classpath:/bean.xml")中除了classpath这种还有其他的写法么,看他的意思是想从本地文件...
- Android资源使用(android资源文件)
-
Android资源管理机制在Android的开发中,需要使用到各式各样的资源,这些资源往往是一些静态资源,比如位图,颜色,布局定义,用户界面使用到的字符串,动画等。这些资源统统放在项目的res/独立子...
- 如何深度理解mybatis?(如何深度理解康乐服务质量管理的5个维度)
-
深度自定义mybatis回顾mybatis的操作的核心步骤编写核心类SqlSessionFacotryBuild进行解析配置文件深度分析解析SqlSessionFacotryBuild干的核心工作编写...
- @Autowired与@Resource原理知识点详解
-
springIOCAOP的不多做赘述了,说下IOC:SpringIOC解决的是对象管理和对象依赖的问题,IOC容器可以理解为一个对象工厂,我们都把该对象交给工厂,工厂管理这些对象的创建以及依赖关系...
- java的redis连接工具篇(java redis client)
-
在Java里,有不少用于连接Redis的工具,下面为你介绍一些主流的工具及其特点:JedisJedis是Redis官方推荐的Java连接工具,它提供了全面的Redis命令支持,且...
- 一周热门
- 最近发表
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- resize函数 (64)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- mybatis大于等于 (64)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- linuxlink (65)
- pythonwget (67)
- androidinclude (65)
- logstashinput (65)
- hadoop端口 (65)
- vue阻止冒泡 (67)
- oracle时间戳转换日期 (64)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)