Фейджянь: ядро фреймворка
Высокоуровневый API
Фейджянь представляет новое поколение высокоуровневого API, которое является дальнейшим развитием и улучшением API Фейджяня. Оно предоставляет более простой и удобный в использовании API, что повышает лёгкость изучения и использования Фейджяня, а также расширяет его функциональность.
Высокоуровневый API предназначен для всех пользователей, от начинающих до опытных разработчиков. Для новичков в области глубокого обучения он позволяет быстро создавать проекты глубокого обучения, а для опытных разработчиков — быстро выполнять итерации алгоритмов.
— Простота изучения и использования: высокоуровневый API представляет собой дальнейшее развитие и оптимизацию обычного динамического графического API, сохраняя при этом совместимость с обычным API. Высокоуровневый API проще в изучении и использовании, позволяя сократить объём кода при реализации тех же функций. — Низкоуровневое программирование: использование высокоуровневого API характеризуется значительным сокращением объёма кода, который должен написать пользователь. — Динамическая и статическая конверсия: высокоуровневый API поддерживает динамическую и статическую конверсию, позволяя пользователям изменять режим работы всего одной строкой кода. Это удобно для отладки моделей в динамическом режиме и повышения эффективности их обучения в статическом режиме.
В плане улучшения функциональности и удобства использования высокоуровневый API предлагает следующие обновления:
— Метод обучения модели: в высокоуровневом API реализован класс Model, который наследуется от нейронных сетей. С помощью нескольких строк кода можно обучить модель. — Новый модуль обработки изображений transform: Фейджянь добавил новый модуль обработки изображений, включающий десятки функций обработки данных, охватывающих основные методы обработки и расширения данных. — Предоставлены часто используемые модели нейронных сетей: высокоуровневый API включает в себя модели компьютерного зрения и обработки естественного языка, такие как mobilenet, resnet, yolov3, cyclegan, bert, transformer, seq2seq и другие. Также были выпущены предварительно обученные модели этих архитектур, которые могут быть использованы напрямую или служить основой для дальнейшего развития.
Характеристики
Высокоуровневый API основан на реализации динамических графиков Фейджяня и совместим со всеми функциями динамических графиков. Он сохраняет простоту изучения, удобство использования и возможность отладки динамических графиков, а также дополнительно оптимизирует их.
По сравнению с реализацией алгоритмов с использованием динамических графических API, использование высокоуровневых API требует значительно меньшего объёма программирования. Например, для написания кода для распознавания рукописных символов с использованием обычного динамического API требуется более 20 строк кода, тогда как с высокоуровневым API достаточно всего 8 строк.
На рисунке ниже показано сравнение написания кода для распознавания рукописного текста с использованием обычного и высокоуровневого API. Видно, что код с использованием высокоуровневого API значительно короче.
Высокоуровневые API поддерживают динамические и статические графики, позволяя пользователям переключаться между ними одной строкой кода. Динамические графики удобны для отладки моделей, а статические обеспечивают более высокую эффективность обучения.
По умолчанию высокоуровневые API используют динамический режим обучения, аналогичный основному фреймворку. Чтобы переключиться на статический режим, можно использовать paddle.disable_static()
, а чтобы вернуться к динамическому режиму — paddle.enable_static()
.
Пример кода для переключения между режимами обучения:
# Переключение в динамический режим обучения одной строкой
paddle.disable_static()
# Настройка среды выполнения на GPU
paddle.set_device('gpu')
# Создание структуры сети
model = paddle.Model(Mnist())
# Определение оптимизатора
optimizer = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
# Подготовка модели с помощью prepare()
model.prepare(optimizer, CrossEntropy(), Accuracy())
# Запуск обучения модели с помощью fit()
model.fit(train_dataset, val_dataset, batch_size=100, epochs=1, log_freq=100, save_dir="./output/")
Рассмотрим пример использования высокоуровневого API для задачи распознавания рукописных цифр MNIST.
Создание структуры сети с использованием высокоуровневых API аналогично созданию с использованием обычных динамических API. Необходимо унаследовать от класса paddle.nn.Layer
и определить структуру сети в методе forward
.
Пример создания структуры сети:
import paddle
# Установка среды выполнения на GPU
paddle.set_device('gpu')
# Использование динамического режима обучения
paddle.disable_static()
class Mnist(paddle.nn.Layer):
def __init__(self):
super(Mnist, self).__init__()
self.fc = paddle.nn.Linear(input_dim=784, output_dim=10)
# Определение процесса прямого распространения сети
def forward(self, inputs):
outputs = self.fc(inputs)
return outputs
Перед началом обучения необходимо определить оптимизатор, функцию потерь, метрику и подготовить данные. Эти шаги выполняются в функции prepare
класса Model
.
Пример подготовки к обучению:
# Определение формата входных данных
inputs = [Input([None, 784], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
# Объявление структуры сети
model = paddle.Model(Mnist())
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
parameters=model.parameters())
# Подготовка к обучению с помощью prepare()
model.prepare(optimizer,
paddle.nn.CrossEntropy(),
paddle.metricAccuracy())
С помощью высокоуровневых API процесс обучения можно запустить одной строкой кода, создав цикл обучения с контролем количества эпох и процессом чтения данных.
Пример запуска обучения:
from paddle.vision.datasets import MNIST as MnistDataset
# Определение считывателя данных
train_dataset = MnistDataset(mode='train')
val_dataset = MnistDataset(mode='test')
# Запуск обучения
model.fit(train_dataset, val_dataset, batch_size=100, epochs=10, log_freq=100, save_dir="./output/")
Функция fit
в высокоуровневых API выполняет цикл обучения, требуя только указать считыватель данных, размер пакета, количество эпох, частоту вывода журнала и путь сохранения модели. paddle.vision.transforms
Модуль transforms в области обработки изображений включает в себя ряд реализаций для улучшения и обработки изображений, что может быть полезно при решении задач, связанных с компьютерным зрением.
В таблице ниже представлены данные об API для обработки и улучшения данных в Transforms:
Функция обработки данных | Описание функции |
---|---|
Compose | Объединение нескольких преобразований данных |
BatchCompose | Интерфейс предварительной обработки для пакетных данных |
Resize | Преобразование изображения в фиксированный размер |
RandomResizedCrop | Случайное обрезание изображения с последующим изменением размера до указанного размера |
CenterCrop | Обрезание изображения по центру |
CenterCropResize | Добавление отступов к изображению, обрезание по центру и изменение размера до указанного |
RandomHorizontalFlip | Горизонтальное случайное переворачивание изображения |
RandomVerticalFlip | Вертикальное случайное переворачивание изображения |
RandomCrop | Вырезание случайного участка из входного изображения |
RandomErasing | Случайный выбор прямоугольной области на изображении и удаление пикселей в этой области |
RandomRotate | Поворот изображения на заданный угол |
Permute | Перестановка измерений данных |
Normalize | Нормализация данных с использованием указанных среднего значения и стандартного отклонения |
GaussianNoise | Добавление гауссовского шума к данным |
BrightnessTransform | Регулировка яркости входного изображения |
SaturationTransform | Регулировка насыщенности входного изображения |
ContrastTransform | Регулировка контрастности входного изображения |
HueTransform | Регулировка оттенка входного изображения |
ColorJitter | Случайная регулировка яркости, насыщенности, контраста и оттенка изображения |
Grayscale | Превращение изображения в оттенки серого |
Pad | Заполнение входного изображения с использованием определенного режима и значений заполнения |
Использование:
from paddle.vision import transforms
import cv2
img_path = "./output/sample.jpg"
img = cv2.imread(img_path)
# Использование Compose для объединения нескольких функций улучшения данных
trans_funcs = transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.BrightnessTransform(0.2)])
label = None
img_processed, label = trans_funcs(img, label)
Эффект кода выше представлен на рисунке:
paddle.vision.models содержит высокоуровневые API для популярных моделей, таких как ResNet, VGG, MobileNet и LeNet. Эти модели позволяют быстро выполнять задачи, связанные с обучением нейронных сетей, такими как тренировка, finetune и т. д.
Используя модели в paddle.vision, можно легко и быстро создать задачу глубокого обучения, например, тренировку resnet на наборе данных Cifar10, используя всего 13 строк кода:
from paddle.vision.models import resnet50
from paddle.vision.datasets import Cifar10
from paddle.optimizer import Momentum
from paddle.regularizer import L2Decay
from paddle.nn import CrossEntropy
from paddle.metirc import Accuracy
# Вызов модели resnet50
model = paddle.Model(resnet50(pretrained=False, num_classes=10))
# Использование набора данных Cifar10
train_dataset = Cifar10(mode='train')
val_dataset = Cifar10(mode='test')
# Определение оптимизатора
optimizer = Momentum(learning_rate=0.01,
momentum=0.9,
weight_decay=L2Decay(1e-4),
parameters=model.parameters())
# Подготовка перед тренировкой
model.prepare(optimizer, CrossEntropy(), Accuracy(topk=(1, 5)))
# Запуск тренировки
model.fit(train_dataset,
val_dataset,
epochs=50,
batch_size=64,
save_dir="./output",
num_workers=8)
Для получения дополнительной информации о примерах использования высокоуровневых API обратитесь к следующим ресурсам:
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Комментарии ( 0 )