MMFlow разделяет метод оценки потока flow_estimator
на encoder
и decoder
. Этот учебник посвящён тому, как добавлять новые компоненты.
mmflow/models/encoders/my_model.py
.from mmcv.runner import BaseModule
from ..builder import ENCODERS
@ENCODERS.register_module()
class MyModel(BaseModule):
def __init__(self, arg1, arg2):
pass
def forward(self, x): # должен возвращать кортеж
pass
def init_weights(self, pretrained=None):
pass
mmflow/models/encoders/__init__.py
.from .my_model import MyModel
mmflow/models/decoders/my_decoder.py
.Вы можете написать новую голову, наследуя от BaseModule
из MMCV, и переписать методы forward(self, x)
, forward_train
и forward_test
.
У нас есть унифицированный интерфейс для инициализации весов в MMCV. Вы можете использовать init_cfg
, чтобы указать функцию инициализации и аргументы, или переписать init_weigths
, если вы предпочитаете индивидуальную инициализацию.
from ..builder import DECODERS
@DECODERS.register_module()
class MyDecoder(BaseModule):
def __init__(self, arg1, arg2):
pass
def forward(self, *args):
pass
# опционально
def init_weights(self):
pass
def forward_train(self, *args, flow_gt):
flow_pred = self.forward(*args)
return self.losses(flow_pred, flow_gt)
def forward_test(self,*args, img_metas):
flow_pred = self.forward(*args)
return self.get_flow(flow_pred, img_metas)
losses
— это функция потерь для вычисления потерь между выходом модели и целью, get_flow
реализован в BaseDecoder
для восстановления формы потока в исходную форму входных изображений.
mmflow/models/decoders/__init__.py
from .my_decoder import MyDecoder
mmflow/models/flow_estimators/my_estimator.py
Вы можете написать новый оценщик потока, наследуя от FlowEstimator
, например PWC-Net, и реализовать forward_train
и forward_test
from ..builder import FLOW_ESTIMATORS
from .base import FlowEstimator
@FLOW_ESTIMATORS.register_module()
class MyEstimator(FlowEstimator):
def __init__(self, arg1, arg2):
pass
def forward_train(self, imgs):
pass
def forward_test(self, imgs):
pass
mmflow/models/flow_estimator/__init__.py
from .my_estimator import MyEstimator
Мы устанавливаем тип модуля как MyEstimator
.
model = dict(
type='MyEstimator',
encoder=dict(
type='MyModel',
arg1=xxx,
arg2=xxx),
decoder=dict(
type='MyDecoder',
arg1=xxx,
arg2=xxx))
Предположим, вы хотите добавить новую потерю как MyLoss
для оценки потока. Чтобы добавить новую функцию потерь, пользователи должны реализовать её в mmflow/models/losses/my_loss.py
.
import torch
import torch.nn as nn
from mmflow.models import LOSSES
def my_loss(pred, target):
pass
@LOSSES.register_module()
class MyLoss(nn.Module):
def __init__(self, arg1):
super(MyLoss, self).__init__()
def forward(self, output, target):
return my_loss(output, target)
Затем пользователи должны добавить его в mmflow/models/losses/__init__.py
.
from .my_loss import MyLoss, my_loss
Чтобы использовать его, измените поле flow_loss
в модели.
flow_loss=dict(type='MyLoss', use_target_weight=False)
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )