DETR架构的内部工作方式分析

这是一个facebook的目标检测transformer (detr)的完整指南。
介绍
detection transformer (detr)是facebook研究团队巧妙地利用了transformer 架构开发的一个目标检测模型。在这篇文章中,我将通过分析detr架构的内部工作方式来帮助提供一些关于它的直觉。
下面,我将解释一些结构,但是如果你只是想了解如何使用模型,可以直接跳到代码部分。
结构
detr模型由一个预训练的cnn骨干(如resnet)组成,它产生一组低维特征集。这些特征被格式化为一个特征集合并添加位置编码,输入一个由transformer组成的编码器和解码器中,和原始的transformer论文中描述的encoder-decoder的使用方式非常的类似。解码器的输出然后被送入固定数量的预测头,这些预测头由预定义数量的前馈网络组成。每个预测头的输出都包含一个类预测和一个预测框。损失是通过计算二分匹配损失来计算的。
该模型做出了预定义数量的预测,并且每个预测都是并行计算的。
cnn主干
假设我们的输入图像,有三个输入通道。cnn backbone由一个(预训练过的)cnn(通常是resnet)组成,我们用它来生成c个具有宽度w和高度h的低维特征(在实践中,我们设置c=2048, w=w₀/32和h=h₀/32)。
这留给我们的是c个二维特征,由于我们将把这些特征传递给一个transformer,每个特征必须允许编码器将每个特征处理为一个序列的方式重新格式化。这是通过将特征矩阵扁平化为h⋅w向量,然后将每个向量连接起来来实现的。
扁平化的卷积特征再加上空间位置编码,位置编码既可以学习,也可以预定义。
the transformer
transformer几乎与原始的编码器-解码器架构完全相同。不同之处在于,每个解码器层并行解码n个(预定义的数目)目标。该模型还学习了一组n个目标的查询,这些查询是(类似于编码器)学习出来的位置编码。
目标查询
下图描述了n=20个学习出来的目标查询(称为prediction slots)如何聚焦于一张图像的不同区域。
“我们观察到,在不同的操作模式下,每个slot 都会学习特定的区域和框大小。“ —— detr的作者
理解目标查询的直观方法是想象每个目标查询都是一个人。每个人都可以通过注意力来查看图像的某个区域。一个目标查询总是会问图像中心是什么,另一个总是会问左下角是什么,以此类推。
使用pytorch实现简单的detr
import torchimport torch.nn as nnfrom torchvision.models import resnet50class simpledetr(nn.module):minimal example of the detection transformer model with learned positional embedding def __init__(self, num_classes, hidden_dim, num_heads,             num_enc_layers, num_dec_layers):    super(simpledetr, self).__init__()    self.num_classes = num_classes    self.hidden_dim = hidden_dim    self.num_heads = num_heads    self.num_enc_layers = num_enc_layers    self.num_dec_layers = num_dec_layers    # cnn backbone    self.backbone = nn.sequential(         *list(resnet50(pretrained=true).children())[:-2])    self.conv = nn.conv2d(2048, hidden_dim, 1)    # transformer    self.transformer = nn.transformer(hidden_dim, num_heads,         num_enc_layers, num_dec_layers)    # prediction heads    self.to_classes = nn.linear(hidden_dim, num_classes+1)    self.to_bbox = nn.linear(hidden_dim, 4)    # positional encodings    self.object_query = nn.parameter(torch.rand(100, hidden_dim))    self.row_embed = nn.parameter(torch.rand(50, hidden_dim // 2)    self.col_embed = nn.parameter(torch.rand(50, hidden_dim // 2))                                   def forward(self, x):    x = self.backbone(x)    h = self.conv(x)    h, w = h.shape[-2:]    pos_enc = torch.cat([          self.col_embed[:w].unsqueeze(0).repeat(h,1,1),          self.row_embed[:h].unsqueeze(1).repeat(1,w,1)],       dim=-1).flatten(0,1).unsqueeze(1)    h = self.transformer(pos_enc + h.flatten(2).permute(2,0,1),    self.object_query.unsqueeze(1))    class_pred = self.to_classes(h)    bbox_pred = self.to_bbox(h).sigmoid()        return class_pred, bbox_pred  
二分匹配损失 (optional)
让为预测的集合,其中是包括了预测类别(可以是空类别)和包围框的二元组,其中上划线表示框的中心点, 和表示框的宽和高。
设y为ground truth集合。假设y和ŷ之间的损失为l,每一个yᵢ和ŷᵢ之间的损失为lᵢ。由于我们是在集合的层次上工作,损失l必须是排列不变的,这意味着无论我们如何排序预测,我们都将得到相同的损失。因此,我们想找到一个排列,它将预测的索引映射到ground truth目标的索引上。在数学上,我们求解:
计算的过程称为寻找最优的二元匹配。这可以用匈牙利算法找到。但为了找到最优匹配,我们需要实际定义一个损失函数,计算和之间的匹配成本。
回想一下,我们的预测包含一个边界框和一个类。现在让我们假设类预测实际上是一个类集合上的概率分布。那么第i个预测的总损失将是类预测产生的损失和边界框预测产生的损失之和。作者在http://arxiv.org/abs/1906.05909中将这种损失定义为边界框损失和类预测概率的差异:
其中,是的argmax,是是来自包围框的预测的损失,如果,则表示匹配损失为0。
框损失的计算为预测值与ground truth的l₁损失和的giou损失的线性组合。同样,如果你想象两个不相交的框,那么框的错误将不会提供任何有意义的上下文(我们可以从下面的框损失的定义中看到)。
其中,λᵢₒᵤ和是超参数。注意,这个和也是面积和距离产生的误差的组合。为什么会这样呢?
可以把上面的等式看作是与预测相关联的总损失,其中面积误差的重要性是λᵢₒᵤ,距离误差的重要性是。
现在我们来定义giou损失函数。定义如下:
由于我们从已知的已知类的数目来预测类,那么类预测就是一个分类问题,因此我们可以使用交叉熵损失来计算类预测误差。我们将损失函数定义为每n个预测损失的总和:
为目标检测使用detr
在这里,你可以学习如何加载预训练的detr模型,以便使用pytorch进行目标检测。
加载模型
首先导入需要的模块。
# import required modulesimport torchfrom torchvision import transforms as t import requests # for loading images from webfrom pil import image # for viewing imagesimport matplotlib.pyplot as plt  
下面的代码用resnet50作为cnn骨干从torch hub加载预训练的模型。其他主干请参见detr github:https://github.com/facebookresearch/detr
detr = torch.hub.load('facebookresearch/detr',                      'detr_resnet50',                       pretrained=true)  
加载一张图像
要从web加载图像,我们使用requests库:
url = 'https://www.tempetourism.com/wp-content/uploads/postino-downtown-tempe-2.jpg' # sample imageimage = image.open(requests.get(url, stream=true).raw) plt.imshow(image)plt.show()
设置目标检测的pipeline
为了将图像输入到模型中,我们需要将pil图像转换为张量,这是通过使用torchvision的transforms库来完成的。
transform = t.compose([t.resize(800),                       t.totensor(),                       t.normalize([0.485, 0.456, 0.406],                                  [0.229, 0.224, 0.225])])  
上面的变换调整了图像的大小,将pil图像进行转换,并用均值-标准差对图像进行归一化。其中[0.485,0.456,0.406]为各颜色通道的均值,[0.229,0.224,0.225]为各颜色通道的标准差。
我们装载的模型是预先在coco dataset上训练的,有91个类,还有一个表示空类(没有目标)的附加类。我们用下面的代码手动定义每个标签:
classes = ['n/a', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic-light', 'fire-hydrant', 'n/a', 'stop-sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'n/a', 'backpack', 'umbrella', 'n/a', 'n/a', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports-ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'n/a', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot-dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'n/a', 'dining table', 'n/a','n/a', 'toilet', 'n/a', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell-phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'n/a', 'book', 'clock', 'vase', 'scissors', 'teddy-bear', 'hair-dryer', 'toothbrush']  
如果我们想输出不同颜色的边框,我们可以手动定义我们想要的rgb格式的颜色
colors = [    [0.000, 0.447, 0.741],     [0.850, 0.325, 0.098],     [0.929, 0.694, 0.125],    [0.494, 0.184, 0.556],    [0.466, 0.674, 0.188],    [0.301, 0.745, 0.933]  ]  
格式化输出
我们还需要重新格式化模型的输出。给定一个转换后的图像,模型将输出一个字典,包含100个预测类的概率和100个预测边框。
每个包围框的形式为(x, y, w, h),其中(x,y)为包围框的中心(包围框是单位正方形[0,1]×[0,1]), w, h为包围框的宽度和高度。因此,我们需要将边界框输出转换为初始和最终坐标,并重新缩放框以适应图像的实际大小。
下面的函数返回边界框端点:
# get coordinates (x0, y0, x1, y0) from model output (x, y, w, h)def get_box_coords(boxes):    x, y, w, h = boxes.unbind(1)    x0, y0 = (x - 0.5 * w), (y - 0.5 * h)    x1, y1 = (x + 0.5 * w), (y + 0.5 * h)    box = [x0, y0, x1, y1]    return torch.stack(box, dim=1)  
我们还需要缩放了框的大小。下面的函数为我们做了这些:
# scale box from [0,1]x[0,1] to [0, width]x[0, height]def scale_boxes(output_box, width, height):    box_coords = get_box_coords(output_box)    scale_tensor = torch.tensor(                 [width, height, width, height]).to(                 torch.cuda.current_device())    return box_coords * scale_tensor  
现在我们需要一个函数来封装我们的目标检测pipeline。下面的detect函数为我们完成了这项工作。
# object detection pipelinedef detect(im, model, transform):    device = torch.cuda.current_device()    width = im.size[0]    height = im.size[1]       # mean-std normalize the input image (batch-size: 1)    img = transform(im).unsqueeze(0)    img = img.to(device)        # demo model only support by default images with aspect ratio    between 0.5 and 2    assert img.shape[-2] <= 1600 and img.shape[-1]  0.85       # convert boxes from [0; 1] to image scales    bboxes_scaled = scale_boxes(outputs['pred_boxes'][0, keep], width, height)    return probas[keep], bboxes_scaled  
现在,我们需要做的是运行以下程序来获得我们想要的输出:
probs, bboxes = detect(image, detr, transform)  
绘制结果
现在我们有了检测到的目标,我们可以使用一个简单的函数来可视化它们。
# plot predicted bounding boxesdef plot_results(pil_img, prob, boxes,labels=true):    plt.figure(figsize=(16,10))    plt.imshow(pil_img)    ax = plt.gca()        for prob, (x0, y0, x1, y1), color in zip(prob, boxes.tolist(),   colors * 100):        ax.add_patch(plt.rectangle((x0, y0), x1 - x0, y1 - y0,               fill=false, color=color, linewidth=2))        cl = prob.argmax()        text = f'{classes[cl]}: {prob[cl]:0.2f}'        if labels:            ax.text(x0, y0, text, fontsize=15,                bbox=dict(facecolor=color, alpha=0.75))    plt.axis('off')    plt.show()  
现在可以可视化结果:
plot_results(image, probs, bboxes, labels=true)


人工智能怎样加强人类健康管理
交通运输部:港口长期性能监测传感器国家重点研发项目启动
采样电阻在精度上与取样电阻有什么区别
智能门锁行业发展进入井喷期,距离盈利风口的到来只差时机
新能源汽车市场竞争激烈,国内新能源车企造车迫近伪窗口期?
DETR架构的内部工作方式分析
BG822CX芯片的工作原理及应用设计
一款新型模块化运行数据监测器
关于未来精密流形环的应用和发展趋势的分析
脑机接口安全两难,血管支架另辟蹊径
天合储能电芯开发路径:3减1补1智造
小米投资“芯来科技”促进国产RISC-V架构产品应用于更多的物联网设备
负反馈放大电路自激振荡相关的问题
网络适配器是什么
苹果利用拆解机器人进行设备的拆解回收
骨感耳机最好的品牌,不伤听力不堵耳朵的骨传导耳机
洪震:发表了《显示技术的新驱动力》的主题演讲
肖特基二极管的结构原理及作用
瑞萨G2L系列核心板-RZ/G2L处理器简介
DDR4设计规则及DDR4的PCB布线指南