英文教程太难啃?这里有一份TensorFlow2.0中文教程
bigegpt 2024-11-10 08:26 3 浏览
今年 3 月份,谷歌在 Tensorflow Developer Summit 2019 大会上发布 TensorFlow 2.0 Alpha 版。作为当前最为流行的深度学习框架,2.0 Alpha 版的正式发布引人关注。近两个月,网上已经出现了大量 TensorFlow 2.0 英文教程。在此文章中,机器之心为大家推荐一个持续更新的中文教程,以便大家学习。
虽然,自 TensorFlow 2.0 发布以来,我们总是能够听到「TensorFlow 2.0 就是 keras」、「说的很好,但我用 PyTorch」类似的吐槽。但毋庸置疑,TensorFlow 依然是当前最主流的深度学习框架(感兴趣的读者可查看机器之心文章:2019 年,TensorFlow 被拉下马了吗?)。
整体而言,为了吸引用户,TensorFlow 2.0 从简单、强大、可扩展三个层面进行了重新设计。特别是在简单化方面,TensorFlow 2.0 提供更简化的 API、注重 Keras、结合了 Eager execution。
过去一段时间,机器之心为大家编译介绍了部分英文教程,例如:
- 如何在 TensorFlow 2.0 中构建强化学习智能体
- TensorFlow 2.0 到底怎么样?简单的图像分类任务探一探
此文章中,机器之心为大家推荐一个持续更新的中文教程,方便大家更系统的学习、使用 TensorFlow 2.0 :
- 知乎专栏地址:https://zhuanlan.zhihu.com/c_1091021863043624960
- Github 项目地址:https://github.com/czy36mengfei/tensorflow2_tutorials_chinese
该教程是 NLP 爱好者 Doit 在知乎上开的一个专栏,由作者从 TensorFlow2.0 官方教程的个人学习复现笔记整理而来。作者将此教程分为了三类:TensorFlow 2.0 基础教程、TensorFlow 2.0 深度学习实践、TensorFlow 2.0 基础网络结构。
以基础教程为例,作者整理了 Keras 快速入门教程、eager 模式、Autograph 等。目前为止,该中文教程已经包含 20 多篇文章,作者还在持续更新中,感兴趣的读者可以 follow。
以下是作者整理的「Keras 快速入门」教程内容。
Keras 快速入门
Keras 是一个用于构建和训练深度学习模型的高阶 API。它可用于快速设计原型、高级研究和生产。
keras 的 3 个优点: 方便用户使用、模块化和可组合、易于扩展
1. 导入 tf.keras
tensorflow2 推荐使用 keras 构建网络,常见的神经网络都包含在 keras.layer 中 (最新的 tf.keras 的版本可能和 keras 不同)
import tensorflow as tf from tensorflow.keras import layers print(tf.__version__) print(tf.keras.__version__)
2. 构建简单模型
2.1 模型堆叠
最常见的模型类型是层的堆叠:tf.keras.Sequential 模型
model = tf.keras.Sequential() model.add(layers.Dense(32, activation='relu')) model.add(layers.Dense(32, activation='relu')) model.add(layers.Dense(10, activation='softmax'))
2.2 网络配置
tf.keras.layers 中网络配置:
- activation:设置层的激活函数。此参数由内置函数的名称指定,或指定为可调用对象。默认情况下,系统不会应用任何激活函数。
- kernel_initializer 和 bias_initializer:创建层权重(核和偏差)的初始化方案。此参数是一个名称或可调用对象,默认为 "Glorot uniform" 初始化器。
- kernel_regularizer 和 bias_regularizer:应用层权重(核和偏差)的正则化方案,例如 L1 或 L2 正则化。默认情况下,系统不会应用正则化函数。
layers.Dense(32, activation='sigmoid') layers.Dense(32, activation=tf.sigmoid) layers.Dense(32, kernel_initializer='orthogonal') layers.Dense(32, kernel_initializer=tf.keras.initializers.glorot_normal) layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(0.01)) layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l1(0.01))
3. 训练和评估
3.1 设置训练流程
构建好模型后,通过调用 compile 方法配置该模型的学习流程:
model = tf.keras.Sequential() model.add(layers.Dense(32, activation='relu')) model.add(layers.Dense(32, activation='relu')) model.add(layers.Dense(10, activation='softmax')) model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss=tf.keras.losses.categorical_crossentropy, metrics=[tf.keras.metrics.categorical_accuracy])
3.2 输入 Numpy 数据
import numpy as np train_x = np.random.random((1000, 72)) train_y = np.random.random((1000, 10)) val_x = np.random.random((200, 72)) val_y = np.random.random((200, 10)) model.fit(train_x, train_y, epochs=10, batch_size=100, validation_data=(val_x, val_y))
3.3tf.data 输入数据
dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)) dataset = dataset.batch(32) dataset = dataset.repeat() val_dataset = tf.data.Dataset.from_tensor_slices((val_x, val_y)) val_dataset = val_dataset.batch(32) val_dataset = val_dataset.repeat() model.fit(dataset, epochs=10, steps_per_epoch=30, validation_data=val_dataset, validation_steps=3)
3.4 评估与预测
test_x = np.random.random((1000, 72)) test_y = np.random.random((1000, 10)) model.evaluate(test_x, test_y, batch_size=32) test_data = tf.data.Dataset.from_tensor_slices((test_x, test_y)) test_data = test_data.batch(32).repeat() model.evaluate(test_data, steps=30) # predict result = model.predict(test_x, batch_size=32) print(result)
4. 构建高级模型
4.1 函数式 api
tf.keras.Sequential 模型是层的简单堆叠,无法表示任意模型。使用 Keras 函数式 API 可以构建复杂的模型拓扑,例如:
- 多输入模型,
- 多输出模型,
- 具有共享层的模型(同一层被调用多次),
- 具有非序列数据流的模型(例如,残差连接)。
使用函数式 API 构建的模型具有以下特征:
- 层实例可调用并返回张量。
- 输入张量和输出张量用于定义 tf.keras.Model 实例。
- 此模型的训练方式和 Sequential 模型一样。
input_x = tf.keras.Input(shape=(72,)) hidden1 = layers.Dense(32, activation='relu')(input_x) hidden2 = layers.Dense(16, activation='relu')(hidden1) pred = layers.Dense(10, activation='softmax')(hidden2) model = tf.keras.Model(inputs=input_x, outputs=pred) model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss=tf.keras.losses.categorical_crossentropy, metrics=['accuracy']) model.fit(train_x, train_y, batch_size=32, epochs=5)
4.2 模型子类化
通过对 tf.keras.Model 进行子类化并定义您自己的前向传播来构建完全可自定义的模型。在 init 方法中创建层并将它们设置为类实例的属性。在 call 方法中定义前向传播
class MyModel(tf.keras.Model): def __init__(self, num_classes=10): super(MyModel, self).__init__(name='my_model') self.num_classes = num_classes self.layer1 = layers.Dense(32, activation='relu') self.layer2 = layers.Dense(num_classes, activation='softmax') def call(self, inputs): h1 = self.layer1(inputs) out = self.layer2(h1) return out def compute_output_shape(self, input_shape): shape = tf.TensorShapej(input_shape).as_list() shape[-1] = self.num_classes return tf.TensorShape(shape) model = MyModel(num_classes=10) model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001), loss=tf.keras.losses.categorical_crossentropy, metrics=['accuracy']) model.fit(train_x, train_y, batch_size=16, epochs=5)
4.3 自定义层
通过对 tf.keras.layers.Layer 进行子类化并实现以下方法来创建自定义层:
- build:创建层的权重。使用 add_weight 方法添加权重。
- call:定义前向传播。
- compute_output_shape:指定在给定输入形状的情况下如何计算层的输出形状。或者,可以通过实现 get_config 方法和 from_config 类方法序列化层。
class MyLayer(layers.Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(MyLayer, self).__init__(**kwargs) def build(self, input_shape): shape = tf.TensorShape((input_shape[1], self.output_dim)) self.kernel = self.add_weight(name='kernel1', shape=shape, initializer='uniform', trainable=True) super(MyLayer, self).build(input_shape) def call(self, inputs): return tf.matmul(inputs, self.kernel) def compute_output_shape(self, input_shape): shape = tf.TensorShape(input_shape).as_list() shape[-1] = self.output_dim return tf.TensorShape(shape) def get_config(self): base_config = super(MyLayer, self).get_config() base_config['output_dim'] = self.output_dim return base_config @classmethod def from_config(cls, config): return cls(**config) model = tf.keras.Sequential( [ MyLayer(10), layers.Activation('softmax') ]) model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001), loss=tf.keras.losses.categorical_crossentropy, metrics=['accuracy']) model.fit(train_x, train_y, batch_size=16, epochs=5)
4.4 回调
callbacks = [ tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'), tf.keras.callbacks.TensorBoard(log_dir='./logs') ] model.fit(train_x, train_y, batch_size=16, epochs=5, callbacks=callbacks, validation_data=(val_x, val_y))
5 保持和恢复
5.1 权重保存
model = tf.keras.Sequential([ layers.Dense(64, activation='relu'), layers.Dense(10, activation='softmax')]) model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy']) model.save_weights('./weights/model') model.load_weights('./weights/model') model.save_weights('./model.h5') model.load_weights('./model.h5')
5.2 保存网络结构
# 序列化成json import json import pprint json_str = model.to_json() pprint.pprint(json.loads(json_str)) fresh_model = tf.keras.models.model_from_json(json_str) # 保持为yaml格式 #需要提前安装pyyaml yaml_str = model.to_yaml() print(yaml_str) fresh_model = tf.keras.models.model_from_yaml(yaml_str)
5.3 保存整个模型
model = tf.keras.Sequential([ layers.Dense(10, activation='softmax', input_shape=(72,)), layers.Dense(10, activation='softmax') ]) model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy']) model.fit(train_x, train_y, batch_size=32, epochs=5) model.save('all_model.h5') model = tf.keras.models.load_model('all_model.h5')
6. 将 keras 用于 Estimator
Estimator API 用于针对分布式环境训练模型。它适用于一些行业使用场景,例如用大型数据集进行分布式训练并导出模型以用于生产
model = tf.keras.Sequential([layers.Dense(10,activation='softmax'), layers.Dense(10,activation='softmax')]) model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001), loss='categorical_crossentropy', metrics=['accuracy']) estimator = tf.keras.estimator.model_to_estimator(model)
相关推荐
- 悠悠万事,吃饭为大(悠悠万事吃饭为大,什么意思)
-
新媒体编辑:杜岷赵蕾初审:程秀娟审核:汤小俊审签:周星...
- 高铁扒门事件升级版!婚宴上‘冲喜’老人团:我们抢的是社会资源
-
凌晨两点改方案时,突然收到婚庆团队发来的视频——胶东某酒店宴会厅,三个穿大红棉袄的中年妇女跟敢死队似的往前冲,眼瞅着就要扑到新娘的高额钻石项链上。要不是门口小伙及时阻拦,这婚礼造型团队熬了三个月的方案...
- 微服务架构实战:商家管理后台与sso设计,SSO客户端设计
-
SSO客户端设计下面通过模块merchant-security对SSO客户端安全认证部分的实现进行封装,以便各个接入SSO的客户端应用进行引用。安全认证的项目管理配置SSO客户端安全认证的项目管理使...
- 还在为 Spring Boot 配置类加载机制困惑?一文为你彻底解惑
-
在当今微服务架构盛行、项目复杂度不断攀升的开发环境下,SpringBoot作为Java后端开发的主流框架,无疑是我们手中的得力武器。然而,当我们在享受其自动配置带来的便捷时,是否曾被配置类加载...
- Seata源码—6.Seata AT模式的数据源代理二
-
大纲1.Seata的Resource资源接口源码2.Seata数据源连接池代理的实现源码3.Client向Server发起注册RM的源码4.Client向Server注册RM时的交互源码5.数据源连接...
- 30分钟了解K8S(30分钟了解微积分)
-
微服务演进方向o面向分布式设计(Distribution):容器、微服务、API驱动的开发;o面向配置设计(Configuration):一个镜像,多个环境配置;o面向韧性设计(Resista...
- SpringBoot条件化配置(@Conditional)全面解析与实战指南
-
一、条件化配置基础概念1.1什么是条件化配置条件化配置是Spring框架提供的一种基于特定条件来决定是否注册Bean或加载配置的机制。在SpringBoot中,这一机制通过@Conditional...
- 一招解决所有依赖冲突(克服依赖)
-
背景介绍最近遇到了这样一个问题,我们有一个jar包common-tool,作为基础工具包,被各个项目在引用。突然某一天发现日志很多报错。一看是NoSuchMethodError,意思是Dis...
- 你读过Mybatis的源码?说说它用到了几种设计模式
-
学习设计模式时,很多人都有类似的困扰——明明概念背得滚瓜烂熟,一到写代码就完全想不起来怎么用。就像学了一堆游泳技巧,却从没下过水实践,很难真正掌握。其实理解一个知识点,就像看立体模型,单角度观察总...
- golang对接阿里云私有Bucket上传图片、授权访问图片
-
1、为什么要设置私有bucket公共读写:互联网上任何用户都可以对该Bucket内的文件进行访问,并且向该Bucket写入数据。这有可能造成您数据的外泄以及费用激增,若被人恶意写入违法信息还可...
- spring中的资源的加载(spring加载原理)
-
最近在网上看到有人问@ContextConfiguration("classpath:/bean.xml")中除了classpath这种还有其他的写法么,看他的意思是想从本地文件...
- Android资源使用(android资源文件)
-
Android资源管理机制在Android的开发中,需要使用到各式各样的资源,这些资源往往是一些静态资源,比如位图,颜色,布局定义,用户界面使用到的字符串,动画等。这些资源统统放在项目的res/独立子...
- 如何深度理解mybatis?(如何深度理解康乐服务质量管理的5个维度)
-
深度自定义mybatis回顾mybatis的操作的核心步骤编写核心类SqlSessionFacotryBuild进行解析配置文件深度分析解析SqlSessionFacotryBuild干的核心工作编写...
- @Autowired与@Resource原理知识点详解
-
springIOCAOP的不多做赘述了,说下IOC:SpringIOC解决的是对象管理和对象依赖的问题,IOC容器可以理解为一个对象工厂,我们都把该对象交给工厂,工厂管理这些对象的创建以及依赖关系...
- java的redis连接工具篇(java redis client)
-
在Java里,有不少用于连接Redis的工具,下面为你介绍一些主流的工具及其特点:JedisJedis是Redis官方推荐的Java连接工具,它提供了全面的Redis命令支持,且...
- 一周热门
- 最近发表
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- resize函数 (64)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- mybatis大于等于 (64)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- linuxlink (65)
- pythonwget (67)
- androidinclude (65)
- logstashinput (65)
- hadoop端口 (65)
- vue阻止冒泡 (67)
- oracle时间戳转换日期 (64)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)