百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

使用pytorch构建GAN示例 pytorch gather_nd

bigegpt 2024-10-12 05:07 10 浏览

生成对抗网络(GAN)是一对相互学习的学习引擎。一个有用的类比是把伪造者和专家放在一起,两者都学会了超越对方。伪造者更擅长伪造(生成),而专家更擅长鉴别(鉴别)赝品。

生成和判别网络已经存在了一段时间,它们彼此相互映射产生了一些非常有趣的结果,比教导生成器如何生产更好的东西。

下面的Python代码显示了一个非常基本的GAN,它的目的是显示一种方法来配对两个对手。有一些变体,可以随意尝试,将其作为工作原型使用。

我用了pytorch,这是python里的torch的重建。它使创建您自己的机器学习(ML)应用程序非常容易。

示例代码 - 学习正态分布的1D GAN

导入库

import torch

import torch.nn as nn

import torch.optim as optim

from torch.distributions.normal import Normal

定义我们将学习的分布

我们想学习平均3.0和标准差0.4的正态分布。数据将保存为长度为30的数组。

(数组不会被排序,即打印原始数据看起来会很混乱,这只能学习分布,而不是概率分布曲线!)。

data_mean = 3.0

data_stddev = 0.4

Series_Length = 30

定义生成器网络

取20个随机输入并生成一个分布(如上所述),定义了隐层神经元数。

g_input_size = 20

g_hidden_size = 150

g_output_size = Series_Length

定义判别器网络

输出是一个值

  • True(1.0)匹配所需的发行版
  • False(0.0)与分布不匹配

更改隐藏大小以查看效果

d_input_size = Series_Length

d_hidden_size = 75

d_output_size = 1

定义如何向流程发送数据

训练分为两个阶段,分别训练判别器和生成器。判别器比生成器“更好”似乎很重要。有时用于训练判别器的批次比用于生成器的批次多。在这种情况下,我们在判别器训练中投入了更多的内容。

注意,真正的训练使用超过5000个epochs。

d_minibatch_size = 15

g_minibatch_size = 10

num_epochs = 5000

print_interval = 1000

设置学习利率

学习率需要多做这些实验。太小会收敛过慢,太大,我们可能会在一个解周围振荡。

d_learning_rate = 3e-3

g_learning_rate = 8e-3

定义两个函数来返回提供真实样本和一些随机噪声的函数。真实样本训练判别器,随机噪声馈送生成器。

制作signal generator函数的本地副本

def get_real_sampler(mu, sigma):

dist = Normal( mu, sigma )

return lambda m, n: dist.sample( (m, n) ).requires_grad_()

def get_noise_sampler():

return lambda m, n: torch.rand(m, n).requires_grad_() # Uniform-dist data into generator, _NOT_ Gaussian

actual_data = get_real_sampler( data_mean, data_stddev )

noise_data = get_noise_sampler()

生成器

重要的是生成器可以输出匹配的装置。小心使用sigmoid之类的东西,输出0..1。用mean 2.0是学不到东西的!

这是一个非常简单的4层网络,接收噪声并产生输出。

xfer是transfer函数

class Generator(nn.Module):

def __init__(self, input_size, hidden_size, output_size):

super(Generator, self).__init__()

self.map1 = nn.Linear(input_size, hidden_size)

self.map2 = nn.Linear(hidden_size, hidden_size)

self.map3 = nn.Linear(hidden_size, output_size)

self.xfer = torch.nn.SELU()

def forward(self, x):

x = self.xfer( self.map1(x) )

x = self.xfer( self.map2(x) )

return self.xfer( self.map3( x ) )

请注意,最后一层应限制为0..1,这使我们可以在损失函数中进行更多选择。

判别器

这个网络是一个经典的多层感知器——真的没什么特别的。它根据所学习的函数返回true/false。

class Discriminator(nn.Module):

def __init__(self, input_size, hidden_size, output_size):

super(Discriminator, self).__init__()

self.map1 = nn.Linear(input_size, hidden_size)

self.map2 = nn.Linear(hidden_size, hidden_size)

self.map3 = nn.Linear(hidden_size, output_size)

self.elu = torch.nn.ELU()

def forward(self, x):

x = self.elu(self.map1(x))

x = self.elu(self.map2(x))

