W3Cschool
恭喜您成為首批注冊(cè)用戶
獲得88經(jīng)驗(yàn)值獎(jiǎng)勵(lì)
tf.nn.ctc_loss(
labels,
inputs,
sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
ignore_longer_outputs_than_inputs=False,
time_major=True
)
定義在:tensorflow/python/ops/ctc_ops.py.
參見(jiàn)指南:神經(jīng)網(wǎng)絡(luò)>連接時(shí)間分類(lèi)(CTC)
計(jì)算CTC(連接時(shí)間分類(lèi))loss.
輸入要求:
sequence_length(b) <= time for all b
max(labels.indices(labels.indices[:, 1] == b, 2))
<= sequence_length(b) for all b.
筆記:
此類(lèi)為您執(zhí)行softmax操作,因此輸入應(yīng)該是例如LSTM對(duì)輸出的線性預(yù)測(cè).
該inputs
張量的最內(nèi)層的維度大小,num_classes
,代表num_labels + 1
類(lèi)別,其中num_labels是實(shí)際的標(biāo)簽的數(shù)量,而最大的值(num_classes - 1)
是為空白標(biāo)簽保留的.
例如,對(duì)于包含3個(gè)標(biāo)簽[a, b, c]
的詞匯表,num_classes = 4
,并且標(biāo)簽索引是{a: 0, b: 1, c: 2, blank: 3}
.
關(guān)于參數(shù)preprocess_collapse_repeated
和ctc_merge_repeated
:
如果preprocess_collapse_repeated
為T(mén)rue,則在loss計(jì)算之前運(yùn)行預(yù)處理步驟,其中傳遞給loss的重復(fù)標(biāo)簽會(huì)合并為單個(gè)標(biāo)簽.如果訓(xùn)練標(biāo)簽來(lái)自,例如強(qiáng)制對(duì)齊,并因此具有不必要的重復(fù),則這是有用的.
如果ctc_merge_repeated
設(shè)置為False,則在CTC計(jì)算的深處,重復(fù)的非空白標(biāo)簽將不會(huì)合并,并被解釋為單個(gè)標(biāo)簽.這是CTC的簡(jiǎn)化(非標(biāo)準(zhǔn))版本.
以下是(大致)預(yù)期的第一順序行為表:
preprocess_collapse_repeated=False
, ctc_merge_repeated=True
典型的CTC行為:輸出實(shí)際的重復(fù)類(lèi),其間有空白,還可以輸出中間沒(méi)有空白的重復(fù)類(lèi),這需要由解碼器折疊.
preprocess_collapse_repeated=True
, ctc_merge_repeated=False
不要得知輸出重復(fù)的類(lèi),因?yàn)樗鼈冊(cè)谟?xùn)練之前在輸入標(biāo)簽中折疊.
preprocess_collapse_repeated=False
, ctc_merge_repeated=False
輸出中間有空白的重復(fù)類(lèi),但通常不需要解碼器折疊/合并重復(fù)的類(lèi).
preprocess_collapse_repeated=True
, ctc_merge_repeated=True
未經(jīng)測(cè)試,很可能不會(huì)得知輸出重復(fù)的類(lèi).
該ignore_longer_outputs_than_inputs
選項(xiàng)允許在處理輸出長(zhǎng)于輸入的序列時(shí)指定CTCLoss的行為.如果為true,則CTCLoss將僅為這些項(xiàng)返回零梯度,否則返回InvalidArgument錯(cuò)誤,停止訓(xùn)練.
參數(shù):
labels
:一個(gè)int32
SparseTensor
;labels.indices[i, :] == [b, t]
表示labels.values[i]
存儲(chǔ)(batch b, time t)的id;labels.values[i]
必須采用[0, num_labels)
中的值.inputs
:3-D float
Tensor
;如果time_major == False,這將是一個(gè)Tensor
,形狀:[batch_size, max_time, num_classes]
;如果time_major == True(默認(rèn)值),這將是一個(gè)Tensor
,形狀:[max_time, batch_size, num_classes]
;是logits.sequence_length
:1-Dint32
向量,大小為[batch_size]
;序列長(zhǎng)度.preprocess_collapse_repeated
:Boolean,默認(rèn)值:False;如果為T(mén)rue,則在CTC計(jì)算之前折疊重復(fù)的標(biāo)簽.ctc_merge_repeated
:Boolean,默認(rèn)值:True.ignore_longer_outputs_than_inputs
:Boolean,默認(rèn)值:False;如果為T(mén)rue,則輸出比輸入長(zhǎng)的序列將被忽略.time_major
:inputs
張量的形狀格式;如果是True,那些Tensors
必須具有形狀[max_time, batch_size, num_classes]
;如果為False,則Tensors
必須具有形狀[batch_size, max_time, num_classes]
;使用time_major = True
(默認(rèn))更有效,因?yàn)樗苊饬嗽赾tc_loss計(jì)算開(kāi)始時(shí)的轉(zhuǎn)置.但是,大多數(shù)TensorFlow數(shù)據(jù)都是批處理為主的,因此通過(guò)此函數(shù)還可以接受以批處理為主的形式的輸入.返回:
1-Dfloat
Tensor
,大小為[batch]
包含負(fù)對(duì)數(shù)概率.
可能引發(fā)的異常:
TypeError
:如果標(biāo)簽不是SparseTensor
.Copyright©2021 w3cschool編程獅|閩ICP備15016281號(hào)-3|閩公網(wǎng)安備35020302033924號(hào)
違法和不良信息舉報(bào)電話:173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號(hào)
聯(lián)系方式:
更多建議: