百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

C++ 中的卷积神经网络 (CNN)

bigegpt 2024-08-07 17:36 10 浏览

有很多卷积神经网络文章解释了 CNN 是什么以及它的用途是什么,而本文将用 C++ 编写一个 CNN 和一个名为 mlpack 的库来对MNIST数据集进行分类。

你们可能会问为什么 C++ 在 Python 中很容易使用大量库,你们现在可能已经看到一些特斯拉汽车,这些类型的系统需要从它们的环境中进行实时推理,而 Python 非常适合原型设计,但不提供实时当使用它部署如此庞大的模型时会更新。

一、mlpack的含义

它是一个用 C++ 编写的机器学习库,它利用其他一些底层库来提供快速且可扩展的尖端机器学习和深度学习方法。

二、MINST数据集

我们要使用的数据包含在一个 CSV 文件中,由 0 到 9 的数字图像组成,其中列包含标签,行包含特征,但是当我们要将数据加载到矩阵中时,数据将被转置,并且提到哪个特征的标签也将被加载,所以我们需要注意这一点。

#include <mlpack/core.hpp>
#include <mlpack/core/data/split_data.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <ensmallen.hpp>  /* The numerical optimization library that mlpack uses */ 

using namespace mlpack;
using namespace mlpack::ann;

// Namespace for the armadillo library(linear algebra library).
using namespace arma;
using namespace std;

// Namespace for ensmallen.
using namespace ens;

然后我们将声明一个辅助函数将模型输出转换为行矩阵,以匹配我们加载的行矩阵形式的标签。

arma::Row<size_t> getLabels(arma::mat predOut)
{
  arma::Row<size_t> predLabels(predOut.n_cols);
  for(arma::uword i = 0; i < predOut.n_cols; ++i)
  {
    predLabels(i) = predOut.col(i).index_mat() + 1;
  }
  return predLabels;
}

在这一部分下面,代码将出现在main函数中,但它的编写并不是为了让代码易于解释。现在我们将声明一些我们需要的明显训练参数,将解释那些突出的参数。

constexpr double RATIO = 0.1; // ratio to divide the data in train and val set.
constexpr int MAX_ITERATIONS = 0; // set to zero to allow infinite iterations.
constexpr double STEP_SIZE = 1.2e-3;// step size for Adam optimizer.
constexpr int BATCH_SIZE = 50;
constexpr size_t EPOCH = 2;

mat tempDataset;
data::Load("train.csv", tempDataset, true);

mat tempTest;
data::Load("test.csv", test, true);

参数 MAX_ITERATIONS 设置为 0,因为这允许我们在一个 epoch 中无限迭代,以便在训练阶段后期使用提前停止。作为旁注,当此参数未设置为 0 时,也可以使用提前停止。

让我们处理和删除描述每一行中包含的内容的列,如我在数据部分所述,并为训练、验证和测试集的标签和特征创建一个单独的矩阵。

mat dataset = tempDataset.submat(0, 1, tempDataset.n_rows - 1, tempDataset.n_cols - 1);

mat test = tempTest.submat(0, 1, tempTest.n_rows - 1, tempTest.n_cols - 1);

mat train, valid;
data::Split(dataset, train, valid, RATIO);

const mat trainX = train.submat(1, 0, train.n_rows - 1, train.n_cols - 1);
const mat validX = valid.submat(1, 0, valid.n_rows - 1, valid.n_cols - 1);
const mat testX = test.submat(1, 0, test.n_rows - 1, test.n_cols - 1);
const mat trainY = train.row(0) + 1;
const mat validY = valid.row(0) + 1;
const mat testY = test.row(0) + 1;

我们将使用负对数似然损失,在 mlpack 库中,它的标签从 1 而不是 0 开始,因此我们在标签中添加了 1。

三、卷积框架

现在让我们看一下我们将要定义的简单卷积架构。

FFN<NegativeLogLikelihood<>, RandomInitialization> model;

