tf.estimator.Estimator函數(shù)
Estimator類
定義在:tensorflow/python/estimator/estimator.py
estimator類對(duì)TensorFlow模型進(jìn)行訓(xùn)練和計(jì)算.
Estimator對(duì)象包裝由model_fn指定的模型,其中,給定輸入和其他一些參數(shù),返回需要進(jìn)行訓(xùn)練、計(jì)算,或預(yù)測(cè)的操作.
所有輸出(檢查點(diǎn),事件文件等)都被寫入model_dir或其子目錄.如果model_dir未設(shè)置,則使用臨時(shí)目錄.
可以通過RunConfig對(duì)象(包含了有關(guān)執(zhí)行環(huán)境的信息)傳遞config參數(shù).它被傳遞給model_fn,如果model_fn有一個(gè)名為“config”的參數(shù)(和輸入函數(shù)以相同的方式).如果該config參數(shù)未被傳遞,則由Estimator進(jìn)行實(shí)例化.不傳遞配置意味著使用對(duì)本地執(zhí)行有用的默認(rèn)值.Estimator使配置對(duì)模型可用(例如,允許根據(jù)可用的工作人員數(shù)量進(jìn)行專業(yè)化),并且還使用其一些字段來控制內(nèi)部,特別是關(guān)于檢查點(diǎn).
該params參數(shù)包含hyperparameter,如果model_fn有一個(gè)名為“PARAMS”的參數(shù),并且以相同的方式傳遞給輸入函數(shù),則將它傳遞給 model_fn.Estimator只是沿著參數(shù)傳遞,并不檢查它.因此,params的結(jié)構(gòu)完全取決于開發(fā)人員.
不能在子類中重寫任何Estimator方法(其構(gòu)造函數(shù)強(qiáng)制執(zhí)行此操作).子類應(yīng)使用model_fn來配置基類,并且可以添加實(shí)現(xiàn)專門功能的方法.
Eager兼容性
estimator與eager執(zhí)行不兼容.
屬性
- config
- model_dir
- model_fn
返回綁定到self.params的model_fn.
返回:返回具有以下簽名的model_fn: def model_fn(features, labels, mode, config) - params
方法
__init__
__init__(
model_fn,
model_dir=None,
config=None,
params=None,
warm_start_from=None
)
構(gòu)造一個(gè)Estimator實(shí)例.
請(qǐng)參閱Estimator了解更多信息.啟動(dòng)一個(gè)Estimator的方法如下所示:
estimator = tf.estimator.DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
hidden_units=[1024, 512, 256],
warm_start_from="/path/to/checkpoint/dir")
有關(guān)warm-start啟動(dòng)配置的更多詳細(xì)信息,請(qǐng)參閱WarmStartSettings.
參數(shù):
- model_fn:模型函數(shù),具有以下簽名:
ARGS.- features:這是從input_fn傳遞給train、evaluate和predict返回的第一個(gè)項(xiàng)目.這應(yīng)該是一個(gè)相同的單一的Tensor或dict.
- labels:這是從input_fn傳遞給train、evaluate和predict返回的第二個(gè)項(xiàng)目.這應(yīng)該是相同的單個(gè)Tensor或dict(對(duì)于multi-head模型).如果模式是ModeKeys.PREDICT,則將傳遞labels=None.如果model_fn簽名不接受mode,model_fn必須仍然能夠處理labels=None.
- mode:可選的.指定train、evaluate和predict.參考ModeKeys.
- params:hyperparameters的可選字典.將在params參數(shù)中接收傳遞給Estimator的內(nèi)容.這允許從hyperparameters調(diào)整來配置Estimator.
- config:可選配置對(duì)象.將收到傳遞給Estimator的config參數(shù)或默認(rèn)值config.允許根據(jù)配置(如num_ps_replicas或model_dir)更新您的model_fn中的內(nèi)容.
返回:EstimatorSpec
- model_dir:保存模型參數(shù)、圖形等的目錄.這也可用于將目錄中的檢查點(diǎn)加載到Estimator中,以繼續(xù)訓(xùn)練以前保存的模型.如果為PathLike對(duì)象,則路徑將被解析.如果為None,則將使用config中的model_dir(如果設(shè)置的話).如果兩者都設(shè)置,則它們必須相同.如果兩者都是None,則會(huì)使用臨時(shí)目錄.
- config:配置對(duì)象.
- params:dict將傳遞到model_fn中的hyperparameters.key是參數(shù)的名稱,value是基本的Python類型.
- warm_start_from:可選的字符串文件路徑,用于從warm-start的檢查點(diǎn);或tf.estimator.WarmStartSettings對(duì)象,用于完全配置warm-start.如果提供字符串文件路徑而不是WarmStartSettings,則所有變量都是warm-start的,并且假定詞匯表和張量名稱未更改.
可能引發(fā)的異常:
- RuntimeError:如果eager執(zhí)行已啟用.
- ValueError:參數(shù)model_fn不匹配params.
- ValueError:如果這是通過子類調(diào)用的,并且該類重寫了Estimator的一個(gè)成員.
evaluate
evaluate(
input_fn,
steps=None,
hooks=None,
checkpoint_path=None,
name=None
)
計(jì)算給定計(jì)算數(shù)據(jù)input_fn的模型.
對(duì)于每個(gè)步驟來說,調(diào)用input_fn返回一批數(shù)據(jù).計(jì)算直到: -steps批處理被處理,或-input_fn引發(fā)輸入結(jié)束異常(OutOfRangeError或StopIteration).
參數(shù):
- input_fn:構(gòu)造用于計(jì)算的輸入數(shù)據(jù)的函數(shù).有關(guān)更多信息,請(qǐng)參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列選項(xiàng)之一:
- tf.data.Dataset對(duì)象:Dataset對(duì)象的輸出必須是一個(gè)具有相同約束的元組(特征(features),標(biāo)簽(labels)),其約束條件與下面相同.
- tuple
(features, labels):其中features是Tensor或者名為Tensor的字符串特征的字典,而labels是Tensor或者名為Tensor的字符串標(biāo)簽的字典.這兩個(gè)特征和標(biāo)簽都由model_fn消耗.他們應(yīng)該滿足model_fn對(duì)輸入的期望.
- steps:計(jì)算模型所需的步驟數(shù).如果為None,則計(jì)算直到input_fn引發(fā)輸入異常時(shí)結(jié)束.
- hooks:SessionRunHook子類實(shí)例列表.用于計(jì)算調(diào)用中的回調(diào).
- checkpoint_path:計(jì)算特定檢查點(diǎn)的路徑.如果為None,則使用model_dir中的最新檢查點(diǎn).
- name:需要使用的計(jì)算的名稱,如果用戶需要在不同的數(shù)據(jù)集上運(yùn)行多個(gè)計(jì)算(如培訓(xùn)數(shù)據(jù)和測(cè)試數(shù)據(jù)).不同計(jì)算的度量標(biāo)準(zhǔn)保存在單獨(dú)的文件夾中,并單獨(dú)出現(xiàn)在tensorboard中.
返回值:
返回一個(gè)包含按name為鍵的model_fn中指定的計(jì)算指標(biāo)的詞典,以及包含執(zhí)行此技術(shù)的全局步驟的值的條目global_step.
可能引發(fā)的異常:
- ValueError:如果steps <= 0.
- ValueError:如果沒有模型被訓(xùn)練,名為model_dir,或者給定checkpoint_path是空的.
export_savedmodel
export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=None,
as_text=False,
checkpoint_path=None,
strip_default_attrs=False
)
將推理圖作為SavedModel導(dǎo)出到給定的目錄中.
該方法通過首先調(diào)用serving_input_receiver_fn來獲取特征Tensors來構(gòu)建一個(gè)新圖,然后調(diào)用這個(gè)Estimator的model_fn來基于這些特征生成模型圖.它在新的會(huì)話中將給定的檢查點(diǎn)恢復(fù)到該圖中.最后它會(huì)在給定的export_dir_base下面創(chuàng)建一個(gè)時(shí)間戳導(dǎo)出目錄,并在其中寫入一個(gè)SavedModel,其中包含從此會(huì)話保存的單個(gè)MetaGraphDef.
導(dǎo)出的MetaGraphDef將為從model_fn返回的export_outputs字典的每個(gè)元素提供一個(gè)SignatureDef,該字典使用相同的key命名.其中一個(gè)key始終為signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,指示在服務(wù)請(qǐng)求未指定簽名時(shí)將提供哪個(gè)簽名.對(duì)于每個(gè)簽名,輸出由相應(yīng)的ExportOutputs提供,并且輸入始終是由serving_input_receiver_fn提供的輸入接收器.
額外的資產(chǎn)可以通過assets_extra參數(shù)寫入SavedModel.這應(yīng)該是一個(gè)字典,其中每個(gè)key給出與assets.extra目錄相關(guān)的目標(biāo)路徑(包括文件名).相應(yīng)的值給出了要復(fù)制的源文件的完整路徑.例如,在不重命名的情況下復(fù)制單個(gè)文件的簡單情況被指定為{'my_asset_file.txt': '/path/to/my_asset_file.txt'}.
參數(shù):
- export_dir_base:包含一個(gè)目錄的字符串,在該目錄中創(chuàng)建包含導(dǎo)出的SavedModels的時(shí)間戳子目錄.
- serving_input_receiver_fn:一個(gè)不帶參數(shù)并返回一個(gè)ServingInputReceiver的函數(shù).
- assets_extra:指定如何在導(dǎo)出的SavedModel中填充assets.extra目錄的字典,如果不需要額外的資產(chǎn),則為 None.
- as_text:是否以文本格式編寫SavedModel原型.
- checkpoint_path:要導(dǎo)出的檢查點(diǎn)路徑.如果None(默認(rèn)),則選擇在模型目錄中找到的最近檢查點(diǎn).
- strip_default_attrs:布爾值.如果True,則將從NodeDefs中刪除默認(rèn)值屬性.
返回值:
導(dǎo)出目錄的字符串路徑.
可能引發(fā)的異常:
- ValueError:如果未提供serving_input_receiver_fn,則不提供export_outputs,或者找不到檢查點(diǎn).
get_variable_names
get_variable_names()
返回此模型中所有變量名稱的列表.
返回值:
返回名字列表.
可能引發(fā)的異常:
- ValueError:如果Estimator尚未產(chǎn)生檢查點(diǎn).
get_variable_value
get_variable_value(name)
返回由名稱給出的變量的值.
參數(shù):
返回值:
Numpy數(shù)組 - 張量的值.
可能引發(fā)的異常:
- ValueError:如果Estimator尚未產(chǎn)生檢查點(diǎn).
latest_checkpoint
latest_checkpoint()
查找model_dir中最新保存的檢查點(diǎn)文件的文件名.
返回值:
返回最新檢查點(diǎn)的完整路徑或None(未找到檢查點(diǎn)).
predict
predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)
對(duì)給定的features產(chǎn)生預(yù)測(cè).
參數(shù):
- input_fn:構(gòu)造特征的函數(shù).預(yù)測(cè)繼續(xù),直到input_fn引發(fā)輸入端異常(OutOfRangeError或StopIteration).有關(guān)更多信息,請(qǐng)參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列之一:
- tf.data.Dataset對(duì)象:Dataset對(duì)象的輸出必須具有與下面相同的約束.
- features:一個(gè)Tensor或者名為Tensor的字符串特征的字典.feature被model_fn消耗.他們應(yīng)該滿足model_fn對(duì)輸入的期望.
- 一個(gè)元組,在這種情況下,第一個(gè)項(xiàng)被提取為feature.
- predict_keys:str列表,要預(yù)測(cè)的鍵名稱.如果EstimatorSpec.predictions是字典,則使用該方法.如果使用predict_keys,則剩余的預(yù)測(cè)將從字典中過濾.如果None,則返回全部.
- hooks:SessionRunHook子類實(shí)例列表.用于預(yù)測(cè)調(diào)用中的回調(diào).
- checkpoint_path:要預(yù)測(cè)的特定檢查點(diǎn)的路徑.如果為None,則使用model_dir中的最新的檢查點(diǎn).
返回值:
predictions張量的計(jì)算值.
可能引發(fā)的異常:
- ValueError:在model_dir中找不到訓(xùn)練有素的模型.
- ValueError:如果批次的預(yù)測(cè)長度不相同.
- ValueError:如果predict_keys和predictions之間有沖突.例如,如果predict_keys不是None,但EstimatorSpec.predictions不是一個(gè)dict.
train
train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)
訓(xùn)練給定訓(xùn)練數(shù)據(jù)input_fn的模型.
參數(shù):
- input_fn:提供作為minibatches培訓(xùn)的輸入數(shù)據(jù)的函數(shù).有關(guān)更多信息,請(qǐng)參閱TensorFlow入門.該函數(shù)應(yīng)該構(gòu)造并返回下列之一:
- tf.data.Dataset對(duì)象:Dataset對(duì)象的輸出必須是一個(gè)具有相同約束的元組(特征,標(biāo)簽)((features, labels)),其約束條件與下面相同.
- tuple
(features, labels):其中features是一個(gè)Tensor或者名為Tensor的字符串特征的字典,labels是一個(gè)Tensor或者名為Tensor的字符串標(biāo)簽的字典.這兩個(gè)特征和標(biāo)簽都由model_fn消耗.他們應(yīng)該滿足model_fn對(duì)輸入的期望.
- hooks:SessionRunHook子類實(shí)例列表.用于訓(xùn)練循環(huán)內(nèi)的回調(diào).
- steps:訓(xùn)練模型的步驟數(shù).如果為None,則永遠(yuǎn)訓(xùn)練或訓(xùn)練直到input_fn產(chǎn)生OutOfRange錯(cuò)誤或StopIteration異常.“steps”逐步運(yùn)作.如果您調(diào)用兩次train(steps=10),則訓(xùn)練總共發(fā)生20個(gè)步驟.如果OutOfRange或StopIteration發(fā)生在中間,訓(xùn)練在20步之前停止.如果你不想有增量行為,請(qǐng)改為設(shè)置.如果設(shè)置max_steps,max_steps必須None.
- max_steps:訓(xùn)練模型的總步驟數(shù).如果為None,則永遠(yuǎn)訓(xùn)練或訓(xùn)練直到input_fn產(chǎn)生OutOfRange錯(cuò)誤或StopIteration異常.如果設(shè)置,steps必須None.如果OutOfRange或StopIteration發(fā)生在中間,訓(xùn)練在max_steps步驟之前停止.兩次調(diào)用train(steps=100)意味著200次訓(xùn)練迭代.另一方面,兩次調(diào)用train(max_steps=100)意味著第二次調(diào)用將不會(huì)做任何迭代,因?yàn)榈谝淮握{(diào)用完成了所有100個(gè)步驟.
- saving_listeners:CheckpointSaverListener對(duì)象列表.用于在檢查點(diǎn)節(jié)省之前或之后立即執(zhí)行的回調(diào).
可能引發(fā)的異常:
- ValueError:如果steps和max_steps都不是None.
- ValueError:如果steps或max_steps其中之一小于等于0.
更多建議: