首先找到包含Trainer
这个类的train.py
文件,然后在这个类前面定义数据集:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
#引入以下注释 from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data.datasets.coco import load_coco_json import pycocotools #声明类别,尽量保持 CLASS_NAMES =["__background__","name_1","name_2"...] # 数据集路径 DATASET_ROOT = '/home/Yourdatadir' ANN_ROOT = os.path.join(DATASET_ROOT, 'COCOformat') TRAIN_PATH = os.path.join(DATASET_ROOT, 'JPEGImages') VAL_PATH = os.path.join(DATASET_ROOT, 'JPEGImages') TRAIN_JSON = os.path.join(ANN_ROOT, 'train.json') #VAL_JSON = os.path.join(ANN_ROOT, 'val.json') VAL_JSON = os.path.join(ANN_ROOT, 'test.json') def plain_register_dataset(): #训练集 DatasetCatalog.register("coco_my_train", lambda: load_coco_json(TRAIN_JSON, TRAIN_PATH)) MetadataCatalog.get("coco_my_train").set(thing_classes=CLASS_NAMES, # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭 evaluator_type='coco', # 指定评估方式 json_file=TRAIN_JSON, image_root=TRAIN_PATH) #DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH, "coco_2017_val")) #验证/测试集 DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH)) MetadataCatalog.get("coco_my_val").set(thing_classes=CLASS_NAMES, # 可以选择开启,但是不能显示中文,这里需要注意,中文的话最好关闭 evaluator_type='coco', # 指定评估方式 json_file=VAL_JSON, image_root=VAL_PATH) |
这里的__background__
记得留着。
上面就把数据集定义好了,接下来要设置后面训练的数据集。找到该文件中的setup
方法,加入:
1 2 3 |
cfg.DATASETS.TRAIN = ("coco_my_train",) # 训练数据集名称 cfg.DATASETS.TEST = ("coco_my_val",) cfg.MODEL.RETINANET.NUM_CLASSES = 81 |
需要注意的是,这三行需要在cfg.freeze()
这句之前加入。
最后,在开始训练之前,执行上面我们定义的plain_register_dataset
方法。
1 2 3 4 |
def start_train(args): plain_register_dataset() cfg = setup(args) |
这样就可以了。
本文最后更新于2022年4月5日,已超过 1 年没有更新,如果文章内容或图片资源失效,请留言反馈,我们会及时处理,谢谢!
最新评论
这模板不错啊,收藏了
看看可不可用
还有macapp.org.cn macwk.cn
大佬,IOS17.1能用吗?
没安装桌面的时候就有网,安了就没有了
可是安装好了没有网啊,怎么办大佬
您好,我这边需要跟您沟通下亚马逊云科技文章合作事宜,您看可以加个微信,详聊一下嘛
感谢分享。。。