TensorFlow的estimator類函數(shù):tf.estimator.Estimator

2018-04-27 09:55 更新

tf.estimator.Estimator函數(shù)

Estimator類

定義在:tensorflow/python/estimator/estimator.py

estimator類對(duì)TensorFlow模型進(jìn)行訓(xùn)練和計(jì)算.

Estimator對(duì)象包裝由model_fn指定的模型,其中,給定輸入和其他一些參數(shù),返回需要進(jìn)行訓(xùn)練、計(jì)算,或預(yù)測(cè)的操作.

所有輸出(檢查點(diǎn),事件文件等)都被寫入model_dir或其子目錄.如果model_dir未設(shè)置,則使用臨時(shí)目錄.

可以通過RunConfig對(duì)象(包含了有關(guān)執(zhí)行環(huán)境的信息)傳遞config參數(shù).它被傳遞給model_fn,如果model_fn有一個(gè)名為“config”的參數(shù)(和輸入函數(shù)以相同的方式).如果該config參數(shù)未被傳遞,則由Estimator進(jìn)行實(shí)例化.不傳遞配置意味著使用對(duì)本地執(zhí)行有用的默認(rèn)值.Estimator使配置對(duì)模型可用(例如,允許根據(jù)可用的工作人員數(shù)量進(jìn)行專業(yè)化),并且還使用其一些字段來控制內(nèi)部,特別是關(guān)于檢查點(diǎn).

該params參數(shù)包含hyperparameter,如果model_fn有一個(gè)名為“PARAMS”的參數(shù),并且以相同的方式傳遞給輸入函數(shù),則將它傳遞給 model_fn.Estimator只是沿著參數(shù)傳遞,并不檢查它.因此,params的結(jié)構(gòu)完全取決于開發(fā)人員.

不能在子類中重寫任何Estimator方法(其構(gòu)造函數(shù)強(qiáng)制執(zhí)行此操作).子類應(yīng)使用model_fn來配置基類,并且可以添加實(shí)現(xiàn)專門功能的方法.

Eager兼容性

estimator與eager執(zhí)行不兼容.

屬性

  • config
  • model_dir
  • model_fn
    返回綁定到self.params的model_fn.
    返回:返回具有以下簽名的model_fn: def model_fn(features, labels, mode, config)
  • params

方法

__init__

__init__(
    model_fn,
    model_dir=None,
    config=None,
    params=None,
    warm_start_from=None
)

構(gòu)造一個(gè)Estimator實(shí)例.

請(qǐng)參閱Estimator了解更多信息.啟動(dòng)一個(gè)Estimator的方法如下所示:

estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    warm_start_from="/path/to/checkpoint/dir")

有關(guān)warm-start啟動(dòng)配置的更多詳細(xì)信息,請(qǐng)參閱WarmStartSettings.

參數(shù):

  • model_fn:模型函數(shù),具有以下簽名:
      ARGS.
    • features:這是從input_fn傳遞給train、evaluate和predict返回的第一個(gè)項(xiàng)目.這應(yīng)該是一個(gè)相同的單一的Tensor或dict.
    • labels:這是從input_fn傳遞給train、evaluate和predict返回的第二個(gè)項(xiàng)目.這應(yīng)該是相同的單個(gè)Tensor或dict(對(duì)于multi-head模型).如果模式是ModeKeys.PREDICT,則將傳遞labels=None.如果model_fn簽名不接受mode,model_fn必須仍然能夠處理labels=None.
    • mode:可選的.指定train、evaluate和predict.參考ModeKeys.
    • params:hyperparameters的可選字典.將在params參數(shù)中接收傳遞給Estimator的內(nèi)容.這允許從hyperparameters調(diào)整來配置Estimator.
    • config:可選配置對(duì)象.將收到傳遞給Estimator的config參數(shù)或默認(rèn)值config.允許根據(jù)配置(如num_ps_replicas或model_dir)更新您的model_fn中的內(nèi)容.
    • 返回:EstimatorSpec
  • model_dir:保存模型參數(shù)、圖形等的目錄.這也可用于將目錄中的檢查點(diǎn)加載到Estimator中,以繼續(xù)訓(xùn)練以前保存的模型.如果為PathLike對(duì)象,則路徑將被解析.如果為None,則將使用config中的model_dir(如果設(shè)置的話).如果兩者都設(shè)置,則它們必須相同.如果兩者都是None,則會(huì)使用臨時(shí)目錄.
  • config:配置對(duì)象.
  • params:dict將傳遞到model_fn中的hyperparameters.key是參數(shù)的名稱,value是基本的Python類型.
  • warm_start_from:可選的字符串文件路徑,用于從warm-start的檢查點(diǎn);或tf.estimator.WarmStartSettings對(duì)象,用于完全配置warm-start.如果提供字符串文件路徑而不是WarmStartSettings,則所有變量都是warm-start的,并且假定詞匯表和張量名稱未更改.

可能引發(fā)的異常:

  • RuntimeError:如果eager執(zhí)行已啟用.
  • ValueError:參數(shù)model_fn不匹配params.
  • ValueError:如果這是通過子類調(diào)用的,并且該類重寫了Estimator的一個(gè)成員.

evaluate

evaluate(
    input_fn,
    steps=None,
    hooks=None,
    checkpoint_path=None,
    name=None
)

計(jì)算給定計(jì)算數(shù)據(jù)input_fn的模型.

對(duì)于每個(gè)步驟來說,調(diào)用input_fn返回一批數(shù)據(jù).計(jì)算直到: -steps批處理被處理,或-input_fn引發(fā)輸入結(jié)束異常(OutOfRangeError或StopIteration).

參數(shù):

  • input_fn:構(gòu)造用于計(jì)算的輸入數(shù)據(jù)的函數(shù).有關(guān)更多信息,請(qǐng)參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列選項(xiàng)之一:
    • tf.data.Dataset對(duì)象:Dataset對(duì)象的輸出必須是一個(gè)具有相同約束的元組(特征(features),標(biāo)簽(labels)),其約束條件與下面相同.
    • tuple (features, labels):其中features是Tensor或者名為Tensor的字符串特征的字典,而labels是Tensor或者名為Tensor的字符串標(biāo)簽的字典.這兩個(gè)特征和標(biāo)簽都由model_fn消耗.他們應(yīng)該滿足model_fn對(duì)輸入的期望.
  • steps:計(jì)算模型所需的步驟數(shù).如果為None,則計(jì)算直到input_fn引發(fā)輸入異常時(shí)結(jié)束.
  • hooks:SessionRunHook子類實(shí)例列表.用于計(jì)算調(diào)用中的回調(diào).
  • checkpoint_path:計(jì)算特定檢查點(diǎn)的路徑.如果為None,則使用model_dir中的最新檢查點(diǎn).
  • name:需要使用的計(jì)算的名稱,如果用戶需要在不同的數(shù)據(jù)集上運(yùn)行多個(gè)計(jì)算(如培訓(xùn)數(shù)據(jù)和測(cè)試數(shù)據(jù)).不同計(jì)算的度量標(biāo)準(zhǔn)保存在單獨(dú)的文件夾中,并單獨(dú)出現(xiàn)在tensorboard中.

返回值:

返回一個(gè)包含按name為鍵的model_fn中指定的計(jì)算指標(biāo)的詞典,以及包含執(zhí)行此技術(shù)的全局步驟的值的條目global_step.

可能引發(fā)的異常:

  • ValueError:如果steps <= 0.
  • ValueError:如果沒有模型被訓(xùn)練,名為model_dir,或者給定checkpoint_path是空的.

export_savedmodel

export_savedmodel(
    export_dir_base,
    serving_input_receiver_fn,
    assets_extra=None,
    as_text=False,
    checkpoint_path=None,
    strip_default_attrs=False
)

將推理圖作為SavedModel導(dǎo)出到給定的目錄中.

該方法通過首先調(diào)用serving_input_receiver_fn來獲取特征Tensors來構(gòu)建一個(gè)新圖,然后調(diào)用這個(gè)Estimator的model_fn來基于這些特征生成模型圖.它在新的會(huì)話中將給定的檢查點(diǎn)恢復(fù)到該圖中.最后它會(huì)在給定的export_dir_base下面創(chuàng)建一個(gè)時(shí)間戳導(dǎo)出目錄,并在其中寫入一個(gè)SavedModel,其中包含從此會(huì)話保存的單個(gè)MetaGraphDef.

導(dǎo)出的MetaGraphDef將為從model_fn返回的export_outputs字典的每個(gè)元素提供一個(gè)SignatureDef,該字典使用相同的key命名.其中一個(gè)key始終為signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,指示在服務(wù)請(qǐng)求未指定簽名時(shí)將提供哪個(gè)簽名.對(duì)于每個(gè)簽名,輸出由相應(yīng)的ExportOutputs提供,并且輸入始終是由serving_input_receiver_fn提供的輸入接收器.

額外的資產(chǎn)可以通過assets_extra參數(shù)寫入SavedModel.這應(yīng)該是一個(gè)字典,其中每個(gè)key給出與assets.extra目錄相關(guān)的目標(biāo)路徑(包括文件名).相應(yīng)的值給出了要復(fù)制的源文件的完整路徑.例如,在不重命名的情況下復(fù)制單個(gè)文件的簡單情況被指定為{'my_asset_file.txt': '/path/to/my_asset_file.txt'}.

參數(shù):

  • export_dir_base:包含一個(gè)目錄的字符串,在該目錄中創(chuàng)建包含導(dǎo)出的SavedModels的時(shí)間戳子目錄.
  • serving_input_receiver_fn:一個(gè)不帶參數(shù)并返回一個(gè)ServingInputReceiver的函數(shù).
  • assets_extra:指定如何在導(dǎo)出的SavedModel中填充assets.extra目錄的字典,如果不需要額外的資產(chǎn),則為 None.
  • as_text:是否以文本格式編寫SavedModel原型.
  • checkpoint_path:要導(dǎo)出的檢查點(diǎn)路徑.如果None(默認(rèn)),則選擇在模型目錄中找到的最近檢查點(diǎn).
  • strip_default_attrs:布爾值.如果True,則將從NodeDefs中刪除默認(rèn)值屬性.

返回值:

導(dǎo)出目錄的字符串路徑.

可能引發(fā)的異常:

  • ValueError:如果未提供serving_input_receiver_fn,則不提供export_outputs,或者找不到檢查點(diǎn).

get_variable_names

get_variable_names()

返回此模型中所有變量名稱的列表.

返回值:

返回名字列表.

可能引發(fā)的異常:

  • ValueError:如果Estimator尚未產(chǎn)生檢查點(diǎn).

get_variable_value

get_variable_value(name)

返回由名稱給出的變量的值.

參數(shù):

  • name:字符串或字符串列表,張量的名稱.

返回值:

Numpy數(shù)組 - 張量的值.

可能引發(fā)的異常:

  • ValueError:如果Estimator尚未產(chǎn)生檢查點(diǎn).

latest_checkpoint

latest_checkpoint()

查找model_dir中最新保存的檢查點(diǎn)文件的文件名.

返回值:

返回最新檢查點(diǎn)的完整路徑或None(未找到檢查點(diǎn)).

predict

predict(
    input_fn,
    predict_keys=None,
    hooks=None,
    checkpoint_path=None
)

對(duì)給定的features產(chǎn)生預(yù)測(cè).

