百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

几行代码构建全功能对象检测模型,这位杜克大学学生做到了

bigegpt 2024-08-04 11:19 16 浏览

作者 | Alan Bi

译者 | 武明利 责编 | Carol

出品 | AI科技大本营(ID:rgznai100)

如今,机器学习和计算机视觉已成为一种热潮。我们都看过关于自动驾驶汽车和面部识别的新闻,可能会想象建立自己的计算机视觉模型有多酷。然而,进入这个领域并不总是那么容易,尤其是在没有很强的数学背景的情况下。如果你只想做一些小的实验,像PyTorch和TensorFlow这样的库可能会很枯燥。

在本教程中,作者提供了一种简单的方法,任何人都可以使用几行代码构建全功能的对象检测模型。更具体地说,我们将使用Detecto,这是一个在PyTorch之上构建的Python软件包,可简化该过程并向所有级别的程序员开放。

快速简单的例子

为了演示如何简单地使Detecto,让我们加载一个预先训练的模型,并对以下图像进行推断:

首先,使用pip下载Detecto软件包:

pip3 install detecto

然后,将上面的图像另存为“fruit.jpg”,并在与图像相同的文件夹中创建一个Python文件。在Python文件中,编写以下5行代码:

from detectoimport core, utils, visualize

image = utils.read_image('fruit.jpg')

model = core.Model

labels, boxes, scores = model.predict_top(image)

visualize.show_labeled_image(image, boxes, labels)

运行此文件后(如果你的计算机上没有启用CUDA的GPU,可能会花费几秒钟;稍后再进行介绍),你应该会看到类似下面的图:

作者仅用了5行代码就完成了所有工作,真的是太棒了。下面是我们每步中分别做的:

1)导入Detecto模块

2)读入图像

3)初始化预训练模型

4)在图像上生成最高预测

5)为预测绘图

绘制我们的预测

Detecto使用来自PyTorch模型动物园中的Faster R-CNN ResNet-50 FPN,它能够检测大约80种不同的物体,例如动物,车辆,厨房用具等。但是,如果你想要检测自定义对象,例如可口可乐与百事可乐罐,斑马与长颈鹿,该怎么办呢?

这时你会发现,在自定义数据集上训练探测器模型同样简单; 同样,你只需要5行代码,以及现有的数据集或花一些时间标记图像。

构建自定义数据集

在本教程中,作者将从头开始构建自己的数据集。建议你也这样做,但是如果你想跳过这一步,你可以在这里下载一个示例数据集(从斯坦福的Dog数据集修改)。

对于我们的数据集,我们将训练我们的模型来检测来自RoboSub竞赛的水下外星人,蝙蝠和女巫,如下所示:

理想情况下,每个类至少需要100张图像。好在每张图像中可以有多个对象,所以理论上,如果每张图像包含你想要检测的每类对象,那么你可以总共获得100张图像。另外,如果你有视频素材,Detico可以轻松地将这些视频素材分割成可用于数据集的图像:

from detecto.utilsimport split_video

split_video('video.mp4','frames/', step_size=4)

上面的代码在“video.mp4”中每第4帧拍摄一次,并将其另存为JPEG文件存在“frames”文件夹中。

生成训练数据集后,应该具有一个类似于以下内容的文件夹:

images/

| image0.jpg

| image1.jpg

| image2.jpg

| ...

如果需要的话,你还可以使用另一个文件夹,其中包含一组验证图像。

现在是耗时的部分:标记。Detecto支持PASCAL VOC格式,其中具有XML文件,其中包含图像中每个对象的标签和位置数据。要创建这些XML文件,可以使用开源LabelImg工具,如下所示:

pip3 install labelImg # Download LabelImg using pip

labelImg # Launch the application

现在,你应该会看到一个弹出窗口。单击左侧“打开目录”按钮,然后选择想要标记的图像文件夹。如果一切正常,你应该会看到类似以下内容:

要绘制边界框,请单击左侧菜单栏中的图标(或使用键盘快捷键“w”)。然后,你可以在对象周围拖动一个框并编写/选择标签:

标记完图像后,请使用CTRL+S或CMD+S保存XML文件(为简便起见,你可以使用自动填充的默认文件位置和名称)。要标记下一张图像,请单击“下一张图像”(或使用键盘快捷键“d”)。

整个数据集处理完毕之后,你的文件夹应如下所示:

images/

| image0.jpg

| image0.xml

| image1.jpg

| image1.xml

| ...

我们已经准备好开始训练我们的对象检测模型了!

访问GPU

首先,检查你的计算机是否具有启用CUDA的GPU。由于深度学习需要大量处理能力,因此在通常的CPU上进行训练可能会非常缓慢。值得庆幸的是,大多数现代深度学习框架(例如PyTorch和Tensorflow)都可以在GPU上运行,从而使处理速度更快。确保已经下载了PyTorch(如果你安装了Detecto,应该已经下载了),然后运行以下两行代码:

import torch

print(torch.cuda.is_available)

