TensorFlow函數(shù)教程:tf.nn.ctc_loss

2019-01-31 13:45 更新

tf.nn.ctc_loss函數(shù)

tf.nn.ctc_loss(
    labels,
    inputs,
    sequence_length,
    preprocess_collapse_repeated=False,
    ctc_merge_repeated=True,
    ignore_longer_outputs_than_inputs=False,
    time_major=True
)

定義在:tensorflow/python/ops/ctc_ops.py.

參見(jiàn)指南:神經(jīng)網(wǎng)絡(luò)>連接時(shí)間分類(lèi)(CTC)

計(jì)算CTC(連接時(shí)間分類(lèi))loss.

輸入要求:

sequence_length(b) <= time for all b

max(labels.indices(labels.indices[:, 1] == b, 2))
  <= sequence_length(b) for all b.

筆記:

此類(lèi)為您執(zhí)行softmax操作,因此輸入應(yīng)該是例如LSTM對(duì)輸出的線性預(yù)測(cè).

inputs張量的最內(nèi)層的維度大小,num_classes,代表num_labels + 1類(lèi)別,其中num_labels是實(shí)際的標(biāo)簽的數(shù)量,而最大的值(num_classes - 1)是為空白標(biāo)簽保留的.

例如,對(duì)于包含3個(gè)標(biāo)簽[a, b, c]的詞匯表,num_classes = 4,并且標(biāo)簽索引是{a: 0, b: 1, c: 2, blank: 3}.

關(guān)于參數(shù)preprocess_collapse_repeatedctc_merge_repeated

如果preprocess_collapse_repeated為T(mén)rue,則在loss計(jì)算之前運(yùn)行預(yù)處理步驟,其中傳遞給loss的重復(fù)標(biāo)簽會(huì)合并為單個(gè)標(biāo)簽.如果訓(xùn)練標(biāo)簽來(lái)自,例如強(qiáng)制對(duì)齊,并因此具有不必要的重復(fù),則這是有用的.

如果ctc_merge_repeated設(shè)置為False,則在CTC計(jì)算的深處,重復(fù)的非空白標(biāo)簽將不會(huì)合并,并被解釋為單個(gè)標(biāo)簽.這是CTC的簡(jiǎn)化(非標(biāo)準(zhǔn))版本.

以下是(大致)預(yù)期的第一順序行為表:

  • preprocess_collapse_repeated=Falsectc_merge_repeated=True

典型的CTC行為:輸出實(shí)際的重復(fù)類(lèi),其間有空白,還可以輸出中間沒(méi)有空白的重復(fù)類(lèi),這需要由解碼器折疊.

  • preprocess_collapse_repeated=Truectc_merge_repeated=False

不要得知輸出重復(fù)的類(lèi),因?yàn)樗鼈冊(cè)谟?xùn)練之前在輸入標(biāo)簽中折疊.

  • preprocess_collapse_repeated=Falsectc_merge_repeated=False

輸出中間有空白的重復(fù)類(lèi),但通常不需要解碼器折疊/合并重復(fù)的類(lèi).

  • preprocess_collapse_repeated=Truectc_merge_repeated=True

未經(jīng)測(cè)試,很可能不會(huì)得知輸出重復(fù)的類(lèi).

ignore_longer_outputs_than_inputs選項(xiàng)允許在處理輸出長(zhǎng)于輸入的序列時(shí)指定CTCLoss的行為.如果為true,則CTCLoss將僅為這些項(xiàng)返回零梯度,否則返回InvalidArgument錯(cuò)誤,停止訓(xùn)練.

參數(shù):

  • labels:一個(gè)int32SparseTensorlabels.indices[i, :] == [b, t]表示labels.values[i]存儲(chǔ)(batch b, time t)的id;labels.values[i]必須采用[0, num_labels)中的值.
  • inputs:3-D float Tensor;如果time_major == False,這將是一個(gè)Tensor,形狀:[batch_size, max_time, num_classes]如果time_major == True(默認(rèn)值),這將是一個(gè)Tensor,形狀:[max_time, batch_size, num_classes];是logits.
  • sequence_length:1-Dint32向量,大小為[batch_size];序列長(zhǎng)度.
  • preprocess_collapse_repeatedBoolean,默認(rèn)值:False;如果為T(mén)rue,則在CTC計(jì)算之前折疊重復(fù)的標(biāo)簽.
  • ctc_merge_repeatedBoolean,默認(rèn)值:True.
  • ignore_longer_outputs_than_inputs:Boolean,默認(rèn)值:False;如果為T(mén)rue,則輸出比輸入長(zhǎng)的序列將被忽略.
  • time_majorinputs張量的形狀格式;如果是True,那些Tensors必須具有形狀[max_time, batch_size, num_classes];如果為False,則Tensors必須具有形狀[batch_size, max_time, num_classes]使用time_major = True(默認(rèn))更有效,因?yàn)樗苊饬嗽赾tc_loss計(jì)算開(kāi)始時(shí)的轉(zhuǎn)置.但是,大多數(shù)TensorFlow數(shù)據(jù)都是批處理為主的,因此通過(guò)此函數(shù)還可以接受以批處理為主的形式的輸入.

返回:

1-DfloatTensor,大小為[batch]包含負(fù)對(duì)數(shù)概率.

可能引發(fā)的異常:

  • TypeError:如果標(biāo)簽不是SparseTensor.
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)