Слияние кода завершено, страница обновится автоматически
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from sklearn.decomposition import PCA
import scipy.stats as sps
import numpy as np
from copy import deepcopy
from models.StyleGAN2.model import ModulatedConv2dPatchedFixedBasisDelta
from models.StyleGAN2.model import ModulatedConv2dPatchedSVDBasisDelta
def get_conv_from_generator(generator, conv_ix):
if 'StyleGAN2Wrapper' in generator.__class__.__name__:
return generator.style_gan2.convs[conv_ix].conv
else:
assert NotImplementedError
class WeightDeformatorFixedBasis(nn.Module):
def __init__(self, generator, conv_layer_ix, directions_count,
basis_vectors=None, basis_vectors_path=None):
super(WeightDeformatorFixedBasis, self).__init__()
assert (basis_vectors is not None) or (basis_vectors_path is not None),\
'either basis tensor or basis tensor path must be provided'
# Get conv layer to be hooked
# List is used for this layer not to show up in .parameters()
if basis_vectors is None:
basis_vectors = torch.load(basis_vectors_path)
generator.style_gan2.convs[conv_layer_ix].conv = ModulatedConv2dPatchedFixedBasisDelta(
basis_vectors=basis_vectors.cuda(),
conv_to_patch=generator.style_gan2.convs[conv_layer_ix].conv,
directions_count=directions_count
)
self.hooked_conv_layer = [get_conv_from_generator(generator, conv_layer_ix)]
self.disable_deformation()
def deformate(self, batch_directions, batch_shifts):
self.hooked_conv_layer[0].is_deformated = True
self.hooked_conv_layer[0].batch_directions = batch_directions
self.hooked_conv_layer[0].batch_shifts = batch_shifts
def disable_deformation(self):
self.hooked_conv_layer[0].is_deformated = False
def parameters(self):
return [self.hooked_conv_layer[0].direction_to_basis_coefs]
class WeightDeformatorSVDBasis(nn.Module):
def __init__(self, generator, conv_layer_ix, directions_count):
super(WeightDeformatorSVDBasis, self).__init__()
# Get conv layer to be hooked
#向层里添加了direction_to_eigenvalues_delta属性,这个属性就是方向控制矩阵A
generator.style_gan2.convs[conv_layer_ix].conv = ModulatedConv2dPatchedSVDBasisDelta(
conv_to_patch=generator.style_gan2.convs[conv_layer_ix].conv,
directions_count=directions_count
)
# 列表的作用是在使用.parameters()时,只显示方向控制矩阵A
self.hooked_conv_layer = [get_conv_from_generator(generator, conv_layer_ix)]
self.disable_deformation()
def deformate(self, batch_directions, batch_shifts):
self.hooked_conv_layer[0].is_deformated = True
self.hooked_conv_layer[0].batch_directions = batch_directions
self.hooked_conv_layer[0].batch_shifts = batch_shifts
def disable_deformation(self):
self.hooked_conv_layer[0].is_deformated = False
def parameters(self):
'''
方向矩阵A,维度为[方向,维数],如(64,512)
'''
return [self.hooked_conv_layer[0].direction_to_eigenvalues_delta]
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )