大模型训练的奥秘:从一个字母的奇妙转变开始
bigegpt 2024-11-10 08:26 3 浏览
为了更好地理解模型训练的过程,我们将通过一个简单的例子来说明。假设我们有一个模型,其输入是小写字母(如'a'、'b'、'c'等),目标是输出对应的大写字母(如'A'、'B'、'C'等)。我们将逐步演示如何创建、训练和使用这个模型,并解释其中涉及的关键概念。
为了更加易于理解,假设我们正在设计一个魔法小精灵,这本书能够识别小写字母,并将它们变成大写字母。我们用一个小精灵来代表我们的模型,这个小精灵需要学会如何将小写字母转换成大写字母。
mapping = {'a': 'A', 'b': 'B', 'c': 'C', ..., 'z': 'Z'}
这里我们创建了一个字典 mapping,它把小写字母映射到对应的大写字母。例如,'a' 映射到 'A','b' 映射到 'B' 等等。
2.模型构建
我们需要训练一个小精灵,让它学会如何把小写字母变成大写字母。为此,我们给小精灵设计了一个学习计划,包括三个阶段:输入阶段、学习阶段和输出阶段。
import tensorflow as tf
from tensorflow import keras
model = keras.models.Sequential([
keras.layers.Embedding(input_dim=26, output_dim=10), # 输入层
keras.layers.Dense(64, activation='relu'), # 学习层
keras.layers.Dense(26, activation='softmax') # 输出层
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
- 输入层(Embedding):小精灵需要先理解每个字母的含义。在这里,我们用一个 Embedding 层来表示每个字母。我们有26个小写字母,每个字母用一个10维的向量表示(output_dim=10)。
- 比如,'a' 用 [1, 0, 0, ...] 表示,'b' 用 [0, 1, 0, ...] 表示,等等。
- 学习层(Dense):小精灵需要通过学习来理解每个字母的含义。这里我们使用一个 Dense 层,它有64个神经元,使用 ReLU 激活函数。
- 这个层的作用是让小精灵学会如何从输入的向量中提取有用的特征。
- 输出层(Dense):最后,小精灵需要决定哪个大写字母是最有可能的答案。这里我们再次使用一个 Dense 层,这次有26个神经元,使用 Softmax 激活函数。
- Softmax 函数的作用是将每个大写字母的概率输出为一个0到1之间的数,并且所有概率之和为1。
3.数据预处理
为了让小精灵更好地学习,我们需要将输入数据转换成它能够理解的形式。
inputs = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
outputs = [ord(mapping[input]) - ord('A') for input in inputs]
input_indices = [ord(char) - ord('a') for char in inputs]
- 我们有一些输入数据 inputs,即小写字母。
- outputs 是对应的输出,即大写字母的索引。
- input_indices 是输入数据的索引,表示每个字母在字母表中的位置(如 'a' 对应 0,'b' 对应 1,等等)。
4.模型训练
现在我们需要让小精灵学习如何将小写字母转换成大写字母。我们会给它看一些例子,让它通过反复练习来学习。
X_train = np.array(input_indices)
y_train = np.array(outputs)
history = model.fit(X_train, y_train, epochs=100, batch_size=1, verbose=1)
- X_train 是输入数据,即小写字母的索引。
- y_train 是输出数据,即大写字母的索引。
- epochs 参数表示训练轮数,即小精灵需要学习多少次。
- batch_size 参数表示每次训练时看多少个字母。
5.模型评估与使用
小精灵学完后,我们需要检查它是否真的学会了。我们还会用它来做一些预测,看看它能否正确地将小写字母转换成大写字母。
test_inputs = ['k', 'l', 'm']
test_outputs = [ord(mapping[char]) - ord('A') for char in test_inputs]
test_input_indices = [ord(char) - ord('a') for char in test_inputs]
X_test = np.array(test_input_indices)
y_test = np.array(test_outputs)
loss, accuracy = model.evaluate(X_test, y_test)
print(f"Test Loss: {loss}, Test Accuracy: {accuracy}")
predictions = model.predict(X_test)
predicted_indices = np.argmax(predictions, axis=1)
predicted_outputs = [chr(index + ord('A')) for index in predicted_indices]
print("Predicted Outputs:", predicted_outputs)
- test_inputs 是一些新的测试数据,小精灵之前没有见过。
- test_outputs 是对应的正确答案。
- evaluate 方法用来评估模型在测试数据上的表现。
- predict 方法用来预测新的数据。
- np.argmax 函数用来找出预测结果中概率最高的那个字母的索引。
- 最后,我们将索引转换成对应的字母,并打印出来。
通过这个简单的例子,我们展示了如何创建、训练和使用一个模型来将小写字母转换为大写字母。在这个过程中,我们涉及到了数据准备、模型构建、数据预处理、模型训练、模型参数以及模型评估与使用等关键步骤。
这个例子虽然简单,但包含了深度学习模型训练的基本流程。希望这个例子能够帮助你更好地理解模型训练的过程及其背后的概念。
5.模型的参数
我们的模型训练完成后,一般会说这个模型会有多少参数,这个参数又是怎么回事呢?在这个例子中,模型的参数主要是权重(weights)和偏置(biases),它们存储在模型的每一层中。我们可以通过具体的代码和解释来展示这些参数的大小。
模型参数及其大小
1.Embedding 层
在 Embedding 层中,参数是一个嵌入矩阵,它的大小是 input_dim × output_dim。
embedding_layer = keras.layers.Embedding(input_dim=26, output_dim=10)
- input_dim 是输入词汇表的大小,这里是26个字母。
- output_dim 是每个字母嵌入到的向量维度,这里是10。
因此,Embedding 层的参数矩阵大小为 26 × 10。这意味着模型需要存储260个权重(每个字母对应一个10维向量)。
2.Dense 层
在 Dense 层中,参数包括权重矩阵和偏置向量。
hidden_layer = keras.layers.Dense(64, activation='relu')
- Dense 层的输入维度是由前一层的输出维度决定的,这里是10。
- Dense 层的输出维度是64。
因此,Dense 层的权重矩阵大小为 10 × 64,并且还有一个长度为64的偏置向量。
3.输出层
在输出层中,参数同样包括权重矩阵和偏置向量。
output_layer = keras.layers.Dense(26, activation='softmax')
- 输出层的输入维度是64(由前一层的输出维度决定)。
- 输出层的输出维度是26(对应26个字母)。
因此,输出层的权重矩阵大小为 64 × 26,并且还有一个长度为26的偏置向量。
查看模型参数大小
我们可以使用 TensorFlow 和 Keras 提供的方法来查看模型的参数大小。
# 查看模型的总参数数量
model.summary()
输出示例:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, None, 10) 260
_________________________________________________________________
dense (Dense) (None, 64) 640
_________________________________________________________________
dense_1 (Dense) (None, 26) 1690
=================================================================
Total params: 2,590
Trainable params: 2,590
Non-trainable params: 0
_________________________________________________________________
解释模型参数大小
- Embedding 层:参数大小为260(26×10)。
- Hidden Dense 层:参数大小为640(10×64权重 + 64偏置)。
- Output Dense 层:参数大小为1690(64×26权重 + 26偏置)。
因此,整个模型的总参数数量为2590。
查看模型参数的具体值
我们还可以查看模型的具体参数值,以便更深入地理解模型是如何工作的。
# 查看模型的权重和偏置
weights = model.get_weights()
print("Embedding Weights:")
print(weights[0])
print("\nHidden Layer Weights:")
print(weights[1])
print("Hidden Layer Biases:")
print(weights[2])
print("\nOutput Layer Weights:")
print(weights[3])
print("Output Layer Biases:")
print(weights[4])
通过上述步骤,我们可以清楚地了解到模型的参数及其大小。在这个例子中,模型的主要参数包括:
- Embedding 层:260个参数(26×10)。
- Hidden Dense 层:640个参数(10×64权重 + 64偏置)。
- Output Dense 层:1690个参数(64×26权重 + 26偏置)。
总参数数量为2590。这些参数在训练过程中会被不断调整,以使模型能够更好地从输入数据中学习并预测正确的输出。希望这个解释能够帮助你更好地理解模型参数及其大小。
- 上一篇:掌握深度学习,数据不足也能进行图像分类
- 下一篇:专为初学者设计——最小的神经网络
相关推荐
- 悠悠万事,吃饭为大(悠悠万事吃饭为大,什么意思)
-
新媒体编辑:杜岷赵蕾初审:程秀娟审核:汤小俊审签:周星...
- 高铁扒门事件升级版!婚宴上‘冲喜’老人团:我们抢的是社会资源
-
凌晨两点改方案时,突然收到婚庆团队发来的视频——胶东某酒店宴会厅,三个穿大红棉袄的中年妇女跟敢死队似的往前冲,眼瞅着就要扑到新娘的高额钻石项链上。要不是门口小伙及时阻拦,这婚礼造型团队熬了三个月的方案...
- 微服务架构实战:商家管理后台与sso设计,SSO客户端设计
-
SSO客户端设计下面通过模块merchant-security对SSO客户端安全认证部分的实现进行封装,以便各个接入SSO的客户端应用进行引用。安全认证的项目管理配置SSO客户端安全认证的项目管理使...
- 还在为 Spring Boot 配置类加载机制困惑?一文为你彻底解惑
-
在当今微服务架构盛行、项目复杂度不断攀升的开发环境下,SpringBoot作为Java后端开发的主流框架,无疑是我们手中的得力武器。然而,当我们在享受其自动配置带来的便捷时,是否曾被配置类加载...
- Seata源码—6.Seata AT模式的数据源代理二
-
大纲1.Seata的Resource资源接口源码2.Seata数据源连接池代理的实现源码3.Client向Server发起注册RM的源码4.Client向Server注册RM时的交互源码5.数据源连接...
- 30分钟了解K8S(30分钟了解微积分)
-
微服务演进方向o面向分布式设计(Distribution):容器、微服务、API驱动的开发;o面向配置设计(Configuration):一个镜像,多个环境配置;o面向韧性设计(Resista...
- SpringBoot条件化配置(@Conditional)全面解析与实战指南
-
一、条件化配置基础概念1.1什么是条件化配置条件化配置是Spring框架提供的一种基于特定条件来决定是否注册Bean或加载配置的机制。在SpringBoot中,这一机制通过@Conditional...
- 一招解决所有依赖冲突(克服依赖)
-
背景介绍最近遇到了这样一个问题,我们有一个jar包common-tool,作为基础工具包,被各个项目在引用。突然某一天发现日志很多报错。一看是NoSuchMethodError,意思是Dis...
- 你读过Mybatis的源码?说说它用到了几种设计模式
-
学习设计模式时,很多人都有类似的困扰——明明概念背得滚瓜烂熟,一到写代码就完全想不起来怎么用。就像学了一堆游泳技巧,却从没下过水实践,很难真正掌握。其实理解一个知识点,就像看立体模型,单角度观察总...
- golang对接阿里云私有Bucket上传图片、授权访问图片
-
1、为什么要设置私有bucket公共读写:互联网上任何用户都可以对该Bucket内的文件进行访问,并且向该Bucket写入数据。这有可能造成您数据的外泄以及费用激增,若被人恶意写入违法信息还可...
- spring中的资源的加载(spring加载原理)
-
最近在网上看到有人问@ContextConfiguration("classpath:/bean.xml")中除了classpath这种还有其他的写法么,看他的意思是想从本地文件...
- Android资源使用(android资源文件)
-
Android资源管理机制在Android的开发中,需要使用到各式各样的资源,这些资源往往是一些静态资源,比如位图,颜色,布局定义,用户界面使用到的字符串,动画等。这些资源统统放在项目的res/独立子...
- 如何深度理解mybatis?(如何深度理解康乐服务质量管理的5个维度)
-
深度自定义mybatis回顾mybatis的操作的核心步骤编写核心类SqlSessionFacotryBuild进行解析配置文件深度分析解析SqlSessionFacotryBuild干的核心工作编写...
- @Autowired与@Resource原理知识点详解
-
springIOCAOP的不多做赘述了,说下IOC:SpringIOC解决的是对象管理和对象依赖的问题,IOC容器可以理解为一个对象工厂,我们都把该对象交给工厂,工厂管理这些对象的创建以及依赖关系...
- java的redis连接工具篇(java redis client)
-
在Java里,有不少用于连接Redis的工具,下面为你介绍一些主流的工具及其特点:JedisJedis是Redis官方推荐的Java连接工具,它提供了全面的Redis命令支持,且...
- 一周热门
- 最近发表
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- resize函数 (64)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- mybatis大于等于 (64)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- linuxlink (65)
- pythonwget (67)
- androidinclude (65)
- logstashinput (65)
- hadoop端口 (65)
- vue阻止冒泡 (67)
- oracle时间戳转换日期 (64)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)