一、数据集加载

1
2
3
dataset_train = pmr.datasets(args.dataset, args.data_dir, "train2017", train=True)  # 导入训练数据集
indices = torch.randperm(len(dataset_train)).tolist() # 生成一个随机的索引列表
d_train = torch.utils.data.Subset(dataset_train, indices) # 根据随机索引列表生成训练集的随机分布子集

pmr.datasets函数是一个中间函数,根据输入的数据集类型不同,调用不同的数据集类来加载数据,以coco数据集为例。

数据集初始时先加载注释文件,使用COCO工具包将注释数据读取出来。

然后对数据进行检查,根据checked_.txt文件内容判断当前数据是否已经被检查过,如果未被检查,则调用GeneralizedDataset类中的check_dataset函数,检查图像的宽高比是否符合要求。

二、网络模型

网络模型定义主程序如下:

1
model = pmr.maskrcnn_resnet50(True, num_classes).to(device)

2.1 基础网络:ResNet50

1
backbone = ResBackbone('resnet50', pretrained_backbone)
  1. 实例化ResNet网络,使用指定的ResNet类型和预训练权重;
  2. 冻结ResNet网络中除layer2、layer3和layer4以外的所有层的参数,使这些参数在训练时不更新;
  3. 将ResNet网络中前8层作为子模块加入到ModuleDict中,以便后续网络中使用;
  4. 定义一个特征金字塔,将输入特征图通道从2048转换为256,并对特征金字塔的两个卷积层权重和偏置初始化。

前向传播时,将输入的特征图依次通过 ResNet 的前8层卷积层,然后经过两个卷积层进行特征金字塔的处理,最终输出处理后的特征图

2.2 MaskRCNN网络本体

1
model = MaskRCNN(backbone, num_classes)