torch_musa — это расширенный пакет Python, основанный на PyTorch. Разработка torch_musa в виде плагина позволяет отвязать его от PyTorch, что удобно для поддержки кода. В сочетании с PyTorch пользователи могут использовать мощность графических карт Moore Threads через torch_musa. Кроме того, у torch_musa есть два значительных преимущества:
torch_musa также предоставляет набор инструментов для пользователей для проведения портирования CUDA, сборки расширений MUSA и отладки. Пожалуйста, обратитесь к README.md пакета torch_musa.utils.
(WIP) Мы загружаем наш torch_musa на PyPi, чтобы пользователи могли установить его с помощью pip
.
ВАЖНО: Поскольку некоторые зависимые библиотеки находятся в стадии бета-тестирования и еще не были официально выпущены, мы рекомендуем использовать предоставленный ниже Docker образ для разработчика для сборки torch_musa. Если вы действительно хотите собрать torch_musa в своем окружении, пожалуйста, свяжитесь с нами для получения дополнительных зависимостей.
apt-get install ccache
apt-get install libomp-11-dev
pip install -r requirements.txt
export MUSA_HOME=путь/к/musa_библиотекам(включая_mudnn_и_musa_toolkits) # значение по умолчанию /usr/local/musa/
export LD_LIBRARY_PATH=$MUSA_HOME/lib:$LD_LIBRARY_PATH
# если PYTORCH_REPO_PATH не установлен, PyTorch-v2.0.0 будет скачан вне этого каталога при сборке с помощью build.sh
export PYTORCH_REPO_PATH=путь/к/исходному_коду_PyTorch
bash build.sh # соберите оригинальный PyTorch и torch_musa с нуля
# Некоторые важные параметры:
bash build.sh --torch # соберите только оригинальный PyTorch
bash build.sh --musa # соберите только torch_musa
bash build.sh --fp64 # скомпилируйте fp64 в ядрах с помощью mcc в torch_musa
bash build.sh --debug # соберите в режиме отладки
bash build.sh --asan # соберите в режиме asan
bash build.sh --clean # очистите все собранное и соберите заново
bash build.sh --patch
cd pytorch
pip install -r requirements.txt
python setup.py install
# режим отладки: DEBUG=1 python setup.py install
# режим asan: USE_ASAN=1 python setup.py install
cd torch_musa
pip install -r requirements.txt
python setup.py install
# режим отладки: DEBUG=1 python setup.py install
# режим asan: USE_ASAN=1 python setup.py install
ВАЖНО: Если вы хотите использовать torch_musa в контейнере Docker, пожалуйста, установите mt-container-toolkit и используйте '--env MTHREADS_VISIBLE_DEVICES=all' при запуске контейнера. При первом запуске Docker выполняет самотестирование. Результаты юнит-тестов и интеграционных тестов torch_musa в разработческом Docker образе находятся по адресам /home/integration_test_output.txt и /home/ut_output.txt. Разработческий Docker уже установлен torch и torch_musa, а исходный код находится по пути /home.
# Для запуска Docker для s3000/s80 просто замените 's3000/s80' на 's4000' в следующей команде.
# Для запуска Docker с различными версиями Python, просто замените 'py38', 'py39' на 'py310'.
# Python 3.10
docker run -it --privileged --pull always --network=host --name=torch_musa_dev --env MTHREADS_VISIBLE_DEVICES=all --shm-size=80g registry.mthreads.com/mcconline/musa-pytorch-dev-public:rc3.1.0-v1.3.0-S4000-py310 /bin/bash
Docker тег | Описание |
---|---|
rc3.1.0-v1.3.0-S80/rc3.1.0-v1.3.0-S3000/rc3.1.0-v1.3.0-S4000 Python 3.8 Python 3.9 Python 3.10 |
musatoolkits rc3.1.0 mudnn rc2.7.0 mccl rc1.7.0 MUSA SDK rc3.1.0 |
rc2.1.0-v1.1.0-qy1/rc2.1.0-v1.1.0-qy2 Python 3.8 Python 3.9 Python 3.10 |
musatoolkits rc2.1.0 mudnn rc2.5.0 mccl rc2.0.0 muAlg_dev-0.3.0 muSPARSE_dev0.1.0 muThrust_dev-0.3.0 torch_musa ветка v1.1.0-rc1 |
# Для запуска Docker для s3000/s80 просто замените 's3000/s80' на 's4000' в следующей команде.
# Для запуска Docker с различными версиями Python, просто замените 'py38', 'py39' на 'py310'.
# python 3.10
docker run -it --privileged --pull always --network=host --name=torch_musa_release --env MTHREADS_VISIBLE_DEVICES=all --shm-size=80g registry.mthreads.com/mcconline/musa-pytorch-release-public:rc3.1.0-v1.3.0-S4000-py310 /bin/bash
Docker тег | Описание |
---|---|
rc3.1.0-v1.3.0-S80/rc3.1.0-v1.3.0-S3000/rc3.1.0-v1.3.0-S4000 Python 3.8 Python 3.9 Python 3.10 |
musatoolkits rc3.1.0 mudnn rc2.7.0 mccl rc1.7.0 MUSA SDK rc3.1.0 |
rc2.1.0-v1.1.0-qy1/rc2.1.0-v1.1.0-qy2 Python 3.8 Python 3.9 Python 3.10 |
musatoolkits rc2.1.0 mudnn rc2.5.0 mccl rc2.0.0 muAlg_dev-0.3.0 muSPARSE_dev0.1.0 muThrust_dev-0.3.0 torch_musa ветка v1.1.0-rc1 |
torch_musa главным образом следует стилю Google C++ (Google C++ style) и кастомизированному PEP8 Python стилю.
Вы можете использовать инструменты проверки под линтигом в папке tools/lint
для проверки соблюдения стилей кодирования.
# Проверка ошибок Python линтера
bash tools/lint/pylint.sh --rev main
# Проверка C++ линтера
bash tools/lint/git-clang-format.sh --rev main
Вы можете использовать следующую команду для исправления ошибок C++ линтера с помощью clang-format-11 и выше.
bash tools/lint/git-clang-format.sh -i --rev main
Ошибки Python немного отличаются. tools/lint/git-black.sh
можно использовать для форматирования кода Python, но другие ошибки линтера, например названия, все еще нужно исправлять вручную согласно предложенным ошибкам.
Два основных изменения необходимы при использовании torch_musa:
Импортировать пакет torch_musa
import torch
import torch_musa
Изменить устройство на musa
import torch
import torch_musa
a = torch.tensor([1.2, 2.3], dtype=torch.float32, device='musa')
b = torch.tensor([1.2, 2.3], dtype=torch.float32, device='cpu').to('musa')
c = torch.tensor([1.2, 2.3], dtype=torch.float32).musa()
torch_musa интегрировал операции torchvision в бэкенд musa. Пожалуйста, выполните следующие шаги, если torchvision не установлен:
# убедитесь, что torchvision не установлен
pip uninstall torchvision
git clone https://github.com/pytorch/vision.git
cd vision
python setup.py install
import torch
import torch_musa
import torchvision
def get_forge_data(num_boxes):
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2
assert max(boxes[:, 1]) < min(boxes[:, 3]) # y1 < y2
scores = torch.rand(num_boxes)
return boxes, scores
num_boxes = 10
boxes, scores = get_forge_data(num_boxes)
iou_threshold = 0.5
print(torchvision.ops.nms(boxes=boxes.to("musa"), scores=scores.to("musa"), iou_threshold=iou_threshold))
import torch
import torch_musa
torch.musa.is_available()
torch.musa.device_count()
torch.musa.synchronize()
with torch.musa.device(0):
assert torch.musa.current_device() == 0
if torch.musa.device_count() > 1:
torch.musa.set_device(1)
assert torch.musa.current_device() == 1
torch.musa.synchronize("musa:1")
a = torch.tensor([1.2, 2.3], dtype=torch.float32, device='musa')
b = torch.tensor([1.8, 1.2], dtype=torch.float32, device='musa')
c = a + b
import torch
import torch_musa
import torchvision.models as models
model = models.resnet50().eval()
x = torch.rand((1, 3, 224, 224), device="musa")
model = model.to("musa")
# Выполнить инференс
y = model(x)
import torch
import torch_musa
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# 1. Подготовка данных
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
batch_size = 4
train_set = torchvision.datasets.CIFAR10(root='./data',
train=True,
download=True,
transform=transform)
train_loader = torch.utils.data.DataLoader(train_set,
batch_size=batch_size,
shuffle=True,
num_workers=2)
test_set = torchvision.datasets.CIFAR10(root='./data',
train=False,
download=True,
transform=transform)
test_loader = torch.utils.data.DataLoader(test_set,
batch_size=batch_size,
shuffle=False,
num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
device = torch.device("musa")
# 2. Построение сети
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # развернуть все измерения, кроме батча
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net().to(device)
# 3. Определение функции потерь и оптимизатора
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 4. Обучение
for epoch in range(2):
running_loss = 0.0
for i, data in enumerate(train_loader, 0):
inputs, labels = data
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs.to(device))
loss = criterion(outputs, labels.to(device))
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
running_loss = 0.0
print('Обучение завершено')
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
net.load_state_dict(torch.load(PATH))
# 5. Тестирование
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
images, labels = data
outputs = net(images.to(device))
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels.to(device)).sum().item()
print(f'Точность сети на 10000 тестовых изображениях: {100 * correct // total} %')
В torch_musa мы предоставляем модуль codegen для реализации связей и регистрации пользовательских ядер MUSA, см. ссылку.
Пожалуйста, обратитесь к README.md в директории docker/common.
Пожалуйста, обратитесь к файлам в папке docs.
Пожалуйста, обратитесь к файлу op_list.md
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Комментарии ( 0 )