百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

「AI实战」快速掌握TensorFlow(四):损失函数

bigegpt 2024-10-07 06:34 16 浏览



在前面的文章中,我们已经学习了TensorFlow激励函数的操作使用方法(见文章:快速掌握TensorFlow(三)),今天我们将继续学习TensorFlow。

本文主要是学习掌握TensorFlow的损失函数。

一、什么是损失函数

损失函数(loss function)是机器学习中非常重要的内容,它是度量模型输出值与目标值的差异,也就是作为评估模型效果的一种重要指标,损失函数越小,表明模型的鲁棒性就越好。

二、怎样使用损失函数

在TensorFlow中训练模型时,通过损失函数告诉TensorFlow预测结果相比目标结果是好还是坏。在多种情况下,我们会给出模型训练的样本数据和目标数据,损失函数即是比较预测值与给定的目标值之间的差异。

下面将介绍在TensorFlow中常用的损失函数。

1、回归模型的损失函数

首先讲解回归模型的损失函数,回归模型是预测连续因变量的。为方便介绍,先定义预测结果(-1至1的等差序列)、目标结果(目标值为0),代码如下:

import tensorflow as tf
sess=tf.Session()
y_pred=tf.linspace(-1., 1., 100)
y_target=tf.constant(0.)

注意,在实际训练模型时,预测结果是模型输出的结果值,目标结果是样本提供的。

(1)L1正则损失函数(即绝对值损失函数)

L1正则损失函数是对预测值与目标值的差值求绝对值,公式如下:


在TensorFlow中调用方式如下:

loss_l1_vals=tf.abs(y_pred-y_target)
loss_l1_out=sess.run(loss_l1_vals)

L1正则损失函数在目标值附近不平滑,会导致模型不能很好地收敛。

(2)L2正则损失函数(即欧拉损失函数)

L2正则损失函数是预测值与目标值差值的平方和,公式如下:


当对L2取平均值,就变成均方误差(MSE, mean squared error),公式如下:


在TensorFlow中调用方式如下:

# L2损失
loss_l2_vals=tf.square(y_pred - y_target)
loss_l2_out=sess.run(loss_l2_vals)
# 均方误差
loss_mse_vals= tf.reduce.mean(tf.square(y_pred - y_target))
loss_mse_out = sess.run(loss_mse_vals)

L2正则损失函数在目标值附近有很好的曲度,离目标越近收敛越慢,是非常有用的损失函数。

L1、L2正则损失函数如下图所示:


(3)Pseudo-Huber 损失函数

Huber损失函数经常用于回归问题,它是分段函数,公式如下:


从这个公式可以看出当残差(预测值与目标值的差值,即y-f(x) )很小的时候,损失函数为L2范数,残差大的时候,为L1范数的线性函数。

Peseudo-Huber损失函数是Huber损失函数的连续、平滑估计,在目标附近连续,公式如下:


该公式依赖于参数delta,delta越大,则两边的线性部分越陡峭。

在TensorFlow中的调用方式如下:

delta=tf.constant(0.25)
loss_huber_vals = tf.mul(tf.square(delta), tf.sqrt(1. + tf.square(y_target – y_pred)/delta)) – 1.)
loss_huber_out = sess.run(loss_huber_vals)

L1、L2、Huber损失函数的对比图如下,其中Huber的delta取0.25、5两个值:


2、分类模型的损失函数

分类损失函数主要用于评估预测分类结果,重新定义预测值(-3至5的等差序列)和目标值(目标值为1),如下:

y_pred=tf.linspace(-3., 5., 100)
y_target=tf.constant(1.)
y_targets=tf.fill([100, ], 1.)

(1)Hinge损失函数

Hinge损失常用于二分类问题,主要用来评估向量机算法,但有时也用来评估神经网络算法,公式如下:


在TensorFlow中的调用方式如下:

loss_hinge_vals = tf.maximum(0., 1. – tf.mul(y_target, y_pred))
loss_hinge_out = sess.run(loss_hinge_vals)

上面的代码中,目标值为1,当预测值离1越近,则损失函数越小,如下图:


(2)两类交叉熵(Cross-entropy)损失函数

交叉熵来自于信息论,是分类问题中使用广泛的损失函数。交叉熵刻画了两个概率分布之间的距离,当两个概率分布越接近时,它们的交叉熵也就越小,给定两个概率分布p和q,则距离如下:


对于两类问题,当一个概率p=y,则另一个概率q=1-y,因此代入化简后的公式如下:


在TensorFlow中的调用方式如下:

loss_ce_vals = tf.mul(y_target, tf.log(y_pred)) – tf.mul((1. – y_target), tf.log(1. – y_pred))
loss_ce_out = sess.run(loss_ce_vals)

Cross-entropy损失函数主要应用在二分类问题上,预测值为概率值,取值范围为[0,1],损失函数图如下:


(3)Sigmoid交叉熵损失函数

与上面的两类交叉熵类似,只是将预测值y_pred值通过sigmoid函数进行转换,再计算交叉熵损失。在TensorFlow中有内置了该函数,调用方式如下:

loss_sce_vals=tf.nn.sigmoid_cross_entropy_with_logits(y_pred, y_targets)
loss_sce_out=sess.run(loss_sce_vals)

由于sigmoid函数会将输入值变小很多,从而平滑了预测值,使得sigmoid交叉熵在预测值离目标值比较远时,其损失的增长没有那么的陡峭。与两类交叉熵的比较图如下:


(4)加权交叉熵损失函数

加权交叉熵损失函数是Sigmoid交叉熵损失函数的加权,是对正目标的加权。假定权重为0.5,在TensorFlow中的调用方式如下:

weight = tf.constant(0.5)
loss_wce_vals = tf.nn.weighted_cross_entropy_with_logits(y)vals, y_targets, weight)
loss_wce_out = sess.run(loss_wce_vals)

(5)Softmax交叉熵损失函数

Softmax交叉熵损失函数是作用于非归一化的输出结果,只针对单个目标分类计算损失。

通过softmax函数将输出结果转化成概率分布,从而便于输入到交叉熵里面进行计算(交叉熵要求输入为概率),softmax定义如下:


结合前面的交叉熵定义公式,则Softmax交叉熵损失函数公式如下:


在TensorFlow中调用方式如下:

y_pred=tf.constant([[1., -3., 10.]]
y_target=tf.constant([[0.1, 0.02, 0.88]])
loss_sce_vals=tf.nn.softmax_cross_entropy_with_logits(y_pred, y_target)
loss_sce_out=sess.run(loss_sce_vals)

用于回归相关的损失函数,对比图如下:


3、总结

下面对各种损失函数进行一个总结,如下表所示:


在实际使用中,对于回归问题经常会使用MSE均方误差(L2取平均)计算损失,对于分类问题经常会使用Sigmoid交叉熵损失函数。

大家在使用时,还要根据实际的场景、具体的模型,选择使用的损失函数,希望本文对你有帮助。

接下来的“快速掌握TensorFlow”系列文章,还会有更多讲解TensorFlow的精彩内容,敬请期待。

欢迎关注本人的微信公众号“大数据与人工智能Lab”(BigdataAILab),获取更多信息

推荐相关阅读

  • 【AI实战】快速掌握TensorFlow(一):基本操作
  • 【AI实战】快速掌握TensorFlow(二):计算图、会话
  • 【AI实战】快速掌握TensorFlow(三):激励函数
  • 【AI实战】快速掌握TensorFlow(四):损失函数
  • 【AI实战】搭建基础环境
  • 【AI实战】训练第一个模型
  • 【AI实战】编写人脸识别程序
  • 【AI实战】动手训练目标检测模型(SSD篇)
  • 【AI实战】动手训练目标检测模型(YOLO篇)
  • 【精华整理】CNN进化史
  • 大话卷积神经网络(CNN)
  • 大话循环神经网络(RNN)
  • 大话深度残差网络(DRN)
  • 大话深度信念网络(DBN)
  • 大话CNN经典模型:LeNet
  • 大话CNN经典模型:AlexNet
  • 大话CNN经典模型:VGGNet
  • 大话CNN经典模型:GoogLeNet
  • 大话目标检测经典模型:RCNN、Fast RCNN、Faster RCNN
  • 大话目标检测经典模型:Mask R-CNN
  • 27种深度学习经典模型
  • 浅说“迁移学习”
  • 什么是“强化学习”
  • AlphaGo算法原理浅析
  • 大数据究竟有多少个V
  • Apache Hadoop 2.8 完全分布式集群搭建超详细教程
  • Apache Hive 2.1.1 安装配置超详细教程
  • Apache HBase 1.2.6 完全分布式集群搭建超详细教程
  • 离线安装Cloudera Manager 5和CDH5(最新版5.13.0)超详细教程

关注本人公众号“大数据与人工智能Lab”(BigdataAILab),获取更多信息

K码农提供了很多不同领域技术,包含人工智能,android,ios ,前端,后端,大数据,云计算,区块链,物联网等大量的技术:http://kmanong.top

相关推荐

当Frida来“敲”门(frida是什么)

0x1渗透测试瓶颈目前,碰到越来越多的大客户都会将核心资产业务集中在统一的APP上,或者对自己比较重要的APP,如自己的主业务,办公APP进行加壳,流量加密,投入了很多精力在移动端的防护上。而现在挖...

服务端性能测试实战3-性能测试脚本开发

前言在前面的两篇文章中,我们分别介绍了性能测试的理论知识以及性能测试计划制定,本篇文章将重点介绍性能测试脚本开发。脚本开发将分为两个阶段:阶段一:了解各个接口的入参、出参,使用Python代码模拟前端...

Springboot整合Apache Ftpserver拓展功能及业务讲解(三)

今日分享每天分享技术实战干货,技术在于积累和收藏,希望可以帮助到您,同时也希望获得您的支持和关注。架构开源地址:https://gitee.com/msxyspringboot整合Ftpserver参...

Linux和Windows下:Python Crypto模块安装方式区别

一、Linux环境下:fromCrypto.SignatureimportPKCS1_v1_5如果导包报错:ImportError:Nomodulenamed'Crypt...

Python 3 加密简介(python des加密解密)

Python3的标准库中是没多少用来解决加密的,不过却有用于处理哈希的库。在这里我们会对其进行一个简单的介绍,但重点会放在两个第三方的软件包:PyCrypto和cryptography上,我...

怎样从零开始编译一个魔兽世界开源服务端Windows

第二章:编译和安装我是艾西,上期我们讲述到编译一个魔兽世界开源服务端环境准备,那么今天跟大家聊聊怎么编译和安装我们直接进入正题(上一章没有看到的小伙伴可以点我主页查看)编译服务端:在D盘新建一个文件夹...

附1-Conda部署安装及基本使用(conda安装教程)

Windows环境安装安装介质下载下载地址:https://www.anaconda.com/products/individual安装Anaconda安装时,选择自定义安装,选择自定义安装路径:配置...

如何配置全世界最小的 MySQL 服务器

配置全世界最小的MySQL服务器——如何在一块IntelEdison为控制板上安装一个MySQL服务器。介绍在我最近的一篇博文中,物联网,消息以及MySQL,我展示了如果Partic...

如何使用Github Action来自动化编译PolarDB-PG数据库

随着PolarDB在国产数据库领域荣膺桂冠并持续获得广泛认可,越来越多的学生和技术爱好者开始关注并涉足这款由阿里巴巴集团倾力打造且性能卓越的关系型云原生数据库。有很多同学想要上手尝试,却卡在了编译数据...

面向NDK开发者的Android 7.0变更(ndk android.mk)

订阅Google官方微信公众号:谷歌开发者。与谷歌一起创造未来!受Android平台其他改进的影响,为了方便加载本机代码,AndroidM和N中的动态链接器对编写整洁且跨平台兼容的本机...

信创改造--人大金仓(Kingbase)数据库安装、备份恢复的问题纪要

问题一:在安装KingbaseES时,安装用户对于安装路径需有“读”、“写”、“执行”的权限。在Linux系统中,需要以非root用户执行安装程序,且该用户要有标准的home目录,您可...

OpenSSH 安全漏洞,修补操作一手掌握

1.漏洞概述近日,国家信息安全漏洞库(CNNVD)收到关于OpenSSH安全漏洞(CNNVD-202407-017、CVE-2024-6387)情况的报送。攻击者可以利用该漏洞在无需认证的情况下,通...

Linux:lsof命令详解(linux lsof命令详解)

介绍欢迎来到这篇博客。在这篇博客中,我们将学习Unix/Linux系统上的lsof命令行工具。命令行工具是您使用CLI(命令行界面)而不是GUI(图形用户界面)运行的程序或工具。lsoflsof代表&...

幻隐说固态第一期:固态硬盘接口类别

前排声明所有信息来源于网络收集,如有错误请评论区指出更正。废话不多说,目前固态硬盘接口按速度由慢到快分有这几类:SATA、mSATA、SATAExpress、PCI-E、m.2、u.2。下面我们来...

新品轰炸 影驰SSD多款产品登Computex

分享泡泡网SSD固态硬盘频道6月6日台北电脑展作为全球第二、亚洲最大的3C/IT产业链专业展,吸引了众多IT厂商和全球各地媒体的热烈关注,全球存储新势力—影驰,也积极参与其中,为广大玩家朋友带来了...