MaskRCNN代码结构
一、数据集加载
1 | dataset_train = pmr.datasets(args.dataset, args.data_dir, "train2017", train=True) # 导入训练数据集 |
pmr.datasets
函数是一个中间函数,根据输入的数据集类型不同,调用不同的数据集类来加载数据,以coco数据集为例。
数据集初始时先加载注释文件,使用COCO工具包将注释数据读取出来。
然后对数据进行检查,根据checked_.txt文件内容判断当前数据是否已经被检查过,如果未被检查,则调用GeneralizedDataset类中的check_dataset函数,检查图像的宽高比是否符合要求。
二、网络模型
网络模型定义主程序如下:
1 | model = pmr.maskrcnn_resnet50(True, num_classes).to(device) |
2.1 基础网络:ResNet50
1 | backbone = ResBackbone('resnet50', pretrained_backbone) |
- 实例化ResNet网络,使用指定的ResNet类型和预训练权重;
- 冻结ResNet网络中除layer2、layer3和layer4以外的所有层的参数,使这些参数在训练时不更新;
- 将ResNet网络中前8层作为子模块加入到
ModuleDict
中,以便后续网络中使用; - 定义一个特征金字塔,将输入特征图通道从2048转换为256,并对特征金字塔的两个卷积层权重和偏置初始化。
前向传播时,将输入的特征图依次通过 ResNet 的前8层卷积层,然后经过两个卷积层进行特征金字塔的处理,最终输出处理后的特征图
2.2 MaskRCNN网络本体
1 | model = MaskRCNN(backbone, num_classes) |
评论
TwikooGitalk