我们通常将深度学习任务中的神经网络定义为模型,这个模型即是算法的核心。MMEngine 抽象出了一个统一模型 BaseModel 以标准化训练、测试和其他过程。MMSegmentation 实现的所有模型都继承自 BaseModel
,并且在 MMSegmention 中,我们实现了前向传播并为语义分割算法添加了一些功能。
在 MMSegmentation 中,我们将网络架构抽象为分割器,它是一个包含网络所有组件的模型。我们已经实现了编码器解码器(EncoderDecoder)和级联编码器解码器(CascadeEncoderDecoder),它们通常由数据预处理器、骨干网络、解码头和辅助头组成。
数据预处理器是将数据复制到目标设备并将数据预处理为模型输入格式的部分。
主干网络是将图像转换为特征图的部分,例如没有最后全连接层的 ResNet-50。
颈部是连接主干网络和头的部分。它对主干网络生成的原始特征图进行一些改进或重新配置。例如 Feature Pyramid Network(FPN)。
解码头是将特征图转换为分割掩膜的部分,例如 PSPNet。
辅助头是一个可选组件,它将特征图转换为仅用于计算辅助损失的分割掩膜。
MMSegmentation 封装 BaseModel
并实现了 BaseSegmentor 类,主要提供 forward
、train_step
、val_step
和 test_step
接口。接下来将详细介绍这些接口。
forward
方法返回训练、验证、测试和简单推理过程的损失或预测。
该方法应接受三种模式:“tensor”、“predict” 和 “loss”:
nn.Module
相同。SegDataSample
列表中。字典
。注:SegDataSample 是 MMSegmentation 的数据结构接口,用作不同组件之间的接口。SegDataSample
实现了抽象数据元素 mmengine.structures.BaseDataElement
,请参阅 MMMEngine 中的 SegDataSample 文档和数据元素文档了解更多信息。
注意,此方法不处理在 train_step
方法中完成的反向传播或优化器更新。
参数:
metainfo
和 gt_sem_seg
等信息。默认值为 None。返回值:
dict
或 list
:
mode == "loss"
,则返回用于反向过程和日志记录的损失张量字典
。mode == "predict"
,则返回 SegDataSample
的列表
,推理结果将被递增地添加到传递给 forward 方法的 data_sample
参数中,每个 SegDataSeample
包含以下关键词:
PixelData
):语义分割的预测。PixelData
):标准化前语义分割的预测指标。mode == "tensor"
,则返回张量
或张量数组
的字典
以供自定义使用。我们在配置文档中简要描述了模型配置的字段,这里我们详细介绍 model.test_cfg
字段。model.test_cfg
用于控制前向行为,"predict"
模式下的 forward
方法可以在两种模式下运行:
whole_inference
:如果 cfg.model.test_cfg.mode == 'whole'
,则模型将使用完整图像进行推理。
whole_inference
模式的一个示例配置:
model = dict(
type='EncoderDecoder'
...
test_cfg=dict(mode='whole')
)
slide_inference
:如果 cfg.model.test_cfg.mode == ‘slide’
,则模型将通过滑动窗口进行推理。注意: 如果选择 slide
模式,还应指定 cfg.model.test_cfg.stride
和 cfg.model.test_cfg.crop_size
。
slide_inference
模式的一个示例配置:
model = dict(
type='EncoderDecoder'
...
test_cfg=dict(mode='slide', crop_size=256, stride=170)
)
train_step
方法调用 loss
模式的前向接口以获得损失字典
。BaseModel
类实现默认的模型训练过程,包括预处理、模型前向传播、损失计算、优化和反向传播。
参数:
inputs
和 data_samples
两个字段。注:OptimWrapper 提供了一个用于更新参数的通用接口,请参阅 MMMEngine 中的优化器封装文档了解更多信息。
返回值:
-Dict[str, torch.Tensor
]:用于记录日志的张量的字典
。
val_step
方法调用 predict
模式的前向接口并返回预测结果,预测结果将进一步被传递给评测器的进程接口和钩子的 after_val_inter
接口。
参数:
dict
or tuple
or list
) - 从数据集中采样的数据。在 MMSegmentation 中,数据字典包含 inputs
和 data_samples
两个字段。返回值:
list
- 给定数据的预测结果。BaseModel
中 test_step
与 val_step
的实现相同。
MMSegmentation 实现的 SegDataPreProcessor 继承自由 MMEngine 实现的 BaseDataPreprocessor,提供数据预处理和将数据复制到目标设备的功能。
Runner 在构建阶段将模型传送到指定的设备,而 SegDataPreProcessor 在 train_step
、val_step
和 test_step
中将数据传送到指定设备,之后处理后的数据将被进一步传递给模型。
SegDataPreProcessor
构造函数的参数:
数据将按如下方式处理:
pad_val
将输入填充到输入大小,并用定义的 seg_Pad_val
填充分割图。forward
方法的参数:
forward
方法的返回值:
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )