菜鸟:简单神经网络train and test详解(双层)
bigegpt 2024-10-07 06:34 16 浏览
简单神经网络train and test详解(双层)
【 The latest data : 2018/05/01 】Yuchen
1. NN模型如下
神经网络整体架构内容可参考之前的云笔记《06_神经网络整体架构》
http://note.youdao.com/noteshare?id=2c27bbf6625d75e4173d9fcbeea5e8c1&sub=7F4BC70112524F9289531EC6AE435E14
其中,
n是指的样本数
Mnist数据集 784是28×28×1 灰度图 channel = 1
wb是指的权重参数
输出的是10分类的得分值,也可以接softmax分类器
out是L2层和输出层之间的关系
256 128 10是指的神经元数量
2. 构造参数
函数构造
3. Code
1. 网络模型架构搭建
导入相应数据
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import input_data
mnist = input_data.read_data_sets('data/', one_hot=True)
network topologies
# 网络拓扑 network topologies
# layer中神经元数量
n_hidden_1 =256
n_hidden_2 =128
# 输入数据的像素点 28x28x1
n_input =784
# 10分类
n_classes =10
input and output
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])
network parameters
# network parameters
# 方差
stddev =0.1
# random_normal 高斯初始化
weights ={
'w1': tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev=stddev)),
'w2': tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
'out': tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))
}
# 对于 b 零值初始化也可以
biases ={
'b1': tf.Variable(tf.random_normal([n_hidden_1])),
'b2': tf.Variable(tf.random_normal([n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_classes]))
}
print("Network Ready")
output
NetworkReady
可以看到网络模型架构搭建成功
2.训练网络模型
定义前向传播函数
# 定义前向传播函数
def multilayer_perceptron(_X, _weights, _biases):
# 之所以加 sigmoid 是因为每一个 hidden layer 都有一个非线性函数
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X, _weights['w1']), _biases['b1']))
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, _weights['w2']), _biases['b2']))
return(tf.matmul(layer_2, _weights['out'])+ _biases['out'])
反向传播
(1)将前向传播预测值
# prediction
pred = multilayer_perceptron(x, weights, biases)
(2)定义损失函数
# 首先定义损失函数 softmax_cross_entropy_with_logits 交叉熵函数
# 交叉熵函数的输入有 pred : 网络的预测值 (前向传播的结果)
# y : 实际的label值
# 将两参数的一系列的比较结果,除以 batch 求平均之后的 loss 返回给 cost 损失值
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
(3)梯度下降最优化
optm = tf.train.GradientDescentOptimizer(learning_rate =0.001).minimize(cost)
(4)精确值
具体解释详见上一篇笔记《06_迭代完成逻辑回归模型》
corr = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accr = tf.reduce_mean(tf.cast(corr,"float"))
(5)初始化
# initializer
init = tf.global_variables_initializer()
print("Function ready")
output
Function ready
可以看出传播中的参数和优化模型搭建成功
3. Train and Test
training_epochs =20
# 每次 iteration 的样本
batch_size =100
# 每四个 epoch 打印一次结果
display_step =4
# lanch the graph
sess = tf.Session()
sess.run(init)
# optimize
for epoch in range(training_epochs):
# 初始,平均 loss = 0
avg_cost =0
total_batch =int(mnist.train.num_examples/batch_size)
# iteration
for i in range(total_batch):
# 通过 next_batch 返回相应的 batch_xs,batch_ys
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
feeds ={x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict = feeds)
avg_cost += sess.run(cost, feed_dict = feeds)
avg_cost = avg_cost / total_batch
# display
if(epoch+1)% display_step ==0:
print("Epoch: %03d/%03d cost: %.9f "%(epoch, training_epochs, avg_cost))
feeds ={x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict = feeds)
print("train accuracy: %.3f"%(train_acc))
feeds ={x: mnist.test.images, y: mnist.test.labels}
test_acc = sess.run(accr, feed_dict = feeds)
print("test accuracy: %.3f"%(test_acc))
print("optimization finished")
output
Epoch:003/020 cost:2.273774184
train accuracy:0.250
test accuracy:0.197
Epoch:007/020 cost:2.240329206
train accuracy:0.270
test accuracy:0.311
Epoch:011/020 cost:2.203503076
train accuracy:0.370
test accuracy:0.404
Epoch:015/020 cost:2.161286944
train accuracy:0.490
test accuracy:0.492
Epoch:019/020 cost:2.111541148
train accuracy:0.410
test accuracy:0.534
optimization finished
20个batch每个batch 100个样本,每隔4个batch打印一次
处理器:Intel Core i5-6200U CPU @ 2.30GHz 2.04GHz
04 epoch:train+test, cost_time: 25’40”
08 epoch:train+test, cost_time: 50’29”
12 epoch:train+test, cost_time: 74’42”
16 epoch:train+test, cost_time: 98’63”
20 epoch:train+test, cost_time: 121’49”
想要更完整代码或者跟作者交流,请留言头条号。加入我们的社群。
相关推荐
- 当Frida来“敲”门(frida是什么)
-
0x1渗透测试瓶颈目前,碰到越来越多的大客户都会将核心资产业务集中在统一的APP上,或者对自己比较重要的APP,如自己的主业务,办公APP进行加壳,流量加密,投入了很多精力在移动端的防护上。而现在挖...
- 服务端性能测试实战3-性能测试脚本开发
-
前言在前面的两篇文章中,我们分别介绍了性能测试的理论知识以及性能测试计划制定,本篇文章将重点介绍性能测试脚本开发。脚本开发将分为两个阶段:阶段一:了解各个接口的入参、出参,使用Python代码模拟前端...
- Springboot整合Apache Ftpserver拓展功能及业务讲解(三)
-
今日分享每天分享技术实战干货,技术在于积累和收藏,希望可以帮助到您,同时也希望获得您的支持和关注。架构开源地址:https://gitee.com/msxyspringboot整合Ftpserver参...
- Linux和Windows下:Python Crypto模块安装方式区别
-
一、Linux环境下:fromCrypto.SignatureimportPKCS1_v1_5如果导包报错:ImportError:Nomodulenamed'Crypt...
- Python 3 加密简介(python des加密解密)
-
Python3的标准库中是没多少用来解决加密的,不过却有用于处理哈希的库。在这里我们会对其进行一个简单的介绍,但重点会放在两个第三方的软件包:PyCrypto和cryptography上,我...
- 怎样从零开始编译一个魔兽世界开源服务端Windows
-
第二章:编译和安装我是艾西,上期我们讲述到编译一个魔兽世界开源服务端环境准备,那么今天跟大家聊聊怎么编译和安装我们直接进入正题(上一章没有看到的小伙伴可以点我主页查看)编译服务端:在D盘新建一个文件夹...
- 附1-Conda部署安装及基本使用(conda安装教程)
-
Windows环境安装安装介质下载下载地址:https://www.anaconda.com/products/individual安装Anaconda安装时,选择自定义安装,选择自定义安装路径:配置...
- 如何配置全世界最小的 MySQL 服务器
-
配置全世界最小的MySQL服务器——如何在一块IntelEdison为控制板上安装一个MySQL服务器。介绍在我最近的一篇博文中,物联网,消息以及MySQL,我展示了如果Partic...
- 如何使用Github Action来自动化编译PolarDB-PG数据库
-
随着PolarDB在国产数据库领域荣膺桂冠并持续获得广泛认可,越来越多的学生和技术爱好者开始关注并涉足这款由阿里巴巴集团倾力打造且性能卓越的关系型云原生数据库。有很多同学想要上手尝试,却卡在了编译数据...
- 面向NDK开发者的Android 7.0变更(ndk android.mk)
-
订阅Google官方微信公众号:谷歌开发者。与谷歌一起创造未来!受Android平台其他改进的影响,为了方便加载本机代码,AndroidM和N中的动态链接器对编写整洁且跨平台兼容的本机...
- 信创改造--人大金仓(Kingbase)数据库安装、备份恢复的问题纪要
-
问题一:在安装KingbaseES时,安装用户对于安装路径需有“读”、“写”、“执行”的权限。在Linux系统中,需要以非root用户执行安装程序,且该用户要有标准的home目录,您可...
- OpenSSH 安全漏洞,修补操作一手掌握
-
1.漏洞概述近日,国家信息安全漏洞库(CNNVD)收到关于OpenSSH安全漏洞(CNNVD-202407-017、CVE-2024-6387)情况的报送。攻击者可以利用该漏洞在无需认证的情况下,通...
- Linux:lsof命令详解(linux lsof命令详解)
-
介绍欢迎来到这篇博客。在这篇博客中,我们将学习Unix/Linux系统上的lsof命令行工具。命令行工具是您使用CLI(命令行界面)而不是GUI(图形用户界面)运行的程序或工具。lsoflsof代表&...
- 幻隐说固态第一期:固态硬盘接口类别
-
前排声明所有信息来源于网络收集,如有错误请评论区指出更正。废话不多说,目前固态硬盘接口按速度由慢到快分有这几类:SATA、mSATA、SATAExpress、PCI-E、m.2、u.2。下面我们来...
- 新品轰炸 影驰SSD多款产品登Computex
-
分享泡泡网SSD固态硬盘频道6月6日台北电脑展作为全球第二、亚洲最大的3C/IT产业链专业展,吸引了众多IT厂商和全球各地媒体的热烈关注,全球存储新势力—影驰,也积极参与其中,为广大玩家朋友带来了...
- 一周热门
- 最近发表
-
- 当Frida来“敲”门(frida是什么)
- 服务端性能测试实战3-性能测试脚本开发
- Springboot整合Apache Ftpserver拓展功能及业务讲解(三)
- Linux和Windows下:Python Crypto模块安装方式区别
- Python 3 加密简介(python des加密解密)
- 怎样从零开始编译一个魔兽世界开源服务端Windows
- 附1-Conda部署安装及基本使用(conda安装教程)
- 如何配置全世界最小的 MySQL 服务器
- 如何使用Github Action来自动化编译PolarDB-PG数据库
- 面向NDK开发者的Android 7.0变更(ndk android.mk)
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- resize函数 (64)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- mybatis大于等于 (64)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- linuxlink (65)
- pythonwget (67)
- androidinclude (65)
- libcrypto.so (74)
- logstashinput (65)
- hadoop端口 (65)
- vue阻止冒泡 (67)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)