[软件设计/软件工程] TFX Trainer 组件未将模型导出到文件系统的问题

[复制链接]
发表于 2022-5-4 13:52:42
问题
首先,我使用的是 TFX 0.21.2 版和 Tensorflow 2.1 版。

我建立了一个管道,主要以芝加哥出租车为例。执行 Trainer 组件时,我可以在日志中看到以下内容:

信息培训完成。模型写入 /root/airflow/tfx/pipelines/fish/Trainer/Model/9/serving_model_dir

检查上面的目录时,它是空的。我错过了什么?

这是我的 DAG 定义文件(忽略 import 语句):
  1. _pipeline_name = 'fish'
  2. _airflow_config = AirflowPipelineConfig(airflow_dag_config = {
  3.     'schedule_interval': None,
  4.     'start_date': datetime.datetime(2019, 1, 1),
  5. })
  6. _project_root = os.path.join(os.environ['HOME'], 'airflow')
  7. _data_root = os.path.join(_project_root, 'data', 'fish_data')
  8. _module_file = os.path.join(_project_root, 'dags', 'fishUtils.py')
  9. _serving_model_dir = os.path.join(_project_root, 'serving_model', _pipeline_name)
  10. _tfx_root = os.path.join(_project_root, 'tfx')
  11. _pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
  12. _metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
  13.                               'metadata.db')


  14. def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
  15.                      module_file: Text, serving_model_dir: Text,
  16.                      metadata_path: Text,
  17.                      direct_num_workers: int) -> pipeline.Pipeline:

  18.     examples = external_input(data_root)
  19.     example_gen = CsvExampleGen(input=examples)

  20.     statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  21.     infer_schema = SchemaGen(
  22.       statistics=statistics_gen.outputs['statistics'],
  23.       infer_feature_shape=False)

  24.     validate_stats = ExampleValidator(
  25.       statistics=statistics_gen.outputs['statistics'],
  26.       schema=infer_schema.outputs['schema'])

  27.     trainer = Trainer(
  28.     examples=example_gen.outputs['examples'], schema=infer_schema.outputs['schema'],
  29.     module_file=_module_file, train_args= trainer_pb2.TrainArgs(num_steps=10000),
  30.     eval_args= trainer_pb2.EvalArgs(num_steps=5000))

  31.     model_validator = ModelValidator(
  32.       examples=example_gen.outputs['examples'],
  33.       model=trainer.outputs['model'])

  34.     pusher = Pusher(
  35.       model=trainer.outputs['model'],
  36.       model_blessing=model_validator.outputs['blessing'],
  37.       push_destination=pusher_pb2.PushDestination(
  38.         filesystem=pusher_pb2.PushDestination.Filesystem(
  39.           base_directory=_serving_model_dir)))

  40.     return pipeline.Pipeline(
  41.       pipeline_name=_pipeline_name,
  42.       pipeline_root=_pipeline_root,
  43.       components=[
  44.           example_gen,
  45.           statistics_gen,
  46.           infer_schema,
  47.           validate_stats,
  48.           trainer,
  49.           model_validator,
  50.           pusher],
  51.       enable_cache=True,
  52.       metadata_connection_config=metadata.sqlite_metadata_connection_config(
  53.           metadata_path),
  54.       beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers]
  55.   )

  56. runner = AirflowDagRunner(config = _airflow_config)
  57. DAG = runner.run(
  58.     _create_pipeline(
  59.         pipeline_name=_pipeline_name,
  60.         pipeline_root=_pipeline_root,
  61.         data_root=_data_root,
  62.         module_file=_module_file,
  63.         serving_model_dir=_serving_model_dir,
  64.         metadata_path=_metadata_path,
  65.         # 0 means auto-detect based on on the number of CPUs available during
  66.         # execution time.
  67.         direct_num_workers=0))
复制代码

这是我的模块文件:
  1. _DENSE_FLOAT_FEATURE_KEYS = ['length']

  2. real_valued_columns = [tf.feature_column.numeric_column('length')]

  3. def _eval_input_receiver_fn():

  4.   serialized_tf_example = tf.compat.v1.placeholder(
  5.       dtype=tf.string, shape=[None], name='input_example_tensor')

  6.   features = tf.io.parse_example(
  7.       serialized=serialized_tf_example,
  8.       features={
  9.           'length': tf.io.FixedLenFeature([], tf.float32),
  10.           'label': tf.io.FixedLenFeature([], tf.int64),
  11.       })

  12.   receiver_tensors = {'examples': serialized_tf_example}

  13.   return tfma.export.EvalInputReceiver(
  14.       features={'length' : features['length']},
  15.       receiver_tensors=receiver_tensors,
  16.       labels= features['label'],
  17.       )

  18. def parser(serialized_example):

  19.   features = tf.io.parse_single_example(
  20.       serialized_example,
  21.       features={
  22.           'length': tf.io.FixedLenFeature([], tf.float32),
  23.           'label': tf.io.FixedLenFeature([], tf.int64),
  24.       })
  25.   return ({'length' : features['length']}, features['label'])

  26. def _input_fn(filenames):
  27.   # TFRecordDataset doesn't directly accept paths with wildcards
  28.   filenames = tf.data.Dataset.list_files(filenames)
  29.   dataset = tf.data.TFRecordDataset(filenames, 'GZIP')
  30.   dataset = dataset.map(parser)
  31.   dataset = dataset.shuffle(2000)
  32.   dataset = dataset.batch(40)
  33.   dataset = dataset.repeat(10)

  34.   return dataset

  35. def trainer_fn(trainer_fn_args, schema):

  36.     estimator = tf.estimator.LinearClassifier(feature_columns=real_valued_columns)

  37.     train_input_fn = lambda: _input_fn(trainer_fn_args.train_files)

  38.     train_spec = tf.estimator.TrainSpec(
  39.       train_input_fn,
  40.       max_steps=trainer_fn_args.train_steps)

  41.     eval_input_fn = lambda: _input_fn(trainer_fn_args.eval_files)

  42.     eval_spec = tf.estimator.EvalSpec(
  43.       eval_input_fn,
  44.       steps=trainer_fn_args.eval_steps,
  45.       name='fish-eval')

  46.     receiver_fn = lambda: _eval_input_receiver_fn()

  47.     return {
  48.       'estimator': estimator,
  49.       'train_spec': train_spec,
  50.       'eval_spec': eval_spec,
  51.       'eval_input_receiver_fn': receiver_fn
  52.   }
复制代码

在此先感谢您的帮助!

回答
为遇到与我相同问题的任何人发布解决方案。

模型没有写入文件系统的原因是估计器需要一个配置参数来知道在哪里写入模型。

以下对 trainer_fn 函数的修改应该可以解决问题。
  1. run_config = tf.estimator.RunConfig(save_checkpoints_steps=999, keep_checkpoint_max=1)  

  2. run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)

  3. estimator=tf.estimator.LinearClassifier(feature_columns=real_valued_columns,config=run_config)
复制代码






上一篇:使用底部选项卡导航器时不显示标题
下一篇:更改任务触发器但不反映 OIM 流程表单上的字段值

使用道具 举报

Archiver|手机版|小黑屋|吾爱开源 |网站地图

Copyright 2011 - 2012 Lnqq.NET.All Rights Reserved( ICP备案粤ICP备14042591号-1粤ICP14042591号 )

关于本站 - 版权申明 - 侵删联系 - Ln Studio! - 广告联系

本站资源来自互联网,仅供用户测试使用,相关版权归原作者所有

快速回复 返回顶部 返回列表