model.Add<Convolution<>>(1,  // Number of input activation maps.
                         6,  // Number of output activation maps.
                         5,  // Filter width.
                         5,  // Filter height.
                         1,  // Stride along width.
                         1,  // Stride along height.
                         0,  // Padding width.
                         0,  // Padding height.
                         28, // Input width.
                         28  // Input height.
);

model.Add<ReLULayer<>>();

model.Add<MaxPooling<>>(2, // Width of field.
                        2, // Height of field.
                        2, // Stride along width.
                        2, // Stride along height.
                        true);

model.Add<Convolution<>>(6, // Number of input activation maps.
                        16, // Number of output activation maps.
                         5, // Filter width.
                         5, // Filter height.
                         1, // Stride along width.
                         1, // Stride along height.
                         0, // Padding width.
                         0, // Padding height.
                        12, // Input width.
                        12  // Input height.
);

model.Add<ReLULayer<>>();
                         
model.Add<MaxPooling<>>(2, 2, 2, 2, true);
                         
model.Add<Linear<>>(16 * 4 * 4, 10);
                         
model.Add<LogSoftMax<>>();

其他细节的展示:

ens::Adam optimizer(
   STEP_SIZE,  // Step size of the optimizer.
   BATCH_SIZE, // Batch size. Number of data points that are used in each iteration.
   0.9,        // Exponential decay rate for the first moment estimates.
   0.999, // Exponential decay rate for the weighted infinity norm estimates.
   1e-8,  // Value used to initialise the mean squared gradient parameter.
   MAX_ITERATIONS, // Max number of iterations.
   1e-8, // Tolerance.
   true);

model.Train(trainX,
           trainY,
           optimizer,
           ens::PrintLoss(),
           ens::ProgressBar(),
           ens::EarlyStopAtMinLoss(EPOCH),
           ens::EarlyStopAtMinLoss(
               [&](const arma::mat& /* param */)
               {
                 double validationLoss = model.Evaluate(validX, validY);
                 std::cout << "Validation loss: " << validationLoss
                     << "." << std::endl;
                 return validationLoss;
               }));

正如你们可以看到在验证准确性上使用 EarlyStopAtMinLoss,这就是将参数 MAX_ITERATIONS 设置为 0 以让我们定义无限迭代的原因。

mat predOut;
model.Predict(trainX, predOut);
arma::Row<size_t> predLabels = getLabels(predOut);
double trainAccuracy = arma::accu(predLabels == trainY) / ( double )trainY.n_elem * 100;
model.Predict(validX, predOut);
predLabels = getLabels(predOut);
double validAccuracy = arma::accu(predLabels == validY) / ( double )validY.n_elem * 100;
std::cout << "Accuracy: train = " << trainAccuracy << "%,"<< "\t valid = " << validAccuracy << "%" << std::endl;

mat testPredOut;
model.Predict(testX,testPredOut);
arma::Row<size_t> testPred = getLabels(testPredOut)
double testAccuracy = arma::accu(testPredOut == testY) /( double )trainY.n_elem * 100;
std::cout<<"Test Accuracy = "<< testAccuracy;

相关推荐

方差分析简介(方差分析通俗理解)

介绍方差分析(ANOVA,AnalysisofVariance)是一种广泛使用的统计方法,用于比较两个或多个组之间的均值。单因素方差分析是方差分析的一种变体,旨在检测三个或更多分类组的均值是否存在...

正如404页面所预示,猴子正成为断网元凶--吧嗒吧嗒真好吃

吧嗒吧嗒,绘图:MakiNaro你可以通过加热、冰冻、水淹、模塑、甚至压溃压力来使网络光缆硬化。但用猴子显然是不行的。光缆那新挤压成型的塑料外皮太尼玛诱人了,无法阻挡一场试吃盛宴的举行。印度政府正...

Python数据可视化:箱线图多种库画法

概念箱线图通过数据的四分位数来展示数据的分布情况。例如:数据的中心位置,数据间的离散程度,是否有异常值等。把数据从小到大进行排列并等分成四份,第一分位数(Q1),第二分位数(Q2)和第三分位数(Q3)...

多组独立(完全随机设计)样本秩和检验的SPSS操作教程及结果解读

作者/风仕在上一期,我们已经讲完了两组独立样本秩和检验的SPSS操作教程及结果解读,这期开始讲多组独立样本秩和检验,我们主要从多组独立样本秩和检验介绍、两组独立样本秩和检验使用条件及案例的SPSS操作...

方差分析 in R语言 and Excel(方差分析r语言例题)

今天来写一篇实际中比较实用的分析方法,方差分析。通过方差分析,我们可以确定组别之间的差异是否超出了由于随机因素引起的差异范围。方差分析分为单因素方差分析和多因素方差分析,这一篇先介绍一下单因素方差分析...

可视化:前端数据可视化插件大盘点 图表/图谱/地图/关系图

前端数据可视化插件大盘点图表/图谱/地图/关系图全有在大数据时代,很多时候我们需要在网页中显示数据统计报表,从而能很直观地了解数据的走向,开发人员很多时候需要使用图表来表现一些数据。随着Web技术的...

matplotlib 必知的 15 个图(matplotlib各种图)

施工专题,我已完成20篇,施工系列几乎覆盖Python完整技术栈,目标只总结实践中最实用的东西,直击问题本质,快速帮助读者们入门和进阶:1我的施工计划2数字专题3字符串专题4列表专题5流程控制专题6编...

R ggplot2常用图表绘制指南(ggplot2绘制折线图)

ggplot2是R语言中强大的数据可视化包,基于“图形语法”(GrammarofGraphics),通过分层方式构建图表。以下是常用图表命令的详细指南,涵盖基本语法、常见图表类型及示例,适合...

Python数据可视化:从Pandas基础到Seaborn高级应用

数据可视化是数据分析中不可或缺的一环,它能帮助我们直观理解数据模式和趋势。本文将全面介绍Python中最常用的三种可视化方法。Pandas内置绘图功能Pandas基于Matplotlib提供了简洁的绘...

Python 数据可视化常用命令备忘录

本文提供了一个全面的Python数据可视化备忘单,适用于探索性数据分析(EDA)。该备忘单涵盖了单变量分析、双变量分析、多变量分析、时间序列分析、文本数据分析、可视化定制以及保存与显示等内容。所...

统计图的种类(统计图的种类及特点图片)

统计图是利用几何图形或具体事物的形象和地图等形式来表现社会经济现象数量特征和数量关系的图形。以下是几种常见的统计图类型及其适用场景:1.条形图(BarChart)条形图是用矩形条的高度或长度来表示...

实测,大模型谁更懂数据可视化?(数据可视化和可视化分析的主要模型)

大家好,我是Ai学习的老章看论文时,经常看到漂亮的图表,很多不知道是用什么工具绘制的,或者很想复刻类似图表。实测,大模型LaTeX公式识别,出乎预料前文,我用Kimi、Qwen-3-235B...

通过AI提示词让Deepseek快速生成各种类型的图表制作

在数据分析和可视化领域,图表是传达信息的重要工具。然而,传统图表制作往往需要专业的软件和一定的技术知识。本文将介绍如何通过AI提示词,利用Deepseek快速生成各种类型的图表,包括柱状图、折线图、饼...

数据可视化:解析箱线图(box plot)

箱线图/盒须图(boxplot)是数据分布的图形表示,由五个摘要组成:最小值、第一四分位数(25th百分位数)、中位数、第三四分位数(75th百分位数)和最大值。箱子代表四分位距(IQR)。IQR是...

[seaborn] seaborn学习笔记1-箱形图Boxplot

1箱形图Boxplot(代码下载)Boxplot可能是最常见的图形类型之一。它能够很好表示数据中的分布规律。箱型图方框的末尾显示了上下四分位数。极线显示最高和最低值,不包括异常值。seaborn中...