生成对抗网络GAN && 人脸图像生成练习
bigegpt 2025-06-23 14:57 2 浏览
介绍
在2016年的一个研讨会上,杨立昆称生成式对抗网络为“机器学习这二十年来最酷的想法”。
生成对抗网络(Generative Adversarial Network,GAN)是一种非监督学习方法,通过两个神经网络相互对抗的方式进行学习。这一方法于2014年由伊恩·古德费洛等人提出。
GAN包括一个生成网络和一个判别网络。生成网络从潜在空间中随机采样作为输入,并试图生成与训练集中真实样本相似的输出。判别网络的输入是真实样本或生成网络的输出,其目标是将生成网络的输出与真实样本区分开来。
生成网络的目标是尽可能欺骗判别网络,而判别网络则努力辨别真实样本和生成网络的输出。
两个网络相互对抗,不断调整参数,最终使判别网络难以判断生成网络的输出是否真实。
生成对抗网络通常用于生成逼真的图像,同时也可以用于生成影片、三维物体模型等领域。
尽管最初生成对抗网络是为了无监督学习而提出的,但已经证明对半监督学习、完全监督学习和强化学习也是有效的。
实现原理
GAN 可以被视为一场博弈。生成器的目标是最小化判断器区分真实数据 (x) 和生成数据 (G(z)) 的能力。另一方面,判断器最大化其进行区分的能力。
这个迭代过程一直持续到达到纳什均衡。当达到这个均衡时,G 和 D 都无法进一步提高。这种平衡导致 G 产生高度真实的数据,而 D 很难将这些数据与真实样本区分开来。
主要包括以下几个步骤:
1、定义生成器(Generator)和判别器(Discriminator)的网络结构:
生成器负责从随机噪声生成样本,判别器则负责区分生成器生成的样本和真实样本。
确定生成器和判别器的网络结构,包括层数、激活函数、优化器等。
2、准备数据集:
准备训练GAN所需的数据集,确保数据集包含足够多的真实样本用于训练。
3、定义损失函数:
GAN的损失函数由生成器和判别器的损失组成。
生成器的损失通常包括生成器生成的样本被判别为真实样本的概率。
判别器的损失包括将生成器生成的样本正确分类为生成样本或真实样本的概率。
4、训练GAN:
交替训练生成器和判别器,每次迭代中先训练判别器,然后训练生成器。
生成器生成样本,判别器评估生成器生成的样本和真实样本的区分度,根据评估结果更新生成器和判别器的参数。
5、评估生成器:
训练完成后,评估生成器生成的样本是否逼真,可以通过人工评估或其他评估指标来判断生成器的效果。
6、调优和改进:
根据评估结果调整网络结构、损失函数或训练策略,以改进生成器的生成效果。
人脸数据集
数据集包括:
- 50,000 张高分辨率人脸图像
- 图像为178 * 218 像素
数据预处理
数据预处理是为分析或机器学习任务准备数据的重要步骤。它涉及对原始数据进行转换和清理,使其适合进一步处理。
具体步骤有:
1. 数据采集
2. 数据标准化
为了确保一致性并提高模型收敛性,对图像的值进行标准化非常重要。在我们的例子中,我们将像素值标准化在 0 和 1 之间。我们还缩小了所有图像的尺寸,因此我们可以在我们的机器上训练模型。
3. 随机播放和批量创建
为了引入随机性并防止数据排序中出现任何潜在偏差,我们可以在创建批次之前对图像进行打乱。这确保图像的顺序不会影响学习过程。打乱后,我们可以继续创建 64 张图像的批次。
4. 批处理
为了有效地处理大型数据集,通常将它们分成较小的批次。在本例中,我们创建了 64 个图像的批次。这使我们能够在训练或分析期间以较小的部分提供数据,减少内存需求并有可能实现并行处理如果可供使用的话。
通过将数据分成批次,我们可以迭代地处理每个批次,而无需立即将整个数据集加载到内存中。在处理无法完全装入内存的大型数据集时,这种方法特别有用。
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, BatchNormalization, Flatten, InputLayer, LeakyReLU, Conv2DTranspose, Reshape, Dropout
from tensorflow.keras import Model
import matplotlib.pyplot as plt
import os
import time
def load_and_preprocess_dataset(directory, image_size, batch_size, shuffle=True):
"""
function that reads the Dataset, Shuffle it, normalizes it and batches it
读取数据集、对其进行洗牌、标准化并对其进行批处理的函数
"""
dataset = tf.keras.utils.image_dataset_from_directory(
directory=directory,
image_size=image_size,
batch_size=batch_size,
shuffle=shuffle,
label_mode=None,
color_mode="rgb" # RGB Color
)
# # 将数据集归一化到0和1之间
dataset = dataset.map(lambda x: x / 255.0)
return dataset
TRAIN_DIR = "./images/"
IMAGE_SIZE = (96,80)
BATCH_SIZE = 64
# 加载训练数据
train_data = load_and_preprocess_dataset(TRAIN_DIR, IMAGE_SIZE, BATCH_SIZE, shuffle=True)
模型定义
判断器(Discriminator)
下一步是建立一个判别器。判别器是生成对抗网络(GAN)中至关重要的组成部分,其作用是区分真实样本和生成的(假)样本。以下是对判别器的简要概述:
- 输入:判别器接收来自生成器的图像或样本作为输入。在图像生成任务中,输入通常由具有特定尺寸的图像组成。
- 架构:判别器通常由处理输入并提取特征的层组成。常见的架构选择包括卷积层,然后是激活函数,如ReLU或LeakyReLU。这些层旨在从输入样本中捕获相关模式和判别信息。
- 输出:判别器的输出是一个概率分数,表示输入是真实的还是假的可能性。通常是一个介于0到1之间的单个标量值。接近1的值表明输入被分类为真实的,而接近0的值表明它被分类为假的。
判别器的目标是提高区分真实样本和生成样本的能力,从而促使生成器产生更为真实的输出。判别器和生成器以对抗方式进行训练,生成器的目标是欺骗判别器,而判别器的目标是准确分类样本。
class Discriminator (Model):
def __init__(self, input_shape, batch_size):
super(Discriminator, self).__init__()
self.__input_shape = input_shape
self.__batch_size = batch_size
self.__discriminator = tf.keras.Sequential(
[
InputLayer(input_shape=self.__input_shape, batch_size=self.__batch_size),
Conv2D(32,(3,3),padding="same", strides=2),
Dropout(0.1),
BatchNormalization(),
LeakyReLU(alpha=0.01),
Conv2D(64,(3,3),padding="same", strides=2),
Dropout(0.1),
BatchNormalization(),
LeakyReLU(alpha=0.01),
Conv2D(128,(3,3),padding="same", strides=2),
Dropout(0.1),
LeakyReLU(alpha=0.01),
Conv2D(128,(3,3),padding="same", strides=2),
Dropout(0.1),
LeakyReLU(alpha=0.01),
Flatten(),
Dropout(0.1),
LeakyReLU(alpha=0.01),
Dense (1, activation ="sigmoid")
]
)
def call (self, input) :
return self.__discriminator(input)
def get_model (self) :
return self.__discriminator
生成器(Generator)
生成器是生成对抗网络(GAN)中的核心组件,其任务是生成类似于真实数据的合成样本。以下是对生成器的简要描述:
- 输入:生成器通常接受随机噪声或潜在向量作为输入。这些向量通常是从概率分布(如均匀分布或高斯分布)中采样得到的。输入向量的大小和维度取决于具体问题和所需的输出。
- 架构:生成器由将输入噪声或潜在向量转换为有意义数据表示的层组成。常见选择包括全连接(密集)层或转置卷积层。这些层逐渐对输入进行上采样,并应用非线性操作以生成更高分辨率的输出。
- 输出:生成器的输出是一个合成样本,旨在类似于真实数据。在图像生成任务中,输出可以是具有特定尺寸的图像。生成器生成的样本应与真实样本处于同一数据分布内。
- 训练:生成器与判别器一起进行训练。其目标是生成合成样本,以欺骗判别器将其分类为真实样本。生成器的参数根据判别器的反馈,通过反向传播和优化算法(如随机梯度下降或Adam)进行更新。
生成器的目标是逐步提高生成真实且多样化样本的能力。它学习捕获真实数据中存在的潜在模式和结构,有效地从随机噪声中合成新样本。随着训练的进行,生成器变得更加熟练地生成类似于真实数据分布的样本。
class Generator (Model):
def __init__(self, batch_size):
super(Generator, self).__init__()
self.__generator = tf.keras.Sequential(
[
InputLayer(input_shape=(64,), batch_size=batch_size),
Dense (6*5*64, use_bias=False),
LeakyReLU(alpha=0.01),
Reshape ((6,5,64)),
Conv2DTranspose(filters=256, padding="same", kernel_size=(3,3), strides = 2),
BatchNormalization (),
LeakyReLU(alpha=0.01),
Conv2DTranspose(filters=128, padding="same", kernel_size=(3,3), strides = 2),
BatchNormalization (),
LeakyReLU(alpha=0.01),
Conv2DTranspose(filters=128, padding="same", kernel_size=(3,3), strides = 2),
BatchNormalization (),
LeakyReLU(alpha=0.01),
Conv2DTranspose(filters=64, padding="same", kernel_size=(3,3), strides = 2),
LeakyReLU(alpha=0.01),
Conv2D(3,(3,3),padding="same",activation="sigmoid")
]
)
def call (self, input) :
return self.__generator(input)
def get_model (self) :
return self.__generator
dis = Discriminator((96,80,3),None)
gen = Generator(batch_size=None)
训练
- 定义损失函数是指定生成器和判别器的损失函数。判别器的损失函数旨在正确分类真实样本和生成样本,而生成器的损失函数则旨在鼓励生成的样本被判别器分类为真实样本。
cross_entropy = tf.keras.losses.BinaryCrossentropy()
- 判断器损失函数计算了在训练过程中判断器的损失。该函数接受两个输入:“real_output”和“fake_output”,分别表示判断器对真实样本和生成样本的预测。这个函数使用交叉熵损失函数来比较判断器的预测和目标标签。对于真实样本,目标标签为1,表示样本是真实的;对于生成样本,目标标签为0,表示样本是伪造的。“real_loss”通过将真实输出的预测与形状与“real_output”相同的张量进行比较来计算。这个损失代表了判断器正确分类真实样本的能力。类似地,“fake_loss”通过将生成输出的预测与形状与“fake_output”相同的零张量进行比较来计算。这个损失代表了判断器正确分类生成样本的能力。最后,通过将“real_loss”和“fake_loss”相加得到“total_loss”。这个综合损失用于在训练过程中优化判断器的权重。最终,该函数返回“total_loss”,它代表了判断器的整体损失。
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss
return total_loss
- 生成器损失函数用于计算生成对抗网络(GAN)中生成器在训练期间的损失。该函数接受“fake_output”作为输入,表示判断器对生成的(假)样本的预测。在该函数中,使用交叉熵损失函数将判断器对假样本的预测与形状与“fake_output”相同的张量进行比较。目标标签为1,表示生成器期望判断器将假样本误分类为真实样本。生成器损失“gen_loss”是通过将假输出的预测与相同形状的张量进行比较来计算的。这个损失反映了生成器生成可以欺骗判断器的样本的表现。最终,该函数以浮点数形式返回“gen_loss”,表示生成器的损失。
def generator_loss(fake_output):
gen_loss = cross_entropy(tf.ones_like(fake_output), fake_output)
return float (gen_loss)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
checkpoint_dir = './training_checkpoints_v1'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=gen,
discriminator=dis)
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=1)
if ckpt_manager.latest_checkpoint:
checkpoint.restore(ckpt_manager.latest_checkpoint)
EPOCHS = 200
noise_dim = 64
num_examples_to_generate = 64
# 创建了一个正态分布的随机种子
seed = tf.random.normal([num_examples_to_generate, noise_dim])
# 这个训练步骤中,首先生成器生成虚假图像,然后判别器分别对真实图像和虚假图像进行判断,并计算损失,最后通过梯度下降来更新生成器和判别器的参数。
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = gen(noise, training=True)
real_output = dis(images, training=True)
fake_output = dis(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, gen.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, dis.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, gen.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, dis.trainable_variables))
return (disc_loss, gen_loss)
def gen_and_show_image (model, epoch, noise) :
gen_image = model(noise, training=False)
plt.imshow(gen_image[0])
plt.axis('off')
plt.savefig('./generated/image_at_epoch_{:04d}.png'.format(epoch))
plt.show()
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
total_generator_loss = 0
total_discriminator_loss = 0
for image_batch in dataset:
disc_loss , gen_loss = train_step(image_batch)
total_generator_loss = total_generator_loss + gen_loss
total_discriminator_loss = total_discriminator_loss + disc_loss
# Generate after the final epoch
gen_and_show_image(
gen,
epoch,
seed
)
checkpoint.save(file_prefix = checkpoint_prefix)
gen_epoch_loss = total_generator_loss / len(dataset)
disc_epoch_loss = total_discriminator_loss / len (dataset)
print ('epoch {} 的时间为 {} 秒,生成器损失 = {} --- 判别器损失 = {}'.format(epoch + 1, time.time()-start, gen_epoch_loss, disc_epoch_loss))
执行
train(train_data, EPOCHS)
epoch 200 的时间为 3.435598134994507 秒,生成器损失 = 3.0046656131744385 --- 判别器损失 = 0.5568044185638428
以下是训练的一些关键参数:
- 训练epoch数:200
- 每个时期的平均训练时间:约 37 秒
- 总训练时间:123 分钟(约 2 小时)
- 批量大小:64
- 使用的显卡:Nvidia RTX 3080(10 GB VRAM)
结果
该模型已经经过成功训练,能够利用我们提供的数据集生成一些人脸。下面的 GIF,展示了训练过程的演变:
然而,结果并不尽如人意,这是由于多种因素造成的:
- 模型规模: 模型的参数不够多
- 数据集质量: 数据集并非完全合理,我们需要解决这些问题并清理数据集
- 判别器学习: 判别器的学习效率太高,它学会区分假图像和真实图像的速度比生成器学会欺骗判别器的速度还要快。
相关推荐
- AI「自我复制」能力曝光,RepliBench警示:大模型正在学会伪造身份
-
科幻中AI自我复制失控场景,正成为现实世界严肃的研究课题。英国AISI推出RepliBench基准,分解并评估AI自主复制所需的四大核心能力。测试显示,当前AI尚不具备完全自主复制能力,但在获取资源...
- 【Python第三方库安装】介绍8种情况,这里最全看这里就够了!
-
**本图文作品主要解决CMD或pycharm终端下载安装第三方库可能出错的问题**本作品介绍了8种安装方法,这里最全的python第三方库安装教程,简单易上手,满满干货!希望大家能愉快地写代码,而不要...
- pyvips,一个神奇的 Python 库!(pythonvip视频)
-
大家好,今天为大家分享一个神奇的Python库-pyvips。在图像处理领域,高效和快速的图像处理工具对于开发者来说至关重要。pyvips是一个强大的Python库,基于libvips...
- mac 安装tesseract、pytesseract以及简单使用
-
一.tesseract-OCR的介绍1.tesseract-OCR是一个开源的OCR引擎,能识别100多种语言,专门用于对图片文字进行识别,并获取文本。但是它的缺点是对手写的识别能力比较差。2.用te...
- 实测o3/o4-mini:3分钟解决欧拉问题,OpenAI最强模型名副其实!
-
号称“OpenAI迄今为止最强模型”,o3/o4-mini真实能力究竟如何?就在发布后的几小时内,网友们的第一波实测已新鲜出炉。最强推理模型o3,即使遇上首位全职提示词工程师RileyGoodsid...
- 使用Python将图片转换为字符画并保存到文件
-
字符画(ASCIIArt)是将图片转换为由字符组成的艺术作品。利用Python,我们可以轻松实现图片转字符画的功能。本教程将带你一步步实现这个功能,并详细解释每一步的代码和实现原理。环境准备首先,你...
- 5分钟-python包管理器pip安装(python pip安装包)
-
pip是一个现代的,通用、普遍的Python包管理工具。提供了对Python包的查找、下载、安装、卸载的功能,是Python开发的基础。第一步:PC端打开网址:选择gz后缀的文件下载第二步:...
- 网络问题快速排查,你也能当好自己家的网络攻城狮
-
前面写了一篇关于网络基础和常见故障排查的,只列举了工具。没具体排查方式。这篇重点把几个常用工具的组合讲解一下。先有请今天的主角:nslookup及dig,traceroute,httping,teln...
- 终于把TCP/IP 协议讲的明明白白了,再也不怕被问三次握手了
-
文:涤生_Woo下周就开始和大家成体系的讲hadoop了,里面的每一个模块的技术细节我都会涉及到,希望大家会喜欢。当然了你也可以评论或者留言自己喜欢的技术,还是那句话,希望咱们一起进步。今天周五,讲讲...
- 记一次工控触摸屏故障的处理(工控触摸屏维修)
-
先说明一下,虽然我是自动化专业毕业,但已经很多年不从事现场一线的工控工作了。但自己在单位做的工作也牵涉到信息化与自动化的整合,所以平时也略有关注。上一周一个朋友接到一个活,一家光伏企业用于启动机组的触...
- 19、90秒快速“读懂”路由、交换命令行基础
-
命令行视图VRP分层的命令结构定义了很多命令行视图,每条命令只能在特定的视图中执行。本例介绍了常见的命令行视图。每个命令都注册在一个或多个命令视图下,用户只有先进入这个命令所在的视图,才能运行相应的命...
- 摄像头没图像的几个检查方法(摄像头没图像怎么修复)
-
背景描述:安防监控项目上,用户的摄像头运行了一段时间有部分摄像头不能进行预览,需要针对不能预览的摄像头进行排查,下面列出几个常见的排查方法。问题解决:一般情况为网络、供电、设备配置等情况。一,网络检查...
- 小谈:必需脂肪酸(必需脂肪酸主要包括)
-
必需脂肪酸是指机体生命活动必不可少,但机体自身又不能合成,必需由食物供给的多不饱和脂肪酸(PUFA)。必需脂肪酸主要包括两种,一种是ω-3系列的α-亚麻酸(18:3),一种是ω-6系列的亚油酸(18:...
- 期刊推荐:15本sci四区易发表的机械类期刊
-
虽然,Sci四区期刊相比收录在sci一区、二区、三区的期刊来说要求不是那么高,投稿起来也相对容易一些。但,sci四区所收录的期刊中每本期刊的投稿难易程度也是不一样的。为方便大家投稿,本文给大家推荐...
- be sick of 用法考察(be in lack of的用法)
-
besick表示病了,做谓语.本身是形容词,有多种意思.最通常的是:生病,恶心,呕吐,不适,晕,厌烦,无法忍受asickchild生病的孩子Hermother'sverysi...
- 一周热门
- 最近发表
-
- AI「自我复制」能力曝光,RepliBench警示:大模型正在学会伪造身份
- 【Python第三方库安装】介绍8种情况,这里最全看这里就够了!
- pyvips,一个神奇的 Python 库!(pythonvip视频)
- mac 安装tesseract、pytesseract以及简单使用
- 实测o3/o4-mini:3分钟解决欧拉问题,OpenAI最强模型名副其实!
- 使用Python将图片转换为字符画并保存到文件
- 5分钟-python包管理器pip安装(python pip安装包)
- 网络问题快速排查,你也能当好自己家的网络攻城狮
- 终于把TCP/IP 协议讲的明明白白了,再也不怕被问三次握手了
- 记一次工控触摸屏故障的处理(工控触摸屏维修)
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- resize函数 (64)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- mybatis大于等于 (64)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- linuxlink (65)
- pythonwget (67)
- androidinclude (65)
- logstashinput (65)
- hadoop端口 (65)
- vue阻止冒泡 (67)
- oracle时间戳转换日期 (64)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)