Config path:
configs/swin/mask-rcnn_swin-t-p4-w7_fpn_1x_coco.py
Reusing variables from the base file
If you want to reuse a variable from the base file in the current config, write {{_base_.xxx}} to get a copy of that variable.
_base_ = './mask-rcnn_r50_fpn_1x_coco.py'
a = {{_base_.model}}  # variable a equals the model defined in _base_
Use dict-based overriding to switch the dataset in this config to VisDrone.
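To make the override idea concrete, here is a minimal pure-Python sketch of a recursive dict merge. It is only an illustration of the behavior, not MMEngine's actual implementation (which also handles things such as the `_delete_` key); `merge_config` and the trimmed `base_model` dict are hypothetical names made up for this example.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`.

    Simplified sketch: nested dicts are merged key by key, and any other
    value in `override` replaces the value from `base`.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical, heavily trimmed stand-in for the model dict of a base config.
base_model = dict(
    type='MaskRCNN',
    roi_head=dict(
        bbox_head=dict(type='Shared2FCBBoxHead', num_classes=80),
        mask_head=dict(type='FCNMaskHead', num_classes=80)))

# The child config only lists the fields it wants to change.
override = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=10),
        mask_head=dict(num_classes=10)))

merged_model = merge_config(base_model, override)
print(merged_model['roi_head']['bbox_head'])
```

Only `num_classes` changes in the merged result; every other field is inherited from the base, which is why the derived config can stay short.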
Official documentation
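# The new config inherits from the base config and makes the necessary changes
_base_ = '../mask_rcnn/mask-rcnn_r50-caffe_fpn_ms-poly-1x_coco.py'
# We also need to change num_classes in the heads to match the number of classes in the dataset
model = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=1), mask_head=dict(num_classes=1)))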
# Modify the dataset settings
data_root = 'data/balloon/'
metainfo = {
    'classes': ('balloon', ),
    'palette': [
        (220, 20, 60),
    ]
}
train_dataloader = dict(
    batch_size=1,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='train/annotation_coco.json',
        data_prefix=dict(img='train/')))
val_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='val/annotation_coco.json',
        data_prefix=dict(img='val/')))
test_dataloader = val_dataloader
# Modify the evaluation metric settings
val_evaluator = dict(ann_file=data_root + 'val/annotation_coco.json')
test_evaluator = val_evaluator
# Initialize from pretrained Mask R-CNN weights, which can improve model performance
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'
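- Change num_classes to match the number of classes in the dataset
The VisDrone dataset has 10 classes.
# Mask R-CNN keeps its heads under roi_head, as in the official example above
model = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=10), mask_head=dict(num_classes=10)))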
- Modify the dataset settings
data_root = '/home/dataset/dataset_visdrone/'
metainfo = {
    'classes': ('pedestrian', 'people', 'bicycle', 'car', 'van', 'truck',
                'tricycle', 'awning-tricycle', 'bus', 'motor'),
}
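Before training, it can help to confirm that the class names in metainfo appear in the annotation file and that their count matches num_classes. The sketch below assumes the annotations are standard COCO-format JSON with a top-level `categories` list; `coco_category_names` and `find_class_mismatches` are hypothetical helpers, and the demo writes a tiny stand-in file rather than reading the real train.json.

```python
import json
import os
import tempfile


def coco_category_names(ann_path):
    """Return the category names stored in a COCO-format annotation file."""
    with open(ann_path, encoding='utf-8') as f:
        data = json.load(f)
    return [cat['name'] for cat in data['categories']]


def find_class_mismatches(classes, ann_path):
    """Return the classes that the annotation file does not define."""
    names = set(coco_category_names(ann_path))
    return [c for c in classes if c not in names]


classes = ('pedestrian', 'people', 'bicycle', 'car', 'van', 'truck',
           'tricycle', 'awning-tricycle', 'bus', 'motor')
num_classes = 10

# Demo on a tiny stand-in annotation file; point ann_path at the real
# train.json to check an actual dataset.
with tempfile.TemporaryDirectory() as tmp:
    ann_path = os.path.join(tmp, 'train.json')
    demo = {
        'images': [],
        'annotations': [],
        'categories': [{'id': i, 'name': n} for i, n in enumerate(classes)],
    }
    with open(ann_path, 'w', encoding='utf-8') as f:
        json.dump(demo, f)
    mismatches = find_class_mismatches(classes, ann_path)

assert len(classes) == num_classes, 'num_classes must equal len(classes)'
print('classes missing from annotations:', mismatches)
```

An empty list means every class name in metainfo was found in the annotation file.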
- Paths
train_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='VisDrone2019-DET-train/annotations/train.json',
        data_prefix=dict(img='VisDrone2019-DET-train/images/')))
val_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='VisDrone2019-DET-val/annotations/val.json',
        data_prefix=dict(img='VisDrone2019-DET-val/images/')))
test_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='VisDrone2019-DET-test-dev/annotations/test.json',
        data_prefix=dict(img='VisDrone2019-DET-test-dev/images/')))
- Modify the evaluation metric settings
val_evaluator = dict(ann_file=data_root + 'VisDrone2019-DET-val/annotations/val.json')
test_evaluator = dict(ann_file=data_root + 'VisDrone2019-DET-test-dev/annotations/test.json')
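Since the dataloaders and evaluators all build their paths from data_root, a quick existence check can catch a wrong directory layout before a long training run starts. This is a minimal sketch: `missing_paths` is a hypothetical helper, and the demo runs against a temporary directory that contains only the train split, so the val and test-dev entries are reported as missing.

```python
import os
import tempfile


def missing_paths(data_root, relative_paths):
    """Return the entries of relative_paths that do not exist under data_root."""
    return [p for p in relative_paths
            if not os.path.exists(os.path.join(data_root, p))]


# Relative paths used by the dataloaders and evaluators in the config.
required = [
    'VisDrone2019-DET-train/annotations/train.json',
    'VisDrone2019-DET-train/images/',
    'VisDrone2019-DET-val/annotations/val.json',
    'VisDrone2019-DET-val/images/',
    'VisDrone2019-DET-test-dev/annotations/test.json',
    'VisDrone2019-DET-test-dev/images/',
]

# Demo: a temporary data_root that only has the train split.
with tempfile.TemporaryDirectory() as data_root:
    os.makedirs(os.path.join(data_root, 'VisDrone2019-DET-train/annotations'))
    os.makedirs(os.path.join(data_root, 'VisDrone2019-DET-train/images'))
    open(os.path.join(data_root, required[0]), 'w').close()
    missing = missing_paths(data_root, required)

print('missing:', missing)
```

To check a real setup, call `missing_paths('/home/dataset/dataset_visdrone/', required)` and make sure it returns an empty list.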
- Pretrained model
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'
Full code
_base_ = './mask-rcnn_swin-t-p4-w7_fpn_1x_coco.py'
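# We also need to change num_classes in the heads to match the number of classes in the dataset
# (Mask R-CNN keeps its heads under roi_head)
model = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=10),
        mask_head=dict(num_classes=10)))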
# Modify the dataset settings
data_root = '/home/dataset/dataset_visdrone/'
metainfo = {
    'classes': ('pedestrian', 'people', 'bicycle', 'car', 'van', 'truck',
                'tricycle', 'awning-tricycle', 'bus', 'motor'),
}
train_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='VisDrone2019-DET-train/annotations/train.json',
        data_prefix=dict(img='VisDrone2019-DET-train/images/')))
val_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='VisDrone2019-DET-val/annotations/val.json',
        data_prefix=dict(img='VisDrone2019-DET-val/images/')))
test_dataloader = dict(
    batch_size=8,
    num_workers=8,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        ann_file='VisDrone2019-DET-test-dev/annotations/test.json',
        data_prefix=dict(img='VisDrone2019-DET-test-dev/images/')))
# Modify the evaluation metric settings
val_evaluator = dict(ann_file=data_root + 'VisDrone2019-DET-val/annotations/val.json')
test_evaluator = dict(ann_file=data_root + 'VisDrone2019-DET-test-dev/annotations/test.json')
# Initialize from the pretrained Swin-T Mask R-CNN weights
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth'
For details, see the link below:
https://mmdetection.readthedocs.io/zh-cn/latest/user_guides/train.html#id7
Training command
python tools/train.py <your modified config .py file>
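If you prefer to launch training from a Python script, the command can be assembled as an argument list. This is only a sketch: `build_train_command` is a hypothetical helper, the config and work-dir paths are placeholders, and the actual launch line is left commented out because it must be run from the MMDetection repository root.

```python
import shlex
import sys


def build_train_command(config_path, work_dir=None):
    """Build the argument list for MMDetection's tools/train.py."""
    cmd = [sys.executable, 'tools/train.py', config_path]
    if work_dir is not None:
        # --work-dir sets the directory where logs and checkpoints are saved.
        cmd += ['--work-dir', work_dir]
    return cmd


# Placeholder paths; replace them with your own config file and output dir.
cmd = build_train_command('configs/swin/my_visdrone_config.py',
                          work_dir='work_dirs/visdrone')
print(shlex.join(cmd))

# To actually start training, run from the MMDetection repository root:
# import subprocess
# subprocess.run(cmd, check=True)
```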