首先找到包含Trainer这个类的train.py文件,然后在这个类前面定义数据集:
| 
					 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33  | 
						#引入以下注释 from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data.datasets.coco import load_coco_json import pycocotools #声明类别,尽量保持 CLASS_NAMES =["__background__","name_1","name_2"...] # 数据集路径 DATASET_ROOT = '/home/Yourdatadir' ANN_ROOT = os.path.join(DATASET_ROOT, 'COCOformat') TRAIN_PATH = os.path.join(DATASET_ROOT, 'JPEGImages') VAL_PATH = os.path.join(DATASET_ROOT, 'JPEGImages') TRAIN_JSON = os.path.join(ANN_ROOT, 'train.json') #VAL_JSON = os.path.join(ANN_ROOT, 'val.json') VAL_JSON = os.path.join(ANN_ROOT, 'test.json') def plain_register_dataset():     #训练集     DatasetCatalog.register("coco_my_train", lambda: load_coco_json(TRAIN_JSON, TRAIN_PATH))     MetadataCatalog.get("coco_my_train").set(thing_classes=CLASS_NAMES,  # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭                                                     evaluator_type='coco', # 指定评估方式                                                     json_file=TRAIN_JSON,                                                     image_root=TRAIN_PATH)     #DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH, "coco_2017_val"))     #验证/测试集     DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH))     MetadataCatalog.get("coco_my_val").set(thing_classes=CLASS_NAMES, # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭                                                 evaluator_type='coco', # 指定评估方式                                                 json_file=VAL_JSON,                                                 image_root=VAL_PATH)  | 
					
这里的__background__记得留着。
上面就把数据集定义好了,接下来要设置后面训练的数据集。找到该文件中的setup方法,加入:
| 
					 1 2 3  | 
						cfg.DATASETS.TRAIN = ("coco_my_train",) # 训练数据集名称 cfg.DATASETS.TEST = ("coco_my_val",) cfg.MODEL.RETINANET.NUM_CLASSES = 81  | 
					
需要注意的是,这三行需要在cfg.freeze()这句之前加入。
最后,在开始训练之前,执行上面我们定义的plain_register_dataset方法。
| 
					 1 2 3 4  | 
						def start_train(args):     plain_register_dataset()     cfg = setup(args)  | 
					
这样就可以了。
本文最后更新于2022年4月5日,已超过 1 年没有更新,如果文章内容或图片资源失效,请留言反馈,我们会及时处理,谢谢!
		
马春杰杰