參數(shù):

  • input_fn:構(gòu)造特征的函數(shù).預(yù)測(cè)繼續(xù),直到input_fn引發(fā)輸入端異常(OutOfRangeError或StopIteration).有關(guān)更多信息,請(qǐng)參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列之一:
    • tf.data.Dataset對(duì)象:Dataset對(duì)象的輸出必須具有與下面相同的約束.
    • features:一個(gè)Tensor或者名為Tensor的字符串特征的字典.feature被model_fn消耗.他們應(yīng)該滿足model_fn對(duì)輸入的期望.
    • 一個(gè)元組,在這種情況下,第一個(gè)項(xiàng)被提取為feature.
  • predict_keys:str列表,要預(yù)測(cè)的鍵名稱.如果EstimatorSpec.predictions是字典,則使用該方法.如果使用predict_keys,則剩余的預(yù)測(cè)將從字典中過濾.如果None,則返回全部.
  • hooks:SessionRunHook子類實(shí)例列表.用于預(yù)測(cè)調(diào)用中的回調(diào).
  • checkpoint_path:要預(yù)測(cè)的特定檢查點(diǎn)的路徑.如果為None,則使用model_dir中的最新的檢查點(diǎn).

返回值:

predictions張量的計(jì)算值.

可能引發(fā)的異常:

  • ValueError:在model_dir中找不到訓(xùn)練有素的模型.
  • ValueError:如果批次的預(yù)測(cè)長度不相同.
  • ValueError:如果predict_keys和predictions之間有沖突.例如,如果predict_keys不是None,但EstimatorSpec.predictions不是一個(gè)dict.

train

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

訓(xùn)練給定訓(xùn)練數(shù)據(jù)input_fn的模型.

參數(shù):

  • input_fn:提供作為minibatches培訓(xùn)的輸入數(shù)據(jù)的函數(shù).有關(guān)更多信息,請(qǐng)參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列之一:
    • tf.data.Dataset對(duì)象:Dataset對(duì)象的輸出必須是一個(gè)具有相同約束的元組(特征,標(biāo)簽)((features, labels)),其約束條件與下面相同.
    • tuple (features, labels):其中features是一個(gè)Tensor或者名為Tensor的字符串特征的字典,labels是一個(gè)Tensor或者名為Tensor的字符串標(biāo)簽的字典.這兩個(gè)特征和標(biāo)簽都由model_fn消耗.他們應(yīng)該滿足model_fn對(duì)輸入的期望.
  • hooks:SessionRunHook子類實(shí)例列表.用于訓(xùn)練循環(huán)內(nèi)的回調(diào).
  • steps:訓(xùn)練模型的步驟數(shù).如果為None,則永遠(yuǎn)訓(xùn)練或訓(xùn)練直到input_fn產(chǎn)生OutOfRange錯(cuò)誤或StopIteration異常.“steps”逐步運(yùn)作.如果您調(diào)用兩次train(steps=10),則訓(xùn)練總共發(fā)生20個(gè)步驟.如果OutOfRange或StopIteration發(fā)生在中間,訓(xùn)練在20步之前停止.如果你不想有增量行為,請(qǐng)改為設(shè)置.如果設(shè)置max_steps,max_steps必須None.
  • max_steps:訓(xùn)練模型的總步驟數(shù).如果為None,則永遠(yuǎn)訓(xùn)練或訓(xùn)練直到input_fn產(chǎn)生OutOfRange錯(cuò)誤或StopIteration異常.如果設(shè)置,steps必須None.如果OutOfRange或StopIteration發(fā)生在中間,訓(xùn)練在max_steps步驟之前停止.兩次調(diào)用train(steps=100)意味著200次訓(xùn)練迭代.另一方面,兩次調(diào)用train(max_steps=100)意味著第二次調(diào)用將不會(huì)做任何迭代,因?yàn)榈谝淮握{(diào)用完成了所有100個(gè)步驟.
  • saving_listeners:CheckpointSaverListener對(duì)象列表.用于在檢查點(diǎn)節(jié)省之前或之后立即執(zhí)行的回調(diào).

可能引發(fā)的異常:

  • ValueError:如果steps和max_steps都不是None.
  • ValueError:如果steps或max_steps其中之一小于等于0.
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)