1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/saterr-pulse

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
В этом репозитории не указан файл с открытой лицензией (LICENSE). При использовании обратитесь к конкретному описанию проекта и его зависимостям в коде.
Клонировать/Скачать
PULSE.py 6.7 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
Alex Damian Отправлено 21.06.2020 04:48 ece9211
from stylegan import G_synthesis,G_mapping
from dataclasses import dataclass
from SphericalOptimizer import SphericalOptimizer
from pathlib import Path
import numpy as np
import time
import torch
from loss import LossBuilder
from functools import partial
from drive import open_url
class PULSE(torch.nn.Module):
def __init__(self, cache_dir, verbose=True):
super(PULSE, self).__init__()
self.synthesis = G_synthesis().cuda()
self.verbose = verbose
cache_dir = Path(cache_dir)
cache_dir.mkdir(parents=True, exist_ok = True)
if self.verbose: print("Loading Synthesis Network")
with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f:
self.synthesis.load_state_dict(torch.load(f))
for param in self.synthesis.parameters():
param.requires_grad = False
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2)
if Path("gaussian_fit.pt").exists():
self.gaussian_fit = torch.load("gaussian_fit.pt")
else:
if self.verbose: print("\tLoading Mapping Network")
mapping = G_mapping().cuda()
with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f:
mapping.load_state_dict(torch.load(f))
if self.verbose: print("\tRunning Mapping Network")
with torch.no_grad():
torch.manual_seed(0)
latent = torch.randn((1000000,512),dtype=torch.float32, device="cuda")
latent_out = torch.nn.LeakyReLU(5)(mapping(latent))
self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)}
torch.save(self.gaussian_fit,"gaussian_fit.pt")
if self.verbose: print("\tSaved \"gaussian_fit.pt\"")
def forward(self, ref_im,
seed,
loss_str,
eps,
noise_type,
num_trainable_noise_layers,
tile_latent,
bad_noise_layers,
opt_name,
learning_rate,
steps,
lr_schedule,
save_intermediate,
**kwargs):
if seed:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
batch_size = ref_im.shape[0]
# Generate latent tensor
if(tile_latent):
latent = torch.randn(
(batch_size, 1, 512), dtype=torch.float, requires_grad=True, device='cuda')
else:
latent = torch.randn(
(batch_size, 18, 512), dtype=torch.float, requires_grad=True, device='cuda')
# Generate list of noise tensors
noise = [] # stores all of the noise tensors
noise_vars = [] # stores the noise tensors that we want to optimize on
for i in range(18):
# dimension of the ith noise tensor
res = (batch_size, 1, 2**(i//2+2), 2**(i//2+2))
if(noise_type == 'zero' or i in [int(layer) for layer in bad_noise_layers.split('.')]):
new_noise = torch.zeros(res, dtype=torch.float, device='cuda')
new_noise.requires_grad = False
elif(noise_type == 'fixed'):
new_noise = torch.randn(res, dtype=torch.float, device='cuda')
new_noise.requires_grad = False
elif (noise_type == 'trainable'):
new_noise = torch.randn(res, dtype=torch.float, device='cuda')
if (i < num_trainable_noise_layers):
new_noise.requires_grad = True
noise_vars.append(new_noise)
else:
new_noise.requires_grad = False
else:
raise Exception("unknown noise type")
noise.append(new_noise)
var_list = [latent]+noise_vars
opt_dict = {
'sgd': torch.optim.SGD,
'adam': torch.optim.Adam,
'sgdm': partial(torch.optim.SGD, momentum=0.9),
'adamax': torch.optim.Adamax
}
opt_func = opt_dict[opt_name]
opt = SphericalOptimizer(opt_func, var_list, lr=learning_rate)
schedule_dict = {
'fixed': lambda x: 1,
'linear1cycle': lambda x: (9*(1-np.abs(x/steps-1/2)*2)+1)/10,
'linear1cycledrop': lambda x: (9*(1-np.abs(x/(0.9*steps)-1/2)*2)+1)/10 if x < 0.9*steps else 1/10 + (x-0.9*steps)/(0.1*steps)*(1/1000-1/10),
}
schedule_func = schedule_dict[lr_schedule]
scheduler = torch.optim.lr_scheduler.LambdaLR(opt.opt, schedule_func)
loss_builder = LossBuilder(ref_im, loss_str, eps).cuda()
min_loss = np.inf
min_l2 = np.inf
best_summary = ""
start_t = time.time()
gen_im = None
if self.verbose: print("Optimizing")
for j in range(steps):
opt.opt.zero_grad()
# Duplicate latent in case tile_latent = True
if (tile_latent):
latent_in = latent.expand(-1, 18, -1)
else:
latent_in = latent
# Apply learned linear mapping to match latent distribution to that of the mapping network
latent_in = self.lrelu(latent_in*self.gaussian_fit["std"] + self.gaussian_fit["mean"])
# Normalize image to [0,1] instead of [-1,1]
gen_im = (self.synthesis(latent_in, noise)+1)/2
# Calculate Losses
loss, loss_dict = loss_builder(latent_in, gen_im)
loss_dict['TOTAL'] = loss
# Save best summary for log
if(loss < min_loss):
min_loss = loss
best_summary = f'BEST ({j+1}) | '+' | '.join(
[f'{x}: {y:.4f}' for x, y in loss_dict.items()])
best_im = gen_im.clone()
loss_l2 = loss_dict['L2']
if(loss_l2 < min_l2):
min_l2 = loss_l2
# Save intermediate HR and LR images
if(save_intermediate):
yield (best_im.cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1))
loss.backward()
opt.step()
scheduler.step()
total_t = time.time()-start_t
current_info = f' | time: {total_t:.1f} | it/s: {(j+1)/total_t:.2f} | batchsize: {batch_size}'
if self.verbose: print(best_summary+current_info)
if(min_l2 <= eps):
yield (gen_im.clone().cpu().detach().clamp(0, 1),loss_builder.D(best_im).cpu().detach().clamp(0, 1))
else:
print("Could not find a face that downscales correctly within epsilon")

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://api.gitlife.ru/oschina-mirror/saterr-pulse.git
git@api.gitlife.ru:oschina-mirror/saterr-pulse.git
oschina-mirror
saterr-pulse
saterr-pulse
master