如果打印True,那你可以跳到下一部分。如果显示False,不要担心。请按照以下步骤创建Google Colaboratory笔记本,这是一个在线编码环境,带有免费可用的GPU。对于本教程,你将只在Google Drive文件夹中工作,而不是在计算机上工作。

1)登录到Google Drive

2)创建一个名为“Detecto Tutorial”的文件夹并导航到该文件夹

3)将你的训练图像(和/或验证图像)上传到此文件夹

4)右键单击,转到“更多”,然后单击“Google Colaboratory”:

你现在应该看到这样的界面:

5)根据需要给笔记本起个名字,然后转到“编辑”->“笔记本设置”->“硬件加速器”,然后选择“GPU”

6)输入以下代码以“装入”你的云端硬盘,将目录更改为当前文件夹,然后安装Detecto:

import os

from google.colabimport drive

drive.mount('/content/drive')

os.chdir('/content/drive/My Drive/Detecto Tutorial')

!pip install detecto

为了确保一切正常,你可以创建一个新的代码单元,然后输入!ls以检查你是否处于正确的目录中。

训练自定义模型

最后,我们现在可以在自定义数据集上训练模型了。如前所述,这是容易的部分。它只需要4行代码:

from detectoimport core, utils, visualize

dataset = core.Dataset('images/')

model = core.Model(['alien','bat','witch'])

model.fit(dataset)

让我们再次分解一下我们每行代码所做的工作:

1、导入的Detecto模块

2、从“images”文件夹(包含我们的JPEG和XML文件)创建了一个数据集

3、初始化模型检测自定义对象(外星人,蝙蝠和女巫)

4、在数据集上训练我们的模型

根据数据集的大小,这可能需要10分钟到1个小时以上的时间来运行,因此请确保你的程序在完成上述语句后不会立即退出(例如:你使用的是Jupyter / Colab笔记本,它在活动时保留状态)。

使用训练好的模型

现在你已经有了训练好的模型,让我们在一些图像上对其进行测试。要从文件路径读取图像,可以使用detecto.utils模块中的read_image函数(也可以使用上面创建的数据集中的图像):

# Specify the path to your image

image = utils.read_image('images/image0.jpg')

predictions = model.predict(image)

# predictions format: (labels, boxes, scores)

labels, boxes, scores = predictions

# ['alien', 'bat', 'bat']

print(labels)

# xmin ymin xmax ymax

# tensor([[ 569.2125, 203.6702, 1003.4383, 658.1044],

# [ 276.2478, 144.0074, 579.6044, 508.7444],

# [ 277.2929, 162.6719, 627.9399, 511.9841]])

print(boxes)

# tensor([0.9952, 0.9837, 0.5153])

print(scores)

正像你看到的,模型的预测方法返回一个由3个元素组成的元组:标签,方框和分数。

在上面的示例中,此模型在坐标[569、204、1003、658](框[0])处,

预测了一个外星人(标签[0]),其置信度为0.995(得分[0])。

根据这些预测,我们可以使用detecto.visualize模块绘制结果。例如:

visualize.show_labeled_image(image, boxes, labels)

将上面的代码与收到的图像和预测一起运行将产生如下所示的内容:

如果你有一个视频,你可以在它上面运行对象检测:

visualize.detect_video(model,'input.mp4','output.avi')

这将获取一个名为“input.mp4”的视频文件,并根据给定模型的预测结果生成一个“output.avi”文件。如果你使用VLC或其他视频播放器打开此文件,应该会看到一些希望看到的结果!

最后,你可以从文件中保存和加载模型,从而可以保存进度并稍后返回:

model.save('model_weights.pth')

# ... Later ...

model = core.Model.load('model_weights.pth', ['alien','bat','witch'])

高级用法

你会发现Detecto不仅限于5行代码。举例来说,这个模型没有你希望的那么好。我们可以尝试通过使用Torchvision转换来扩展我们的数据集并定义一个自定义数据加载器来提高其性能:

from torchvisionimport transforms

augmentations = transforms.Compose([

transforms.ToPILImage(),

transforms.RandomHorizontalFlip(0.5),

transforms.ColorJitter(saturation=0.5),

transforms.ToTensor(),

utils.normalize_transform(),

])

dataset = core.Dataset('images/', transform=augmentations)

loader = core.DataLoader(dataset, batch_size=2, shuffle=True)

此代码对数据集中的图像应用了随机的水平翻转和饱和效果,从而增加了数据的多样性。然后,我们使用batch_size = 2定义一个数据加载对象;我们将其传递给model.fit而不是Dataset,这样来告诉我们的模型是对2张图像进行批量训练,而不是默认的1张。

如果你之前创建了单独的验证数据集,那么现在是在训练期间加载它的时候了。通过提供验证数据集,fit方法将返回每个时期的损失列表,如果verbose = True,则会在训练过程中将其打印出来。以下代码块演示了这一点,并自定义了其他几个训练参数:

import matplotlib.pyplotas plt

val_dataset = core.Dataset('validation_images/')

losses = model.fit(loader, val_dataset, epochs=10, learning_rate=0.001,

lr_step_size=5, verbose=True)

plt.plot(losses)

plt.show

损失的结果图应或多或少地减少:

为了更具有灵活性和对模型的控制,你可以完全绕过Detecto。你可以根据需要随意调整model.get_internal_model方法返回使用的基础模型。

结论

在本教程中,作者展示了计算机视觉和对象检测不需要具有挑战性。你所需要的是一点时间和耐心来处理标记的数集。

如果你对进一步探索感兴趣的话,请查看Detecto on GitHub或访问文档以获取更多教程和用例!

原文:https://hackernoon.com/build-a-custom-trained-object-detection-model-with-5-lines-of-code-y08n33vi

本文为 CSDN 翻译,转载请注明来源出处。

相关推荐

当Frida来“敲”门(frida是什么)

0x1渗透测试瓶颈目前,碰到越来越多的大客户都会将核心资产业务集中在统一的APP上,或者对自己比较重要的APP,如自己的主业务,办公APP进行加壳,流量加密,投入了很多精力在移动端的防护上。而现在挖...

服务端性能测试实战3-性能测试脚本开发

前言在前面的两篇文章中,我们分别介绍了性能测试的理论知识以及性能测试计划制定,本篇文章将重点介绍性能测试脚本开发。脚本开发将分为两个阶段:阶段一:了解各个接口的入参、出参,使用Python代码模拟前端...

Springboot整合Apache Ftpserver拓展功能及业务讲解(三)

今日分享每天分享技术实战干货,技术在于积累和收藏,希望可以帮助到您,同时也希望获得您的支持和关注。架构开源地址:https://gitee.com/msxyspringboot整合Ftpserver参...

Linux和Windows下:Python Crypto模块安装方式区别

一、Linux环境下:fromCrypto.SignatureimportPKCS1_v1_5如果导包报错:ImportError:Nomodulenamed'Crypt...

Python 3 加密简介(python des加密解密)

Python3的标准库中是没多少用来解决加密的,不过却有用于处理哈希的库。在这里我们会对其进行一个简单的介绍,但重点会放在两个第三方的软件包:PyCrypto和cryptography上,我...

怎样从零开始编译一个魔兽世界开源服务端Windows

第二章:编译和安装我是艾西,上期我们讲述到编译一个魔兽世界开源服务端环境准备,那么今天跟大家聊聊怎么编译和安装我们直接进入正题(上一章没有看到的小伙伴可以点我主页查看)编译服务端:在D盘新建一个文件夹...

附1-Conda部署安装及基本使用(conda安装教程)

Windows环境安装安装介质下载下载地址:https://www.anaconda.com/products/individual安装Anaconda安装时,选择自定义安装,选择自定义安装路径:配置...

如何配置全世界最小的 MySQL 服务器

配置全世界最小的MySQL服务器——如何在一块IntelEdison为控制板上安装一个MySQL服务器。介绍在我最近的一篇博文中,物联网,消息以及MySQL,我展示了如果Partic...

如何使用Github Action来自动化编译PolarDB-PG数据库

随着PolarDB在国产数据库领域荣膺桂冠并持续获得广泛认可,越来越多的学生和技术爱好者开始关注并涉足这款由阿里巴巴集团倾力打造且性能卓越的关系型云原生数据库。有很多同学想要上手尝试,却卡在了编译数据...

面向NDK开发者的Android 7.0变更(ndk android.mk)

订阅Google官方微信公众号:谷歌开发者。与谷歌一起创造未来!受Android平台其他改进的影响,为了方便加载本机代码,AndroidM和N中的动态链接器对编写整洁且跨平台兼容的本机...

信创改造--人大金仓(Kingbase)数据库安装、备份恢复的问题纪要

问题一:在安装KingbaseES时,安装用户对于安装路径需有“读”、“写”、“执行”的权限。在Linux系统中,需要以非root用户执行安装程序,且该用户要有标准的home目录,您可...

OpenSSH 安全漏洞,修补操作一手掌握

1.漏洞概述近日,国家信息安全漏洞库(CNNVD)收到关于OpenSSH安全漏洞(CNNVD-202407-017、CVE-2024-6387)情况的报送。攻击者可以利用该漏洞在无需认证的情况下,通...

Linux:lsof命令详解(linux lsof命令详解)

介绍欢迎来到这篇博客。在这篇博客中,我们将学习Unix/Linux系统上的lsof命令行工具。命令行工具是您使用CLI(命令行界面)而不是GUI(图形用户界面)运行的程序或工具。lsoflsof代表&...

幻隐说固态第一期:固态硬盘接口类别

前排声明所有信息来源于网络收集,如有错误请评论区指出更正。废话不多说,目前固态硬盘接口按速度由慢到快分有这几类:SATA、mSATA、SATAExpress、PCI-E、m.2、u.2。下面我们来...

新品轰炸 影驰SSD多款产品登Computex

分享泡泡网SSD固态硬盘频道6月6日台北电脑展作为全球第二、亚洲最大的3C/IT产业链专业展,吸引了众多IT厂商和全球各地媒体的热烈关注,全球存储新势力—影驰,也积极参与其中,为广大玩家朋友带来了...