Слияние кода завершено, страница обновится автоматически
from typing import Dict
import torch
from torch import nn, Tensor
from torch.nn import CTCLoss as TorchCTCLoss
class CTCLoss(nn.Module):
def __init__(self, blank_idx: int, reduction: str = 'sum'):
super().__init__()
self.loss_func = TorchCTCLoss(
blank=blank_idx, reduction=reduction, zero_infinity=True)
def forward(self,
pred: Tensor,
label: Tensor,
label_length: Tensor) -> Dict[str, Tensor]:
pred = pred.permute(1, 0, 2)
batch_size = pred.size(1)
pred = pred.log_softmax(2)
preds_lengths = torch.tensor(
[pred.size(0)] * batch_size, dtype=torch.long)
loss = self.loss_func(pred, label, preds_lengths, label_length)
return dict(loss=loss)
Вы можете оставить комментарий после Вход в систему
Неприемлемый контент может быть отображен здесь и не будет показан на странице. Вы можете проверить и изменить его с помощью соответствующей функции редактирования.
Если вы подтверждаете, что содержание не содержит непристойной лексики/перенаправления на рекламу/насилия/вульгарной порнографии/нарушений/пиратства/ложного/незначительного или незаконного контента, связанного с национальными законами и предписаниями, вы можете нажать «Отправить» для подачи апелляции, и мы обработаем ее как можно скорее.
Опубликовать ( 0 )