return torch.sigmoid( self.map3(x) )

创建两个网络

G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)

D = Discriminator(input_size=d_input_size, hidden_size=d_hidden_size, output_size=d_output_size)

设置学习规则:

  • 损失函数
  • 每个网络的优化器

在这里您可以自由选择。

criterion = nn.BCELoss()

d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate )

g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate )

实际数据的训练函数

应该学习真实数据=> 1.0

def train_D_on_actual() :

real_data = actual_data( d_minibatch_size, d_input_size )

real_decision = D( real_data )

real_error = criterion( real_decision, torch.ones( d_minibatch_size, 1 )) # ones = true

real_error.backward()

生成器数据的训练函数

学习将生成的数据识别为假=> 0.0

def train_D_on_generated() :

noise = noise_data( d_minibatch_size, g_input_size )

fake_data = G( noise )

fake_decision = D( fake_data )

fake_error = criterion( fake_decision, torch.zeros( d_minibatch_size, 1 )) # zeros = fake

fake_error.backward()

判别器的训练函数

假设生成器产生完美的数据(即判别器返回1.0)。然后学习如何根据判别器的实际输出来改善发生器的输出。

这是GAN的关键部分:通过两个网络传递错误,但只更新生成器权重。

def train_G():

noise = noise_data( g_minibatch_size, g_input_size )

fake_data = G( noise )

fake_decision = D( fake_data )

error = criterion( fake_decision, torch.ones( g_minibatch_size, 1 ) )

error.backward()

return error.item(), fake_data

算法

算法的工作原理如下:

第1步是普通的批量学习,如果删除了其余代码,您将拥有一个可以识别所需分发的网络

  • 训练鉴别器就像你训练任何网络一样
  • 使用真假(生成)样本来学习

第2步是GAN差值

  • 训练生成器生产,但不要将输出与良好的样品进行比较
  • 通过判别器提供样本生成的输出以发现假的
  • 通过判别器和生成器反向传播误差

因此,让我们考虑可能的情况(在所有情况下,仅在步骤2中更新生成器参数)

Discrimator perfect,Generator Perfect Generator生成一个标识为1.0的样本。Error是0.0,没有学习

Discrimator perfect,Generator Rubbish Generator产生的噪声被识别为0.0。Error是1.0,传播错误并且生成器学习

Discrimator rubbish,Generator Perfect Generator生成标识为0.0的样本。Error是1.0,传播错误,发生器不会学到太多,因为error 将被鉴别器吸收

Discrimator rubbish,Generator Rubbish Generator生成标识为0.5的样本。Error为0.5,误差传播到鉴别器和发生器中的梯度将意味着误差在两者之间共享并且学习发生

此步骤可能很慢 - 具体取决于可用的计算能力

losses = []

for epoch in range(num_epochs):

D.zero_grad()

train_D_on_actual()

train_D_on_generated()

d_optimizer.step()

G.zero_grad()

loss,generated = train_G()

g_optimizer.step()

losses.append( loss )

if( epoch % print_interval) == (print_interval-1) :

print( "Epoch %6d. Loss %5.3f" % ( epoch+1, loss ) )

print( "Training complete" )

显示结果

训练完毕后,我们将生成一些样本并绘制它们。很容易看到我们有一个很好的正态分布。这一步完全是可选的,但很高兴看到我们实际工作的内容。

import matplotlib.pyplot as plt

def draw( data ) :

plt.figure()

d = data.tolist() if isinstance(data, torch.Tensor ) else data

plt.plot( d )

plt.show()

d = torch.empty( generated.size(0), 53 )

for i in range( 0, d.size(0) ) :

d[i] = torch.histc( generated[i], min=0, max=5, bins=53 )

draw( d.t() )

相关推荐

方差分析简介(方差分析通俗理解)

介绍方差分析(ANOVA,AnalysisofVariance)是一种广泛使用的统计方法,用于比较两个或多个组之间的均值。单因素方差分析是方差分析的一种变体,旨在检测三个或更多分类组的均值是否存在...

正如404页面所预示,猴子正成为断网元凶--吧嗒吧嗒真好吃

吧嗒吧嗒,绘图:MakiNaro你可以通过加热、冰冻、水淹、模塑、甚至压溃压力来使网络光缆硬化。但用猴子显然是不行的。光缆那新挤压成型的塑料外皮太尼玛诱人了,无法阻挡一场试吃盛宴的举行。印度政府正...

Python数据可视化:箱线图多种库画法

概念箱线图通过数据的四分位数来展示数据的分布情况。例如:数据的中心位置,数据间的离散程度,是否有异常值等。把数据从小到大进行排列并等分成四份,第一分位数(Q1),第二分位数(Q2)和第三分位数(Q3)...

多组独立(完全随机设计)样本秩和检验的SPSS操作教程及结果解读

作者/风仕在上一期,我们已经讲完了两组独立样本秩和检验的SPSS操作教程及结果解读,这期开始讲多组独立样本秩和检验,我们主要从多组独立样本秩和检验介绍、两组独立样本秩和检验使用条件及案例的SPSS操作...

方差分析 in R语言 and Excel(方差分析r语言例题)

今天来写一篇实际中比较实用的分析方法,方差分析。通过方差分析,我们可以确定组别之间的差异是否超出了由于随机因素引起的差异范围。方差分析分为单因素方差分析和多因素方差分析,这一篇先介绍一下单因素方差分析...

可视化:前端数据可视化插件大盘点 图表/图谱/地图/关系图

前端数据可视化插件大盘点图表/图谱/地图/关系图全有在大数据时代,很多时候我们需要在网页中显示数据统计报表,从而能很直观地了解数据的走向,开发人员很多时候需要使用图表来表现一些数据。随着Web技术的...

matplotlib 必知的 15 个图(matplotlib各种图)

施工专题,我已完成20篇,施工系列几乎覆盖Python完整技术栈,目标只总结实践中最实用的东西,直击问题本质,快速帮助读者们入门和进阶:1我的施工计划2数字专题3字符串专题4列表专题5流程控制专题6编...

R ggplot2常用图表绘制指南(ggplot2绘制折线图)

ggplot2是R语言中强大的数据可视化包,基于“图形语法”(GrammarofGraphics),通过分层方式构建图表。以下是常用图表命令的详细指南,涵盖基本语法、常见图表类型及示例,适合...

Python数据可视化:从Pandas基础到Seaborn高级应用

数据可视化是数据分析中不可或缺的一环,它能帮助我们直观理解数据模式和趋势。本文将全面介绍Python中最常用的三种可视化方法。Pandas内置绘图功能Pandas基于Matplotlib提供了简洁的绘...

Python 数据可视化常用命令备忘录

本文提供了一个全面的Python数据可视化备忘单,适用于探索性数据分析(EDA)。该备忘单涵盖了单变量分析、双变量分析、多变量分析、时间序列分析、文本数据分析、可视化定制以及保存与显示等内容。所...

统计图的种类(统计图的种类及特点图片)

统计图是利用几何图形或具体事物的形象和地图等形式来表现社会经济现象数量特征和数量关系的图形。以下是几种常见的统计图类型及其适用场景:1.条形图(BarChart)条形图是用矩形条的高度或长度来表示...

实测,大模型谁更懂数据可视化?(数据可视化和可视化分析的主要模型)

大家好,我是Ai学习的老章看论文时,经常看到漂亮的图表,很多不知道是用什么工具绘制的,或者很想复刻类似图表。实测,大模型LaTeX公式识别,出乎预料前文,我用Kimi、Qwen-3-235B...

通过AI提示词让Deepseek快速生成各种类型的图表制作

在数据分析和可视化领域,图表是传达信息的重要工具。然而,传统图表制作往往需要专业的软件和一定的技术知识。本文将介绍如何通过AI提示词,利用Deepseek快速生成各种类型的图表,包括柱状图、折线图、饼...

数据可视化:解析箱线图(box plot)

箱线图/盒须图(boxplot)是数据分布的图形表示,由五个摘要组成:最小值、第一四分位数(25th百分位数)、中位数、第三四分位数(75th百分位数)和最大值。箱子代表四分位距(IQR)。IQR是...

[seaborn] seaborn学习笔记1-箱形图Boxplot

1箱形图Boxplot(代码下载)Boxplot可能是最常见的图形类型之一。它能够很好表示数据中的分布规律。箱型图方框的末尾显示了上下四分位数。极线显示最高和最低值,不包括异常值。seaborn中...