Benchmark results:
epoch:50 vgg16: test set accuracy 86.45%
cifar_10 epoch:300 resnet18 [batchSize:128, initial learningRate:0.1, learnRateUpdate:GD_GECAY, optimizer:adamw], data augmentation [randomCrop, randomHorizontalFlip, cutout, normalize]: test set accuracy 91.23%
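Of the listed augmentations, cutout is the least standard: it zeroes out a random square patch of each training image. A minimal sketch of the idea, assuming a flat CHW float layout (this helper is illustrative and not part of the engine's API):

```java
import java.util.Random;

public class CutoutSketch {
    private static final Random RND = new Random();

    /** Zeroes a random size x size patch in an image stored as [channel][height * width]. */
    public static void apply(float[][] img, int height, int width, int size) {
        int cy = RND.nextInt(height); // patch center, sampled uniformly
        int cx = RND.nextInt(width);
        int y0 = Math.max(0, cy - size / 2), y1 = Math.min(height, cy + size / 2);
        int x0 = Math.max(0, cx - size / 2), x1 = Math.min(width, cx + size / 2);
        for (float[] channel : img) {
            for (int y = y0; y < y1; y++) {
                for (int x = x0; x < x1; x++) {
                    channel[y * width + x] = 0f; // mask the pixel
                }
            }
        }
    }
}
```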
**bp neural network iris demo**
```java
public void bpNetwork_iris() {
    /**
     * Load the training and test datasets.
     */
    String iris_train = "/dataset/iris/iris.txt";
    String iris_test = "/dataset/iris/iris_test.txt";
    String[] labelSet = new String[] {"1", "-1"};
    DataSet trainData = DataLoader.loalDataByTxt(iris_train, ",", 1, 1, 4, 2, labelSet);
    DataSet testData = DataLoader.loalDataByTxt(iris_test, ",", 1, 1, 4, 2, labelSet);
    System.out.println("train_data:" + JsonUtils.toJson(trainData));
    // A small MLP: 4 inputs -> 40 -> 20 -> 2 classes, softmax cross-entropy loss.
    BPNetwork netWork = new BPNetwork(new SoftmaxWithCrossEntropyLoss());
    InputLayer inputLayer = new InputLayer(1, 1, 4);
    FullyLayer hidden1 = new FullyLayer(4, 40);
    ReluLayer active1 = new ReluLayer();
    FullyLayer hidden2 = new FullyLayer(40, 20);
    ReluLayer active2 = new ReluLayer();
    FullyLayer hidden3 = new FullyLayer(20, 2);
    SoftmaxWithCrossEntropyLayer hidden4 = new SoftmaxWithCrossEntropyLayer(2);
    netWork.addLayer(inputLayer);
    netWork.addLayer(hidden1);
    netWork.addLayer(active1);
    netWork.addLayer(hidden2);
    netWork.addLayer(active2);
    netWork.addLayer(hidden3);
    netWork.addLayer(hidden4);
    try {
        // Train for 8 epochs with batch size 10; the 1e-5 argument is the stop-error threshold.
        MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 8, 0.00001d, 10, LearnRateUpdate.NONE);
        optimizer.train(trainData);
        optimizer.test(testData);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```
**cnn mnist demo**
```java
public void cnnNetwork_mnist() {
    try {
        /**
         * Load the training and test datasets (MNIST idx-ubyte format).
         */
        String mnist_train_data = "/dataset/mnist/train-images.idx3-ubyte";
        String mnist_train_label = "/dataset/mnist/train-labels.idx1-ubyte";
        String mnist_test_data = "/dataset/mnist/t10k-images.idx3-ubyte";
        String mnist_test_label = "/dataset/mnist/t10k-labels.idx1-ubyte";
        String[] labelSet = new String[] {"0","1","2","3","4","5","6","7","8","9"};
        Resource trainDataRes = new ClassPathResource(mnist_train_data);
        Resource trainLabelRes = new ClassPathResource(mnist_train_label);
        Resource testDataRes = new ClassPathResource(mnist_test_data);
        Resource testLabelRes = new ClassPathResource(mnist_test_label);
        DataSet trainData = DataLoader.loadDataByUByte(trainDataRes.getFile(), trainLabelRes.getFile(), labelSet, 1, 1, 784, true);
        DataSet testData = DataLoader.loadDataByUByte(testDataRes.getFile(), testLabelRes.getFile(), labelSet, 1, 1, 784, true);
        int channel = 1;
        int height = 28;
        int width = 28;
        CNN netWork = new CNN(new SoftmaxWithCrossEntropyLoss(), UpdaterType.momentum);
        netWork.learnRate = 0.001d;
        InputLayer inputLayer = new InputLayer(channel, 1, 784);
        ConvolutionLayer conv1 = new ConvolutionLayer(channel, 6, width, height, 5, 5, 2, 1, false);
        BNLayer bn1 = new BNLayer();
        LeakyReluLayer active1 = new LeakyReluLayer();
        PoolingLayer pool1 = new PoolingLayer(conv1.oChannel, conv1.oWidth, conv1.oHeight, 2, 2, 2, PoolingType.MAX_POOLING);
        // ConvolutionLayer conv2 = new ConvolutionLayer(pool1.oChannel, ... (the rest of this demo is truncated in the source)
```

**resnet18 cifar_10 demo (fragment)**

The listing below continues with the layer stack of the ResNet-18 CIFAR-10 demo (batchSize 128, GD_GECAY decay, per the benchmark above). The first residual block `bl1` is reconstructed here from the pattern of the blocks that follow.

```java
/**
 * block1  64 * 32 * 32
 */
BasicBlockLayer bl1 = new BasicBlockLayer(conv1.oChannel, 64, conv1.oHeight, conv1.oWidth, 1, netWork); // reconstructed from the bl2..bl8 pattern
ReluLayer active2 = new ReluLayer();
/**
* block2 64 * 32 * 32
*/
BasicBlockLayer bl2 = new BasicBlockLayer(bl1.oChannel, 64, bl1.oHeight, bl1.oWidth, 1, netWork);
ReluLayer active3 = new ReluLayer();
/**
* block3 128 * 16 * 16
* downSample 32 / 2 = 16
*/
BasicBlockLayer bl3 = new BasicBlockLayer(bl2.oChannel, 128, bl2.oHeight, bl2.oWidth, 2, netWork);
ReluLayer active4 = new ReluLayer();
/**
* block4 128 * 16 * 16
*/
BasicBlockLayer bl4 = new BasicBlockLayer(bl3.oChannel, 128, bl3.oHeight, bl3.oWidth, 1, netWork);
ReluLayer active5 = new ReluLayer();
/**
* block5 256 * 8 * 8
* downSample 16 / 2 = 8
*/
BasicBlockLayer bl5 = new BasicBlockLayer(bl4.oChannel, 256, bl4.oHeight, bl4.oWidth, 2, netWork);
ReluLayer active6 = new ReluLayer();
/**
* block6 256 * 8 * 8
*/
BasicBlockLayer bl6 = new BasicBlockLayer(bl5.oChannel, 256, bl5.oHeight, bl5.oWidth, 1, netWork);
ReluLayer active7 = new ReluLayer();
/**
* block7 512 * 4 * 4
* downSample 8 / 2 = 4
*/
BasicBlockLayer bl7 = new BasicBlockLayer(bl6.oChannel, 512, bl6.oHeight, bl6.oWidth, 2, netWork);
ReluLayer active8 = new ReluLayer();
/**
* block8 512 * 4 * 4
*/
BasicBlockLayer bl8 = new BasicBlockLayer(bl7.oChannel, 512, bl7.oHeight, bl7.oWidth, 1, netWork);
ReluLayer active9 = new ReluLayer();
AVGPoolingLayer pool2 = new AVGPoolingLayer(bl8.oChannel, bl8.oWidth, bl8.oHeight);
/**
* fully 512 * 1 * 1
*/
int fInputCount = pool2.oChannel * pool2.oWidth * pool2.oHeight;
FullyLayer full1 = new FullyLayer(fInputCount, 10);
netWork.addLayer(inputLayer);
netWork.addLayer(conv1);
netWork.addLayer(bn1);
netWork.addLayer(active1);
/**
* block1 64
*/
netWork.addLayer(bl1);
netWork.addLayer(active2);
netWork.addLayer(bl2);
netWork.addLayer(active3);
/**
* block2 128
*/
netWork.addLayer(bl3);
netWork.addLayer(active4);
netWork.addLayer(bl4);
netWork.addLayer(active5);
/**
* block3 256
*/
netWork.addLayer(bl5);
netWork.addLayer(active6);
netWork.addLayer(bl6);
netWork.addLayer(active7);
/**
* block4 512
*/
netWork.addLayer(bl7);
netWork.addLayer(active8);
netWork.addLayer(bl8);
netWork.addLayer(active9);
netWork.addLayer(pool2);
netWork.addLayer(full1);
// mean/std are the per-channel normalization constants; their definition is elided in the source.
MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 250, 0.001f, 128, LearnRateUpdate.GD_GECAY, false);
long start = System.currentTimeMillis();
optimizer.train(trainData, testData, mean, std);
optimizer.test(testData);
System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
CUDAMemoryManager.freeAll();
} catch (Exception e) {
e.printStackTrace();
}
}
```

**yolov1-tiny banana detection demo**
```java
try {
String cfg_path = "H:/voc/train/yolov1-tiny.cfg";
String trainPath = "H:\\voc\\banana-detection\\bananas_train\\images";
String trainLabelPath = "H:\\voc\\banana-detection\\bananas_train\\label.csv";
String testPath = "H:\\voc\\banana-detection\\bananas_val\\images";
String testLabelPath = "H:\\voc\\banana-detection\\bananas_val\\label.csv";
YoloDataLoader trainData = new YoloDataLoader(trainPath, trainLabelPath, 1000, 3, 256, 256, 5, LabelType.csv, true);
YoloDataLoader vailData = new YoloDataLoader(testPath, testLabelPath, 100, 3, 256, 256, 5, LabelType.csv, true);
// formatToYolo/showImg are demo helper methods (definitions not shown in this README).
DataSet trainSet = formatToYolo(trainData.getDataSet());
DataSet vailSet = formatToYolo(vailData.getDataSet());
System.out.println("load data finish.");
CNN netWork = new CNN(LossType.yolo3, UpdaterType.adamw);
netWork.CUDNN = true;
netWork.learnRate = 0.001f;
ModelLoader.loadConfigToModel(netWork, cfg_path);
MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 1000, 0.001f, 64, LearnRateUpdate.CONSTANT, false);
long start = System.currentTimeMillis();
optimizer.trainObjectRecognition(trainSet, vailSet);
/**
 * Process the test-set predictions.
 */
float[][][] draw_bbox = optimizer.showObjectRecognition(vailSet, 64);
YoloDataLoader testData = new YoloDataLoader(testPath, testLabelPath, 1000, 3, 256, 256, 5, LabelType.csv, false);
String outputPath = "H:\\voc\\banana-detection\\test\\";
showImg(outputPath, testData.getDataSet(), 1, draw_bbox, false);
System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
CUDAMemoryManager.freeAll();
} catch (Exception e) {
e.printStackTrace();
}
}
```
**yolov3-tiny mask detection demo**
```java
public void yolov3_tiny_mask() {
int im_w = 416;
int im_h = 416;
int batchSize = 24;
int class_num = 2;
String[] labelset = new String[] {"unmask", "mask"};
try {
String cfg_path = "H:\\voc\\mask\\data\\dataset\\yolov3-tiny-mask.cfg";
String trainPath = "H:\\voc\\mask\\data\\resized\\train";
String trainLabelPath = "H:\\voc\\mask\\data\\resized\\train_label.txt";
String testPath = "H:\\voc\\mask\\data\\resized\\vail";
String testLabelPath = "H:\\voc\\mask\\data\\resized\\vail_label.txt";
String weightPath = "H:\\voc\\yolo-weights\\yolov3-tiny.conv.15";
/**
 * Load the data.
 */
DetectionDataLoader trainData = new DetectionDataLoader(trainPath, trainLabelPath, LabelFileType.txt, im_w, im_h, class_num, batchSize, DataType.yolov3);
DetectionDataLoader vailData = new DetectionDataLoader(testPath, testLabelPath, LabelFileType.txt, im_w, im_h, class_num, batchSize, DataType.yolov3);
/**
 * Create the yolo model.
 */
Yolo netWork = new Yolo(LossType.yolo3, UpdaterType.adamw);
netWork.CUDNN = true;
netWork.learnRate = 0.001f;
/**
 * Load the model configuration.
 */
ModelLoader.loadConfigToModel(netWork, cfg_path);
/**
 * Load the pretrained weights.
 */
DarknetLoader.loadWeight(netWork, weightPath, 14, true);
/**
 * Create the optimizer.
 */
// MBSGDOptimizer optimizer = new ... (the rest of this demo is truncated in the source)
```

**yolov7-tiny demo (fragment)**

This fragment begins with reading the label set from `labelPath`; the cfg and dataset path variables are elided in the source. The pretrained weights are `H:\\voc\\darknet_yolov7\\yolov7-tiny.conv.87`.
```java
// labelPath, labelset, class_num, cfg_path and the dataset path variables are
// declared earlier in this demo; their declarations are elided in the source.
String weightPath = "H:\\voc\\darknet_yolov7\\yolov7-tiny.conv.87";
try (FileInputStream fin = new FileInputStream(labelPath);
InputStreamReader reader = new InputStreamReader(fin);
BufferedReader buffReader = new BufferedReader(reader)) {
String strTmp = "";
int idx = 0;
while((strTmp = buffReader.readLine())!=null) {
labelset[idx] = strTmp;
idx++;
}
} catch (Exception e) {
e.printStackTrace();
}
DetectionDataLoader trainData = new DetectionDataLoader(trainPath, trainLabelPath, LabelFileType.txt, im_w, im_h, class_num, batchSize, DataType.yolov3);
DetectionDataLoader vailData = new DetectionDataLoader(testPath, testLabelPath, LabelFileType.txt, im_w, im_h, class_num, batchSize, DataType.yolov3);
Yolo netWork = new Yolo(LossType.yolov7, UpdaterType.adamw);
netWork.CUDNN = true;
netWork.learnRate = 0.001f;
ModelLoader.loadConfigToModel(netWork, cfg_path);
DarknetLoader.loadWeight(netWork, weightPath, 86, true);
MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 1000, 0.001f, batchSize, LearnRateUpdate.SMART_HALF, false);
optimizer.trainObjectRecognitionOutputs(trainData, vailData);
/**
 * Process the test-set predictions.
 */
List<YoloBox> draw_bbox = optimizer.showObjectRecognitionYoloV3(vailData, batchSize);
String outputPath = "H:\\voc\\sm\\test_yolov7\\";
showImg(outputPath, vailData, class_num, draw_bbox, batchSize, false, im_w, im_h, labelset);
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
CUDAMemoryManager.freeAll();
} catch (Exception e) {
e.printStackTrace();
}
}
}
```
**gan mnist demo — handwritten digit generation**
```java
public static void gan_anime() {
int imgSize = 784;
int ngf = 784; // number of generator feature maps
int nz = 100; // noise dimensionality
int batchSize = 2048;
int d_every = 1;
int g_every = 1;
float[] mean = new float[] {0.5f};
float[] std = new float[] {0.5f};
try {
String mnist_train_data = "/dataset/mnist/train-images.idx3-ubyte";
String mnist_train_label = "/dataset/mnist/train-labels.idx1-ubyte";
String[] labelSet = new String[] {"0","1","2","3","4","5","6","7","8","9"};
Resource trainDataRes = new ClassPathResource(mnist_train_data);
Resource trainLabelRes = new ClassPathResource(mnist_train_label);
DataSet trainData = DataLoader.loadDataByUByte(trainDataRes.getFile(), trainLabelRes.getFile(), labelSet, 1, 1, 784, true, mean, std);
// NetG/NetD are builder methods for the generator and discriminator (definitions not shown in this README).
BPNetwork netG = NetG(ngf, nz);
BPNetwork netD = NetD(imgSize);
GANOptimizer optimizer = new GANOptimizer(netG, netD, batchSize, 3500, d_every, g_every, 0.001f, LearnRateUpdate.CONSTANT, false);
optimizer.train(trainData);
} catch (Exception e) {
e.printStackTrace();
}
}
```
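The `d_every` and `g_every` arguments appear to control how often the discriminator and the generator are stepped, by analogy with common GAN training recipes (an assumption based on the parameter names; the optimizer's source is not shown here). A minimal sketch of that alternation:

```java
public class GanCadenceSketch {
    // Hypothetical stand-ins for the optimizer's internal update steps.
    static void stepD() { /* update the discriminator on a real batch and a fake batch */ }
    static void stepG() { /* update the generator to fool the current discriminator */ }

    public static void main(String[] args) {
        int d_every = 1, g_every = 5, iterations = 20;
        for (int it = 0; it < iterations; it++) {
            if (it % d_every == 0) stepD();
            if (it % g_every == 0) stepG();
        }
    }
}
```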
**dcgan anime demo — anime character generation**
A generative adversarial network that generates anime-style face images.
```java
public static void dcgan_anime() {
int imw = 64;
int imh = 64;
int ngf = 64; // number of generator feature maps
int ndf = 64; // number of discriminator feature maps
int nz = 100; // noise dimensionality
int batchSize = 64;
int d_every = 1;
int g_every = 5;
float[] mean = new float[] {0.5f,0.5f,0.5f};
float[] std = new float[] {0.5f,0.5f,0.5f};
try {
String imgDirPath = "H:\\voc\\gan_anime\\ml2021spring-hw6\\faces\\";
// NetG/NetD are builder methods for the generator and discriminator (definitions not shown in this README).
CNN netG = NetG(ngf, nz);
CNN netD = NetD(ndf, imw, imh);
ImageDataLoader dataLoader = new ImageDataLoader(imgDirPath, imw, imh, batchSize, true, mean, std);
GANOptimizer optimizer = new GANOptimizer(netG, netD, batchSize, 2000, d_every, g_every, 0.001f, LearnRateUpdate.POLY, false);
optimizer.train(dataLoader);
} catch (Exception e) {
e.printStackTrace();
}
}
```
**RNN — Chinese novel generator**
```java
public void charRNN() {
try {
int time = 256;
int batchSize = 64;
int embedding_dim = 256;
int hiddenSize = 512;
String trainPath = "H:\\rnn_dataset\\dpcc.txt";
OneHotDataLoader trainData = new OneHotDataLoader(trainPath, time, batchSize);
RNN netWork = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, time);
InputLayer inputLayer = new InputLayer(1, 1, trainData.characters);
EmbeddingLayer em = new EmbeddingLayer(trainData.characters, embedding_dim);
RNNLayer l1 = new RNNLayer(embedding_dim, hiddenSize, time, ActiveType.tanh, false, netWork);
RNNLayer l2 = new RNNLayer(hiddenSize, hiddenSize, time, ActiveType.tanh, false, netWork);
RNNLayer l3 = new RNNLayer(hiddenSize, hiddenSize, time, ActiveType.tanh, false, netWork);
FullyLayer f1 = new FullyLayer(hiddenSize, hiddenSize, false);
BNLayer bn = new BNLayer();
LeakyReluLayer a1 = new LeakyReluLayer();
FullyLayer f2 = new FullyLayer(hiddenSize, trainData.characters, true);
netWork.addLayer(inputLayer);
netWork.addLayer(em);
netWork.addLayer(l1);
netWork.addLayer(l2);
netWork.addLayer(l3);
netWork.addLayer(f1);
netWork.addLayer(bn);
netWork.addLayer(a1);
netWork.addLayer(f2);
netWork.CUDNN = true;
netWork.learnRate = 0.01f;
MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 2, 0.001f, batchSize, LearnRateUpdate.POLY, false);
optimizer.trainRNN(trainData);
int gen_len = 1000;
int max_len = 256;
String pre_txt = "这个故事所造成的后果,便是造就了大批每天";
Tensor input = null;
Tensor output = null;
// createTxtData/genTxt are demo helper methods (definitions not shown in this README).
input = createTxtData(input, pre_txt, trainData.characters, trainData.dictionary, max_len);
netWork.RUN_MODEL = RunModel.TEST;
for(int i = 0;i<gen_len;i++) {
netWork.time = input.number;
String txt = genTxt(input, output, netWork, trainData, max_len);
if(netWork.time > 1) {
pre_txt += txt.substring(input.number - 1, input.number);
}else {
pre_txt += txt;
}
input = createTxtData(input, pre_txt, trainData.characters, trainData.dictionary, max_len);
}
System.out.println(pre_txt);
} catch (Exception e) {
e.printStackTrace();
}
}
```
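Generation runs one character at a time: `genTxt` presumably feeds the current context through the network and picks the next character from the softmax output (the helper definitions are not shown in this README). For reference, sampling an index from a softmax distribution is typically done like this sketch:

```java
import java.util.Random;

public class SamplerSketch {
    /** Samples an index from a softmax probability vector (probabilities sum to 1). */
    public static int sample(float[] probs, Random rnd) {
        float r = rnd.nextFloat();
        float cum = 0f;
        for (int i = 0; i < probs.length; i++) {
            cum += probs[i];
            if (r < cum) return i; // first index whose cumulative mass exceeds r
        }
        return probs.length - 1; // guard against floating-point rounding
    }
}
```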
**Seq2Seq — English-to-Chinese translator**
```java
try {
int batchSize = 128;
int en_em = 64;
int de_em = 128;
int en_hidden = 256;
int de_hidden = 256;
String trainPath = "H:\\rnn_dataset\\translate1000.csv";
IndexDataLoader trainData = new IndexDataLoader(trainPath, batchSize);
Seq2Seq network = new Seq2Seq(LossType.softmax_with_cross_entropy, UpdaterType.adamw,
trainData.max_en, trainData.max_ch - 1, en_em, en_hidden, trainData.en_characters, de_em, de_hidden, trainData.ch_characters);
network.CUDNN = true;
network.learnRate = 0.01f;
EDOptimizer optimizer = new EDOptimizer(network, batchSize, 100, 0.001f, LearnRateUpdate.SMART_HALF, false);
optimizer.lr_step = new int[] {100,200};
optimizer.trainRNN(trainData);
Scanner scanner = new Scanner(System.in);
while (true) {
System.out.println("请输入英文:");
String input_txt = scanner.nextLine();
if(input_txt.equals("exit")){
break;
}
input_txt = input_txt.toLowerCase();
System.out.println(input_txt);
optimizer.predict(trainData, input_txt);
}
scanner.close();
} catch (Exception e) {
e.printStackTrace();
}
```
**nanoGPT Chinese text generation demo**
```java
public static void gpt_dp() {
try {
boolean bias = false;
boolean dropout = true;
int batchSize = 32;
int max_len = 64;
int embedDim = 512;
int headNum = 8;
int decoderNum = 6;
String trainPath = "H:\\transformer_dataset\\gpt\\dpcc50.txt";
CNTokenizer trainData = new CNTokenizer(trainPath, max_len, batchSize);
NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, headNum, decoderNum, trainData.characters, max_len, embedDim, bias, dropout);
network.learnRate = 0.001f;
EDOptimizer optimizer = new EDOptimizer(network, batchSize, 3, 0.001f, LearnRateUpdate.GD_GECAY, false);
optimizer.trainNanoGPT_GEN(trainData);
int gen_len = 1000;
network.RUN_MODEL = RunModel.TEST;
Tensor input = null;
Tensor output = null;
String pre_txt = "萧炎";
Tensor positions = CNChatTokenizer.getPositions(1, pre_txt.length());
Tensor mask = CNChatTokenizer.triu(1, network.headNum, pre_txt.length(), pre_txt.length(), 1);
input = createTxtData(input, pre_txt, trainData.characters, trainData.dictionary, max_len);
for(int i = 0;i<gen_len;i++) {
network.time = input.number;
String txt = genTxt(input, output, network, trainData, pre_txt.length(), mask, positions);
if(network.time > 1) {
pre_txt += txt.substring(input.number - 1, input.number);
}else {
pre_txt += txt;
}
input = createTxtData(input, pre_txt, trainData.characters, trainData.dictionary, max_len);
}
System.out.println(pre_txt);
} catch (Exception e) {
e.printStackTrace();
}
}
```
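`CNChatTokenizer.triu(...)` above builds the causal attention mask for the decoder. As a reference for what such a mask encodes, a minimal sketch (the engine's actual tensor layout is not documented in this README):

```java
public class CausalMaskSketch {
    /**
     * Causal (upper-triangular) mask: position i may attend to positions 0..i only.
     * A value of 1 marks a blocked (future) position.
     */
    public static int[][] causalMask(int len) {
        int[][] mask = new int[len][len];
        for (int i = 0; i < len; i++) {
            for (int j = i + 1; j < len; j++) {
                mask[i][j] = 1; // block attention to future tokens
            }
        }
        return mask;
    }
}
```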
**gpt-2 Chinese chat demo**
```java
public static void ch_chat_gpt2() {
try {
boolean bias = false;
boolean dropout = true;
int batchSize = 32;
int max_len = 128;
int embedDim = 768;
int head_num = 12;
int decoderNum = 12;
String trainPath = "H:\\transformer_dataset\\gpt\\chatdata\\train-format20w.txt";
CNChatTokenizer trainData = new CNChatTokenizer(trainPath, max_len, batchSize);
NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, head_num, decoderNum, trainData.vocab_size, max_len, embedDim, bias, dropout, false);
network.learnRate = 0.0001f;
EDOptimizer optimizer = new EDOptimizer(network, batchSize, 3, 0.0001f, LearnRateUpdate.SMART_HALF, false);
optimizer.lr_step = new int[] {1, 2};
optimizer.trainNanoGPT(trainData);
Scanner scanner = new Scanner(System.in);
String context = "";
while (true) {
System.out.println("请输入中文:");
String input_txt = scanner.nextLine();
if(input_txt.equals("clean")){
context = "";
continue;
}
// ... (the rest of the chat loop is elided in the source) ...
System.out.println("chatbot:" + input_txt.split(" ")[1]);
}
} catch (Exception e) {
e.printStackTrace();
}
}
```
**llama2 Chinese pretraining demo (chatglm vocab)**
```java
public static void llama2_chinese_chatglm_vocab() {
    try {
        boolean bias = false;
        boolean dropout = false;
        boolean flashAttention = false;
        int batchSize = 8;
        int max_len = 512;
        int embedDim = 512;
        int head_num = 8;
        int decoderNum = 8;
        String trainPath = "H:\\transformer_dataset\\wbm_idx_chatglm_vocab.txt";
        String tokenizer_path = "H:\\transformer_dataset\\tokenizer.model";
        SentencePieceTokenizer tokenizer = new SentencePieceTokenizer(tokenizer_path, 64793);
        CNWikiTokenizer4 trainData = new CNWikiTokenizer4(trainPath, max_len, batchSize, 6250865, tokenizer);
        Llama2 network = new Llama2(LossType.softmax_with_cross_entropy_idx, UpdaterType.adamw, head_num, decoderNum, trainData.vocab_size, max_len, embedDim, bias, dropout, flashAttention);
        network.learnRate = 3e-4f;
        EDOptimizer optimizer = new EDOptimizer(network, batchSize, 1, 0.0001f, LearnRateUpdate.COSINE, false);
        optimizer.lr_step = new int[] {1, 2};
        optimizer.lr = 3e-4f;
        optimizer.min_lr = 1e-5f;
        optimizer.setWarmUp(true);
        optimizer.warmUpTime = 1000;
        optimizer.lrDecayIters = (int) (trainData.count_it * 0.96);
        optimizer.trainLlama2_chinese(trainData);
        String model_path = "H:\\model\\llama2-92m-chinese.model";
        ModelUtils.saveModel(network, model_path);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```
**llama3 Chinese pretraining demo**
```java
public static void llama3_monkey() {
try {
boolean bias = false;
boolean dropout = false;
boolean flashAttention = false;
int batchSize = 2;
int max_len = 512;
int embedDim = 512;
int head_num = 16;
int nKVHeadNum = 8;
int decoderNum = 8;
String trainPath = "H:\\model\\pretrain_data_6400.bin";
String vocabPath = "H:\\transformer_dataset\\6400\\vocab.json";
String mergesPath = "H:\\transformer_dataset\\6400\\merges.txt";
BPETokenizer3 tokenizer = new BPETokenizer3(vocabPath, mergesPath);
CNBpeTokenizer trainData = new CNBpeTokenizer(trainPath, max_len, batchSize, tokenizer, BinDataType.unint16);
Llama3 network = new Llama3(LossType.softmax_with_cross_entropy_idx, UpdaterType.adamw, head_num, nKVHeadNum, decoderNum, trainData.vocab_size, max_len, embedDim, bias, dropout, flashAttention);
network.learnRate = 1e-4f;
network.CLIP_GRAD_NORM = true;
initWeight(network, decoderNum);
EDOptimizer optimizer = new EDOptimizer(network, batchSize, 2, 0.0001f, LearnRateUpdate.CONSTANT, false);
optimizer.trainLlama3_chinese(trainData, 8, true);
String save_model_path = "H:\\model\\llama3-26m-chinese.model";
ModelUtils.saveModel(network, save_model_path);
} catch (Exception e) {
e.printStackTrace();
}
}
```
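`network.CLIP_GRAD_NORM = true` enables gradient-norm clipping. The standard recipe, as a sketch (not taken from the engine's implementation): compute the global L2 norm over all gradients and, if it exceeds a threshold, scale every gradient by `threshold / norm`.

```java
public class GradClipSketch {
    /** Global L2 gradient-norm clipping over a set of gradient buffers. */
    public static void clipGradNorm(float[][] grads, float maxNorm) {
        double sumSq = 0.0;
        for (float[] g : grads) {
            for (float v : g) sumSq += (double) v * v;
        }
        double norm = Math.sqrt(sumSq);
        if (norm > maxNorm) {
            float scale = (float) (maxNorm / norm); // shrink all gradients uniformly
            for (float[] g : grads) {
                for (int i = 0; i < g.length; i++) g[i] *= scale;
            }
        }
    }
}
```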
**diffusion anime demo**
```java
public static void duffsion_anime() {
    try {
        boolean bias = false;
        int batchSize = 8;
        int imw = 96;
        int imh = 96;
        int mChannel = 64;
        int resBlockNum = 2;
        int T = 1000;
        int[] channelMult = new int[] {1, 2, 3, 4};
        String imgDirPath = "H:\\voc\\gan_anime\\ml2021spring-hw6\\faces\\";
        DiffusionImageDataLoader dataLoader = new DiffusionImageDataLoader(imgDirPath, imw, imh, batchSize, false);
        DiffusionUNet network = new DiffusionUNet(LossType.MSE, UpdaterType.adamw, T, 3, mChannel, channelMult, resBlockNum, imw, imh, bias);
        network.CUDNN = true;
        network.learnRate = 0.0002f;
        MBSGDOptimizer optimizer = new MBSGDOptimizer(network, 50, 0.00001f, batchSize, LearnRateUpdate.GD_GECAY, false);
        optimizer.trainGaussianDiffusion(dataLoader);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```
**Package dependency versions**
Pick the dependency matching your local CUDA version:
```xml
<!-- windows cuda 11.7 -->
<dependency>
<groupId>io.gitee.iangellove</groupId>
<artifactId>omega-engine-v4-gpu</artifactId>
<version>win-cu11.7-v1.0-beta</version>
</dependency>
<!-- windows cuda 11.8 -->
<dependency>
<groupId>io.gitee.iangellove</groupId>
<artifactId>omega-engine-v4-gpu</artifactId>
<version>win-cu11.8-v1.0-beta</version>
</dependency>
<!-- windows cuda 12.x -->
<dependency>
<groupId>io.gitee.iangellove</groupId>
<artifactId>omega-engine-v4-gpu</artifactId>
<version>win-cu12.x-v1.0-beta</version>
</dependency>
```
Planned for the future:
- implement Llama2, UNet, and diffusion models;
- support dynamic parameter tuning and visualization of the training process.
As a bonus:
- an AI car-racing game built with neural networks and genetic algorithms.
June 20, 2022:
- added GPU support, using JCuda to call CUBLAS Sgemm for matrix multiplication;
- improved convolution performance by reorganizing it as im2col + GEMM (see the sketch below);
- sped up computation with multithreading based on the ForkJoin framework;
- updated the learning-rate schedule logic, including the RANDOM, POLY, STEP, EXP, and SIG policies.
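For reference, the im2col + GEMM trick rewrites convolution as a single matrix multiplication: each kernel-sized window of the input is unrolled into a column, so the flattened kernel times the resulting matrix yields every output pixel at once. A minimal single-channel sketch with stride 1 and no padding (illustrative, not the engine's implementation):

```java
public class Im2colSketch {
    /**
     * Unrolls a h x w single-channel image into a [k*k][outH*outW] matrix.
     * A flattened k*k kernel (as a 1 x k*k row) times this matrix equals
     * the valid convolution of the kernel with the image.
     */
    public static float[][] im2col(float[] img, int h, int w, int k) {
        int outH = h - k + 1, outW = w - k + 1;
        float[][] col = new float[k * k][outH * outW];
        for (int ky = 0; ky < k; ky++) {
            for (int kx = 0; kx < k; kx++) {
                int row = ky * k + kx;
                for (int oy = 0; oy < outH; oy++) {
                    for (int ox = 0; ox < outW; ox++) {
                        col[row][oy * outW + ox] = img[(oy + ky) * w + (ox + kx)];
                    }
                }
            }
        }
        return col;
    }
}
```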