首先找到包含Trainer
这个类的train.py
文件,然后在这个类前面定义数据集:
#引入以下注释 from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data.datasets.coco import load_coco_json import pycocotools #声明类别,尽量保持 CLASS_NAMES =["__background__","name_1","name_2"...] # 数据集路径 DATASET_ROOT = '/home/Yourdatadir' ANN_ROOT = os.path.join(DATASET_ROOT, 'COCOformat') TRAIN_PATH = os.path.join(DATASET_ROOT, 'JPEGImages') VAL_PATH = os.path.join(DATASET_ROOT, 'JPEGImages') TRAIN_JSON = os.path.join(ANN_ROOT, 'train.json') #VAL_JSON = os.path.join(ANN_ROOT, 'val.json') VAL_JSON = os.path.join(ANN_ROOT, 'test.json') def plain_register_dataset(): #训练集 DatasetCatalog.register("coco_my_train", lambda: load_coco_json(TRAIN_JSON, TRAIN_PATH)) MetadataCatalog.get("coco_my_train").set(thing_classes=CLASS_NAMES, # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭 evaluator_type='coco', # 指定评估方式 json_file=TRAIN_JSON, image_root=TRAIN_PATH) #DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH, "coco_2017_val")) #验证/测试集 DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH)) MetadataCatalog.get("coco_my_val").set(thing_classes=CLASS_NAMES, # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭 evaluator_type='coco', # 指定评估方式 json_file=VAL_JSON, image_root=VAL_PATH)
这里的__background__
记得留着。
上面就把数据集定义好了,接下来要设置后面训练的数据集。找到该文件中的setup
方法,加入:
cfg.DATASETS.TRAIN = ("coco_my_train",) # 训练数据集名称 cfg.DATASETS.TEST = ("coco_my_val",) cfg.MODEL.RETINANET.NUM_CLASSES = 81
需要注意的是,这三行需要在cfg.freeze()
这句之前加入。
最后,在开始训练之前,执行上面我们定义的plain_register_dataset
方法。
def start_train(args): plain_register_dataset() cfg = setup(args)
这样就可以了。
本文最后更新于2022年4月5日,已超过 1 年没有更新,如果文章内容或图片资源失效,请留言反馈,我们会及时处理,谢谢!