1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/zhaohuxing-bert_seq2seq

Присоединиться к Gitlife
Откройте для себя и примите участие в публичных проектах с открытым исходным кодом с участием более 10 миллионов разработчиков. Приватные репозитории также полностью бесплатны :)
Присоединиться бесплатно
В этом репозитории не указан файл с открытой лицензией (LICENSE). При использовании обратитесь к конкретному описанию проекта и его зависимостям в коде.
Клонировать/Скачать
predict_math.py 6.9 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
xingzhaohu Отправлено 31.10.2020 06:45 95562ae
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import csv
from bert_seq2seq.utils import load_bert
from bert_seq2seq.tokenizer import Tokenizer, load_chinese_base_vocab
from sympy import Integer
import re
# Checkpoint settings read by the prediction script below.
model_name = "roberta"
model_path = "./state_dict/bert_math_ques_model.bin"

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
_device_kind = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_kind)
print("device is", device, sep=" "[:0] or "")

# Token -> id vocabulary loaded from the vocab file shipped with the weights.
word2idx = load_chinese_base_vocab("./state_dict/roberta_wwm_vocab.txt")
# Mixed numbers are written as "1_2/3" in the questions.
_MIXED_NUMBER_RE = re.compile(r"(\d+)_(\d+/\d+)")
_INTEGER_RE = re.compile(r"(\d+)")
_FRACTION_RE = re.compile(r"\d+/\d+")
# A plain "p/q" string, i.e. how an exact rational result is rendered by str().
_RATIONAL_RE = re.compile(r"(-?\d+)/(\d+)")
# "how many <measure word>" questions, whose answer has to be a whole number.
# Inside a character class "|" is a literal character, not an alternation,
# so the measure words are listed without separators.
_COUNT_UNIT_RE = re.compile("(?:多少|几)[辆人个只箱包本束头盒张]")
_AT_LEAST_RE = re.compile("至少|最少")
_AT_MOST_RE = re.compile("至多|最多")

# A fractional result is reported as a percentage when the question contains
# one of these terms ...
_RATE_TERMS = (
    "百分之几", "出米率", "出糖率", "利润率", "出粉率",
    "出勤率", "缺勤率", "过标率", "错误率", "出席率",
    "发芽率", "近视率", "含盐率", "命中率",
)
# ... or one of these terms, unless the paired "how many" phrase is present.
_RATE_TERMS_UNLESS_COUNT = (
    ("成活率", "多少棵"),
    ("合格率", "多少个"),
)


def evaluate_equation(pred_equation):
    """Evaluate a generated equation string and return its value.

    When the equation contains no decimal point, every integer literal is
    wrapped as ``Integer(...)`` before evaluation. If evaluation raises, a
    random integer in ``1..21`` is returned instead, so that every question
    still receives an answer.

    Args:
        pred_equation: Equation text produced by the model, spaces removed.

    Returns:
        The evaluated value, or the random fallback integer.
    """
    if "." not in pred_equation:
        pred_equation = _INTEGER_RE.sub(r"Integer(\1)", pred_equation)
    try:
        # SECURITY NOTE: eval() executes whatever the model generated. Only
        # run this with a trusted checkpoint and trusted input data.
        return eval(pred_equation)
    except Exception:
        # Deliberate best-effort guess; KeyboardInterrupt still propagates.
        return np.random.choice(21) + 1


def _round_count(value, question):
    """Round ``value`` to an int: down for "at most" questions, otherwise up.

    An "at least" wording takes precedence over an "at most" wording.
    """
    if not _AT_LEAST_RE.search(question) and _AT_MOST_RE.search(question):
        return int(np.floor(value))
    return int(np.ceil(value))


def _asks_for_rate(question):
    """Return True when a fractional answer should be shown as a percentage."""
    if any(term in question for term in _RATE_TERMS):
        return True
    return any(
        rate in question and count not in question
        for rate, count in _RATE_TERMS_UNLESS_COUNT
    )


def _format_decimal(question, value):
    """Format the result of an equation that contained a decimal point."""
    is_percent = "百分之几" in question
    if is_percent:
        value = value * 100
    value = round(value, 2)
    if int(value) == value:
        value = int(value)
    if _COUNT_UNIT_RE.search(question):
        value = _round_count(value, question)
    text = str(value)
    if is_percent:
        text += "%"
    return text


def _format_exact(question, value):
    """Format the result of an equation that contained no decimal point."""
    text = str(value)
    if "/" not in text:
        return text
    if "几分之几" in question:
        # The question asks for a fraction: keep the "p/q" form untouched.
        return text
    parsed = _RATIONAL_RE.fullmatch(text)
    if parsed is None:
        # Not a plain "p/q" string; emit it unchanged rather than abort the
        # whole prediction run on one unusual result.
        return text
    ratio = int(parsed.group(1)) / int(parsed.group(2))
    if _asks_for_rate(question):
        return str(round(round(ratio, 5) * 100)) + "%"
    if _FRACTION_RE.search(question) or ":" in question:
        return str(round(ratio, 5))
    return str(_round_count(ratio, question))


def format_answer(question, pred_equation, pred_answer):
    """Convert an evaluated equation into the answer text for the submission.

    Args:
        question: Question text, with mixed numbers already rewritten.
        pred_equation: Equation generated for the question, spaces removed.
        pred_answer: Value obtained from ``evaluate_equation(pred_equation)``.

    Returns:
        The answer as a string.
    """
    if "." in pred_equation:
        return _format_decimal(question, pred_answer)
    return _format_exact(question, pred_answer)


def main():
    """Predict an answer for every row of ./test.csv and write submit.csv.

    Column 0 of the input is copied to the output as the row id and column 1
    is used as the question text. Each output row is ``[id, answer]``.
    """
    data = pd.read_csv("./test.csv", header=None)
    bert_model = load_bert(word2idx, model_name=model_name)
    bert_model.load_state_dict(torch.load(model_path, map_location=device))
    bert_model.to(device)
    bert_model.eval()
    # newline="" is what the csv module requires of files handed to
    # csv.writer; without it rows get doubled line endings on Windows.
    with open("submit.csv", "w", newline="", encoding="utf-8") as f_out:
        writer = csv.writer(f_out)
        for i, row in data.iterrows():
            print(i)
            question = _MIXED_NUMBER_RE.sub(r"(\1+\2)", row[1])
            pred_equation = bert_model.generate(question, beam_size=3, device=device)
            pred_equation = pred_equation.replace(" ", "")
            pred_answer = format_answer(
                question, pred_equation, evaluate_equation(pred_equation)
            )
            writer.writerow([row[0], pred_answer])


if __name__ == "__main__":
    main()

Опубликовать ( 0 )

Вы можете оставить комментарий после входа в систему

1
https://api.gitlife.ru/oschina-mirror/zhaohuxing-bert_seq2seq.git
git@api.gitlife.ru:oschina-mirror/zhaohuxing-bert_seq2seq.git
oschina-mirror
zhaohuxing-bert_seq2seq
zhaohuxing-bert_seq2seq
master