TensorFlow函數(shù):tf.nn.compute_accidental_hits

2019-01-31 11:30 更新

tf.nn.compute_accidental_hits函數(shù)

tf.nn.compute_accidental_hits(
    true_classes,
    sampled_candidates,
    num_true,
    seed=None,
    name=None
)

定義在:tensorflow/python/ops/candidate_sampling_ops.py.

請參閱指南:神經(jīng)網(wǎng)絡(luò)>候選抽樣

計算與true_classes匹配的sampled_candidate中的位置id.

在Candidate Sampling中,此操作實際上有助于刪除恰好與目標類匹配的抽樣類.這在Sampled Softmax和Sampled Logistic中完成.

我們預(yù)先假定sampled_candidates是獨一無二的.

當(dāng)其中一個目標類與其中一個抽樣類匹配時,我們將其稱為“意外命中”.此操作將意外命中報告為三元組(index, id, weight),其中index表示true_classes中的行號,id表示sampled_candidates中的位置,權(quán)重為-FLOAT_MAX.

此op的結(jié)果應(yīng)該通過一個sparse_to_dense操作來傳遞,然后添加到抽樣類的logits中.這消除了意外采樣真實目標類作為同一示例的噪聲類的矛盾效果.

參數(shù):

  • true_classes:一個Tensor,器類型為int64,并且形狀為[batch_size, num_true];是目標類.
  • sampled_candidates:一個Tensor,類型為int64,并且形狀為[num_sampled];CandidateSampler的sampled_candidates輸出.
  • num_true:int,每個訓(xùn)練示例的目標類數(shù).
  • seed:int,特定于操作的seed;默認值為0.
  • name:操作的名稱(可選).

返回:

  • indices:一個Tensor,其類型為int32,并且形狀為[num_accidental_hits];值表示true_classes中的行.
  • ids:一個Tensor,類型為int64,并且形狀為[num_accidental_hits];值表示sampled_candidates中的位置.
  • weights:一個Tensor,其類型為float,并且形狀為[num_accidental_hits];每個值都是-FLOAT_MAX.
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號