1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/rWySp2020-mmtracking

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
customize_sot_model.md 6.4 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
gitlife-traslator Отправлено 30.11.2024 19:32 8c2b4a3

Настройка моделей SOT

Мы в основном классифицируем компоненты модели на 4 типа.

  • backbone: обычно сеть FCN для извлечения карт признаков, например, ResNet, MobileNet.
  • neck: компонент между backbones и heads, например, ChannelMapper, FPN.
  • head: компонент для конкретных задач, например, отслеживание предсказания bbox.
  • loss: компонент в head для расчёта потерь, например, FocalLoss, L1Loss.

Добавление новых backbones

Здесь мы покажем, как разрабатывать новые компоненты на примере MobileNet.

1. Определите новый backbone (например, MobileNet)

Создайте новый файл mmtrack/models/backbones/mobilenet.py.

import torch.nn as nn

from mmdet.models.builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        pass

    def forward(self, x):  # должен возвращать кортеж
        pass

    def init_weights(self, pretrained=None):
        pass

2. Импортируйте модуль

Вы можете либо добавить следующую строку в mmtrack/models/backbones/__init__.py:

from .mobilenet import MobileNet

либо альтернативно добавить:

custom_imports = dict(
    imports=['mmtrack.models.backbones.mobilenet'],
    allow_failed_imports=False)

в конфигурационный файл, чтобы избежать изменения исходного кода.

3. Используйте backbone в конфигурационном файле

model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...

Добавление нового neck

1. Определение шеи (например, MyFPN)

Создайте новый файл mmtrack/models/necks/my_fpn.py.

from mmdet.models.builder import NECKS

@NECKS.register_module()
class MyFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # реализация игнорируется
        pass

2. Импорт модуля

Вы можете либо добавить следующую строку в mmtrack/models/necks/__init__.py,

from .my_fpn import MyFPN

либо альтернативно добавить:

custom_imports = dict(
    imports=['mmtrack.models.necks.my_fpn.py'],
    allow_failed_ imports=False)

в конфигурационный файл и избежать изменения исходного кода.

3. Модификация конфигурационного файла

neck=dict(
    type='MyFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

Добавление новой головы

1. Определение головы (например, MyHead)

Создайте новый файл mmtrack/models/track_heads/my_head.py.

from mmdet.models import HEADS

@HEADS.register_module()
class MyHead(nn.Module):

    def __init__(self,
                arg1,
                arg2):
        pass

    def forward(self, inputs):
        # реализация игнорируется
        pass

2. Импорт модуля

Вы можете либо добавить следующую строку в mmtrack/models/track_heads/__init__.py,

from .my_head import MyHead

либо альтернативно добавить:

custom_imports = dict(
    imports=['mmtrack.models.track_heads.my_head.py'],
    allow_failed_imports=False)

в конфигурационный файл и избежать изменения исходного кода.

3. Модификация конфигурационного файла

head=dict(
    type='MyHead',
    arg1=xxx,
    arg2=xxx)

Добавление новой потери

Предположим, вы хотите добавить новую потерю как MyLoss, для регрессии ограничивающего прямоугольника. Чтобы добавить новую функцию потерь, пользователи должны реализовать её в mmtrack/models/losses/my_loss.py. Декоратор weighted_loss позволяет сделать так, чтобы потеря была взвешена для каждого элемента.

import torch
import torch.nn as nn

from mmdet.models import LOSSES, weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
``` ```
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
    assert reduction_override in (None, 'none', 'mean', 'sum')
    reduction = (
        reduction_override if reduction_override else self.reduction)
    loss_bbox = self.loss_weight * my_loss(
        pred, target, weight, reduction=reduction, avg_factor=avg_factor)
    return loss_bbox

Затем пользователям необходимо добавить это в mmtrack/models/losses/__init__.py.

from .my_loss import MyLoss, my_loss

Или же можно добавить

custom_imports=dict(
    imports=['mmtrack.models.losses.my_loss'])

в файл конфигурации и достичь той же цели.

Чтобы использовать его, измените поле loss_xxx. Поскольку MyLoss предназначен для регрессии, вам необходимо изменить поле loss_bbox в заголовке.

loss_bbox=dict(type='MyLoss', loss_weight=1.0))

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://api.gitlife.ru/oschina-mirror/rWySp2020-mmtracking.git
git@api.gitlife.ru:oschina-mirror/rWySp2020-mmtracking.git
oschina-mirror
rWySp2020-mmtracking
rWySp2020-mmtracking
master