1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/monkeycc-mmsegmentation

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
cc_head.py 1.3 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
谢昕辰 Отправлено 19.09.2022 09:06 230246f
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.registry import MODELS
from .fcn_head import FCNHead
try:
from mmcv.ops import CrissCrossAttention
except ModuleNotFoundError:
CrissCrossAttention = None
@MODELS.register_module()
class CCHead(FCNHead):
"""CCNet: Criss-Cross Attention for Semantic Segmentation.
This head is the implementation of `CCNet
<https://arxiv.org/abs/1811.11721>`_.
Args:
recurrence (int): Number of recurrence of Criss Cross Attention
module. Default: 2.
"""
def __init__(self, recurrence=2, **kwargs):
if CrissCrossAttention is None:
raise RuntimeError('Please install mmcv-full for '
'CrissCrossAttention ops')
super().__init__(num_convs=2, **kwargs)
self.recurrence = recurrence
self.cca = CrissCrossAttention(self.channels)
def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
output = self.convs[0](x)
for _ in range(self.recurrence):
output = self.cca(output)
output = self.convs[1](output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://api.gitlife.ru/oschina-mirror/monkeycc-mmsegmentation.git
git@api.gitlife.ru:oschina-mirror/monkeycc-mmsegmentation.git
oschina-mirror
monkeycc-mmsegmentation
monkeycc-mmsegmentation
main