1 В избранное 0 Ответвления 0

OSCHINA-MIRROR/spytensor-plants_disease_detection

В этом репозитории не указан файл с открытой лицензией (LICENSE). При использовании обратитесь к конкретному описанию проекта и его зависимостям в коде.
Клонировать/Скачать
utils.py 4.4 КБ
Копировать Редактировать Web IDE Исходные данные Просмотреть построчно История
spytensor Отправлено 27.10.2018 04:22 97bbed7
import shutil
import torch
import sys
import os
import json
import numpy as np
from config import config
from torch import nn
import torch.nn.functional as F
def save_checkpoint(state, is_best, fold):
    """Save ``state`` as the latest checkpoint for ``fold``; copy it as the best model if ``is_best``.

    The checkpoint is written to
    ``config.weights + config.model_name + os.sep + str(fold) + os.sep + "_checkpoint.pth.tar"``.
    When ``is_best`` is true it is copied to
    ``config.best_models + config.model_name + os.sep + str(fold) + os.sep + "model_best.pth.tar"``.

    Paths are built by plain string concatenation, so ``config.weights`` and
    ``config.best_models`` must carry their own trailing separator.
    Missing directories are created.
    """
    fold_dir = config.weights + config.model_name + os.sep + str(fold) + os.sep
    # torch.save does not create parent directories; make sure they exist.
    os.makedirs(fold_dir, exist_ok=True)
    filename = fold_dir + "_checkpoint.pth.tar"
    torch.save(state, filename)
    if is_best:
        best_dir = config.best_models + config.model_name + os.sep + str(fold) + os.sep
        os.makedirs(best_dir, exist_ok=True)
        shutil.copyfile(filename, best_dir + 'model_best.pth.tar')
class AverageMeter(object):
    """Keep the latest value of a metric together with its running weighted mean.

    Attributes (all start at 0 after ``reset``):
        val:   most recent value passed to ``update``
        sum:   accumulated ``val * n`` over all updates
        count: accumulated ``n`` over all updates
        avg:   ``sum / count``
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and recompute the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch, decay=0.1, step=3):
    """Apply a step-decay learning rate to every parameter group of ``optimizer``.

    The new rate is ``config.lr * decay ** (epoch // step)``. With the default
    arguments this is the initial LR decayed by 10x every 3 epochs.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an 'lr' key).
        epoch: current epoch number.
        decay: multiplicative factor applied once per completed ``step`` epochs.
        step: number of epochs between decays; must be positive.

    Returns:
        The learning rate that was written into every parameter group.

    Raises:
        ValueError: if ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError("step must be a positive number of epochs, got %r" % (step,))
    lr = config.lr * (decay ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
def schedule(current_epoch, current_lrs, **logs):
    """Write a piecewise-constant learning-rate schedule into ``current_lrs`` (in place).

    The rate for ``current_epoch`` is taken from the last breakpoint it has reached:
    epoch 0 -> 1e-3, 1 -> 1e-4, 6 -> 0.5e-4, 8 -> 1e-5, 12 -> 0.5e-5.
    Index 5 always receives that rate. From epoch 2 on, indices 1-4 receive it as
    well and index 0 receives a tenth of it. Before epoch 0 nothing is changed.
    ``current_lrs`` must therefore support indices 0..5.
    ``logs`` is ignored (presumably it exists to match a callback signature).

    Returns the same ``current_lrs`` object.
    """
    breakpoints = ((0, 1e-3), (1, 1e-4), (6, 0.5e-4), (8, 1e-5), (12, 0.5e-5))
    reached = [rate for start, rate in breakpoints if current_epoch >= start]
    if reached:
        rate = reached[-1]
        current_lrs[5] = rate
        if current_epoch >= 2:
            current_lrs[4] = current_lrs[3] = current_lrs[2] = current_lrs[1] = rate * 1
            current_lrs[0] = rate * 0.1
    return current_lrs
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: prediction scores; dimension 1 is the class dimension
            (``topk`` is taken along it) and dimension 0 the sample.
        target: ground-truth class index for every sample.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k. Each holds the percentage of samples
        whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # pred: (maxk, batch) class indices, best prediction first.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout of `pred`,
            # so .view(-1) can raise; reshape copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
class Logger(object):
    """Tee text to the terminal and, once ``open`` has been called, to a log file.

    The terminal stream is whatever ``sys.stdout`` was when the Logger was created.
    """

    def __init__(self):
        self.terminal = sys.stdout  # stdout captured at construction time
        self.file = None  # log file handle; None until open() is called

    def open(self, file, mode=None):
        """Start logging to ``file``, opened with ``mode`` (default 'w').

        Any log file opened by an earlier call is closed first, so handles are not leaked.
        """
        if mode is None:
            mode = 'w'
        self.close()
        self.file = open(file, mode)

    def write(self, message, is_terminal=1, is_file=1):
        """Write ``message`` to the terminal and/or the log file, flushing each.

        Messages containing a carriage return (in-place progress updates) are
        never written to the file. File output is skipped while no file is open.
        """
        if '\r' in message:
            is_file = 0
        if is_terminal == 1:
            self.terminal.write(message)
            self.terminal.flush()
        if is_file == 1 and self.file is not None:
            self.file.write(message)
            self.file.flush()

    def flush(self):
        """Flush both outputs; lets the Logger stand in where a file-like object is expected."""
        self.terminal.flush()
        if self.file is not None:
            self.file.flush()

    def close(self):
        """Close the log file if one is open; terminal output keeps working."""
        if self.file is not None:
            self.file.close()
            self.file = None
def get_learning_rate(optimizer):
    """Return the learning rate of the first parameter group of ``optimizer``.

    The 'lr' of every group is read (a group without that key raises KeyError),
    but only the first one is returned. An optimizer with no groups raises IndexError.
    """
    rates = [group['lr'] for group in optimizer.param_groups]
    # Only the first group's rate is reported; other groups are not checked for agreement.
    return rates[0]
def time_to_str(t, mode='min'):
    """Format a duration ``t`` given in seconds.

    mode='min' -> ' H hr MM min' (seconds are truncated to whole seconds first).
    mode='sec' -> ' M min SS sec'.
    Any other mode raises NotImplementedError.
    """
    if mode == 'min':
        # Fractional minutes are truncated by the %d conversions.
        hours, minutes = divmod(int(t) / 60, 60)
        return '%2d hr %02d min' % (hours, minutes)
    if mode == 'sec':
        minutes, seconds = divmod(int(t), 60)
        return '%2d min %02d sec' % (minutes, seconds)
    raise NotImplementedError
class FocalLoss(nn.Module):
    """Balanced focal loss for multi-class classification.

    For each sample, with p_t the softmax probability of its target class:
        loss_i = -balance_param * (1 - p_t) ** focusing_param * log(p_t)
    The result is the mean of loss_i over the batch.
    With focusing_param=0 and balance_param=1 this reduces to the mean cross-entropy.
    """

    def __init__(self, focusing_param=2, balance_param=0.25):
        super(FocalLoss, self).__init__()
        self.focusing_param = focusing_param
        self.balance_param = balance_param

    def forward(self, output, target):
        """Return the scalar focal loss for logits ``output`` and class indices ``target``."""
        # Per-sample log p_t. The modulating factor must be applied per example;
        # applying it to the batch-mean cross-entropy does not down-weight easy samples.
        logpt = -F.cross_entropy(output, target, reduction='none')
        pt = torch.exp(logpt)
        focal_loss = -((1 - pt) ** self.focusing_param) * logpt
        balanced_focal_loss = self.balance_param * focal_loss
        return balanced_focal_loss.mean()
class MyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays into native Python values.

    Handles NumPy integers, floats, booleans and ndarrays (via ``tolist``).
    Anything else falls back to ``json.JSONEncoder.default``, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.bool_):
            # np.bool_ is not a subclass of bool, so json cannot encode it natively.
            return bool(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)

Опубликовать ( 0 )

Вы можете оставить комментарий после входа в систему

1
https://api.gitlife.ru/oschina-mirror/spytensor-plants_disease_detection.git
git@api.gitlife.ru:oschina-mirror/spytensor-plants_disease_detection.git
oschina-mirror
spytensor-plants_disease_detection
spytensor-plants_disease_detection
master