Слияние кода завершено, страница обновится автоматически
import argparse
import os
from network.wae_model import WAE
import numpy as np
import tensorflow as tf
import math
import time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.misc import imsave, imresize
import pickle
from fid_score import evaluate_fid_score
from dataset import load_dataset, load_test_dataset
def main():
# exp info
if args.philly:
exp_folder = args.output_path
model_path = args.output_path
else:
exp_folder = os.path.join(args.output_path, args.dataset, args.exp_name)
if not os.path.exists(exp_folder):
os.makedirs(exp_folder)
model_path = exp_folder
# dataset
x, side_length, channels = load_dataset(args.dataset, args.root_folder)
input_x = tf.placeholder(tf.uint8, [args.batch_size, side_length, side_length, channels], 'x')
num_sample = np.shape(x)[0]
print('Num Sample = {}.'.format(num_sample))
# model
model = WAE(input_x, args.latent_dim, args.loss_type, args.wae_lambda, args.cost)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
if not args.val:
writer = tf.summary.FileWriter(exp_folder, sess.graph)
# train model
iteration_per_epoch = math.floor(float(num_sample) / float(args.batch_size))
model.pretrain_encoder(x, args.lr, sess)
for epoch in range(args.epochs):
lr = args.lr if args.lr_epochs <= 0 else args.lr * math.pow(args.lr_fac, math.floor(float(epoch) / float(args.lr_epochs)))
loss, summary = model.train(x, lr, sess)
print('Date: {date}\t'
'Epoch: [0][{0}/{1}]\t'
'Loss: {2:.4f}.'.format(epoch, args.epochs, loss, date=time.strftime('%Y-%m-%d %H:%M:%S')))
writer.add_summary(summary, epoch)
saver.save(sess, os.path.join(model_path, 'model'))
else:
saver.restore(sess, os.path.join(model_path, 'model'))
x, side_length, channels = load_test_dataset(args.dataset, args.root_folder)
np.random.shuffle(x)
x = x[0:10064]
# reconstruction
img_recons = []
for i in range(int(math.ceil(float(10000) / float(args.batch_size)))):
img_recon = sess.run(model.x_hat, feed_dict={input_x: x[i*args.batch_size:(i+1)*args.batch_size], model.is_training: False})
img_recons.append(img_recon)
img_recons = np.concatenate(img_recons)[0:10000]
# generate images of first VAE
img_gens1 = []
for i in range(int(math.ceil(float(10000) / float(args.batch_size)))):
img_gen = sess.run(model.x_hat, feed_dict={model.z: np.random.normal(0.0, 1.0, model.z.get_shape().as_list()), model.is_training: False})
img_gens1.append(img_gen)
img_gens1 = np.concatenate(img_gens1, 0)[0:10000]
img_recons_sample = stich_imgs(img_recons)
img_gens1_sample = stich_imgs(img_gens1)
plt.imsave(os.path.join(exp_folder, 'recon_sample.jpg'), img_recons_sample)
plt.imsave(os.path.join(exp_folder, 'gen_sample.jpg'), img_gens1_sample)
tf.reset_default_graph()
fid_recon = evaluate_fid_score(img_recons.copy(), args.dataset, args.root_folder, True)
fid_gen = evaluate_fid_score(img_gens1.copy(), args.dataset, args.root_folder, True)
print('Reconstruction Results:')
print('FID = {:.4F}\n'.format(fid_recon))
print('Generation Results:')
print('FID = {:.4f}\n'.format(fid_gen))
fid = open(os.path.join(exp_folder, 'report_fid.txt'), 'wt')
fid.write('Reconstruction FID = {:.4f}\n'.format(fid_recon))
fid.write('Generation FID = {:.4f}\n'.format(fid_gen))
fid.close()
def train(optimizer, iteration_per_epoch, sess, writer, lr, stage, train_x, input_x):
np.random.shuffle(train_x)
total_loss = 0
for i in range(int(iteration_per_epoch)):
x_batch = train_x[i*args.batch_size:(i+1)*args.batch_size]
if stage == 0:
loss, _, summary = sess.run([optimizer.model.loss, optimizer.optimizer, optimizer.model.summary], feed_dict={optimizer.lr: lr, optimizer.model.is_training: True, input_x: x_batch})
else:
loss, _, summary = sess.run([optimizer.model.loss2, optimizer.optimizer2, optimizer.model.summary2], feed_dict={optimizer.lr: lr, optimizer.model.is_training: True, input_x: x_batch})
global_step = optimizer.global_step.eval(sess)
if writer is not None and global_step % args.write_iteration == 0:
writer.add_summary(summary, global_step)
total_loss += loss
total_loss /= iteration_per_epoch
return total_loss
def stich_imgs(x, img_per_row=10, img_per_col=10):
x_shape = np.shape(x)
assert(len(x_shape) == 4)
output = np.zeros([img_per_row*x_shape[1], img_per_col*x_shape[2], x_shape[3]])
idx = 0
for r in range(img_per_row):
start_row = r * x_shape[1]
end_row = start_row + x_shape[1]
for c in range(img_per_col):
start_col = c * x_shape[2]
end_col = start_col + x_shape[2]
output[start_row:end_row, start_col:end_col] = x[idx]
idx += 1
if idx == x_shape[0]:
break
if idx == x_shape[0]:
break
if np.shape(output)[-1] == 1:
output = np.reshape(output, np.shape(output)[0:2])
return output
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--root-folder', type=str, default='.')
parser.add_argument('--output-path', type=str, default='./experiments')
parser.add_argument('--exp-name', type=str, default='debug')
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--weight-decay', type=float, default=0.0000)
parser.add_argument('--opt', type=str, default='adam')
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--write-iteration', type=int, default=600)
parser.add_argument('--latent-dim', type=int, default=64)
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--lr-epochs', type=int, default=150)
parser.add_argument('--lr-fac', type=float, default=0.5)
parser.add_argument('--val', default=False, action='store_true')
parser.add_argument('--fix-gamma', default=False, action='store_true')
parser.add_argument('--init-loggamma', type=float, default=0.0)
parser.add_argument('--loss-type', type=str, default='mmd')
parser.add_argument('--wae-lambda', type=float, default=100)
parser.add_argument('--cost', type=str, default='l2')
parser.add_argument('--num-trial', type=int, default=10)
parser.add_argument('--use-optimal', default=False, action='store_true')
parser.add_argument('--optimal-ratio', type=float, default=1.0)
parser.add_argument('--philly', default=False, action='store_true')
args = parser.parse_args()
print(args)
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
main()
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )