Слияние кода завершено, страница обновится автоматически
from __future__ import print_function, division
import _thread
import os
import cv2
import numpy as np
# from tensorflow.python.keras import Sequential, Input, Model
# from tensorflow.python.keras.layers import Conv2D, LeakyReLU, BatchNormalization, Dropout, Flatten, Dense
# from tensorflow.python.keras.optimizer_v2.adam import Adam
from keras import Input
from keras.engine.saving import load_model
from keras.layers import Dropout, BatchNormalization, LeakyReLU, Dense, Flatten
from keras.layers.convolutional import Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from util.dataloader import dataloader
import matplotlib.pyplot as plt
class CNN:
def __init__(self, using_model=True):
# 数据集图片大小信息
self.img_rows = 100
self.img_cols = 100
self.channels = 1 # 图像通道数,彩色图为3,黑白图为1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.epochs = 2000
self.batch_size = 4
self.save_interval = 10
# 优化器
self.optimizer = Adam(0.0002, 0.5)
self.load_all = True # 是否一次性将全部内容加载入内存进行训练
self.dataloader = dataloader(self.img_shape, load_all=self.load_all)
if using_model:
if os.path.exists("./model/model.h5"):
self.model = load_model("./model/model.h5")
else:
print("No model present there!")
self.model = self.build_model()
else:
self.model = self.build_model()
self.model.summary()
if self.dataloader.multi_category:
self.model.compile(loss='categorical_crossentropy',
optimizer=self.optimizer,
metrics=['accuracy'])
else:
self.model.compile(loss='binary_crossentropy',
optimizer=self.optimizer,
metrics=['accuracy'])
def build_model(self):
'''
构筑卷积神经网络
:return: CNN
'''
cnum = 16
model = Sequential()
model.add(Conv2D(cnum, kernel_size=4, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Conv2D(cnum * 2, kernel_size=4, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dropout(0.25))
model.add(Conv2D(cnum * 4, kernel_size=4, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(1, kernel_size=4, strides=1, padding="same"))
model.add(Dropout(0.25))
model.add(Flatten())
if self.dataloader.multi_category:
model.add(Dense(self.dataloader.num_category, activation='sigmoid'))
else:
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
def train(self):
'''
训练模型
:return: None
'''
self.dataloader.load_all_data()
if self.dataloader.load_all:
for epoch in range(int(self.epochs / self.save_interval)):
self.model.fit(self.dataloader.x_data, self.dataloader.y_data, batch_size=self.batch_size,
epochs=self.save_interval)
self.save_model()
else:
self.dataloader.self_load_data(self.batch_size)
x_data = self.dataloader.x_data
y_data = self.dataloader.y_data
for epoch in range(self.epochs):
_thread.start_new_thread(self.dataloader.self_load_data, (self.batch_size,))
loss = self.model.train_on_batch(x_data, y_data)
x_data = self.dataloader.x_data
y_data = self.dataloader.y_data
print("%d [D loss: %2f acc: %.2f]" % (epoch, loss[0], loss[1]))
if epoch % self.save_interval == 0:
self.save_model()
def save_model(self):
'''
保存模型
:return: None
'''
print("save model...")
if not os.path.exists("./model"): # 如果路径不存在
os.makedirs("./model")
self.model.save("./model/model.h5")
def predict(self, img):
'''
预测结果
:param img: 经过dataloader处理的图片
:return: 返回label结果
'''
img = np.expand_dims(img, axis=0)
if self.dataloader.multi_category:
return self.dataloader.category[np.argmax(self.model.predict(img)[0])]
else:
state = self.model.predict(img)
if state >= 0.5:
return self.dataloader.category[1]
else:
return self.dataloader.category[0]
def predict_by_numpyImg(self, img, is_plot=True):
'''
使用Numpy图片作为输入进行预测
:param img: numpy图片
:param is_plot: 是否使用matplotlib显示结果
:return: 返回label结果
'''
img_processed = self.dataloader.process_img(img)
label = self.predict(img_processed)
if is_plot:
plt.imshow(img)
plt.title(label)
plt.show()
return label
def predict_by_path(self, path, is_plot=True):
'''
使用图片路径作为输入进行预测
:param img: 图片路径
:param is_plot: 是否使用matplotlib显示结果
:return: 返回label结果
'''
img = cv2.imread(path)
img_processed = self.dataloader.process_img(img)
label = self.predict(img_processed)
if is_plot:
plt.imshow(img)
plt.title(label)
plt.show()
return label
if __name__ == '__main__':
cnn = CNN(using_model=False) # 如果需要使用旧模型则using_model为True,反之要训练新模型则为False
cnn.dataloader.rename_all_file() # 对于一些含有中文的图片名称可能会导致程序错误,推荐可以用dataloader自动重命名。
cnn.train()
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )