Слияние кода завершено, страница обновится автоматически
import shutil
import torch
import sys
import os
import json
import numpy as np
from config import config
from torch import nn
import torch.nn.functional as F
def save_checkpoint(state, is_best,fold):
filename = config.weights + config.model_name + os.sep +str(fold) + os.sep + "_checkpoint.pth.tar"
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, config.best_models + config.model_name+ os.sep +str(fold) + os.sep + 'model_best.pth.tar')
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch):
"""Sets the learning rate to the initial LR decayed by 10 every 3 epochs"""
lr = config.lr * (0.1 ** (epoch // 3))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def schedule(current_epoch, current_lrs, **logs):
lrs = [1e-3, 1e-4, 0.5e-4, 1e-5, 0.5e-5]
epochs = [0, 1, 6, 8, 12]
for lr, epoch in zip(lrs, epochs):
if current_epoch >= epoch:
current_lrs[5] = lr
if current_epoch >= 2:
current_lrs[4] = lr * 1
current_lrs[3] = lr * 1
current_lrs[2] = lr * 1
current_lrs[1] = lr * 1
current_lrs[0] = lr * 0.1
return current_lrs
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
class Logger(object):
def __init__(self):
self.terminal = sys.stdout #stdout
self.file = None
def open(self, file, mode=None):
if mode is None: mode ='w'
self.file = open(file, mode)
def write(self, message, is_terminal=1, is_file=1 ):
if '\r' in message: is_file=0
if is_terminal == 1:
self.terminal.write(message)
self.terminal.flush()
#time.sleep(1)
if is_file == 1:
self.file.write(message)
self.file.flush()
def flush(self):
# this flush method is needed for python 3 compatibility.
# this handles the flush command by doing nothing.
# you might want to specify some extra behavior here.
pass
def get_learning_rate(optimizer):
lr=[]
for param_group in optimizer.param_groups:
lr +=[ param_group['lr'] ]
#assert(len(lr)==1) #we support only one param_group
lr = lr[0]
return lr
def time_to_str(t, mode='min'):
if mode=='min':
t = int(t)/60
hr = t//60
min = t%60
return '%2d hr %02d min'%(hr,min)
elif mode=='sec':
t = int(t)
min = t//60
sec = t%60
return '%2d min %02d sec'%(min,sec)
else:
raise NotImplementedError
class FocalLoss(nn.Module):
def __init__(self, focusing_param=2, balance_param=0.25):
super(FocalLoss, self).__init__()
self.focusing_param = focusing_param
self.balance_param = balance_param
def forward(self, output, target):
cross_entropy = F.cross_entropy(output, target)
cross_entropy_log = torch.log(cross_entropy)
logpt = - F.cross_entropy(output, target)
pt = torch.exp(logpt)
focal_loss = -((1 - pt) ** self.focusing_param) * logpt
balanced_focal_loss = self.balance_param * focal_loss
return balanced_focal_loss
class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return super(MyEncoder, self).default(obj)
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )