Слияние кода завершено, страница обновится автоматически
from mobilenetv3 import mobilenetv3_large,mobilenetv3_small
import os
import torch
import torch.utils.data as Data
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
import numpy as np
from torch import nn, optim
import torchvision
from torchvision import models
# Dataset root. Set the INCAR_DATA_PATH environment variable to point at a
# different location instead of editing this machine-specific default.
DATA_PATH = os.environ.get('INCAR_DATA_PATH', '/home/phillip/datasets/incar_infrared_V2')
BATCH_SIZE = 32
# Alternative settings for the small sample dataset ("示例" means "sample"):
#DATA_PATH = 'incar_infrared_V2_示例'
#BATCH_SIZE = 16
# Base learning rate.
LR = 0.1
class xwzDataset(Data.Dataset):
    """Multi-label image dataset stored under ``<data_path>/<split>/``.

    Each split directory holds an ``images/`` folder and a
    ``<split>_labels.txt`` file whose non-blank lines have the form
    ``<image_name> <l0> <l1> <l2> <l3>`` (whitespace separated, integer
    labels; extra trailing fields are ignored).

    Items are ``(image, label)`` pairs.
    - ``image``: the file read as grayscale and replicated into 3 channels,
      resized to 480x480, converted to a tensor and normalised with the
      standard ImageNet mean/std values.
    - ``label``: a float tensor of the four integers.
    """

    _SPLITS = ('train', 'test')

    def __init__(self, data_path, type):  # `type` shadows the builtin; kept for callers
        """
        Args:
            data_path: dataset root directory.
            type: split name, ``'train'`` or ``'test'``.

        Raises:
            ValueError: if ``type`` is not a known split.
            OSError: if the label file cannot be opened.
        """
        if type not in self._SPLITS:
            raise ValueError("type must be 'train' or 'test', got %r" % (type,))
        split_dir = os.path.join(data_path, type)
        txt_path = os.path.join(split_dir, type + '_labels.txt')
        image_dir = os.path.join(split_dir, 'images')
        self.imgs, self.results = self._read_labels(txt_path, image_dir)
        # Train and test currently share identical preprocessing.
        self.img_transform = transforms.Compose([
            transforms.Resize((480, 480)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _read_labels(txt_path, image_dir):
        """Parse a label file into ``(image_paths, labels)``.

        Blank lines are skipped. Every other line must contain an image file
        name followed by at least four integer labels; only the first four
        labels are used. Image names are joined onto ``image_dir``.

        Raises:
            ValueError: if a non-blank line has fewer than five fields or a
                label is not an integer.
        """
        imgs, results = [], []
        with open(txt_path) as fh:
            for line_no, line in enumerate(fh, 1):
                words = line.split()
                if not words:
                    continue  # tolerate blank / trailing empty lines
                if len(words) < 5:
                    raise ValueError('%s:%d: expected an image name and 4 labels, got %r'
                                     % (txt_path, line_no, line.rstrip()))
                imgs.append(os.path.join(image_dir, words[0]))
                results.append([int(x) for x in words[1:5]])
        return imgs, results

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for sample ``index``."""
        label = torch.FloatTensor(self.results[index])
        # The context manager releases the file handle promptly. Converting
        # 'L' -> 'RGB' copies the gray value into R, G and B, giving the
        # 3-channel input that the 3-value Normalize expects.
        with Image.open(self.imgs[index]) as img:
            rgb_image = img.convert('L').convert('RGB')
        return self.img_transform(rgb_image), label

    def __len__(self):
        """Number of labelled images in the split."""
        return len(self.imgs)
# Data pipeline: the training set is reshuffled every epoch. The test set is
# only evaluated (order-independent metrics), so its order is kept fixed.
train_data = xwzDataset(DATA_PATH, 'train')
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = xwzDataset(DATA_PATH, 'test')
test_loader = Data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

# Backbone: MobileNetV3-small initialised from the local checkpoint file.
pretrained_net = mobilenetv3_small()
pretrained_net_dict = pretrained_net.state_dict()
# map_location='cpu' loads the tensors onto the CPU whatever device they were
# saved from, so a GPU-saved checkpoint does not require that same GPU.
pretrained_small_dict = torch.load('mobilenetv3-small-c7eb32fe.pth', map_location='cpu')
# Keep only entries whose name AND shape match this model; a mismatched shape
# would otherwise make load_state_dict raise.
pretrained_small_dict = {
    k: v for k, v in pretrained_small_dict.items()
    if k in pretrained_net_dict
    and getattr(v, 'shape', None) == pretrained_net_dict[k].shape
}
pretrained_net_dict.update(pretrained_small_dict)
pretrained_net.load_state_dict(pretrained_net_dict)
class h_sigmoid(nn.Module):
    """Hard sigmoid activation: ``ReLU6(x + 3) / 6``.

    Args:
        inplace: passed to ``nn.ReLU6``; allows it to overwrite the temporary
            ``x + 3`` tensor (the caller's ``x`` itself is never modified).
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        clipped = self.relu(shifted)
        return clipped / 6
class h_swish(nn.Module):
    """Hard swish activation: ``x * h_sigmoid(x)``.

    Args:
        inplace: forwarded to the inner ``h_sigmoid``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate
# Replace the backbone's head with a 4-output head (one output per label).
# 576 must equal the backbone's feature width feeding the classifier
# (assumed from the original code; TODO confirm against mobilenetv3_small).
pretrained_net.classifier = nn.Sequential(
    nn.Linear(576, 1280),
    nn.BatchNorm1d(1280),
    h_swish(),
    nn.Linear(1280, 4),
    nn.BatchNorm1d(4),
    # NOTE(review): h_swish outputs are >= -0.375, so a sigmoid applied to
    # the network output can never go below sigmoid(-0.375) ~= 0.41.
    # Consider dropping this final activation.
    h_swish(),
)

# Only the new head is optimised (see the optimizer below). Freeze the
# backbone so autograd does not compute and accumulate gradients that would
# never be applied. This does not change the training result.
output_params = list(map(id, pretrained_net.classifier.parameters()))
_classifier_param_ids = set(output_params)
feature_params = [p for p in pretrained_net.parameters() if id(p) not in _classifier_param_ids]
for _p in feature_params:
    _p.requires_grad_(False)
# To fine-tune the whole network instead, remove the freezing above and use:
#   optim.SGD([{'params': feature_params},
#              {'params': pretrained_net.classifier.parameters(), 'lr': LR * 10}],
#             lr=LR, weight_decay=0.001)
optimizer = optim.SGD(pretrained_net.classifier.parameters(), lr=LR)

activate_func = nn.Sigmoid()
loss_func = nn.MSELoss()
activate_func = activate_func.cuda()
loss_func = loss_func.cuda()
# .cuda() keeps the same Parameter objects, so requires_grad flags survive.
pretrained_net = pretrained_net.cuda()
# Human-readable names of the four binary labels, indexed like the label
# columns / model outputs (no seatbelt, smoking, phone call, yawning).
LABEL_NAMES = ('未系安全带', '抽烟', '打电话', '打哈欠')
# Checkpoint directory; created up front so torch.save cannot fail on it.
SAVE_DIR = 'paras'


def _safe_div(num, den):
    """Return ``num / den``, or 0 when ``den`` is 0 (metric undefined)."""
    return num / den if den else 0


def _evaluate(net, loader, activation, criterion=None, threshold=0.5):
    """Run ``net`` over ``loader`` without gradients and tally per-label counts.

    The network is put in eval mode for the pass, so BatchNorm layers use
    their running statistics and are not updated. Its previous mode is
    restored afterwards. A label is predicted positive when
    ``activation(net(x)) > threshold``.

    Returns:
        dict containing:
        - ``correct``: per-label counts where the prediction equals the target.
        - ``tp``: per-label counts of predicted 1 with target 1.
        - ``pred_pos``: per-label counts of predicted 1.
        - ``real_pos``: per-label counts of target 1.
        - ``loss``: mean of ``criterion`` over batches, or None when no
          criterion is given.
        The four count entries are lists of ints.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    device = next(net.parameters()).device
    was_training = net.training
    net.eval()
    totals = None
    loss_sum, n_batches = 0.0, 0
    try:
        with torch.no_grad():
            for x, y in loader:
                y = y.to(device)
                probs = activation(net(x.to(device)))
                if criterion is not None:
                    loss_sum += float(criterion(probs, y))
                pred = probs > threshold
                hit = pred.to(y.dtype) == y
                counts = torch.stack([
                    hit.sum(0),
                    (hit & pred).sum(0),
                    pred.sum(0),
                    (y == 1).sum(0),
                ]).cpu()
                totals = counts if totals is None else totals + counts
                n_batches += 1
    finally:
        net.train(was_training)
    if totals is None:
        raise ValueError('loader yielded no batches')
    correct, tp, pred_pos, real_pos = totals.tolist()
    return {
        'correct': correct,
        'tp': tp,
        'pred_pos': pred_pos,
        'real_pos': real_pos,
        'loss': loss_sum / n_batches if criterion is not None else None,
    }


def _report(split, stats, n_samples):
    """Print accuracy, precision and recall lines for one split.

    Per-label accuracy is ``correct / n_samples``. The "average accuracy"
    is the mean of those per-label accuracies. Precision and recall print
    as 0 when undefined.

    Returns:
        The average accuracy.
    """
    accuracies = [c / n_samples for c in stats['correct']]
    mean_accuracy = sum(accuracies) / len(accuracies)
    print('| %s average accuracy: %.2f' % (split, mean_accuracy))
    for name, acc in zip(LABEL_NAMES, accuracies):
        print('| %s %s accuracy: %.2f' % (split, name, acc))
    for name, tp, pred_pos in zip(LABEL_NAMES, stats['tp'], stats['pred_pos']):
        print('| %s %s precision: %.2f' % (split, name, _safe_div(tp, pred_pos)))
    for name, tp, real_pos in zip(LABEL_NAMES, stats['tp'], stats['real_pos']):
        print('| %s %s recall: %.2f' % (split, name, _safe_div(tp, real_pos)))
    return mean_accuracy


os.makedirs(SAVE_DIR, exist_ok=True)
device = next(pretrained_net.parameters()).device
# Best values seen so far. A checkpoint is written only when its metric
# improves on every previous epoch, not merely on the last one.
best_loss = float('inf')
best_train_accuracy = float('-inf')
best_test_accuracy = float('-inf')
iteration_num = 0
for epoch in range(200):
    pretrained_net.train()
    for b_x, b_y in train_loader:
        if b_x.size(0) < 2:
            # BatchNorm in training mode cannot normalise a single-sample
            # batch (PyTorch raises), so skip a lone trailing sample.
            continue
        b_x, b_y = b_x.to(device), b_y.to(device)
        output = activate_func(pretrained_net(b_x))
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_stats = _evaluate(pretrained_net, train_loader, activate_func, loss_func)
    loss = train_stats['loss']
    print('epoch:', iteration_num, '| loss: %.5f' % loss)
    accuracy1 = _report('train', train_stats, len(train_data))
    print('| ')

    test_stats = _evaluate(pretrained_net, test_loader, activate_func)
    accuracy2 = _report('test', test_stats, len(test_data))
    print(' ')

    if loss < best_loss:
        best_loss = loss
        torch.save(pretrained_net.state_dict(),
                   os.path.join(SAVE_DIR, 'mobilenetv3_lowest_loss.pkl'))
    if accuracy1 > best_train_accuracy:
        best_train_accuracy = accuracy1
        torch.save(pretrained_net.state_dict(),
                   os.path.join(SAVE_DIR, 'mobilenetv3_most_accurate_train.pkl'))
    if accuracy2 > best_test_accuracy:
        best_test_accuracy = accuracy2
        torch.save(pretrained_net.state_dict(),
                   os.path.join(SAVE_DIR, 'mobilenetv3_most_accurate_test.pkl'))
    iteration_num += 1
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )