Слияние кода завершено, страница обновится автоматически
Моя версия torch_musa — 2.0.0+git63348e6, и я обнаружил, что функции torch.cumsum
, torch.where
и torch.std
не поддерживают тензоры на устройстве musa. Вы можете запустить код в конце этого комментария для воспроизведения исключения.
В самом деле, эти операции могут быть заменены комбинацией существующих операторов, но такая комбинация приведёт к большему количеству вызовов ядра и может вызвать числовые ошибки.
Жду вашего обновления и ответа.
import torch
import torch_musa
import traceback
test = torch.rand((5, 10, 20), device='musa')
try:
torch.cumsum(test, dim=-1)
except Exception as e:
print(traceback.format_exc())
try:
torch.where(test > 0.5, test, test)
except Exception as e:
print(traceback.format_exc())
try:
torch.std(test, dim=-1)
except Exception as e:
print(traceback.format_exc())