1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmflow

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
4_new_modules.md 4.7 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
gitlife-traslator Отправлено 29.11.2024 01:34 0fe9dae

Учебник 4: Добавление новых модулей

MMFlow разделяет метод оценки потока flow_estimator на encoder и decoder. Этот учебник посвящён тому, как добавлять новые компоненты.

Добавить новый encoder

  1. Создайте новый файл mmflow/models/encoders/my_model.py.
from mmcv.runner import BaseModule

from ..builder import ENCODERS

@ENCODERS.register_module()
class MyModel(BaseModule):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # должен возвращать кортеж
        pass

    def init_weights(self, pretrained=None):
        pass
  1. Импортируйте модуль в mmflow/models/encoders/__init__.py.
from .my_model import MyModel

Добавить новый decoder

  1. Создайте новый файл mmflow/models/decoders/my_decoder.py.

Вы можете написать новую голову, наследуя от BaseModule из MMCV, и переписать методы forward(self, x), forward_train и forward_test. У нас есть унифицированный интерфейс для инициализации весов в MMCV. Вы можете использовать init_cfg, чтобы указать функцию инициализации и аргументы, или переписать init_weigths, если вы предпочитаете индивидуальную инициализацию.

from ..builder import DECODERS


@DECODERS.register_module()
class MyDecoder(BaseModule):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, *args):
        pass

    # опционально
    def init_weights(self):
        pass

    def forward_train(self, *args, flow_gt):
        flow_pred = self.forward(*args)
        return self.losses(flow_pred, flow_gt)

    def forward_test(self,*args, img_metas):
        flow_pred = self.forward(*args)
        return self.get_flow(flow_pred, img_metas)

losses — это функция потерь для вычисления потерь между выходом модели и целью, get_flow реализован в BaseDecoder для восстановления формы потока в исходную форму входных изображений.

  1. Импортируйте модуль в mmflow/models/decoders/__init__.py
from .my_decoder import MyDecoder

Добавить новый flow_estimator

  1. Создайте новый файл mmflow/models/flow_estimators/my_estimator.py

Вы можете написать новый оценщик потока, наследуя от FlowEstimator, например PWC-Net, и реализовать forward_train и forward_test

from ..builder import FLOW_ESTIMATORS
from .base import FlowEstimator


@FLOW_ESTIMATORS.register_module()
class MyEstimator(FlowEstimator):

    def __init__(self, arg1, arg2):
        pass

    def forward_train(self, imgs):
        pass

    def forward_test(self, imgs):
        pass
  1. Импортируйте модуль в mmflow/models/flow_estimator/__init__.py
from .my_estimator import MyEstimator
  1. Используйте его в своём конфигурационном файле.

Мы устанавливаем тип модуля как MyEstimator.

model = dict(
    type='MyEstimator',
    encoder=dict(
        type='MyModel',
        arg1=xxx,
        arg2=xxx),
    decoder=dict(
        type='MyDecoder',
        arg1=xxx,
        arg2=xxx))

Добавить новую loss

Предположим, вы хотите добавить новую потерю как MyLoss для оценки потока. Чтобы добавить новую функцию потерь, пользователи должны реализовать её в mmflow/models/losses/my_loss.py.

import torch
import torch.nn as nn

from mmflow.models import LOSSES

def my_loss(pred, target):
    pass

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, arg1):
        super(MyLoss, self).__init__()


    def forward(self, output, target):
        return my_loss(output, target)

Затем пользователи должны добавить его в mmflow/models/losses/__init__.py.

from .my_loss import MyLoss, my_loss

Чтобы использовать его, измените поле flow_loss в модели.

flow_loss=dict(type='MyLoss', use_target_weight=False)

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://api.gitlife.ru/oschina-mirror/open-mmlab-mmflow.git
git@api.gitlife.ru:oschina-mirror/open-mmlab-mmflow.git
oschina-mirror
open-mmlab-mmflow
open-mmlab-mmflow
master