百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

几行代码构建全功能对象检测模型,这位杜克大学学生做到了

bigegpt 2024-08-04 11:19 1 浏览

作者 | Alan Bi

译者 | 武明利 责编 | Carol

出品 | AI科技大本营(ID:rgznai100)

如今,机器学习和计算机视觉已成为一种热潮。我们都看过关于自动驾驶汽车和面部识别的新闻,可能会想象建立自己的计算机视觉模型有多酷。然而,进入这个领域并不总是那么容易,尤其是在没有很强的数学背景的情况下。如果你只想做一些小的实验,像PyTorch和TensorFlow这样的库可能会很枯燥。

在本教程中,作者提供了一种简单的方法,任何人都可以使用几行代码构建全功能的对象检测模型。更具体地说,我们将使用Detecto,这是一个在PyTorch之上构建的Python软件包,可简化该过程并向所有级别的程序员开放。

快速简单的例子

为了演示如何简单地使Detecto,让我们加载一个预先训练的模型,并对以下图像进行推断:

首先,使用pip下载Detecto软件包:

pip3 install detecto

然后,将上面的图像另存为“fruit.jpg”,并在与图像相同的文件夹中创建一个Python文件。在Python文件中,编写以下5行代码:

from detectoimport core, utils, visualize

image = utils.read_image('fruit.jpg')

model = core.Model

labels, boxes, scores = model.predict_top(image)

visualize.show_labeled_image(image, boxes, labels)

运行此文件后(如果你的计算机上没有启用CUDA的GPU,可能会花费几秒钟;稍后再进行介绍),你应该会看到类似下面的图:

作者仅用了5行代码就完成了所有工作,真的是太棒了。下面是我们每步中分别做的:

1)导入Detecto模块

2)读入图像

3)初始化预训练模型

4)在图像上生成最高预测

5)为预测绘图

绘制我们的预测

Detecto使用来自PyTorch模型动物园中的Faster R-CNN ResNet-50 FPN,它能够检测大约80种不同的物体,例如动物,车辆,厨房用具等。但是,如果你想要检测自定义对象,例如可口可乐与百事可乐罐,斑马与长颈鹿,该怎么办呢?

这时你会发现,在自定义数据集上训练探测器模型同样简单; 同样,你只需要5行代码,以及现有的数据集或花一些时间标记图像。

构建自定义数据集

在本教程中,作者将从头开始构建自己的数据集。建议你也这样做,但是如果你想跳过这一步,你可以在这里下载一个示例数据集(从斯坦福的Dog数据集修改)。

对于我们的数据集,我们将训练我们的模型来检测来自RoboSub竞赛的水下外星人,蝙蝠和女巫,如下所示:

理想情况下,每个类至少需要100张图像。好在每张图像中可以有多个对象,所以理论上,如果每张图像包含你想要检测的每类对象,那么你可以总共获得100张图像。另外,如果你有视频素材,Detico可以轻松地将这些视频素材分割成可用于数据集的图像:

from detecto.utilsimport split_video

split_video('video.mp4','frames/', step_size=4)

上面的代码在“video.mp4”中每第4帧拍摄一次,并将其另存为JPEG文件存在“frames”文件夹中。

生成训练数据集后,应该具有一个类似于以下内容的文件夹:

images/

| image0.jpg

| image1.jpg

| image2.jpg

| ...

如果需要的话,你还可以使用另一个文件夹,其中包含一组验证图像。

现在是耗时的部分:标记。Detecto支持PASCAL VOC格式,其中具有XML文件,其中包含图像中每个对象的标签和位置数据。要创建这些XML文件,可以使用开源LabelImg工具,如下所示:

pip3 install labelImg # Download LabelImg using pip

labelImg # Launch the application

现在,你应该会看到一个弹出窗口。单击左侧“打开目录”按钮,然后选择想要标记的图像文件夹。如果一切正常,你应该会看到类似以下内容:

要绘制边界框,请单击左侧菜单栏中的图标(或使用键盘快捷键“w”)。然后,你可以在对象周围拖动一个框并编写/选择标签:

标记完图像后,请使用CTRL+S或CMD+S保存XML文件(为简便起见,你可以使用自动填充的默认文件位置和名称)。要标记下一张图像,请单击“下一张图像”(或使用键盘快捷键“d”)。

整个数据集处理完毕之后,你的文件夹应如下所示:

images/

| image0.jpg

| image0.xml

| image1.jpg

| image1.xml

| ...

我们已经准备好开始训练我们的对象检测模型了!

访问GPU

首先,检查你的计算机是否具有启用CUDA的GPU。由于深度学习需要大量处理能力,因此在通常的CPU上进行训练可能会非常缓慢。值得庆幸的是,大多数现代深度学习框架(例如PyTorch和Tensorflow)都可以在GPU上运行,从而使处理速度更快。确保已经下载了PyTorch(如果你安装了Detecto,应该已经下载了),然后运行以下两行代码:

import torch

print(torch.cuda.is_available)

如果打印True,那你可以跳到下一部分。如果显示False,不要担心。请按照以下步骤创建Google Colaboratory笔记本,这是一个在线编码环境,带有免费可用的GPU。对于本教程,你将只在Google Drive文件夹中工作,而不是在计算机上工作。

1)登录到Google Drive

2)创建一个名为“Detecto Tutorial”的文件夹并导航到该文件夹

3)将你的训练图像(和/或验证图像)上传到此文件夹

4)右键单击,转到“更多”,然后单击“Google Colaboratory”:

你现在应该看到这样的界面:

5)根据需要给笔记本起个名字,然后转到“编辑”->“笔记本设置”->“硬件加速器”,然后选择“GPU”

6)输入以下代码以“装入”你的云端硬盘,将目录更改为当前文件夹,然后安装Detecto:

import os

from google.colabimport drive

drive.mount('/content/drive')

os.chdir('/content/drive/My Drive/Detecto Tutorial')

!pip install detecto

为了确保一切正常,你可以创建一个新的代码单元,然后输入!ls以检查你是否处于正确的目录中。

训练自定义模型

最后,我们现在可以在自定义数据集上训练模型了。如前所述,这是容易的部分。它只需要4行代码:

from detectoimport core, utils, visualize

dataset = core.Dataset('images/')

model = core.Model(['alien','bat','witch'])

model.fit(dataset)

让我们再次分解一下我们每行代码所做的工作:

1、导入的Detecto模块

2、从“images”文件夹(包含我们的JPEG和XML文件)创建了一个数据集

3、初始化模型检测自定义对象(外星人,蝙蝠和女巫)

4、在数据集上训练我们的模型

根据数据集的大小,这可能需要10分钟到1个小时以上的时间来运行,因此请确保你的程序在完成上述语句后不会立即退出(例如:你使用的是Jupyter / Colab笔记本,它在活动时保留状态)。

使用训练好的模型

现在你已经有了训练好的模型,让我们在一些图像上对其进行测试。要从文件路径读取图像,可以使用detecto.utils模块中的read_image函数(也可以使用上面创建的数据集中的图像):

# Specify the path to your image

image = utils.read_image('images/image0.jpg')

predictions = model.predict(image)

# predictions format: (labels, boxes, scores)

labels, boxes, scores = predictions

# ['alien', 'bat', 'bat']

print(labels)

# xmin ymin xmax ymax

# tensor([[ 569.2125, 203.6702, 1003.4383, 658.1044],

# [ 276.2478, 144.0074, 579.6044, 508.7444],

# [ 277.2929, 162.6719, 627.9399, 511.9841]])

print(boxes)

# tensor([0.9952, 0.9837, 0.5153])

print(scores)

正像你看到的,模型的预测方法返回一个由3个元素组成的元组:标签,方框和分数。

在上面的示例中,此模型在坐标[569、204、1003、658](框[0])处,

预测了一个外星人(标签[0]),其置信度为0.995(得分[0])。

根据这些预测,我们可以使用detecto.visualize模块绘制结果。例如:

visualize.show_labeled_image(image, boxes, labels)

将上面的代码与收到的图像和预测一起运行将产生如下所示的内容:

如果你有一个视频,你可以在它上面运行对象检测:

visualize.detect_video(model,'input.mp4','output.avi')

这将获取一个名为“input.mp4”的视频文件,并根据给定模型的预测结果生成一个“output.avi”文件。如果你使用VLC或其他视频播放器打开此文件,应该会看到一些希望看到的结果!

最后,你可以从文件中保存和加载模型,从而可以保存进度并稍后返回:

model.save('model_weights.pth')

# ... Later ...

model = core.Model.load('model_weights.pth', ['alien','bat','witch'])

高级用法

你会发现Detecto不仅限于5行代码。举例来说,这个模型没有你希望的那么好。我们可以尝试通过使用Torchvision转换来扩展我们的数据集并定义一个自定义数据加载器来提高其性能:

from torchvisionimport transforms

augmentations = transforms.Compose([

transforms.ToPILImage(),

transforms.RandomHorizontalFlip(0.5),

transforms.ColorJitter(saturation=0.5),

transforms.ToTensor(),

utils.normalize_transform(),

])

dataset = core.Dataset('images/', transform=augmentations)

loader = core.DataLoader(dataset, batch_size=2, shuffle=True)

此代码对数据集中的图像应用了随机的水平翻转和饱和效果,从而增加了数据的多样性。然后,我们使用batch_size = 2定义一个数据加载对象;我们将其传递给model.fit而不是Dataset,这样来告诉我们的模型是对2张图像进行批量训练,而不是默认的1张。

如果你之前创建了单独的验证数据集,那么现在是在训练期间加载它的时候了。通过提供验证数据集,fit方法将返回每个时期的损失列表,如果verbose = True,则会在训练过程中将其打印出来。以下代码块演示了这一点,并自定义了其他几个训练参数:

import matplotlib.pyplotas plt

val_dataset = core.Dataset('validation_images/')

losses = model.fit(loader, val_dataset, epochs=10, learning_rate=0.001,

lr_step_size=5, verbose=True)

plt.plot(losses)

plt.show

损失的结果图应或多或少地减少:

为了更具有灵活性和对模型的控制,你可以完全绕过Detecto。你可以根据需要随意调整model.get_internal_model方法返回使用的基础模型。

结论

在本教程中,作者展示了计算机视觉和对象检测不需要具有挑战性。你所需要的是一点时间和耐心来处理标记的数集。

如果你对进一步探索感兴趣的话,请查看Detecto on GitHub或访问文档以获取更多教程和用例!

原文:https://hackernoon.com/build-a-custom-trained-object-detection-model-with-5-lines-of-code-y08n33vi

本文为 CSDN 翻译,转载请注明来源出处。

相关推荐

得物可观测平台架构升级:基于GreptimeDB的全新监控体系实践

一、摘要在前端可观测分析场景中,需要实时观测并处理多地、多环境的运行情况,以保障Web应用和移动端的可用性与性能。传统方案往往依赖代理Agent→消息队列→流计算引擎→OLAP存储...

warm-flow新春版:网关直连和流程图重构

本期主要解决了网关直连和流程图重构,可以自此之后可支持各种复杂的网关混合、多网关直连使用。-新增Ruoyi-Vue-Plus优秀开源集成案例更新日志[feat]导入、导出和保存等新增json格式支持...

扣子空间体验报告

在数字化时代,智能工具的应用正不断拓展到我们工作和生活的各个角落。从任务规划到项目执行,再到任务管理,作者深入探讨了这款工具在不同场景下的表现和潜力。通过具体的应用实例,文章展示了扣子空间如何帮助用户...

spider-flow:开源的可视化方式定义爬虫方案

spider-flow简介spider-flow是一个爬虫平台,以可视化推拽方式定义爬取流程,无需代码即可实现一个爬虫服务。spider-flow特性支持css选择器、正则提取支持JSON/XML格式...

solon-flow 你好世界!

solon-flow是一个基础级的流处理引擎(可用于业务规则、决策处理、计算编排、流程审批等......)。提供有“开放式”驱动定制支持,像jdbc有mysql或pgsql等驱动,可...

新一代开源爬虫平台:SpiderFlow

SpiderFlow:新一代爬虫平台,以图形化方式定义爬虫流程,不写代码即可完成爬虫。-精选真开源,释放新价值。概览Spider-Flow是一个开源的、面向所有用户的Web端爬虫构建平台,它使用Ja...

通过 SQL 训练机器学习模型的引擎

关注薪资待遇的同学应该知道,机器学习相关的岗位工资普遍偏高啊。同时随着各种通用机器学习框架的出现,机器学习的门槛也在逐渐降低,训练一个简单的机器学习模型变得不那么难。但是不得不承认对于一些数据相关的工...

鼠须管输入法rime for Mac

鼠须管输入法forMac是一款十分新颖的跨平台输入法软件,全名是中州韵输入法引擎,鼠须管输入法mac版不仅仅是一个输入法,而是一个输入法算法框架。Rime的基础架构十分精良,一套算法支持了拼音、...

Go语言 1.20 版本正式发布:新版详细介绍

Go1.20简介最新的Go版本1.20在Go1.19发布六个月后发布。它的大部分更改都在工具链、运行时和库的实现中。一如既往,该版本保持了Go1的兼容性承诺。我们期望几乎所...

iOS 10平台SpriteKit新特性之Tile Maps(上)

简介苹果公司在WWDC2016大会上向人们展示了一大批新的好东西。其中之一就是SpriteKitTileEditor。这款工具易于上手,而且看起来速度特别快。在本教程中,你将了解关于TileE...

程序员简历例句—范例Java、Python、C++模板

个人简介通用简介:有良好的代码风格,通过添加注释提高代码可读性,注重代码质量,研读过XXX,XXX等多个开源项目源码从而学习增强代码的健壮性与扩展性。具备良好的代码编程习惯及文档编写能力,参与多个高...

Telerik UI for iOS Q3 2015正式发布

近日,TelerikUIforiOS正式发布了Q32015。新版本新增对XCode7、Swift2.0和iOS9的支持,同时还新增了对数轴、不连续的日期时间轴等;改进TKDataPoin...

ios使用ijkplayer+nginx进行视频直播

上两节,我们讲到使用nginx和ngixn的rtmp模块搭建直播的服务器,接着我们讲解了在Android使用ijkplayer来作为我们的视频直播播放器,整个过程中,需要注意的就是ijlplayer编...

IOS技术分享|iOS快速生成开发文档(一)

前言对于开发人员而言,文档的作用不言而喻。文档不仅可以提高软件开发效率,还能便于以后的软件开发、使用和维护。本文主要讲述Objective-C快速生成开发文档工具appledoc。简介apple...

macOS下配置VS Code C++开发环境

本文介绍在苹果macOS操作系统下,配置VisualStudioCode的C/C++开发环境的过程,本环境使用Clang/LLVM编译器和调试器。一、前置条件本文默认前置条件是,您的开发设备已...