OSCHINA-MIRROR/soldatjiang-robot-grasp-detection / imagenet_classifier.py (5.8 KB; committed by Tomi Nikolla, 2017-10-24 18:02, ed1e6d9)
#!/usr/local/bin/python
'''Training a network on ImageNet.
'''
import sys
import argparse
import os.path
import glob
import tensorflow as tf
import image_processing
from inference import inference
import time

TRAIN_FILE = '/root/imagenet-data/train-00001-of-01024'
VALIDATION_FILE = '/root/imagenet-data/validation-00004-of-00128'


def data_files(dataset):
    # Collect all TFRecord shards of the given split ('train' or
    # 'validation') under FLAGS.data_dir.
    tf_record_pattern = os.path.join(FLAGS.data_dir, '%s-*' % dataset)
    data_files = tf.gfile.Glob(tf_record_pattern)
    return data_files
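
# For instance, with the default --data_dir, data_files('train') globs
# /root/imagenet-data/train-* and returns the matching shard paths, e.g.
# the TRAIN_FILE shard train-00001-of-01024.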

def run_training():
    #data_files_ = TRAIN_FILE
    #data_files_ = VALIDATION_FILE
    data_files_ = data_files(FLAGS.train_or_validation)
    images, labels = image_processing.distorted_inputs(
        data_files_, FLAGS.num_epochs, batch_size=FLAGS.batch_size)
    labels = tf.one_hot(labels, 1000)
    logits = inference(images)
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=labels))
    tf.summary.scalar('loss', loss)
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()
    # Adam with the configured initial learning rate.
    train_op = tf.train.AdamOptimizer(
        learning_rate=FLAGS.learning_rate, epsilon=0.1).minimize(loss)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Save/restore model: map each named weight/bias variable to itself so
    # the Saver reads and writes exactly these sixteen tensors.
    d = {}
    l = ['w1', 'b1', 'w2', 'b2', 'w3', 'b3', 'w4', 'b4', 'w5', 'b5',
         'w_fc1', 'b_fc1', 'w_fc2', 'b_fc2', 'w_output', 'b_output']
    for i in l:
        d[i] = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                if v.name == i + ':0'][0]
    saver = tf.train.Saver(d)
    saver.restore(sess, FLAGS.model_path)
    try:
        step = 0
        start_time = time.time()
        while not coord.should_stop():
            start_batch = time.time()
            # Train on one batch.
            _, loss_value, pred, acc = sess.run(
                [train_op, loss, correct_pred, accuracy])
            duration = time.time() - start_batch
            if step % 100 == 0:
                print('Step %d | loss = %.2f | accuracy = %.2f (%.3f sec/batch)'
                      % (step, loss_value, acc, duration))
            if step % 5000 == 0:
                saver.save(sess, FLAGS.model_path)
            step += 1
    except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps, %.1f min.'
              % (FLAGS.num_epochs, step, (time.time() - start_time) / 60))
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()
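
# Note: saver.restore() above assumes a checkpoint already exists at
# --model_path; for a cold start one possible (assumed) bootstrap is to save
# the freshly initialized variables once before the training loop, e.g.:
#   saver.save(sess, FLAGS.model_path)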

def evaluation():
    #data_files_ = TRAIN_FILE
    data_files_ = data_files(FLAGS.train_or_validation)
    images, labels = image_processing.inputs(
        data_files_, FLAGS.num_epochs, batch_size=FLAGS.batch_size)
    labels = tf.one_hot(labels, 1000)
    logits = inference(images)
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # Restore the trained weights by name, mirroring run_training().
    d = {}
    l = ['w1', 'b1', 'w2', 'b2', 'w3', 'b3', 'w4', 'b4', 'w5', 'b5',
         'w_fc1', 'b_fc1', 'w_fc2', 'b_fc2', 'w_output', 'b_output']
    for i in l:
        d[i] = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                if v.name == i + ':0'][0]
    saver = tf.train.Saver(d)
    saver.restore(sess, FLAGS.model_path)
    try:
        step = 0
        start_time = time.time()
        while not coord.should_stop():
            start_batch = time.time()
            acc = sess.run(accuracy)
            duration = time.time() - start_batch
            print('Step %d | accuracy = %.2f (%.3f sec/batch)'
                  % (step, acc, duration))
            step += 1
    except tf.errors.OutOfRangeError:
        print('Done evaluating for %d epochs, %d steps, %.1f min.'
              % (FLAGS.num_epochs, step, (time.time() - start_time) / 60))
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()

def main(_):
    if FLAGS.train_or_validation == 'train':
        print(' *** run training.')
        print(FLAGS.train_or_validation)
        run_training()
    else:
        print(' *** run validation.')
        print(FLAGS.train_or_validation)
        evaluation()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.001,
        help='Initial learning rate.'
    )
    parser.add_argument(
        '--data_dir',
        type=str,
        default='/root/imagenet-data',
        help='Directory with training data.'
    )
    parser.add_argument(
        '--num_epochs',
        type=int,
        default=None,
        help='Number of epochs to run trainer.'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=64,
        help='Batch size.'
    )
    parser.add_argument(
        '--log_dir',
        type=str,
        default='/tmp/tf',
        help='TensorBoard log_dir.'
    )
    parser.add_argument(
        '--model_path',
        type=str,
        default='/tmp/tf/model.ckpt',
        help='Variables for the model.'
    )
    parser.add_argument(
        '--train_or_validation',
        type=str,
        default='train',
        help="Whether to train ('train') or evaluate ('validation') the dataset."
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
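
# Example invocations (assuming the ImageNet TFRecord shards are under
# /root/imagenet-data and a checkpoint already exists at /tmp/tf/model.ckpt):
#   python imagenet_classifier.py --train_or_validation=train --batch_size=64
#   python imagenet_classifier.py --train_or_validation=validation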
