博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
mmdetection源码笔记(四):训练模型之train_detector()的解读
阅读量:3905 次
发布时间:2019-05-23

本文共 3815 字,大约阅读时间需要 12 分钟。

引言

之前在写mmdetection源码的解读过程时,觉得train_detector()这部分很重要,对于理解整个的训练过程应该时起着非常大的理解作用。

然后最近研究工作一直在看和修改mmdetection的其他模块的代码这一块。感觉train_detector()这块内容其实也不是特别重要来着,可能就是一个加强理解的过程。这次还是花了点时间,大致的看了一下,顺便加上自己的一些理解,解释了一下整个过程,如果有错的话,希望各路大佬指出,互相学习哈。

train_detector()

下面的代码出现在tools/train.py中,也是main函数的结尾,也就是说,我们训练的时候,到这就是真正的开始训练了。

train_detector(        model,        datasets,        cfg,        distributed=distributed,        validate=args.validate,        logger=logger)

那到底怎么训练的呢?

下面代码是train_detector()函数的定义,在mmdet/api/train.py文件中

def train_detector(model,                   dataset,                   cfg,                   distributed=False,                   validate=False,                   logger=None):    if logger is None:        logger = get_root_logger(cfg.log_level)    # start training    if distributed:        _dist_train(model, dataset, cfg, validate=validate)    else:        _non_dist_train(model, dataset, cfg, validate=validate)

上面的开始训练过程分分布式训练和非分布式训练两种方法,我们只说分布式训练,同样下面代码是_dist_train()的定义,也在mmdet/api/train.py中

def _dist_train(model, dataset, cfg, validate=False):    # prepare data loaders    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]    data_loaders = [        build_dataloader(            ds, cfg.data.imgs_per_gpu, cfg.data.workers_per_gpu, dist=True)        for ds in dataset    ]    # put model on gpus    model = MMDistributedDataParallel(model.cuda())    # build runner 用来为pytorch训练用的类,该类在mmcv/mmcv/runner/runner.py中    optimizer = build_optimizer(model, cfg.optimizer)    # Optimizer 是用来更新和计算影响模型训练和模型输出的网络参数,使其逼近或达到最优值,从而最小化(或最大化)损失函数E(x)    # 这种算法使用各参数的梯度值来最小化或最大化损失函数E(x)。最常用的一阶优化算法是梯度下降。    runner = Runner(model, batch_processor, optimizer, cfg.work_dir,                    cfg.log_level)    # fp16 setting   用来提速的    fp16_cfg = cfg.get('fp16', None)    if fp16_cfg is not None:        optimizer_config = Fp16OptimizerHook(**cfg.optimizer_config,                                             **fp16_cfg)    else:        optimizer_config = DistOptimizerHook(**cfg.optimizer_config)    # register hooks hooks 用来查看中间变量的    # hook的作用是,当反传时,除了完成原有的反传,额外多完成一些任务。你可以定义一个中间变量的hook,将它的grad值打印出来,当然你也可以定义一个全局列表,将每次的grad值添加到里面去。    # 下面的hooks也是一样的,具体pytorch中hooks的作用,可以参考下方链接    runner.register_training_hooks(cfg.lr_config, optimizer_config,                                   cfg.checkpoint_config, cfg.log_config)    runner.register_hook(DistSamplerSeedHook())    # register eval hooks    if validate:        val_dataset_cfg = cfg.data.val        eval_cfg = cfg.get('evaluation', {
}) if isinstance(model.module, RPN): # TODO: implement recall hooks for other datasets runner.register_hook( CocoDistEvalRecallHook(val_dataset_cfg, **eval_cfg)) else: dataset_type = DATASETS.get(val_dataset_cfg.type) if issubclass(dataset_type, datasets.CocoDataset): runner.register_hook( CocoDistEvalmAPHook(val_dataset_cfg, **eval_cfg)) else: runner.register_hook( DistEvalmAPHook(val_dataset_cfg, **eval_cfg)) if cfg.resume_from: # 从resume_from(checkpoint)重新开始训练?? # (resume_from的作用我猜的,可以自己细看这部分的代码) runner.resume(cfg.resume_from) elif cfg.load_from: # 加载 checkpoint,继续训练 runner.load_checkpoint(cfg.load_from) runner.run(data_loaders, cfg.workflow, cfg.total_epochs) # 开始训练

上面代码,还出现了一个类runner,这个类的作用呢,就是用来更好的训练pytorch模型的。

简单的说,就是用runner这个类来操控安排训练过程中的各个环节。 这个操控包括,要在module中获取中间变量啊,或者加载和保存检查点,或者启动训练、启动测试、或者初始化权重,本身这个函数是不能改变这个网络模型的各个部分的,也就是说,我们要真正修改backbone、或者FPN啊,或者分类回归的具体实现,跟这个类无关。
也就是说,你只要把你定义好的网络模型结构,加载好的数据集,你要的优化器等,扔给runner,他就会来帮你跑模型。
runner这个类定义在mmcv/mmcv/runner/runner.py中,里面好多方法,想要了解的可以自己慢慢去看。

所以train_detection()这一部分的作用,其实就是帮我们把之前设计好的网络结构,数据集等,扔给runner,然后就行了,具体怎么跑呢,不需要太转牛角尖,毕竟太黑盒了。

如果以上理解有误,请指出,互相学习哈!

转载地址:http://kdxen.baihongyu.com/

你可能感兴趣的文章
CENTOS 6.5 配置YUM安装NGINX
查看>>
#ifdef DEBUG的理解
查看>>
Linux 任务控制的几个技巧( &, [ctrl]-z, jobs, fg, bg, kill)
查看>>
FASTCGI与CGI解释器的区别,及其工作原理
查看>>
Nginx+FastCGI运行原理
查看>>
centos7 安装 桌面 desktop
查看>>
pycharm搭建python开发环境
查看>>
使用virtualenv搭建独立的Python环境
查看>>
Flask + Gunicorn + Nginx 部署
查看>>
又见KeepAlive HTTP TCP KeepAlive 区别
查看>>
linux服务器出现大量CLOSE_WAIT状态的连接
查看>>
大规模Nginx平台化实践,京东能提供哪些参考经验?
查看>>
linux下python开发环境之一——安装python
查看>>
网络错误定位案例 ICMP host *** unreachable - admin prohibited
查看>>
SaltStack使用教程(一):安装并简单配置使用
查看>>
NGINX 1.9.1 新特性:套接字端口共享
查看>>
pip win10 升级问题
查看>>
安装python 以及pip
查看>>
多版本Python共存,以及pip对应
查看>>
Windows下Anaconda的安装和简单使用
查看>>