页面树结构

2017-11-09 ApacheCN 开源组织,第二期邀请成员活动,一起走的更远 : http://www.apachecn.org/member/209.html


MachineLearning 优酷地址 : http://i.youku.com/apachecn

转至元数据结尾
转至元数据起始

先决条件:

我们把支持文件格式的任务分成两部分:

  • 文件格式:我们使用Reader Op 从文件读取记录(可以是任何字符串)。
  • 记录格式:我们使用解码器或解析操作将字符串记录转换为TensorFlow可用的张量。

例如,要阅读 CSV文件,我们使用 Reader作为文本文件, 然后 使用Op从一行文本中分析CSV数据

 

为文件格式写入读卡器

Reader是从文件读取记录的东西。有一些已经在TensorFlow中构建的Reader Ops的例子:

你可以看到这些都暴露了同一个接口,唯一的区别在于它们的构造函数。最重要的方法是read。它需要一个队列参数,它是从哪里获取文件名从哪里需要一个(例如,当readop第一次运行,或以前read从文件读取最后一个记录时)。它产生两个标量张量:一个字符串键和一个字符串值。

要创建一个新的阅读器SomeReader,您需要:

  1. 在C ++中,定义一个tensorflow::ReaderBase 被调用的子类 SomeReader
  2. 在C ++中,使用名称注册一个新的读者op和内核"SomeReader"
  3. 在Python中,定义一个tf.ReaderBase被调用的子类SomeReader

您可以将所有C ++代码放入文件中 tensorflow/core/user_ops/some_reader_op.cc。读取文件的代码将生活在C ++ ReaderBase类的后代,该类定义在 tensorflow/core/kernels/reader_base.h。您将需要实现以下方法:

  • OnWorkStartedLocked:打开下一个文件
  • ReadLocked:读取记录或报告EOF /错误
  • OnWorkFinishedLocked:关闭当前文件,和
  • ResetLocked得到一个干净的石板后,例如,一个错误

这些方法的名称以“锁定”结尾,因为ReaderBase确保在调用任何这些方法之前获取互斥体,因此您通常不必担心线程安全性(尽管只保护类的成员,而不是全局状态) 。

因为OnWorkStartedLocked,要打开的文件的名称是该current_work()方法返回的值。 ReadLocked有这个签名:

Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)
如果ReadLocked从文件中成功读取记录,则应填写:
  • *key:具有记录的标识符,人可以用来再次找到该记录。您可以包括文件名current_work(),并附加记录号或其他。
  • *value:记录内容。
  • *produced:设置为true

如果您点击文件(EOF)的末尾,请设置*at_endtrue。在任一情况下,返回Status::OK()。如果有错误,只需使用其中一个帮助函数返回, tensorflow/core/lib/core/errors.h 而不修改任何参数。

接下来,您将创建实际的读者操作。如果您熟悉添加操作方法,这将有所帮助。主要步骤是:

  • 注册
  • 定义并注册OpKernel

要注册op,您将使用REGISTER_OP定义的调用 tensorflow/core/framework/op.h。读卡器操作从不采取任何输入,并始终具有单一输出类型 resource。他们应该有字符串containershared_nameattrs。您可以选择定义附加的attr进行配置,或者在a中包含文档Doc。例如,参见tensorflow/core/ops/io_ops.cc,例如:

#include "tensorflow/core/framework/op.h"

REGISTER_OP("TextLineReader")
 .Output("reader_handle: resource")
 .Attr("skip_header_lines: int = 0")
 .Attr("container: string = ''")
 .Attr("shared_name: string = ''")
 .SetIsStateful()
 .SetShapeFn(shape_inference::ScalarShape)
 .Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
)doc");
要定义一个OpKernel,读者可以使用从中下调的快捷方式 ReaderOpKernel,定义tensorflow/core/framework/reader_op_kernel.h和实现一个调用的构造函数SetReaderFactory。定义你的课后,你需要注册REGISTER_KERNEL_BUILDER(...)。没有attrs的例子:
#include "tensorflow/core/framework/reader_op_kernel.h"

class TFRecordReaderOp : public ReaderOpKernel {
 public:
 explicit TFRecordReaderOp(OpKernelConstruction* context)
 : ReaderOpKernel(context) {
 Env* env = context->env();
 SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
 }
};

REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
 TFRecordReaderOp);
attrs的例子:
#include "tensorflow/core/framework/reader_op_kernel.h"

class TextLineReaderOp : public ReaderOpKernel {
 public:
 explicit TextLineReaderOp(OpKernelConstruction* context)
 : ReaderOpKernel(context) {
 int skip_header_lines = -1;
 OP_REQUIRES_OK(context,
 context->GetAttr("skip_header_lines", &skip_header_lines));
 OP_REQUIRES(context, skip_header_lines >= 0,
 errors::InvalidArgument("skip_header_lines must be >= 0 not ",
 skip_header_lines));
 Env* env = context->env();
 SetReaderFactory([this, skip_header_lines, env]() {
 return new TextLineReader(name(), skip_header_lines, env);
 });
 }
};

REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
 TextLineReaderOp);
最后一步是添加Python包装器。您可以通过编译动态库来执行此操作, 或者如果要从源代码构建TensorFlow,则添加到user_ops.py。对于后者,您将导入tensorflow.python.ops.io_opstensorflow/python/user_ops/user_ops.py 添加的后裔io_ops.ReaderBase
 
from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import io_ops

class SomeReader(io_ops.ReaderBase):

 def __init__(self, name=None):
 rr = gen_user_ops.some_reader(name=name)
 super(SomeReader, self).__init__(rr)

ops.NotDifferentiable("SomeReader")
你可以看到一些例子 tensorflow/python/ops/io_ops.py

为录制格式编写作品

一般来说,这是一个普通的op,它将标量字符串记录作为输入,因此按照说明添加一个Op。您可以选择使用标量字符串键作为输入,并将其包含在报告格式不正确的数据的错误消息中。这样用户可以更轻松地跟踪坏数据来自哪里。

对解码记录有用的操作示例:

请注意,使用多个Ops来解码特定的记录格式可能很有用。例如,可能必须保存为一个字符串的图像 一个tf.train.Example协议缓冲器。根据该图像的格式,你可能会采取相应的输出从tf.parse_single_example OP和呼叫tf.image.decode_jpeg, tf.image.decode_pngtf.decode_raw。通常要输出tf.decode_raw并使用 tf.slice和 tf.reshape提取片段。

  • 无标签