faster rcnn 继承于GeneralizedRCNN
关于 GeneralizedRCNN 类,其间有4个重要的接口:
- transform : 首要是标准化和把图片缩放到固定大小,后续说明
- backbone :一般是VGG、ResNet、MobileNet 等网络
- rpn:经过rpn生成proposals 和 proposal_losses
- roi_heads:roi pooling + 分类
class GeneralizedRCNN(nn.Module):
"""
Main class for Generalized R-CNN.
Args:
backbone (nn.Module):
rpn (nn.Module):
roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
detections / masks from it.
transform (nn.Module): performs the data transformation from the inputs to feed into
the model
"""
def __init__(self, backbone, rpn, roi_heads, transform):
super(GeneralizedRCNN, self).__init__()
self.transform = transform
self.backbone = backbone
self.rpn = rpn
self.roi_heads = roi_head
前向传播
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
"""
Args:
images (list[Tensor]): images to be processed
targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
Returns:
result (list[BoxList] or dict[Tensor]): the output from the model.
During training, it returns a dict[Tensor] which contains the losses.
During testing, it returns list[BoxList] contains additional fields
like `scores`, `labels` and `mask` (for Mask R-CNN models).
"""
images 参数:(list[Tensor(C,H,W)*Batch_size])
targets 可选参数:传递gt,是一个列表,列表中的每个元素都是一个字典,字典中包含了与图画中的真实方针相关的信息。"boxes"
:一个张量(Tensor),包含了真实方针框的坐标信息。通常是一个形状为 [N, 4] 的张量,其间 N 是方针框的数量,每行表明一个方针框的坐标信息,通常是左上角和右下角的坐标。 其他键值对:可能还包含其他与方针相关的信息,比如类别标签、切割掩码等。
输出成果 :(list[BoxList] or dict[Tensor])
在练习过程中,模型回来一个字典,其间包含了损失信息, Dict[str, Tensor]
,键是各种损失名称
loss_name,值是损失张量。
在测验过程中,模型回来一个列表,其间包含了检测成果的信息。每个元素是一个字典,表明一张图画的检测成果。这个字典包的key为:检测框的置信度 scores
、类别标签 labels
、切割掩码 mask
等。因而,每个元素的类型是 List[Dict[str, Tensor]]
。
original_image_sizes: List[Tuple[int, int]] = []
for img in images:
val = img.shape[-2:]
assert len(val) == 2
original_image_sizes.append((val[0], val[1])) # 记载改换前original_images_sizes
images, targets = self.transform(images, targets)
# transfrom的界说为class GeneralizedRCNNTransform(nn.Module),对images和target都进行resize等操作
# 这儿transform回来的images是ImageList类型(Tensors:tensor,image_sizes:List)
这儿 transform 首要包含标准化和将图画缩放到固定大小
需求说明的是,把缩放后的图画输入网络,那么网络输出的检测框也是在缩放后的图画上的。但是实践中咱们需求的是在原始图画的检测框,为了对应起来,所以需求记载改换前original_images_sizes。
进入首要网络流程
features = self.backbone(images.tensors)
将transform 后的图画进入backbone(一般包含VGG,ResNet,MobileNet等网络) 提取特征
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
类型查看和转化,假如 features
是 torch.Tensor
类型的实例,将 features
变量转换为一个字典类型的目标,其间包含一个键值对,键是字符串 '0'
,值是原始的 features
变量。
proposals, proposal_losses = self.rpn(images, features, targets)
经过 rpn 模块生成proposals和 proposal_losses
proposals 是 rpn 生成的 bbox ,type:List[tensor(n,4)*Batch_size]
,并且按置信度降序排序
proposal_losses 是 rpn 阶段的loss,包含置信度 loss 和 bbox回归 loss
roi_heads包含roi pooling + 分类
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
接着进入 roi_heads 模块,对候选区域(proposals)进行进一步处理,以发生终究的检测成果。
detctions : 每个元素代表一张输入图画进 roi_heads 处理后的检测成果。每个检测成果通常包含以下信息boxes
,labels
,scores
.回来List[dict[Tensor]*batch_size]
,key为boxes,labels,scores
detector_losses :roi_head阶段的loss包含:class分类 loss 和 bbox回归 loss
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
后续处理,经 postprocess 模块(进行 NMS,同时将 box 经过 original_images_size映射回原图,即transform阶段Resize的逆操作)
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
if torch.jit.is_scripting():
if not self._has_warned:
warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
self._has_warned = True
return losses, detections
else:
return self.eager_outputs(losses, detections)
依据依据train和test阶段不同回来不同值,练习阶段回来 losses,测验阶段回来 detections