1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/open-mmlab-mmgeneration

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
customize_losses.md 6 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
gitlife-traslator Отправлено 30.11.2024 07:34 ccb7da8

Tutorial 4: Дизайн наших модулей потерь

Как было показано в последнем уроке по настройке моделей, потери рассматриваются как МОДУЛИ в MMGeneration. Настройка потерь аналогична настройке любых других моделей. Этот раздел в основном предназначен для разъяснения дизайна модулей потерь в нашем репозитории. Важно отметить, что при написании собственных модулей потерь вы должны следовать тому же дизайну, чтобы новый модуль потерь можно было без дополнительных усилий внедрить в нашу структуру.

Дизайн модулей потерь

В общем случае для реализации модуля потерь мы напишем реализацию функции, а затем обернём её классом. Однако в MMGeneration мы предоставляем ещё один унифицированный интерфейс data_info, который позволяет пользователям определять сопоставление между входным аргументом и элементами данных.

@weighted_loss
def disc_shift_loss(pred):
    return pred**2

@MODULES.register_module()
class DiscShiftLoss(nn.Module):

    def __init__(self, loss_weight=1.0, data_info=None):
        super(DiscShiftLoss, self).__init__()
        # коды можно найти в mmgen/models/losses/disc_auxiliary_loss.py

    def forward(self, *args, **kwargs):
        # коды можно найти в mmgen/models/losses/disc_auxiliary_loss.py

Цель этого дизайна для модулей потерь — позволить автоматически использовать их в генеративных моделях (MODELS) без других сложных кодов для определения сопоставления между данными и ключевыми аргументами. Таким образом, в отличие от других фреймворков в OpenMMLab, наши модули потерь содержат специальное ключевое слово data_info — это словарь, определяющий сопоставление между входными аргументами и данными из генеративных моделей. Возьмём в качестве примера DiscShiftLoss. При написании файла конфигурации пользователи могут использовать эту потерю следующим образом:

dict(type='DiscShiftLoss',
    loss_weight=0.001 * 0.5,
    data_info=dict(pred='disc_pred_real')

Информация в data_info сообщает модулю использовать данные disc_pred_real в качестве входного тензора для аргументов pred. Как только data_info не равно None, наш модуль потерь автоматически построит вычислительный граф.

@MODULES.register_module()
class DiscShiftLoss(nn.Module):

    def __init__(self, loss_weight=1.0, data_info=None):
        super(DiscShiftLoss, self).__init__()
        self.loss_weight = loss_weight
        self.data_info = data_info

    def forward(self, *args, **kwargs):
        # использовать data_info для построения вычислительного пути
        if self.data_info is not None:
            # разобрать args и kwargs
            if len(args) == 1:
                assert isinstance(args[0], dict), (
                    'Вы должны предоставить словарь, содержащий выходные данные сети '
                    'для построения вычислительного графа этого модуля потерь.')
                outputs_dict = args[0]
            elif 'outputs_dict' in kwargs:
                assert len(args) == 0, (
                    'Если словарь выходных данных указан в аргументах с ключевыми словами, '
                    'дальнейшие аргументы без ключевых слов не должны предлагаться.')
                outputs_dict = kwargs.pop('outputs_dict')
            else:
                raise NotImplementedError(
                    'Невозможно проанализировать ваши аргументы, переданные этому модулю потерь.'
                    ' Пожалуйста, проверьте использование этого модуля')
            # связать выходы с входными аргументами потерь согласно self.data_info
            loss_input_dict = {
                k: outputs_dict[v]
                for k, v in self.data_info.items()
            }
            kwargs.update(loss_input_dict)
            kwargs.update(dict(weight=self.loss_weight))
            return disc_shift_loss(**kwargs)
        else:
            # если вы не определили, как построить вычислительный граф, этот
            # модуль просто вернёт потери, как обычно.
            return disc_shift_loss(*args, weight=self.loss_weight, **kwargs)

    @staticmethod
    def loss_name():
        return 'loss_disc_shift'

Как показано в этой части кода, как только пользователи установят data_info, модуль потерь получит словарь, содержащий все необходимые данные.

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://api.gitlife.ru/oschina-mirror/open-mmlab-mmgeneration.git
git@api.gitlife.ru:oschina-mirror/open-mmlab-mmgeneration.git
oschina-mirror
open-mmlab-mmgeneration
open-mmlab-mmgeneration
master