Слияние кода завершено, страница обновится автоматически
import logging
from functools import partial
import tensorflow as tf
from inputs import gpt2_pred_input
from models.gpt2 import encoder
# Takes in the user supplied text and generates output texts. Outputs to log/console and a file
def gpt2_predict(network, text, params):
logger = logging.getLogger('tensorflow')
enc = encoder.get_encoder(params["encoder_path"])
predictions = network.predict(input_fn=partial(gpt2_pred_input, text=text))
with tf.gfile.Open(params["predict_path"], "a") as f:
for i, p in enumerate(predictions):
p = p["tokens"]
text = enc.decode(p)
f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
f.write(text)
f.write("\n" + "=" * 80 + "\n")
logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n")
logger.info(text)
logger.info("\n" + "=" * 80 + "\n")
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )