Слияние кода завершено, страница обновится автоматически
from matplotlib import image
from inference import GeneratorWithWeightDeformator
from loading import load_generator
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch_tools.visualization import to_image_grid
from PIL import Image
from inference import GeneratorWithFixedWeightDeformation
torch.manual_seed(2)
print("开始评估......")
with torch.no_grad():
zs = torch.randn([4, 512], device='cuda')
G = load_generator(
args={'resolution': 1024, 'gan_type': 'StyleGAN2'},
G_weights='models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt')
imgs_o=G(zs).cpu()
G = GeneratorWithWeightDeformator(
generator=G,
deformator_type='svd_rectification',
layer_ix=3,
checkpoint_path='results/checkpoint_y.pt',
)
shift = 3500.0
direction=26
G.deformate(direction, shift)
imgs_deformated=G(zs).cpu()
G.save_deformation("results/temp/deformator_17.pt",direction,shift)
G = load_generator(
args={'resolution': 1024, 'gan_type': 'StyleGAN2'},
G_weights='models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt')
G = GeneratorWithFixedWeightDeformation(generator=G, deformation_path='results/temp/deformator_17.pt')
G.deformate(0.5 * G.scale)
imgs_deformated1 = G(zs).cpu()
print("层数是:",G.layer_index,"方法系数为:",G.scale)
imgs_batch = []
imgs_batch.append(torch.cat([imgs_o,imgs_deformated,imgs_deformated1]))
imgs_grid = torch.cat([t for t in torch.stack(imgs_batch).transpose(0, 1)])
plt.figure(figsize=(7, len(zs)), dpi=200)
plt.axis('off')
plt.imshow(to_image_grid(torch.clamp(imgs_grid, -1, 1), nrow=4))
plt.show()
# for direction in range(64):
# G.deformate(direction, shift)
# G.save_deformation("results/temp/deformate{}.pt".format(direction),direction)
# G = load_generator(
# args={'resolution': 1024, 'gan_type': 'StyleGAN2'},
# G_weights='models/pretrained/generators/StyleGAN2/stylegan2-ffhq-config-f.pt'
# )
# source = 'results/temp/deformate7.pt' #eyes_size.pt
# G = GeneratorWithWeightDeformator(G, 'svd_rectification',3,"")
# print(source,"的层数是:",G.layer_index,"方法系数为:",G.scale)
# # Generate some samples
# zs = torch.randn((2, 512)).cuda()
# images=G(zs).cpu()
# G.deformate(1.0* G.scale)
# imgs_deformated = G(zs).cpu()
# imgs_batch = []
# imgs_batch.append(torch.cat([images,imgs_deformated]))
# imgs_grid = torch.cat([t for t in torch.stack(imgs_batch).transpose(0, 1)])
# plt.figure(figsize=(7, len(zs)), dpi=200)
# plt.axis('off')
# plt.imshow(to_image_grid(torch.clamp(imgs_grid, -1, 1), nrow=2))
# plt.show()
# zs = torch.randn([4, 512], device='cuda')
# imgs = []
# n_steps = 7
# with torch.no_grad():
# for i_scale in np.arange(n_steps):
# scale = 2.0 * float(i_scale) / (n_steps - 1) - 1.
# G.deformate(scale * G.scale)
# batch_size = 4
# imgs_batch = []
# for i in np.arange(0, len(zs), batch_size):
# imgs_batch.append(G(zs[i: i + batch_size]).cpu())
# imgs.append(torch.cat(imgs_batch))
# imgs_grid = torch.cat([t for t in torch.stack(imgs).transpose(0, 1)])
# plt.figure(figsize=(n_steps, len(zs)), dpi=150)
# plt.axis('off')
# plt.imshow(to_image_grid(torch.clamp(imgs_grid, -1, 1), nrow=n_steps))
# plt.show()
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )