Слияние кода завершено, страница обновится автоматически
from stylegan import G_synthesis,G_mapping
from dataclasses import dataclass
from SphericalOptimizer import SphericalOptimizer
from pathlib import Path
import numpy as np
import time
import torch
from loss import LossBuilder
from functools import partial
from drive import open_url
class PULSE(torch.nn.Module):
    """Search a StyleGAN latent space for images that downscale to a reference.

    The synthesis network is downloaded through ``open_url`` into ``cache_dir``
    and frozen; only the latent tensor (and, optionally, some noise tensors)
    is optimized.  A per-dimension mean/std of the mapping network's output
    (with its final leaky ReLU undone) is computed once and cached in
    ``gaussian_fit.pt`` in the *current working directory*.

    All tensors are created on ``'cuda'``, so a GPU is required.
    """

    def __init__(self, cache_dir, verbose=True):
        """Load the frozen synthesis network and the latent Gaussian fit.

        Args:
            cache_dir: Directory (created if missing) used by ``open_url`` to
                cache downloaded network weights.
            verbose: If true, print progress messages.
        """
        super().__init__()
        self.synthesis = G_synthesis().cuda()
        self.verbose = verbose

        cache_dir = Path(cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)

        if self.verbose:
            print("Loading Synthesis Network")
        # NOTE(review): torch.load unpickles the downloaded file; acceptable
        # only because the URL is a fixed, trusted source.
        with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f:
            self.synthesis.load_state_dict(torch.load(f))

        # The generator itself is never trained here.
        for param in self.synthesis.parameters():
            param.requires_grad = False

        self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)

        # Cached fit lives in the CWD, not in cache_dir.
        if Path("gaussian_fit.pt").exists():
            self.gaussian_fit = torch.load("gaussian_fit.pt")
        else:
            if self.verbose:
                print("\tLoading Mapping Network")
            mapping = G_mapping().cuda()
            with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f:
                mapping.load_state_dict(torch.load(f))

            if self.verbose:
                print("\tRunning Mapping Network")
            with torch.no_grad():
                # NOTE: reseeds the global torch RNG as a side effect, so the
                # fit is reproducible.
                torch.manual_seed(0)
                latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda")
                # Slope 5 = 1/0.2 undoes a 0.2-slope leaky ReLU on negative
                # values (presumably the mapping network's last activation),
                # matching the self.lrelu re-application in forward().
                latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
                self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}
                torch.save(self.gaussian_fit, "gaussian_fit.pt")
                if self.verbose:
                    print("\tSaved \"gaussian_fit.pt\"")

    @staticmethod
    def _parse_layer_indices(spec):
        """Parse a '.'-separated list of layer indices, e.g. ``"3.17"`` -> ``{3, 17}``.

        Empty input (``""`` or ``None``) and empty segments yield no indices.
        """
        if not spec:
            return set()
        return {int(part) for part in spec.split('.') if part.strip()}

    @classmethod
    def _make_noise(cls, batch_size, noise_type, num_trainable_noise_layers, bad_noise_layers):
        """Build the 18 per-layer noise tensors fed to the synthesis network.

        Layer ``i`` has spatial size ``2 ** (i // 2 + 2)``, i.e. two layers each
        at 4, 8, ..., 1024.

        Returns:
            ``(noise, noise_vars)``: all 18 tensors in layer order, and the
            subset that has ``requires_grad=True`` (to be optimized).

        Raises:
            ValueError: for an unknown ``noise_type`` (only reached for layers
                not listed in ``bad_noise_layers``).
        """
        zeroed_layers = cls._parse_layer_indices(bad_noise_layers)
        noise = []
        noise_vars = []
        for i in range(18):
            side = 2 ** (i // 2 + 2)
            shape = (batch_size, 1, side, side)
            if noise_type == 'zero' or i in zeroed_layers:
                layer_noise = torch.zeros(shape, dtype=torch.float, device='cuda')
            elif noise_type == 'fixed':
                layer_noise = torch.randn(shape, dtype=torch.float, device='cuda')
            elif noise_type == 'trainable':
                layer_noise = torch.randn(shape, dtype=torch.float, device='cuda')
                # Only the first num_trainable_noise_layers layers are optimized.
                if i < num_trainable_noise_layers:
                    layer_noise.requires_grad = True
                    noise_vars.append(layer_noise)
            else:
                raise ValueError("unknown noise type")
            noise.append(layer_noise)
        return noise, noise_vars

    @staticmethod
    def _lr_schedule(lr_schedule, steps):
        """Return the ``LambdaLR`` multiplier function for a named schedule.

        - ``'fixed'``: constant 1.
        - ``'linear1cycle'``: linear 0.1 -> 1 at ``steps/2`` -> 0.1 at ``steps``.
        - ``'linear1cycledrop'``: the same cycle over the first 90% of steps,
          then linear decay 0.1 -> 0.001 over the last 10%.

        Raises:
            KeyError: for an unknown schedule name.
        """
        schedule_dict = {
            'fixed': lambda x: 1,
            'linear1cycle': lambda x: (9*(1-np.abs(x/steps-1/2)*2)+1)/10,
            'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*steps)-1/2)*2)+1)/10 if x < 0.9*steps else 1/10 + (x-0.9*steps)/(0.1*steps)*(1/1000-1/10),
        }
        return schedule_dict[lr_schedule]

    @staticmethod
    def _export(hr_im, loss_builder):
        """Return detached CPU copies of ``hr_im`` and ``loss_builder.D(hr_im)``, clamped to [0, 1]."""
        # Computed here (not around a `yield`) so grad mode is never left
        # disabled while the generator is suspended.
        with torch.no_grad():
            lr_im = loss_builder.D(hr_im)
        return hr_im.detach().cpu().clamp(0, 1), lr_im.detach().cpu().clamp(0, 1)

    def forward(self, ref_im,
                seed,
                loss_str,
                eps,
                noise_type,
                num_trainable_noise_layers,
                tile_latent,
                bad_noise_layers,
                opt_name,
                learning_rate,
                steps,
                lr_schedule,
                save_intermediate,
                **kwargs):
        """Optimize a latent so the generated image downscales to ``ref_im``.

        This is a generator: calling the module returns an iterator of
        ``(hr_image, downscaled_image)`` CPU tensor pairs clamped to [0, 1].

        Args:
            ref_im: Low-resolution reference batch; dim 0 is the batch size
                (presumably (B, C, H, W) — confirm against LossBuilder).
            seed: If not ``None``, seeds torch (CPU and CUDA) and enables
                deterministic cuDNN.
            loss_str: Loss specification, interpreted by ``LossBuilder``.
            eps: Passed to ``LossBuilder``; the final result is yielded only if
                the smallest observed ``'L2'`` loss term is ``<= eps``.
            noise_type: ``'zero'``, ``'fixed'`` or ``'trainable'``.
            num_trainable_noise_layers: With ``'trainable'``, how many leading
                noise layers are optimized.
            tile_latent: If true, optimize one 512-d latent per image and
                share it across all 18 style inputs.
            bad_noise_layers: '.'-separated layer indices forced to zero noise
                (may be empty).
            opt_name: ``'sgd'``, ``'adam'``, ``'sgdm'`` or ``'adamax'``.
            learning_rate: Base learning rate.
            steps: Number of optimization steps.
            lr_schedule: ``'fixed'``, ``'linear1cycle'`` or ``'linear1cycledrop'``.
            save_intermediate: If true, also yield the best image so far at
                every step.
            **kwargs: Ignored.

        Yields:
            ``(hr, lr)`` pairs; the final pair is the lowest-total-loss image
            and its own downscale.
        """
        # `is not None` so that seed=0 is honoured.
        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        batch_size = ref_im.shape[0]

        # Created before the noise so the RNG consumption order is unchanged.
        num_latents = 1 if tile_latent else 18
        latent = torch.randn(
            (batch_size, num_latents, 512), dtype=torch.float, requires_grad=True, device='cuda')

        noise, noise_vars = self._make_noise(
            batch_size, noise_type, num_trainable_noise_layers, bad_noise_layers)

        var_list = [latent] + noise_vars
        opt_dict = {
            'sgd': torch.optim.SGD,
            'adam': torch.optim.Adam,
            'sgdm': partial(torch.optim.SGD, momentum=0.9),
            'adamax': torch.optim.Adamax,
        }
        opt = SphericalOptimizer(opt_dict[opt_name], var_list, lr=learning_rate)
        # The scheduler drives the wrapped torch optimizer.
        scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, self._lr_schedule(lr_schedule, steps))

        loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()

        min_loss = np.inf
        min_l2 = np.inf
        best_summary = ""
        best_im = None
        gen_im = None
        start_t = time.time()

        if self.verbose:
            print("Optimizing")
        for j in range(steps):
            opt.opt.zero_grad()

            # expand() shares one latent across all 18 inputs without copying.
            latent_in = latent.expand(-1, 18, -1) if tile_latent else latent
            # Map the N(0, I) latent onto the fitted mapping-output
            # distribution, then re-apply the leaky ReLU undone in __init__.
            latent_in = self.lrelu(latent_in*self.gaussian_fit["std"] + self.gaussian_fit["mean"])
            # Synthesis output is in [-1, 1]; rescale to [0, 1].
            gen_im = (self.synthesis(latent_in, noise)+1)/2

            loss, loss_dict = loss_builder(latent_in, gen_im)
            loss_dict['TOTAL'] = loss

            if loss < min_loss:
                min_loss = loss
                best_summary = f'BEST ({j+1}) | '+' | '.join(
                    [f'{x}: {y:.4f}' for x, y in loss_dict.items()])
                # Detached: no need to keep this step's autograd graph alive.
                best_im = gen_im.detach().clone()

            loss_l2 = loss_dict['L2']
            if loss_l2 < min_l2:
                min_l2 = loss_l2

            if save_intermediate:
                # Fall back to the current image if no finite loss was seen yet
                # (e.g. NaN losses), so one pair is still yielded per step.
                yield self._export(best_im if best_im is not None else gen_im, loss_builder)

            loss.backward()
            opt.step()
            scheduler.step()

            total_t = time.time()-start_t
            current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
            if self.verbose:
                print(best_summary+current_info)

        result_im = best_im if best_im is not None else gen_im
        if result_im is not None and min_l2 <= eps:
            # HR image and its downscale come from the same step.
            yield self._export(result_im, loss_builder)
        else:
            print("Could not find a face that downscales correctly within epsilon")
Вы можете оставить комментарий после входа в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )