
OSCHINA-MIRROR/angzhao-TextBrewerNer

main.train.dist.py (12 KB)
zhaoang · committed 2020-09-25 09:19 · 6e58f66
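
Training / prediction entry point for the ELECTRA token-classification (NER) model in this repository. It supports single-GPU and distributed (DDP) runs, uses TextBrewer's BasicTrainer for supervised training, and can run prediction on a single file or over a whole directory.
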
import logging
import os,random
import numpy as np
import torch
from utils_ner import read_features, label2id_dict, Tokenize
from utils import divide_parameters
from transformers import ElectraConfig, AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup, BertTokenizer, get_constant_schedule
import config
from modeling import ElectraForTokenClassification, ElectraForTokenClassificationAdaptorTraining
from textbrewer import DistillationConfig, TrainingConfig,BasicTrainer
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from functools import partial
from train_eval import predict, ddp_predict
import time
def args_check(logger, args):
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        logger.warning("Output directory (%s) already exists and is not empty.", args.output_dir)
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_train and not args.do_predict and not args.do_dir_predict:
        raise ValueError("At least one of `do_train`, `do_predict` or `do_dir_predict` must be True.")
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count() if not args.no_cuda else 0
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend='nccl')
        logger.info("rank %d device %s n_gpu %d distributed training %r", torch.distributed.get_rank(), device, n_gpu, bool(args.local_rank != -1))
    args.n_gpu = n_gpu
    args.device = device
    return device, n_gpu
def main():
    # parse arguments
    config.parse()
    args = config.args
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S',
        level=logging.INFO
    )
    logger = logging.getLogger("Train_eval")
    # arguments check
    device, n_gpu = args_check(logger, args)
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.local_rank != -1:
        logger.warning(f"Process rank: {torch.distributed.get_rank()}, device : {args.device}, n_gpu : {args.n_gpu}, distributed training : {bool(args.local_rank!=-1)}")
    for k, v in vars(args).items():
        logger.info(f"{k}:{v}")
    # set seeds
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    forward_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
    args.forward_batch_size = forward_batch_size
    # load bert config
    bert_config_S = ElectraConfig.from_json_file(args.bert_config_file_S)
    bert_config_S.output_hidden_states = (args.output_encoded_layers == 'true')
    bert_config_S.num_labels = len(label2id_dict)
    assert args.max_seq_length <= bert_config_S.max_position_embeddings
    # read data
    train_examples = None
    train_dataset = None
    eval_examples = None
    eval_dataset = None
    num_train_steps = None
    #tokenizer = BertTokenizer(vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
    tokenizer = Tokenize(dict_path=args.vocab_file)
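    # In distributed (DDP) runs, rank 0 reads and caches the features first; the other ranks
    # wait at the barrier below and only then load the data.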
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    if args.do_train:
        train_examples, train_dataset = read_features(args.train_file, tokenizer=tokenizer, max_seq_length=args.max_seq_length)
        logger.info(f'len train_examples {len(train_examples)}, len train_dataset {len(train_dataset)}')
    if args.do_predict:
        print(f'txt mode {args.txtmode}')
        eval_examples, eval_dataset = read_features(args.predict_file, tokenizer=tokenizer, max_seq_length=args.max_seq_length, txtmode=args.txtmode)
        logger.info(f'len eval_examples {len(eval_examples)}, len eval_dataset {len(eval_dataset)}')
    if args.local_rank == 0:
        torch.distributed.barrier()
    # Build model and load checkpoint
    bert_config_S.usecrf = True
    model_S = ElectraForTokenClassification(bert_config_S)
    # Load student
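    # load_model_type == 'bert': load only pretrained backbone weights (non-strict, so missing
    # and unexpected keys are logged); 'all': load a fully fine-tuned checkpoint; any other
    # value leaves the model randomly initialized.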
    if args.load_model_type == 'bert':
        assert args.init_checkpoint_S is not None
        state_dict_S = torch.load(args.init_checkpoint_S, map_location='cpu')
        #state_weight = {k[5:]:v for k,v in state_dict_S.items() if k.startswith('bert.')}
        #missing_keys,_ = model_S.bert.load_state_dict(state_weight,strict=False)
        missing_keys, unexpected_keys = model_S.load_state_dict(state_dict_S, strict=False)
        logger.info(f"missing keys:{missing_keys}")
        logger.info(f"unexpected keys:{unexpected_keys}")
    elif args.load_model_type == 'all':
        assert args.tuned_checkpoint_S is not None
        state_dict_S = torch.load(args.tuned_checkpoint_S, map_location='cpu')
        model_S.load_state_dict(state_dict_S)
    else:
        logger.info("Model is randomly initialized.")
    logger.info(f'device {device}')
    model_S.to(device)
    if args.do_train:
        # parameters
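        # Layer-wise learning-rate decay: the classifier/CRF head uses the base learning rate,
        # each encoder layer (counting down from the top) gets lr * lr_decay**(i+1), and the
        # embeddings get the smallest rate, lr * lr_decay**(n_layers+1).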
        if args.lr_decay is not None:
            outputs_params = list(model_S.classifier.named_parameters()) + list(model_S.crf.named_parameters())
            outputs_params = divide_parameters(outputs_params, lr=args.learning_rate)
            electra_params = []
            n_layers = len(model_S.electra.encoder.layer)
            assert n_layers == 12
            for i, n in enumerate(reversed(range(n_layers))):
                encoder_params = list(model_S.electra.encoder.layer[n].named_parameters())
                lr = args.learning_rate * args.lr_decay**(i+1)
                electra_params += divide_parameters(encoder_params, lr=lr)
                logger.info(f"{i},{n},{lr}")
            embed_params = [(name, value) for name, value in model_S.electra.named_parameters() if 'embedding' in name]
            logger.info(f"{[name for name, value in embed_params]}")
            lr = args.learning_rate * args.lr_decay**(n_layers+1)
            electra_params += divide_parameters(embed_params, lr=lr)
            logger.info(f"embed lr:{lr}")
            all_trainable_params = outputs_params + electra_params
            assert sum(map(lambda x: len(x['params']), all_trainable_params)) == len(list(model_S.parameters())), \
                (sum(map(lambda x: len(x['params']), all_trainable_params)), len(list(model_S.parameters())))
        else:
            params = list(model_S.named_parameters())
            all_trainable_params = divide_parameters(params, lr=args.learning_rate)
        logger.info("Length of all_trainable_params: %d", len(all_trainable_params))
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        logger.info(f"Length of train_sampler: {len(train_sampler)}")
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.forward_batch_size, drop_last=True)
        num_train_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        logger.info(f'num_train_steps {num_train_steps}')
        optimizer = AdamW(all_trainable_params, lr=args.learning_rate, correct_bias=False)
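        # The LR schedule is not built here: scheduler_class and scheduler_args are handed to
        # the trainer below, which instantiates the scheduler itself. Warmup length is
        # warmup_proportion * num_train_steps.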
        if args.official_schedule == 'const':
            scheduler_class = get_constant_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion*num_train_steps)}
            #scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps))
        elif args.official_schedule == 'linear':
            scheduler_class = get_linear_schedule_with_warmup
            scheduler_args = {'num_warmup_steps': int(args.warmup_proportion*num_train_steps), 'num_training_steps': num_train_steps}
            #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(args.warmup_proportion*num_train_steps), num_training_steps=num_train_steps)
        elif args.official_schedule == 'const_nowarmup':
            scheduler_class = get_constant_schedule
            scheduler_args = {}
        else:
            raise NotImplementedError
        logger.warning("***** Running training *****")
        logger.warning("local_rank %d Num orig examples = %d", args.local_rank, len(train_examples))
        logger.warning("local_rank %d Num split examples = %d", args.local_rank, len(train_dataset))
        logger.warning("local_rank %d Forward batch size = %d", args.local_rank, forward_batch_size)
        logger.warning("local_rank %d Num backward steps = %d", args.local_rank, num_train_steps)
        ########### TRAINING ###########
        train_config = TrainingConfig(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            ckpt_frequency=args.ckpt_frequency,
            log_dir=args.output_dir,
            output_dir=args.output_dir,
            device=args.device,
            fp16=args.fp16,
            local_rank=args.local_rank)
        logger.info(f"{train_config}")
        distiller = BasicTrainer(train_config=train_config,
                                 model=model_S,
                                 adaptor=ElectraForTokenClassificationAdaptorTraining)
        # evaluate the model in a single process via ddp_predict
        callback_func = partial(ddp_predict,
                                filename=None,
                                eval_examples=eval_examples,
                                eval_dataset=eval_dataset,
                                args=args)
        with distiller:
            distiller.train(optimizer, scheduler_class=scheduler_class,
                            scheduler_args=scheduler_args,
                            max_grad_norm=1.0,
                            dataloader=train_dataloader,
                            num_epochs=args.num_train_epochs, callback=callback_func)
    if not args.do_train and args.do_predict:
        res = ddp_predict(None, model_S, eval_examples, eval_dataset, step=0, args=args)
        print(res)
    if args.do_dir_predict:
        # predict over a whole directory of files
        assert args.txtmode == 1
        new_words = {}
        t0 = time.time()
        files_fullpath = []
        for t in os.walk(args.file[0]):
            root, dirs, files = t
            for f in files:
                fullpath = os.path.join(root, f)
                output_prediction_file = os.path.join(args.file[1], f)
                if os.path.exists(output_prediction_file):
                    continue
                files_fullpath.append(fullpath)
        logger.info(f'left files {len(files_fullpath)}')
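        # Files whose prediction output already exists are skipped, so an interrupted run can be
        # resumed; args.bpos additionally skips the first bpos entries of the remaining file list.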
        for i, predict_file in enumerate(files_fullpath[args.bpos:]):
            pwd_file = predict_file[len(args.file[0]) + 1:]
            output_prediction_file = os.path.join(args.file[1], pwd_file)
            if os.path.exists(output_prediction_file):
                continue
            logger.info(f'predict file {predict_file}')
            t00 = time.time()
            eval_examples, eval_dataset = read_features(predict_file, tokenizer=tokenizer,
                                                        max_seq_length=args.max_seq_length, txtmode=args.txtmode, dump=False)
            res = ddp_predict(output_prediction_file, model_S, eval_examples, eval_dataset, step=0, args=args, newwords=new_words)
            t01 = time.time()
            char_cnt = new_words.get('__chars__', 0)
            logger.info(f'pred file {predict_file} {char_cnt} chars, costs {t01 - t00}s, avg {char_cnt / (t01 - t00)}')
        t1 = time.time()
        logger.info(f'all files cost {t1 - t0}s')
if __name__ == "__main__":
    main()
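# Example launch commands (assumed; the actual flag names are defined in config.py, which is
# not shown here):
#   single GPU:  python main.train.dist.py --do_train ...
#   multi GPU:   python -m torch.distributed.launch --nproc_per_node=4 main.train.dist.py --do_train ...
# torch.distributed.launch supplies the local_rank value that the script checks throughout.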
