Мы в основном классифицируем компоненты модели на 5 типов.
Создать новый файл mmtrack/models/mot/trackers/my_tracker.py
.
Реализовать BaseTracker
, который предоставляет основные API для поддержания треков на видео. Мы рекомендуем наследовать новый трекер от него. Пользователи могут обратиться к документации BaseTracker
для получения подробной информации.
from mmtrack.models import TRACKERS
from .base_tracker import BaseTracker
@TRACKERS.register_module()
class MyTracker(BaseTracker):
def __init__(self,
arg1,
arg2,
*args,
**kwargs):
super().__init__(*args, **kwargs)
pass
def track(self, inputs):
# реализация игнорируется
pass
Можно добавить следующую строку в mmtrack/models/mot/trackers/__init__.py
:
from .my_tracker import MyTracker
или альтернативно добавить:
custom_imports = dict(
imports=['mmtrack.models.mot.trackers.my_tracker.py'],
allow_failed_imports=False)
в конфигурационный файл и избежать изменения исходного кода.
tracker=dict(
type='MyTracker',
arg1=xxx,
arg2=xxx)
Пожалуйста, обратитесь к учебнику по mmdetection для разработки нового детектора.
Создать новый файл mmtrack/models/motion/my_flownet.py
.
Вы можете наследовать модель движения от nn.Module
, если это модуль глубокого обучения, и от object
, если нет.
from ..builder import MOTION
@MOTION.register_module()
class MyFlowNet(nn.Module):
def __init__(self,
arg1,
arg2):
pass
def forward(self, inputs):
# реализация игнорируется
pass
Можно добавить следующую строку в mmtrack/models/motion/__init__.py
:
from .my_flownet import MyFlowNet
или альтернативно добавить:
custom_imports = dict(
imports=['mmtrack.models.motion.my_flownet.py'],
allow_failed_imports=False)
в конфигурационный файл и избежать изменения исходного кода.
motion=dict(
type='MyFlowNet',
arg1=xxx,
arg2=xxx)
Создать новый файл mmtrack/models/reid/my_reid.py
.
from ..builder import REID
@REID.register_module()
class MyReID(nn.Module):
def __init__(self,
arg1,
arg2):
pass
def forward(self, inputs):
# реализация игнорируется
pass
Можно добавить следующую строку в mmtrack/models/reid/__init__.py
:
from .my_reid import MyReID
или альтернативно добавить:
custom_imports = dict(
imports=['mmtrack.models.reid.my_reid.py'],
allow_failed_imports=False)
в конфигурационный файл и избежать изменения исходного кода.
reid=dict(
type='MyReID',
arg1=xxx,
arg2=xxx)
Создать новый файл mmtrack/models/track_heads/my_head.py
.
from mmdet.models import HEADS
@HEADS.register_module()
class MyHead(nn.Module):
def
*Здесь текст обрывается.* ```
__init__(self,
arg1,
arg2):
pass
def forward(self, inputs):
# implementation is ignored
pass
Вы можете либо добавить следующую строку в mmtrack/models/track_heads/__init__.py
,
from .my_head import MyHead
либо добавить
custom_imports = dict(
imports=['mmtrack.models.track_heads.my_head.py'],
allow_failed_imports=False)
в конфигурационный файл и избежать изменения исходного кода.
track_head=dict(
type='MyHead',
arg1=xxx,
arg2=xxx)
Предположим, вы хотите добавить новую потерю как MyLoss
для регрессии ограничивающего прямоугольника. Чтобы добавить новую функцию потерь, пользователи должны реализовать её в mmtrack/models/losses/my_loss.py
. Декоратор weighted_loss
позволяет взвешивать потери для каждого элемента.
import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weighted_loss
@weighted_loss
def my_loss(pred, target):
assert pred.size() == target.size() and target.numel() > 0
loss = torch.abs(pred - target)
return loss
@LOSSES.register_module()
class MyLoss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0):
super(MyLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * my_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_bbox
Затем пользователи должны добавить его в mmtrack/models/losses/__init__.py
.
from .my_loss import MyLoss, my_loss
В качестве альтернативы вы можете добавить
custom_imports=dict(
imports=['mmtrack.models.losses.my_loss'])
в конфигурационный файл для достижения той же цели.
Чтобы использовать его, измените поле loss_xxx
. Поскольку MyLoss
предназначен для регрессии, вам необходимо изменить поле loss_bbox
в заголовке.
loss_bbox=dict(type='MyLoss', loss_weight=1.0))
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )