omega-ai
Benchmark results:

vgg16, epoch: 50, test-set accuracy: 86.45%

cifar_10, epoch: 300, resnet18 [batchSize: 128, initial learningRate: 0.1, learnRateUpdate: GD_GECAY, optimizer: adamw], data preprocessing [randomCrop, randomHorizontalFlip, cutout, normalize], test-set accuracy: 91.23%
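The cutout step in the preprocessing list above zeroes out a random square patch of each training image. A minimal single-channel sketch of that augmentation (a generic illustration; the class, method name, and float[] layout are assumptions, not the library's API):

```java
import java.util.Random;

public class CutoutAugment {

    private static final Random RND = new Random();

    /**
     * Zero out one random size x size patch of a single-channel image
     * stored row-major in a float[h * w] buffer (patch clipped at borders).
     */
    public static void apply(float[] img, int h, int w, int size) {
        int cy = RND.nextInt(h); // patch center, anywhere in the image
        int cx = RND.nextInt(w);
        int y0 = Math.max(0, cy - size / 2);
        int y1 = Math.min(h, cy + size / 2);
        int x0 = Math.max(0, cx - size / 2);
        int x1 = Math.min(w, cx + size / 2);
        for (int y = y0; y < y1; y++) {
            for (int x = x0; x < x1; x++) {
                img[y * w + x] = 0f;
            }
        }
    }
}
```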

Code examples

bp iris demo

```java
public void bpNetwork_iris() {

    /**
     * Load the training and test datasets.
     */
    String iris_train = "/dataset/iris/iris.txt";
    String iris_test = "/dataset/iris/iris_test.txt";

    String[] labelSet = new String[] {"1","-1"};

    DataSet trainData = DataLoader.loalDataByTxt(iris_train, ",", 1, 1, 4, 2, labelSet);
    DataSet testData = DataLoader.loalDataByTxt(iris_test, ",", 1, 1, 4, 2, labelSet);

    System.out.println("train_data:" + JsonUtils.toJson(trainData));

    BPNetwork netWork = new BPNetwork(new SoftmaxWithCrossEntropyLoss());

    InputLayer inputLayer = new InputLayer(1, 1, 4);
    FullyLayer hidden1 = new FullyLayer(4, 40);
    ReluLayer active1 = new ReluLayer();
    FullyLayer hidden2 = new FullyLayer(40, 20);
    ReluLayer active2 = new ReluLayer();
    FullyLayer hidden3 = new FullyLayer(20, 2);
    SoftmaxWithCrossEntropyLayer hidden4 = new SoftmaxWithCrossEntropyLayer(2);

    netWork.addLayer(inputLayer);
    netWork.addLayer(hidden1);
    netWork.addLayer(active1);
    netWork.addLayer(hidden2);
    netWork.addLayer(active2);
    netWork.addLayer(hidden3);
    netWork.addLayer(hidden4);

    try {
        MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 8, 0.00001d, 10, LearnRateUpdate.NONE);
        optimizer.train(trainData);
        optimizer.test(testData);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```
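The loalDataByTxt call above parses comma-separated rows with four feature columns and a trailing label drawn from labelSet ("1" or "-1"). An illustrative input row (the column order is an assumption inferred from the call's arguments, not a documented format):

```
5.1,3.5,1.4,0.2,1
4.9,3.0,1.4,0.2,-1
```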

cnn mnist demo

```java
public void cnnNetwork_mnist() {

    try {

        /**
         * Load the MNIST training and test sets (idx/ubyte format).
         */
        String mnist_train_data = "/dataset/mnist/train-images.idx3-ubyte";
        String mnist_train_label = "/dataset/mnist/train-labels.idx1-ubyte";
        String mnist_test_data = "/dataset/mnist/t10k-images.idx3-ubyte";
        String mnist_test_label = "/dataset/mnist/t10k-labels.idx1-ubyte";

        String[] labelSet = new String[] {"0","1","2","3","4","5","6","7","8","9"};

        Resource trainDataRes = new ClassPathResource(mnist_train_data);
        Resource trainLabelRes = new ClassPathResource(mnist_train_label);
        Resource testDataRes = new ClassPathResource(mnist_test_data);
        Resource testLabelRes = new ClassPathResource(mnist_test_label);

        DataSet trainData = DataLoader.loadDataByUByte(trainDataRes.getFile(), trainLabelRes.getFile(), labelSet, 1, 1, 784, true);
        DataSet testData = DataLoader.loadDataByUByte(testDataRes.getFile(), testLabelRes.getFile(), labelSet, 1, 1, 784, true);

        int channel = 1;
        int height = 28;
        int width = 28;

        CNN netWork = new CNN(new SoftmaxWithCrossEntropyLoss(), UpdaterType.momentum);
        netWork.learnRate = 0.001d;

        InputLayer inputLayer = new InputLayer(channel, 1, 784);
        ConvolutionLayer conv1 = new ConvolutionLayer(channel, 6, width, height, 5, 5, 2, 1, false);
        BNLayer bn1 = new BNLayer();
        LeakyReluLayer active1 = new LeakyReluLayer();
        PoolingLayer pool1 = new PoolingLayer(conv1.oChannel, conv1.oWidth, conv1.oHeight, 2, 2, 2, PoolingType.MAX_POOLING);
        ConvolutionLayer conv2 = new ConvolutionLayer(pool1.oChannel,
        // ... (the listing breaks off mid-line here in the mirrored source)
```
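A plausible completion of the truncated tail, following the layer and optimizer signatures already used in this README (the kernel count, classifier sizes, and training settings below are assumptions for illustration, not the original demo's values):

```java
        // hypothetical continuation: second conv/pool stage plus a classifier head
        ConvolutionLayer conv2 = new ConvolutionLayer(pool1.oChannel, 12, pool1.oWidth, pool1.oHeight, 5, 5, 0, 1, false);
        BNLayer bn2 = new BNLayer();
        LeakyReluLayer active2 = new LeakyReluLayer();
        PoolingLayer pool2 = new PoolingLayer(conv2.oChannel, conv2.oWidth, conv2.oHeight, 2, 2, 2, PoolingType.MAX_POOLING);

        // flatten and classify into the 10 digit classes
        int fInputCount = pool2.oChannel * pool2.oWidth * pool2.oHeight;
        FullyLayer full1 = new FullyLayer(fInputCount, 10);
        SoftmaxWithCrossEntropyLayer softmax = new SoftmaxWithCrossEntropyLayer(10);

        netWork.addLayer(inputLayer);
        netWork.addLayer(conv1);
        netWork.addLayer(bn1);
        netWork.addLayer(active1);
        netWork.addLayer(pool1);
        netWork.addLayer(conv2);
        netWork.addLayer(bn2);
        netWork.addLayer(active2);
        netWork.addLayer(pool2);
        netWork.addLayer(full1);
        netWork.addLayer(softmax);

        MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 10, 0.0001d, 128, LearnRateUpdate.NONE);
        optimizer.train(trainData);
        optimizer.test(testData);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```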

resnet18 cifar_10 demo

```java
// The opening of this listing (data loading, conv1/bn1/active1, and the first
// residual block "bl1") is missing from the mirrored source; it resumes mid-method here.
ReluLayer active2 = new ReluLayer();

/**
 * block2  64 * 32 * 32
 */
BasicBlockLayer bl2 = new BasicBlockLayer(bl1.oChannel, 64, bl1.oHeight, bl1.oWidth, 1, netWork);
ReluLayer active3 = new ReluLayer();

/**
 * block3  128 * 16 * 16
 * downSample 32 / 2 = 16
 */
BasicBlockLayer bl3 = new BasicBlockLayer(bl2.oChannel, 128, bl2.oHeight, bl2.oWidth, 2, netWork);
ReluLayer active4 = new ReluLayer();

/**
 * block4  128 * 16 * 16
 */
BasicBlockLayer bl4 = new BasicBlockLayer(bl3.oChannel, 128, bl3.oHeight, bl3.oWidth, 1, netWork);
ReluLayer active5 = new ReluLayer();

/**
 * block5  256 * 8 * 8
 * downSample 16 / 2 = 8
 */
BasicBlockLayer bl5 = new BasicBlockLayer(bl4.oChannel, 256, bl4.oHeight, bl4.oWidth, 2, netWork);
ReluLayer active6 = new ReluLayer();

/**
 * block6  256 * 8 * 8
 */
BasicBlockLayer bl6 = new BasicBlockLayer(bl5.oChannel, 256, bl5.oHeight, bl5.oWidth, 1, netWork);
ReluLayer active7 = new ReluLayer();

/**
 * block7  512 * 4 * 4
 * downSample 8 / 2 = 4
 */
BasicBlockLayer bl7 = new BasicBlockLayer(bl6.oChannel, 512, bl6.oHeight, bl6.oWidth, 2, netWork);
ReluLayer active8 = new ReluLayer();


/**
 * block8  512 * 4 * 4
 */
BasicBlockLayer bl8 = new BasicBlockLayer(bl7.oChannel, 512, bl7.oHeight, bl7.oWidth, 1, netWork);
ReluLayer active9 = new ReluLayer();

AVGPoolingLayer pool2 = new AVGPoolingLayer(bl8.oChannel, bl8.oWidth, bl8.oHeight);

/**
 * fully  512 * 1 * 1
 */
int fInputCount = pool2.oChannel * pool2.oWidth * pool2.oHeight;

FullyLayer full1 = new FullyLayer(fInputCount, 10);

netWork.addLayer(inputLayer);
netWork.addLayer(conv1);
netWork.addLayer(bn1);
netWork.addLayer(active1);

/**
 * block1  64
 */
netWork.addLayer(bl1);
netWork.addLayer(active2);
netWork.addLayer(bl2);
netWork.addLayer(active3);

/**
 * block2  128
 */
netWork.addLayer(bl3);
netWork.addLayer(active4);
netWork.addLayer(bl4);
netWork.addLayer(active5);

/**
 * block3  256
 */
netWork.addLayer(bl5);
netWork.addLayer(active6);
netWork.addLayer(bl6);
netWork.addLayer(active7);

/**
 * block4  512
 */
netWork.addLayer(bl7);
netWork.addLayer(active8);
netWork.addLayer(bl8);
netWork.addLayer(active9);

netWork.addLayer(pool2);
netWork.addLayer(full1);

MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 250, 0.001f, 128, LearnRateUpdate.GD_GECAY, false);

long start = System.currentTimeMillis();

optimizer.train(trainData, testData, mean, std);

optimizer.test(testData);

System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");

} catch (Exception e) {
    e.printStackTrace();
}finally {

    try {
        CUDAMemoryManager.freeAll();
    } catch (Exception e) {
        e.printStackTrace();
    }

    }
}
```

yolov1 banana-detection demo

```java
// (the method signature is missing from the mirrored source)
try {

    String cfg_path = "H:/voc/train/yolov1-tiny.cfg";

    String trainPath = "H:\\voc\\banana-detection\\bananas_train\\images";
    String trainLabelPath = "H:\\voc\\banana-detection\\bananas_train\\label.csv";

    String testPath = "H:\\voc\\banana-detection\\bananas_val\\images";
    String testLabelPath = "H:\\voc\\banana-detection\\bananas_val\\label.csv";

    YoloDataLoader trainData = new YoloDataLoader(trainPath, trainLabelPath, 1000, 3, 256, 256, 5, LabelType.csv, true);

    YoloDataLoader vailData = new YoloDataLoader(testPath, testLabelPath, 100, 3, 256, 256, 5, LabelType.csv, true);

    DataSet trainSet = formatToYolo(trainData.getDataSet());

    DataSet vailSet = formatToYolo(vailData.getDataSet());

    System.out.println("load data finish.");

    CNN netWork = new CNN(LossType.yolo3, UpdaterType.adamw);

    netWork.CUDNN = true;

    netWork.learnRate = 0.001f;

    ModelLoader.loadConfigToModel(netWork, cfg_path);

    MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 1000, 0.001f, 64, LearnRateUpdate.CONSTANT, false);

    long start = System.currentTimeMillis();

    optimizer.trainObjectRecognition(trainSet, vailSet);

    /**
     * Render the test predictions.
     */
    float[][][] draw_bbox = optimizer.showObjectRecognition(vailSet, 64);

    YoloDataLoader testData = new YoloDataLoader(testPath, testLabelPath, 1000, 3, 256, 256, 5, LabelType.csv, false);

    String outputPath = "H:\\voc\\banana-detection\\test\\";

    showImg(outputPath, testData.getDataSet(), 1, draw_bbox, false);

    System.out.println(((System.currentTimeMillis() - start) / 1000) + "s.");
} catch (Exception e) {
    e.printStackTrace();
} finally {
    try {
        CUDAMemoryManager.freeAll();
    } catch (Exception e) {
        e.printStackTrace();
    }
}

}
```

yolov3 mask demo (mask-wearing detection)

```java
public void yolov3_tiny_mask() {

    int im_w = 416;
    int im_h = 416;
    int batchSize = 24;
    int class_num = 2;
    String[] labelset = new String[] {"unmask", "mask"};
    try {
        String cfg_path = "H:\\voc\\mask\\data\\dataset\\yolov3-tiny-mask.cfg";
        String trainPath = "H:\\voc\\mask\\data\\resized\\train";
        String trainLabelPath = "H:\\voc\\mask\\data\\resized\\train_label.txt";
        String testPath = "H:\\voc\\mask\\data\\resized\\vail";
        String testLabelPath = "H:\\voc\\mask\\data\\resized\\vail_label.txt";
        String weightPath = "H:\\voc\\yolo-weights\\yolov3-tiny.conv.15";
        /**
         * Load the data.
         */
        DetectionDataLoader trainData = new DetectionDataLoader(trainPath, trainLabelPath, LabelFileType.txt, im_w, im_h, class_num, batchSize, DataType.yolov3);
        DetectionDataLoader vailData = new DetectionDataLoader(testPath, testLabelPath, LabelFileType.txt, im_w, im_h, class_num, batchSize, DataType.yolov3);
        /**
         * Create the yolo model.
         */
        Yolo netWork = new Yolo(LossType.yolo3, UpdaterType.adamw);
        netWork.CUDNN = true;
        netWork.learnRate = 0.001f;
        /**
         * Load the model configuration.
         */
        ModelLoader.loadConfigToModel(netWork, cfg_path);
        /**
         * Load the pretrained weights.
         */
        DarknetLoader.loadWeight(netWork, weightPath, 14, true);
        /**
         * Create the optimizer.
         */
        MBSGDOptimizer optimizer = new
        // ... (the listing breaks off here in the mirrored source)
```

yolov7 demo

```java
// The opening of this listing is missing from the mirrored source; only the
// pretrained-weight path "H:\\voc\\darknet_yolov7\\yolov7-tiny.conv.87" and the body below survive.

try (FileInputStream fin = new FileInputStream(labelPath);
     InputStreamReader reader = new InputStreamReader(fin);  
     BufferedReader buffReader = new BufferedReader(reader)) {
    String strTmp = "";
    int idx = 0;
    while((strTmp = buffReader.readLine())!=null) {
        labelset[idx] = strTmp;
        idx++;
    }   
} catch (Exception e) {
    e.printStackTrace();
}

DetectionDataLoader trainData = new DetectionDataLoader(trainPath, trainLabelPath, LabelFileType.txt, im_w, im_h, class_num, batchSize, DataType.yolov3);
DetectionDataLoader vailData = new DetectionDataLoader(testPath, testLabelPath, LabelFileType.txt, im_w, im_h, class_num, batchSize, DataType.yolov3);
Yolo netWork = new Yolo(LossType.yolov7, UpdaterType.adamw);
netWork.CUDNN = true;
netWork.learnRate = 0.001f;
ModelLoader.loadConfigToModel(netWork, cfg_path);
DarknetLoader.loadWeight(netWork, weightPath, 86, true);
MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 1000, 0.001f, batchSize, LearnRateUpdate.SMART_HALF, false);
optimizer.trainObjectRecognitionOutputs(trainData, vailData);

/**
 * Render the test predictions.
 */
List<YoloBox> draw_bbox = optimizer.showObjectRecognitionYoloV3(vailData, batchSize);
String outputPath = "H:\\voc\\sm\\test_yolov7\\";
showImg(outputPath, vailData, class_num, draw_bbox, batchSize, false, im_w, im_h, labelset);
} catch (Exception e) {
    e.printStackTrace();
} finally {
    try {
        CUDAMemoryManager.freeAll();
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```

**gan mnist demo (handwritten digit generation)**

```java
public static void gan_anime() {

    int imgSize = 784;
    int ngf = 784; // number of generator feature maps
    int nz = 100;  // noise dimensionality
    int batchSize = 2048;

    int d_every = 1;
    int g_every = 1;

    float[] mean = new float[] {0.5f};
    float[] std = new float[] {0.5f};

    try {

        String mnist_train_data = "/dataset/mnist/train-images.idx3-ubyte";
        String mnist_train_label = "/dataset/mnist/train-labels.idx1-ubyte";

        String[] labelSet = new String[] {"0","1","2","3","4","5","6","7","8","9"};

        Resource trainDataRes = new ClassPathResource(mnist_train_data);
        Resource trainLabelRes = new ClassPathResource(mnist_train_label);

        DataSet trainData = DataLoader.loadDataByUByte(trainDataRes.getFile(), trainLabelRes.getFile(), labelSet, 1, 1, 784, true, mean, std);

        BPNetwork netG = NetG(ngf, nz);
        BPNetwork netD = NetD(imgSize);

        GANOptimizer optimizer = new GANOptimizer(netG, netD, batchSize, 3500, d_every, g_every, 0.001f, LearnRateUpdate.CONSTANT, false);

        optimizer.train(trainData);

    } catch (Exception e) {
        e.printStackTrace();
    }
}
```

**dcgan anime demo (anime character generation)**

```java
public static void dcgan_anime() {

    int imw = 64;
    int imh = 64;
    int ngf = 64; // number of generator feature maps
    int ndf = 64; // number of discriminator feature maps
    int nz = 100; // noise dimensionality
    int batchSize = 64;

    int d_every = 1;
    int g_every = 5;

    float[] mean = new float[] {0.5f, 0.5f, 0.5f};
    float[] std = new float[] {0.5f, 0.5f, 0.5f};

    try {

        String imgDirPath = "H:\\voc\\gan_anime\\ml2021spring-hw6\\faces\\";

        CNN netG = NetG(ngf, nz);
        CNN netD = NetD(ndf, imw, imh);

        ImageDataLoader dataLoader = new ImageDataLoader(imgDirPath, imw, imh, batchSize, true, mean, std);

        GANOptimizer optimizer = new GANOptimizer(netG, netD, batchSize, 2000, d_every, g_every, 0.001f, LearnRateUpdate.POLY, false);

        optimizer.train(dataLoader);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```

**rnn demo (Chinese novel generation)**

```java
public void charRNN() {
    try {
        int time = 256;
        int batchSize = 64;
        int embedding_dim = 256;
        int hiddenSize = 512;

        String trainPath = "H:\\rnn_dataset\\dpcc.txt";
        OneHotDataLoader trainData = new OneHotDataLoader(trainPath, time, batchSize);

        RNN netWork = new RNN(LossType.softmax_with_cross_entropy, UpdaterType.adamw, time);

        InputLayer inputLayer = new InputLayer(1, 1, trainData.characters);
        EmbeddingLayer em = new EmbeddingLayer(trainData.characters, embedding_dim);
        RNNLayer l1 = new RNNLayer(embedding_dim, hiddenSize, time, ActiveType.tanh, false, netWork);
        RNNLayer l2 = new RNNLayer(hiddenSize, hiddenSize, time, ActiveType.tanh, false, netWork);
        RNNLayer l3 = new RNNLayer(hiddenSize, hiddenSize, time, ActiveType.tanh, false, netWork);
        FullyLayer f1 = new FullyLayer(hiddenSize, hiddenSize, false);
        BNLayer bn = new BNLayer();
        LeakyReluLayer a1 = new LeakyReluLayer();
        FullyLayer f2 = new FullyLayer(hiddenSize, trainData.characters, true);
        netWork.addLayer(inputLayer);
        netWork.addLayer(em);
        netWork.addLayer(l1);
        netWork.addLayer(l2);
        netWork.addLayer(l3);
        netWork.addLayer(f1);
        netWork.addLayer(bn);
        netWork.addLayer(a1);
        netWork.addLayer(f2);

        netWork.CUDNN = true;
        netWork.learnRate = 0.01f;

        MBSGDOptimizer optimizer = new MBSGDOptimizer(netWork, 2, 0.001f, batchSize, LearnRateUpdate.POLY, false);
        optimizer.trainRNN(trainData);

        int gen_len = 1000;
        int max_len = 256;
        String pre_txt = "这个故事所造成的后果便是造就了大批每天";
        Tensor input = null;
        Tensor output = null;
        input = createTxtData(input, pre_txt, trainData.characters, trainData.dictionary, max_len);
        netWork.RUN_MODEL = RunModel.TEST;
        for (int i = 0; i < gen_len; i++) {
            netWork.time = input.number;
            String txt = genTxt(input, output, netWork, trainData, max_len);
            if (netWork.time > 1) {
                pre_txt += txt.substring(input.number - 1, input.number);
            } else {
                pre_txt += txt;
            }
            input = createTxtData(input, pre_txt, trainData.characters, trainData.dictionary, max_len);
        }
        System.out.println(pre_txt);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```

**seq2seq demo (English-to-Chinese translation)**

```java
// (the method signature is missing from the mirrored source)
try {
    int batchSize = 128;
    int en_em = 64;
    int de_em = 128;
    int en_hidden = 256;
    int de_hidden = 256;

    String trainPath = "H:\\rnn_dataset\\translate1000.csv";
    IndexDataLoader trainData = new IndexDataLoader(trainPath, batchSize);

    Seq2Seq network = new Seq2Seq(LossType.softmax_with_cross_entropy, UpdaterType.adamw,
        trainData.max_en, trainData.max_ch - 1, en_em, en_hidden, trainData.en_characters, de_em, de_hidden, trainData.ch_characters);
    network.CUDNN = true;
    network.learnRate = 0.01f;

    EDOptimizer optimizer = new EDOptimizer(network, batchSize, 100, 0.001f, LearnRateUpdate.SMART_HALF, false);
    optimizer.lr_step = new int[] {100, 200};
    optimizer.trainRNN(trainData);

    Scanner scanner = new Scanner(System.in);
    while (true) {
        System.out.println("请输入英文:");
        String input_txt = scanner.nextLine();
        if (input_txt.equals("exit")) {
            break;
        }
        input_txt = input_txt.toLowerCase();
        System.out.println(input_txt);
        optimizer.predict(trainData, input_txt);
    }
    scanner.close();
} catch (Exception e) {
    e.printStackTrace();
}
```

gpt demo (Chinese story generation)

```java
public static void gpt_dp() {
    try {
        boolean bias = false;
        boolean dropout = true;
        int batchSize = 32;
        int max_len = 64;
        int embedDim = 512;
        int headNum = 8;
        int decoderNum = 6;
        String trainPath = "H:\\transformer_dataset\\gpt\\dpcc50.txt";
        CNTokenizer trainData = new CNTokenizer(trainPath, max_len, batchSize);
        NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, headNum, decoderNum, trainData.characters, max_len, embedDim, bias, dropout);
        network.learnRate = 0.001f;
        EDOptimizer optimizer = new EDOptimizer(network, batchSize, 3, 0.001f, LearnRateUpdate.GD_GECAY, false);
        optimizer.trainNanoGPT_GEN(trainData);
        int gen_len = 1000;
        network.RUN_MODEL = RunModel.TEST;
        Tensor input = null;
        Tensor output = null;
        String pre_txt = "萧炎";
        Tensor positions = CNChatTokenizer.getPositions(1, pre_txt.length());
        Tensor mask = CNChatTokenizer.triu(1, network.headNum, pre_txt.length(), pre_txt.length(), 1);
        input = createTxtData(input, pre_txt, trainData.characters, trainData.dictionary, max_len);
        for (int i = 0; i < gen_len; i++) {
            network.time = input.number;
            String txt = genTxt(input, output, network, trainData, pre_txt.length(), mask, positions);
            if (network.time > 1) {
                pre_txt += txt.substring(input.number - 1, input.number);
            } else {
                pre_txt += txt;
            }
            input = createTxtData(input, pre_txt, trainData.characters, trainData.dictionary, max_len);
        }
        System.out.println(pre_txt);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```

gpt demo (Chinese chatbot)

```java
public static void ch_chat_gpt2() {
    try {
        boolean bias = false;
        boolean dropout = true;
        int batchSize = 32;
        int max_len = 128;
        int embedDim = 768;
        int head_num = 12;
        int decoderNum = 12;
        String trainPath = "H:\\transformer_dataset\\gpt\\chatdata\\train-format20w.txt";
        CNChatTokenizer trainData = new CNChatTokenizer(trainPath, max_len, batchSize);
        NanoGPT network = new NanoGPT(LossType.softmax_with_cross_entropy, UpdaterType.adamw, head_num, decoderNum, trainData.vocab_size, max_len, embedDim, bias, dropout, false);
        network.learnRate = 0.0001f;
        EDOptimizer optimizer = new EDOptimizer(network, batchSize, 3, 0.0001f, LearnRateUpdate.SMART_HALF, false);
        optimizer.lr_step = new int[] {1, 2};
        optimizer.trainNanoGPT(trainData);
        Scanner scanner = new Scanner(System.in);
        String context = "";
        while (true) {
            System.out.println("请输入中文:");
            String input_txt = scanner.nextLine();
            if (input_txt.equals("clean")) {
                context = "";
            }
            // ... (the rest of the chat loop is garbled in the mirrored source;
            // the only recoverable line prints the model's reply:)
            System.out.println("chatbot:" + input_txt.split(" ")[1]);
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```

llama2 demo (Chinese pretraining, chatglm vocabulary)

```java
public static void llama2_chinese_chatglm_vocab() {
    try {
        boolean bias = false;
        boolean dropout = false;
        boolean flashAttention = false;
        int batchSize = 8;
        int max_len = 512;
        int embedDim = 512;
        int head_num = 8;
        int decoderNum = 8;
        String trainPath = "H:\\transformer_dataset\\wbm_idx_chatglm_vocab.txt";
        String tokenizer_path = "H:\\transformer_dataset\\tokenizer.model";
        SentencePieceTokenizer tokenizer = new SentencePieceTokenizer(tokenizer_path, 64793);
        CNWikiTokenizer4 trainData = new CNWikiTokenizer4(trainPath, max_len, batchSize, 6250865, tokenizer);
        Llama2 network = new Llama2(LossType.softmax_with_cross_entropy_idx, UpdaterType.adamw, head_num, decoderNum, trainData.vocab_size, max_len, embedDim, bias, dropout, flashAttention);
        network.learnRate = 3e-4f;
        EDOptimizer optimizer = new EDOptimizer(network, batchSize, 1, 0.0001f, LearnRateUpdate.COSINE, false);
        optimizer.lr_step = new int[] {1, 2};
        optimizer.lr = 3e-4f;
        optimizer.min_lr = 1e-5f;
        optimizer.setWarmUp(true);
        optimizer.warmUpTime = 1000;
        optimizer.lrDecayIters = (int) (trainData.count_it * 0.96);
        optimizer.trainLlama2_chinese(trainData);
        String model_path = "H:\\model\\llama2-92m-chinese.model";
        ModelUtils.saveModel(network, model_path);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```


llama3 demo

```java
public static void llama3_monkey() {
    try {
        boolean bias = false;
        boolean dropout = false;
        boolean flashAttention = false;
        int batchSize = 2;
        int max_len = 512;
        int embedDim = 512;
        int head_num = 16;
        int nKVHeadNum = 8;
        int decoderNum = 8;

        String trainPath = "H:\\model\\pretrain_data_6400.bin";
        String vocabPath = "H:\\transformer_dataset\\6400\\vocab.json";
        String mergesPath = "H:\\transformer_dataset\\6400\\merges.txt";

        BPETokenizer3 tokenizer = new BPETokenizer3(vocabPath, mergesPath);
        CNBpeTokenizer trainData = new CNBpeTokenizer(trainPath, max_len, batchSize, tokenizer, BinDataType.unint16);
        Llama3 network = new Llama3(LossType.softmax_with_cross_entropy_idx, UpdaterType.adamw, head_num, nKVHeadNum, decoderNum, trainData.vocab_size, max_len, embedDim, bias, dropout, flashAttention);
        network.learnRate = 1e-4f;
        network.CLIP_GRAD_NORM = true;
        initWeight(network, decoderNum);
        EDOptimizer optimizer = new EDOptimizer(network, batchSize, 2, 0.0001f, LearnRateUpdate.CONSTANT, false);
        optimizer.trainLlama3_chinese(trainData, 8, true);
        String save_model_path = "H:\\model\\llama3-26m-chinese.model";
        ModelUtils.saveModel(network, save_model_path);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```

diffusion unet demo (anime)

```java
public static void duffsion_anime() {
    try {
        boolean bias = false;
        int batchSize = 8;
        int imw = 96;
        int imh = 96;
        int mChannel = 64;
        int resBlockNum = 2;
        int T = 1000;
        int[] channelMult = new int[] {1, 2, 3, 4};
        String imgDirPath = "H:\\voc\\gan_anime\\ml2021spring-hw6\\faces\\";
        DiffusionImageDataLoader dataLoader = new DiffusionImageDataLoader(imgDirPath, imw, imh, batchSize, false);
        DiffusionUNet network = new DiffusionUNet(LossType.MSE, UpdaterType.adamw, T, 3, mChannel, channelMult, resBlockNum, imw, imh, bias);
        network.CUDNN = true;
        network.learnRate = 0.0002f;
        MBSGDOptimizer optimizer = new MBSGDOptimizer(network, 50, 0.00001f, batchSize, LearnRateUpdate.GD_GECAY, false);
        optimizer.trainGaussianDiffusion(dataLoader);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
```


**Package dependency versions**

```xml
<!-- windows cuda 11.7 -->
<dependency>
    <groupId>io.gitee.iangellove</groupId>
    <artifactId>omega-engine-v4-gpu</artifactId>
    <version>win-cu11.7-v1.0-beta</version>
</dependency>
<!-- windows cuda 11.8 -->
<dependency>
    <groupId>io.gitee.iangellove</groupId>
    <artifactId>omega-engine-v4-gpu</artifactId>
    <version>win-cu11.8-v1.0-beta</version>
</dependency>
<!-- windows cuda 12.x -->
<dependency>
    <groupId>io.gitee.iangellove</groupId>
    <artifactId>omega-engine-v4-gpu</artifactId>
    <version>win-cu12.x-v1.0-beta</version>
</dependency>
```

Planned for the future:

- implement Llama2, UNet, and diffusion models;

- support dynamic parameter tuning and visualization of the training process.

As a bonus:

- an AI "car racing" game built with neural networks and genetic algorithms.

Omega-Engine-V3 update history

June 20, 2022:

- added GPU support, using JCuda to call cuBLAS Sgemm for matrix multiplication;

- improved convolution performance by rewriting it as im2col + gemm (see the sketch after this list);

- sped up computation with multithreading via the ForkJoin framework;

- reworked the learning-rate update logic, including the RANDOM, POLY, STEP, EXP, and SIG methods.
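As context for the im2col + gemm item above, here is a minimal sketch of the im2col transform: it unrolls every convolution window into a matrix row, so the whole convolution collapses into one matrix multiplication. This is a generic illustration (stride 1, no padding, single channel), not the engine's actual code:

```java
/**
 * Unroll a single-channel h x w image (row-major float[h * w]) into a matrix
 * with one row per output position and one column per kernel element.
 * The convolution then becomes: output = cols x flattenedKernel (one gemm).
 */
public static float[][] im2col(float[] img, int h, int w, int kh, int kw) {
    int oh = h - kh + 1; // output height for stride 1, no padding
    int ow = w - kw + 1; // output width
    float[][] cols = new float[oh * ow][kh * kw];
    for (int oy = 0; oy < oh; oy++) {
        for (int ox = 0; ox < ow; ox++) {
            int row = oy * ow + ox;
            for (int ky = 0; ky < kh; ky++) {
                for (int kx = 0; kx < kw; kx++) {
                    cols[row][ky * kw + kx] = img[(oy + ky) * w + (ox + kx)];
                }
            }
        }
    }
    return cols;
}
```

Once the windows are laid out this way, all the per-position dot products can be handed to a single Sgemm call (for example through JCuda's cuBLAS binding), which is where the speedup over a naive nested-loop convolution comes from.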

