Как было показано в последнем уроке по настройке моделей, потери рассматриваются как МОДУЛИ в MMGeneration. Настройка потерь аналогична настройке любых других моделей. Этот раздел в основном предназначен для разъяснения дизайна модулей потерь в нашем репозитории. Важно отметить, что при написании собственных модулей потерь вы должны следовать тому же дизайну, чтобы новый модуль потерь можно было без дополнительных усилий внедрить в нашу структуру.
В общем случае для реализации модуля потерь мы напишем реализацию функции, а затем обернём её классом. Однако в MMGeneration мы предоставляем ещё один унифицированный интерфейс data_info, который позволяет пользователям определять сопоставление между входным аргументом и элементами данных.
@weighted_loss
def disc_shift_loss(pred):
return pred**2
@MODULES.register_module()
class DiscShiftLoss(nn.Module):
def __init__(self, loss_weight=1.0, data_info=None):
super(DiscShiftLoss, self).__init__()
# коды можно найти в mmgen/models/losses/disc_auxiliary_loss.py
def forward(self, *args, **kwargs):
# коды можно найти в mmgen/models/losses/disc_auxiliary_loss.py
Цель этого дизайна для модулей потерь — позволить автоматически использовать их в генеративных моделях (MODELS) без других сложных кодов для определения сопоставления между данными и ключевыми аргументами. Таким образом, в отличие от других фреймворков в OpenMMLab, наши модули потерь содержат специальное ключевое слово data_info — это словарь, определяющий сопоставление между входными аргументами и данными из генеративных моделей. Возьмём в качестве примера DiscShiftLoss. При написании файла конфигурации пользователи могут использовать эту потерю следующим образом:
dict(type='DiscShiftLoss',
loss_weight=0.001 * 0.5,
data_info=dict(pred='disc_pred_real')
Информация в data_info сообщает модулю использовать данные disc_pred_real в качестве входного тензора для аргументов pred. Как только data_info не равно None, наш модуль потерь автоматически построит вычислительный граф.
@MODULES.register_module()
class DiscShiftLoss(nn.Module):
def __init__(self, loss_weight=1.0, data_info=None):
super(DiscShiftLoss, self).__init__()
self.loss_weight = loss_weight
self.data_info = data_info
def forward(self, *args, **kwargs):
# использовать data_info для построения вычислительного пути
if self.data_info is not None:
# разобрать args и kwargs
if len(args) == 1:
assert isinstance(args[0], dict), (
'Вы должны предоставить словарь, содержащий выходные данные сети '
'для построения вычислительного графа этого модуля потерь.')
outputs_dict = args[0]
elif 'outputs_dict' in kwargs:
assert len(args) == 0, (
'Если словарь выходных данных указан в аргументах с ключевыми словами, '
'дальнейшие аргументы без ключевых слов не должны предлагаться.')
outputs_dict = kwargs.pop('outputs_dict')
else:
raise NotImplementedError(
'Невозможно проанализировать ваши аргументы, переданные этому модулю потерь.'
' Пожалуйста, проверьте использование этого модуля')
# связать выходы с входными аргументами потерь согласно self.data_info
loss_input_dict = {
k: outputs_dict[v]
for k, v in self.data_info.items()
}
kwargs.update(loss_input_dict)
kwargs.update(dict(weight=self.loss_weight))
return disc_shift_loss(**kwargs)
else:
# если вы не определили, как построить вычислительный граф, этот
# модуль просто вернёт потери, как обычно.
return disc_shift_loss(*args, weight=self.loss_weight, **kwargs)
@staticmethod
def loss_name():
return 'loss_disc_shift'
Как показано в этой части кода, как только пользователи установят data_info, модуль потерь получит словарь, содержащий все необходимые данные.
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )