1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/niuyongjie-warped-ganweight

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
Клонировать/Скачать
pinggu.py 3.3 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
niuyongjie Отправлено 10.05.2022 10:11 34fe888
from matplotlib import image
from inference import GeneratorWithWeightDeformator
from loading import load_generator
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch_tools.visualization import to_image_grid
from PIL import Image
from inference import GeneratorWithFixedWeightDeformation
torch.manual_seed(2)
print("开始评估......")
with torch.no_grad():
zs = torch.randn([4, 512], device='cuda')
G = load_generator(
args={'resolution': 1024, 'gan_type': 'StyleGAN2'},
G_weights='models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt')
imgs_o=G(zs).cpu()
G = GeneratorWithWeightDeformator(
generator=G,
deformator_type='svd_rectification',
layer_ix=3,
checkpoint_path='results/checkpoint_y.pt',
)
shift = 3500.0
direction=26
G.deformate(direction, shift)
imgs_deformated=G(zs).cpu()
G.save_deformation("results/temp/deformator_17.pt",direction,shift)
G = load_generator(
args={'resolution': 1024, 'gan_type': 'StyleGAN2'},
G_weights='models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt')
G = GeneratorWithFixedWeightDeformation(generator=G, deformation_path='results/temp/deformator_17.pt')
G.deformate(0.5 * G.scale)
imgs_deformated1 = G(zs).cpu()
print("层数是:",G.layer_index,"方法系数为:",G.scale)
imgs_batch = []
imgs_batch.append(torch.cat([imgs_o,imgs_deformated,imgs_deformated1]))
imgs_grid = torch.cat([t for t in torch.stack(imgs_batch).transpose(0, 1)])
plt.figure(figsize=(7, len(zs)), dpi=200)
plt.axis('off')
plt.imshow(to_image_grid(torch.clamp(imgs_grid, -1, 1), nrow=4))
plt.show()
# for direction in range(64):
# G.deformate(direction, shift)
# G.save_deformation("results/temp/deformate{}.pt".format(direction),direction)
# G = load_generator(
# args={'resolution': 1024, 'gan_type': 'StyleGAN2'},
# G_weights='models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt'
# )
# source = 'results/temp/deformate7.pt' #eyes_size.pt
# G = GeneratorWithWeightDeformator(G, 'svd_rectification',3,"")
# print(source,"的层数是:",G.layer_index,"方法系数为:",G.scale)
# # Generate some samples
# zs = torch.randn((2, 512)).cuda()
# images=G(zs).cpu()
# G.deformate(1.0* G.scale)
# imgs_deformated = G(zs).cpu()
# imgs_batch = []
# imgs_batch.append(torch.cat([images,imgs_deformated]))
# imgs_grid = torch.cat([t for t in torch.stack(imgs_batch).transpose(0, 1)])
# plt.figure(figsize=(7, len(zs)), dpi=200)
# plt.axis('off')
# plt.imshow(to_image_grid(torch.clamp(imgs_grid, -1, 1), nrow=2))
# plt.show()
# zs = torch.randn([4, 512], device='cuda')
# imgs = []
# n_steps = 7
# with torch.no_grad():
# for i_scale in np.arange(n_steps):
# scale = 2.0 * float(i_scale) / (n_steps - 1) - 1.
# G.deformate(scale * G.scale)
# batch_size = 4
# imgs_batch = []
# for i in np.arange(0, len(zs), batch_size):
# imgs_batch.append(G(zs[i: i + batch_size]).cpu())
# imgs.append(torch.cat(imgs_batch))
# imgs_grid = torch.cat([t for t in torch.stack(imgs).transpose(0, 1)])
# plt.figure(figsize=(n_steps, len(zs)), dpi=150)
# plt.axis('off')
# plt.imshow(to_image_grid(torch.clamp(imgs_grid, -1, 1), nrow=n_steps))
# plt.show()

Опубликовать ( 0 )

Вы можете оставить комментарий после Вход в систему

1
https://api.gitlife.ru/oschina-mirror/niuyongjie-warped-ganweight.git
git@api.gitlife.ru:oschina-mirror/niuyongjie-warped-ganweight.git
oschina-mirror
niuyongjie-warped-ganweight
niuyongjie-warped-ganweight
master