faster rcnn 源码(1)——GeneralizedRCNN

faster rcnn 源码(1)——GeneralizedRCNN

faster rcnn 继承于GeneralizedRCNN

关于 GeneralizedRCNN 类,其间有4个重要的接口:

  1. transform : 首要是标准化和把图片缩放到固定大小,后续说明
  2. backbone :一般是VGG、ResNet、MobileNet 等网络
  3. rpn:经过rpn生成proposals 和 proposal_losses
  4. roi_heads:roi pooling + 分类
class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.
    Args:
        backbone (nn.Module):
        rpn (nn.Module):
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model
    """
    def __init__(self, backbone, rpn, roi_heads, transform):
        super(GeneralizedRCNN, self).__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_head

前向传播

def forward(self, images, targets=None):
    # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
    """
    Args:
        images (list[Tensor]): images to be processed
        targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
    Returns:
        result (list[BoxList] or dict[Tensor]): the output from the model.
            During training, it returns a dict[Tensor] which contains the losses.
            During testing, it returns list[BoxList] contains additional fields
            like `scores`, `labels` and `mask` (for Mask R-CNN models).
    """

images 参数:(list[Tensor(C,H,W)*Batch_size])

targets 可选参数:传递gt,是一个列表,列表中的每个元素都是一个字典,字典中包含了与图画中的真实方针相关的信息。"boxes":一个张量(Tensor),包含了真实方针框的坐标信息。通常是一个形状为 [N, 4] 的张量,其间 N 是方针框的数量,每行表明一个方针框的坐标信息,通常是左上角和右下角的坐标。 其他键值对:可能还包含其他与方针相关的信息,比如类别标签、切割掩码等。

输出成果 :(list[BoxList] or dict[Tensor])

在练习过程中,模型回来一个字典,其间包含了损失信息, Dict[str, Tensor],键是各种损失名称 loss_name,值是损失张量。

在测验过程中,模型回来一个列表,其间包含了检测成果的信息。每个元素是一个字典,表明一张图画的检测成果。这个字典包的key为:检测框的置信度 scores、类别标签 labels、切割掩码 mask 等。因而,每个元素的类型是 List[Dict[str, Tensor]]

original_image_sizes: List[Tuple[int, int]] = []
for img in images:
    val = img.shape[-2:]
    assert len(val) == 2
    original_image_sizes.append((val[0], val[1])) # 记载改换前original_images_sizes
images, targets = self.transform(images, targets)
# transfrom的界说为class GeneralizedRCNNTransform(nn.Module),对images和target都进行resize等操作 
# 这儿transform回来的images是ImageList类型(Tensors:tensor,image_sizes:List)

这儿 transform 首要包含标准化和将图画缩放到固定大小

需求说明的是,把缩放后的图画输入网络,那么网络输出的检测框也是在缩放后的图画上的。但是实践中咱们需求的是在原始图画的检测框,为了对应起来,所以需求记载改换前original_images_sizes。

进入首要网络流程

faster rcnn 源码(1)——GeneralizedRCNN

features = self.backbone(images.tensors)

将transform 后的图画进入backbone(一般包含VGG,ResNet,MobileNet等网络) 提取特征

if isinstance(features, torch.Tensor):
    features = OrderedDict([('0', features)])

类型查看和转化,假如 featurestorch.Tensor 类型的实例,将 features 变量转换为一个字典类型的目标,其间包含一个键值对,键是字符串 '0',值是原始的 features 变量。

proposals, proposal_losses = self.rpn(images, features, targets)

经过 rpn 模块生成proposals和 proposal_losses

proposals 是 rpn 生成的 bbox ,type:List[tensor(n,4)*Batch_size],并且按置信度降序排序 proposal_losses 是 rpn 阶段的loss,包含置信度 loss 和 bbox回归 loss

roi_heads包含roi pooling + 分类

detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)

接着进入 roi_heads 模块,对候选区域(proposals)进行进一步处理,以发生终究的检测成果。

detctions : 每个元素代表一张输入图画进 roi_heads 处理后的检测成果。每个检测成果通常包含以下信息boxes,labels,scores.回来List[dict[Tensor]*batch_size],key为boxes,labels,scores

detector_losses :roi_head阶段的loss包含:class分类 loss 和 bbox回归 loss

detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

后续处理,经 postprocess 模块(进行 NMS,同时将 box 经过 original_images_size映射回原图,即transform阶段Resize的逆操作)

losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
if torch.jit.is_scripting():
    if not self._has_warned:
        warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
        self._has_warned = True
    return losses, detections
else:
    return self.eager_outputs(losses, detections)

依据依据train和test阶段不同回来不同值,练习阶段回来 losses,测验阶段回来 detections