OSCHINA-MIRROR/daib13-TwoStageVAE

This repository does not include an open-source LICENSE file. Before using it, consult the project description and the dependencies referenced in the code.

demo_wae.py (7 KB)
Bin Dai (FA Talent), committed 10.06.2019 10:17, 46888a6
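
demo_wae.py trains a Wasserstein auto-encoder (WAE) on the selected dataset and, when run with --val, restores the trained model and reports FID scores for both reconstructed and randomly generated images.
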
import argparse
import os
from network.wae_model import WAE
import numpy as np
import tensorflow as tf
import math
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle
from fid_score import evaluate_fid_score
from dataset import load_dataset, load_test_dataset


def main():
    # exp info
    if args.philly:
        exp_folder = args.output_path
        model_path = args.output_path
    else:
        exp_folder = os.path.join(args.output_path, args.dataset, args.exp_name)
        if not os.path.exists(exp_folder):
            os.makedirs(exp_folder)
        model_path = exp_folder

    # dataset
    x, side_length, channels = load_dataset(args.dataset, args.root_folder)
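    # The input placeholder takes a fixed-size batch of raw uint8 images in NHWC layout.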
    input_x = tf.placeholder(tf.uint8, [args.batch_size, side_length, side_length, channels], 'x')
    num_sample = np.shape(x)[0]
    print('Num Sample = {}.'.format(num_sample))

    # model
    model = WAE(input_x, args.latent_dim, args.loss_type, args.wae_lambda, args.cost)
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    if not args.val:
        writer = tf.summary.FileWriter(exp_folder, sess.graph)

        # train model
        iteration_per_epoch = math.floor(float(num_sample) / float(args.batch_size))
        model.pretrain_encoder(x, args.lr, sess)
        for epoch in range(args.epochs):
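            # Step decay: the learning rate is multiplied by lr_fac every lr_epochs epochs (constant if lr_epochs <= 0).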
            lr = args.lr if args.lr_epochs <= 0 else args.lr * math.pow(args.lr_fac, math.floor(float(epoch) / float(args.lr_epochs)))
            loss, summary = model.train(x, lr, sess)
            print('Date: {date}\t'
                  'Epoch: [0][{0}/{1}]\t'
                  'Loss: {2:.4f}.'.format(epoch, args.epochs, loss, date=time.strftime('%Y-%m-%d %H:%M:%S')))
            writer.add_summary(summary, epoch)
        saver.save(sess, os.path.join(model_path, 'model'))
    else:
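        # Evaluation: reconstruct and generate 10,000 images, save sample grids, then compute FID.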
        saver.restore(sess, os.path.join(model_path, 'model'))
        x, side_length, channels = load_test_dataset(args.dataset, args.root_folder)
        np.random.shuffle(x)
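        # Keep a few more than 10,000 test images so every batch below is full (ceil(10000/64) * 64 = 10048 with the default batch size).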
        x = x[0:10064]

        # reconstruction
        img_recons = []
        for i in range(int(math.ceil(float(10000) / float(args.batch_size)))):
            img_recon = sess.run(model.x_hat, feed_dict={input_x: x[i*args.batch_size:(i+1)*args.batch_size], model.is_training: False})
            img_recons.append(img_recon)
        img_recons = np.concatenate(img_recons)[0:10000]

        # generate images of first VAE
        img_gens1 = []
        for i in range(int(math.ceil(float(10000) / float(args.batch_size)))):
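            # Sample z ~ N(0, I) and decode; no input images are needed for generation.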
            img_gen = sess.run(model.x_hat, feed_dict={model.z: np.random.normal(0.0, 1.0, model.z.get_shape().as_list()), model.is_training: False})
            img_gens1.append(img_gen)
        img_gens1 = np.concatenate(img_gens1, 0)[0:10000]

        img_recons_sample = stich_imgs(img_recons)
        img_gens1_sample = stich_imgs(img_gens1)
        plt.imsave(os.path.join(exp_folder, 'recon_sample.jpg'), img_recons_sample)
        plt.imsave(os.path.join(exp_folder, 'gen_sample.jpg'), img_gens1_sample)
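
        # Release the WAE graph before FID evaluation (evaluate_fid_score presumably builds its own Inception graph).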
        tf.reset_default_graph()
        fid_recon = evaluate_fid_score(img_recons.copy(), args.dataset, args.root_folder, True)
        fid_gen = evaluate_fid_score(img_gens1.copy(), args.dataset, args.root_folder, True)
        print('Reconstruction Results:')
        print('FID = {:.4f}\n'.format(fid_recon))
        print('Generation Results:')
        print('FID = {:.4f}\n'.format(fid_gen))
        with open(os.path.join(exp_folder, 'report_fid.txt'), 'wt') as f:
            f.write('Reconstruction FID = {:.4f}\n'.format(fid_recon))
            f.write('Generation FID = {:.4f}\n'.format(fid_gen))
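

# Per-batch training helper for the two-stage setup: stage 0 optimizes the
# first-stage loss, any other stage the second-stage loss. Note that main()
# above calls model.train directly and does not use this function.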
def train(optimizer, iteration_per_epoch, sess, writer, lr, stage, train_x, input_x):
    np.random.shuffle(train_x)
    total_loss = 0
    for i in range(int(iteration_per_epoch)):
        x_batch = train_x[i*args.batch_size:(i+1)*args.batch_size]
        if stage == 0:
            loss, _, summary = sess.run([optimizer.model.loss, optimizer.optimizer, optimizer.model.summary], feed_dict={optimizer.lr: lr, optimizer.model.is_training: True, input_x: x_batch})
        else:
            loss, _, summary = sess.run([optimizer.model.loss2, optimizer.optimizer2, optimizer.model.summary2], feed_dict={optimizer.lr: lr, optimizer.model.is_training: True, input_x: x_batch})
        global_step = optimizer.global_step.eval(sess)
        if writer is not None and global_step % args.write_iteration == 0:
            writer.add_summary(summary, global_step)
        total_loss += loss
    total_loss /= iteration_per_epoch
    return total_loss
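

# Tile the first img_per_row * img_per_col images into a single grid; a
# trailing singleton channel dimension is squeezed so grayscale grids save as 2-D arrays.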
def stich_imgs(x, img_per_row=10, img_per_col=10):
    x_shape = np.shape(x)
    assert len(x_shape) == 4
    output = np.zeros([img_per_row*x_shape[1], img_per_col*x_shape[2], x_shape[3]])
    idx = 0
    for r in range(img_per_row):
        start_row = r * x_shape[1]
        end_row = start_row + x_shape[1]
        for c in range(img_per_col):
            start_col = c * x_shape[2]
            end_col = start_col + x_shape[2]
            output[start_row:end_row, start_col:end_col] = x[idx]
            idx += 1
            if idx == x_shape[0]:
                break
        if idx == x_shape[0]:
            break
    if np.shape(output)[-1] == 1:
        output = np.reshape(output, np.shape(output)[0:2])
    return output
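

# Example invocations, assuming the datasets are prepared under --root-folder
# in the layout expected by dataset.load_dataset:
#   python demo_wae.py --dataset mnist --epochs 400      # train
#   python demo_wae.py --dataset mnist --val             # evaluate FID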
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root-folder', type=str, default='.')
    parser.add_argument('--output-path', type=str, default='./experiments')
    parser.add_argument('--exp-name', type=str, default='debug')
    parser.add_argument('--dataset', type=str, default='mnist')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--weight-decay', type=float, default=0.0000)
    parser.add_argument('--opt', type=str, default='adam')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--write-iteration', type=int, default=600)
    parser.add_argument('--latent-dim', type=int, default=64)
    parser.add_argument('--epochs', type=int, default=400)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--lr-epochs', type=int, default=150)
    parser.add_argument('--lr-fac', type=float, default=0.5)
    parser.add_argument('--val', default=False, action='store_true')
    parser.add_argument('--fix-gamma', default=False, action='store_true')
    parser.add_argument('--init-loggamma', type=float, default=0.0)
    parser.add_argument('--loss-type', type=str, default='mmd')
    parser.add_argument('--wae-lambda', type=float, default=100)
    parser.add_argument('--cost', type=str, default='l2')
    parser.add_argument('--num-trial', type=int, default=10)
    parser.add_argument('--use-optimal', default=False, action='store_true')
    parser.add_argument('--optimal-ratio', type=float, default=1.0)
    parser.add_argument('--philly', default=False, action='store_true')
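    # Note: --weight-decay, --opt, --fix-gamma, --init-loggamma, --num-trial,
    # --use-optimal, and --optimal-ratio are parsed but not referenced in this
    # script; they are likely shared with the two-stage VAE demo.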
    args = parser.parse_args()
    print(args)
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    main()
