本文是一个面对初学者的,用tensorflow2.0来实现识别手写数字的实践文章,可以帮助初学者快速入门。识别手写数字就像是学习编程语言里的Hello world一样简单和著名,是程序员的最爱,也是进入人工智能里的第一步,方便后面更深入地了解tensorflow编程。
之前写过安装tensorflow2.0及运行一个对抗神经网络的文章。需要看的朋友可以查看之前的文章:
https://www.toutiao.com/i6915983671610638859/
https://www.toutiao.com/i6916094716174139907/
样本的数字如下:
先给大家看一下tensorflow2.0识别手写数字运行的结果:
正确率达到98%。
实现介绍
1、导入 TensorFlow:
# 安装 TensorFlow
import tensorflow as tf
2、载入并准备好 MNIST 数据集,tensorflow里面可以直接导入数据。将样本从整数转换为浮点数:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
3、查看训练数据的形状:
x_train.shape
输出为:
60000张28*28像素的图像。
如果想看看数组,可以一个训练数据看一下
x_train[0]
是28*28的二维数组。如下
上面的不是很直观,我们可以画一个数字出来。
4、先看一下数字形状
用matplotlib库来画形状
import matplotlib.pyplot as plt
fig, ax = plt.subplots(nrows=5, ncols=5, sharex='all', sharey='all')
ax = ax.flatten()
for i in range(25):
img = x_train[i]
ax[i].set_title(y_train[i])
ax[i].imshow(img, cmap='Greys', interpolation='nearest')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
输出为:
上面的数据为它真实的数据,图片里面的为手写的数字。
如果要看单个数字,可以:
plt.figure()
plt.imshow(x_train[0])
plt.show()
输出为:
5、将模型的各层堆叠起来,以搭建 tf.keras.Sequential 模型。为训练选择优化器和损失函数:
这个模型的输入的图片的形状为28*28的长度,全连接到128的全连接层上,有relu函数进行激活,为了防止过拟合,下采样的概率不0.2,最后用softmax再映射到10个分类上。用adam优化器,交叉熵损失函数,指标用准确度。
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
6、训练并验证模型:
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test, verbose=2)
现在,这个照片分类器的准确度已经达到 98%。