百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

一文上手最新Tensorflow2.0系列|“tf.data”API 使用

bigegpt 2024-08-08 12:06 2 浏览

除了GPU和TPU等硬件加速设备以外,一个高效的数据输入管道也可以很大程度的提升模型性能,减少模型训练所需要的时间。数据输入管道本质是一个ELT(Extract、Transform和Load)过程:

  • Extract:从硬盘中读取数据(可以是本地的也可以是云端的)。
  • Transform:数据的预处理(例如数据清洗、格式转换等)。
  • Load:将处理好的数据加载到计算设备(例如CPU、GPU以及TPU等)。

数据输入管道一般使用CPU来执行ELT过程,GPU等其他硬件加速设备则负责模型的训练,ELT过程和模型的训练并行执行,从而提高模型训练的效率。另外ELT过程的各个步骤也都可以进行相应的优化,例如并行的读取数据以及并行的处理数据等。在TensorFlow中我们可以使用“tf.data”API来构建这样的数据输入管道。

我们首先下载实验中需要用的图像数据集(下载地址“https://storage.googleapis.com/download.tensorflow.org/example_images/”,百度网盘地址“https://pan.baidu.com/s/16fvNOBvKyGVa8yCB5mDUOQ”)。

该数据集是一个花朵图片的数据集,将下载下来的数据解压后如图所示,除了一个License文件以外主要是五个分别存放着对应类别花朵图片的文件夹。其中“daisy(雏菊)”文件夹中有633张图片,“dandelion(蒲公英)”文件夹中有898张图片,“roses(玫瑰)”文件夹中有641张图片,“sunflowers(向日葵)”文件夹中有699张图片,“tulips(郁金香)”文件夹中有799张图片。


接下来我们开始实现代码,首先我们导入需要使用的包:

import tensorflow as tf
import pathlib

pathlib提供了一组用于处理文件系统路径的类。导入需要的包后,可以先检查一下TensorFlow的版本:

print(tf.__version__)

首先获取所有图片样本文件的路径:

# 获取当前路径
data_root = pathlib.Path.cwd()
# 获取指定目录下的文件路径(返回是一个列表,每一个元素是一个PosixPath对象)
all_image_paths = list(data_root.glob('*/*/*'))
print(type(all_image_paths[0]))
# 将PosixPath对象转为字符串
all_image_paths = [str(path) for path in all_image_paths]
print(all_image_paths[0])
print(data_root)

输出结果如图所示:

接下来我们需要统计图片的类别,并给每一个类别分配一个类标:

# 获取图片类别的名称,即存放样本图片的五个文件夹的名称
label_names = sorted(item.name for item in data_root.glob('*/*/') if item.is_dir())
# 将类别名称转为数值型的类标
label_to_index = dict((name, index) for index, name in enumerate(label_names))
# 获取所有图片的类标
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
 
print(label_to_index)
print("First 10 labels indices: ", all_image_labels[:2])
print("First 10 labels indices: ", all_image_paths[:2])

输出结果如图所示,daisy(雏菊)、dandelion(蒲公英)、roses(玫瑰)、sunflowers(向日葵)、tulips(郁金香)的类标分别为0、1、2、3、4、5。


处理完类标之后,我们接下来需要对图片本身做一些处理,这里我们定义一个函数用来加载和预处理图片数据:

def load_and_preprocess_image(path):
 # 读取图片
 image = tf.io.read_file(path)
 # 将jpeg格式的图片解码,得到一个张量(三维的矩阵)
 image = tf.image.decode_jpeg(image, channels=3)
 # 由于数据集中每张图片的大小不一样,统一调整为192*192
 image = tf.image.resize(image, [192, 192])
 # 对每个像素点的RGB值做归一化处理
 image /= 255.0
 
 return image

完成对类标和图像数据的预处理之后,我们使用“tf.data.Dataset”来构建和管理数据集:

# 构建图片路径的“dataset”
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
# 使用AUTOTUNE自动调节管道参数
AUTOTUNE = tf.data.experimental.AUTOTUNE
# 构建图片数据的“dataset”
image_ds = path_ds.map(load_and_preprocess_image,
num_parallel_calls=AUTOTUNE)
# 构建类标数据的“dataset”
label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
# 将图片和类标压缩为(图片,类标)对
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
 
print(image_ds)
print(label_ds)
print(image_label_ds)

输出结果:


在代码中,我们使用了“from_tensor_slices”方法使用张量的切片元素构建“dataset”,“tf.data.Dataset”类还提供了“from_tensor”直接使用单个张量来构建“dataset”,以及可以使用生成器生成的元素来构建“dataset”的“from_generator”方法。

我们还使用了“tf.data.Dataset”的“map”方法,该方法允许我们自己定义一个函数,将原数据集中的元素依次经过该函数处理,并将处理后的数据作为新的数据集,处理前和处理后的数据顺序不变。例如这里我们自己定义了一个“load_and_preprocess_image”函数,将“path_ds”中的图片路径转换成了经过预处理的图像数据,并保存在了“image_ds”中。

最后我们使用“tf.data.Dataset”的“zip”方法将图片数据和类标数据压缩成“(图片,类标)”对,其结构如图所示。我们可视化一下数据集中的部分数据:

import matplotlib.pyplot as plt
 
plt.figure(figsize=(8,8))
for n,image_label in enumerate(image_label_ds.take(4)):
 plt.subplot(2,2,n+1)
 plt.imshow(image_label[0])
 plt.grid(False)
 plt.xticks([])
 plt.yticks([])
 plt.xlabel(image_label[1])

结果如图所示:


接下来我们用创建的dataset训练一个分类模型,为了简单,我们直接使用“tf.keras.applications”包中训练好的模型,并将其迁移到我们的花朵分类任务上来。这里我们使用“MobileNetV2”模型。

# 下载的模型在用户根目录下,具体位置:“~/.keras/models/
mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_192_no_top.h5”
mobile_net = tf.keras.applications.MobileNetV2(input_shape=(192, 192, 3),
include_top=False)
# 禁止训练更新“MobileNetV2”模型的参数
mobile_net.trainable = False

当我们执行代码后,训练好的“MobileNetV2”模型会被下载到本地,该模型是在ImageNet数据集上训练的。因为我们是想把该训练好的模型迁移到我们的花朵分类问题中来,所以我们设置该模型的参数不可训练和更新。

接下来我们打乱一下数据集,以及定义好训练过程中每个“batch”的大小。

# 使用Dataset类的shuffle方法打乱数据集
image_count = len(all_image_paths)
ds = image_label_ds.shuffle(buffer_size=image_count)
# 让数据集重复多次
ds = ds.repeat()
# 设置每个batch的大小
BATCH_SIZE = 32
ds = ds.batch(BATCH_SIZE)
# 通过“prefetch”方法让模型的训练和每个batch数据集的加载并行
ds = ds.prefetch(buffer_size=AUTOTUNE)

在代码中,使用“tf.data.Dataset”类的“shuffle”方法将数据集进行打乱。“repeat”方法让数据集可以重复获取,通常情况下如果我们一个“epoch”只对完整的数据集训练一遍的话,可以不需要设置“repeat”。“repeat”方法可以设置参数,例如“ds.repeat(2)”是让数据集可以重复获取两遍,即一个训练回合(epoch)中我们可以使用两遍数据集,不加参数的话,则默认可以无限次重复获取数据集。

由于“MobileNetV2”模型接收的输入数据是归一化后范围在[-1,1]之间的数据,我们在代码中对数据进行了一次归一化处理后,其范围在[0,1]之间,因此我们需要将我们的数据映射到[-1,1]之间。

def change_range(image,label):
 return 2*image-1, label
# 使用“map”方法对dataset进行处理
keras_ds = ds.map(change_range)

接下来我们定义模型,由于预训练好的“MobileNetV2”返回的数据维度为“(32,6,6,1280)”,其中32是一个“batch”的大小,“6,6”代表输出的特征图的大小为“6X6”,1280代表该层使用了1280个卷积核。为了适应我们的分类任务,我们需要在“MobileNetV2”返回数据的基础上再增加两层网络层。

model = tf.keras.Sequential([
 mobile_net,
 tf.keras.layers.GlobalAveragePooling2D(),
 tf.keras.layers.Dense(len(label_names))])

全局平均池化(GAP,Global Average Pooling)将每一个特征图求平均,将该平均值作为该特征图池化后的结果,因此经过该操作后数据的维度变为了(32,1280)。由于我们的花朵分类任务是一个5分类的任务,因此我们再使用一个全连接(Dense),将维度变为(32,5)。

接着我们编译一下模型,同时指定使用的优化器和损失函数:

model.compile(optimizer=tf.keras.optimizers.Adam(),
 loss='sparse_categorical_crossentropy',
 metrics=["accuracy"])
model.summary()
“model.summary()”可以输出模型各层的参数概况,如图所示:


最后我们使用“model.fit”训练模型:

model.fit(ds, epochs=1, steps_per_epoch=10)

这里参数“epochs”指定需要训练的回合数,“steps_per_epoch”代表每个回合要取多少个“batch”数据,通常“steps_per_epoch”的大小等于我们数据集的大小除以“batch”的大小后上取整。

相关推荐

【Docker 新手入门指南】第十章:Dockerfile

Dockerfile是Docker镜像构建的核心配置文件,通过预定义的指令集实现镜像的自动化构建。以下从核心概念、指令详解、最佳实践三方面展开说明,帮助你系统掌握Dockerfile的使用逻...

Windows下最简单的ESP8266_ROTS_ESP-IDF环境搭建与腾讯云SDK编译

前言其实也没啥可说的,只是我感觉ESP-IDF对新手来说很不友好,很容易踩坑,尤其是对业余DIY爱好者搭建环境非常困难,即使有官方文档,或者网上的其他文档,但是还是很容易踩坑,多研究,记住两点就行了,...

python虚拟环境迁移(python虚拟环境conda)

主机A的虚拟环境向主机B迁移。前提条件:主机A和主机B已经安装了virtualenv1.主机A操作如下虚拟环境目录:venv进入虚拟环境:sourcevenv/bin/active(1)记录虚拟环...

Python爬虫进阶教程(二):线程、协程

简介线程线程也叫轻量级进程,它是一个基本的CPU执行单元,也是程序执行过程中的最小单元,由线程ID、程序计数器、寄存器集合和堆栈共同组成。线程的引入减小了程序并发执行时的开销,提高了操作系统的并发性能...

基于网络安全的Docker逃逸(docker)

如何判断当前机器是否为Docker容器环境Metasploit中的checkcontainer模块、(判断是否为虚拟机,checkvm模块)搭配学习教程1.检查根目录下是否存在.dockerenv文...

Python编程语言被纳入浙江高考,小学生都开始学了

今年9月份开始的新学期,浙江省三到九年级信息技术课将同步替换新教材。其中,新初二将新增Python编程课程内容。新高一信息技术编程语言由VB替换为Python,大数据、人工智能、程序设计与算法按照教材...

CentOS 7下安装Python 3.10的完整过程

1.安装相应的编译工具yum-ygroupinstall"Developmenttools"yum-yinstallzlib-develbzip2-develope...

如何在Ubuntu 20.04上部署Odoo 14

Odoo是世界上最受欢迎的多合一商务软件。它提供了一系列业务应用程序,包括CRM,网站,电子商务,计费,会计,制造,仓库,项目管理,库存等等,所有这些都无缝集成在一起。Odoo可以通过几种不同的方式进...

Ubuntu 系统安装 PyTorch 全流程指南

当前环境:Ubuntu22.04,显卡为GeForceRTX3080Ti1、下载显卡驱动驱动网站:https://www.nvidia.com/en-us/drivers/根据自己的显卡型号和...

spark+python环境搭建(python 环境搭建)

最近项目需要用到spark大数据相关技术,周末有空spark环境搭起来...目标spark,python运行环境部署在linux服务器个人通过vscode开发通过远程python解释器执行代码准备...

centos7.9安装最新python-3.11.1(centos安装python环境)

centos7.9安装最新python-3.11.1centos7.9默认安装的是python-2.7.5版本,安全扫描时会有很多漏洞,比如:Python命令注入漏洞(CVE-2015-2010...

Linux系统下,五大步骤安装Python

一、下载Python包网上教程大多是通过官方地址进行下载Python的,但由于国内网络环境问题,会导致下载很慢,所以这里建议通过国内镜像进行下载例如:淘宝镜像http://npm.taobao.or...

centos7上安装python3(centos7安装python3.7.2一键脚本)

centos7上默认安装的是python2,要使用python3则需要自行下载源码编译安装。1.安装依赖yum-ygroupinstall"Developmenttools"...

利用本地数据通过微调方式训练 本地DeepSeek-R1 蒸馏模型

网络上相应的教程基本都基于LLaMA-Factory进行,本文章主要顺着相应的教程一步步实现大模型的微调和训练。训练环境:可自行定义,mac、linux或者window之类的均可以,本文以ma...

【法器篇】天啦噜,库崩了没备份(天啦噜是什么意思?)

背景数据库没有做备份,一天突然由于断电或其他原因导致无法启动了,且设置了innodb_force_recovery=6都无法启动,里面的数据怎么才能恢复出来?本例采用解析建表语句+表空间传输的方式进行...