Create-TFRecord

TFRecord是一种高效的数据存储格式,尤其是在处理大数据集时,我们无法对数据进行一次读取,这时我们就可以将文件存储为TFRecord,然后再进行读取。这样可以可以提高数据移动、读取、处理等速度。
在对小数据集进行读取时,可以直接使用tf.data API来进行处理。

在TFRecord中是将每个样本example 以字典的方式进行存储。

主要的数据类型如下:

  • int64:tf.train.Feature(int64_list = tf.train.Int64List(value=输入))
  • float32:tf.train.Feature(float_list = tf.train.FloatList(value=输入))
  • string:tf.train.Feature(bytes_list=tf.train.BytesList(value=输入))
  • 注:输入必须是list(向量)

这里我们举一个NLP中常见例子。

  • 这里有10个句子sentence,每个句子有128个token_id
  • 每个句子对应的10个标签label
  • 每个句子中对应的token weight (mask)
  • 每个句子经过Embedding后的 句子matrixtensor (两者是同一个东西,只是为了后面介绍两种不同的存储方式。)

那么我们怎样将这些转换为TFRecord呢?

Create_TFRecord.py

大致可以分为以下几步:

  1. 由于TFRecord中是将每个样本当做一个example进行存储。所以我们先要取得每个样本对应的sentence, label, weight, matrix, tensor.
  2. 将每个样本属性转换为对应的feature字典类型。(注意,这里的value均为**list**类型)
    • int64:tf.train.Feature(int64_list = tf.train.Int64List(value=输入))
    • float32:tf.train.Feature(float_list = tf.train.FloatList(value=输入))
    • string:tf.train.Feature(bytes_list=tf.train.BytesList(value=输入))
  3. 将feature字典包装成features。
    features=tf.train.Features(feature=feature字典)
  4. 将features转换成example
    example = tf.train.Example(features=features)
  5. 通过example.SerializeToString() 将example 进行序列化,并通过 tfwriter.write()进行写入文件。

具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import tensorflow as tf
from pathlib import Path
import numpy as np

# 随机生成相应数据
senteces = np.random.randint(0,512,(10,128))
senteces_label = np.random.randint(0,2,(10))
senteces_weight = [[1.0]*128]*10
tensors = np.random.randn(10,128,512)
matrixs = np.random.randn(10,128,512)


tfrecord_save_path = 'data.tfrecord'

with tf.io.TFRecordWriter(tfrecord_save_path) as tfwriter:
for text,label ,weight, tensor, matrix in zip(senteces ,senteces_label,senteces_weight, tensors, matrixs):
example = tf.train.Example(features=tf.train.Features(
feature={
'text':tf.train.Feature(int64_list=tf.train.Int64List(value=text.tolist())),
'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
'weight':tf.train.Feature(float_list=tf.train.FloatList(value=weight)),

#当需要存入矩阵时,有两种方法 一种是将矩阵Flatten 然后在读取的时候进行 形状reshape
'matrix': tf.train.Feature(float_list=tf.train.FloatList(value=matrix.reshape(-1))),
'matrix_shape': tf.train.Feature(int64_list=tf.train.Int64List(value=matrix.shape)),

# 存入矩阵时,会使得矩阵形状丢失 因此 需要额外记录矩阵的形状,以便还原。

#另一种方法 是将矩阵转换为字符类型存储,随后在还原。
# 两种方法都会导致 形状丢失,都需要进行矩阵形状存储
'tensor': tf.train.Feature(bytes_list=tf.train.BytesList(value=[tensor.tostring()])),
'tensor_shape': tf.train.Feature(int64_list=tf.train.Int64List(value=tensor.shape))

}
))

tfwriter.write(example.SerializeToString())

Read_TFRecord.py

在得到了TFRecord后,我们又改如何解析呢?

解析大致也可以分为几个步骤:

  1. 通过tf.data.TFRecordDataset对TFRecord进行读取。

  2. 在前面创建TFRecord时,我们需要创建feature字典,同样在解析时也需要定义一个feature_description字典,告诉程序,TFRecord中的数据类型。

    定长特征解析tf.FixedLenFeature(shape, dtype, default_value)

    • shape:可当reshape来用,如vector的shape从(3,)改动成了(1,3)。
    • 注:如果写入的feature使用了.tostring() 其shape就是()
    • dtype:必须tf.float32tf.int64tf.string中的一种。
    • default_value:feature值缺失时所指定的值。

    不定长特征解析tf.VarLenFeature(dtype)

    • 注:可以不明确指定shape,但得到的tensor是SparseTensor。
  3. 通过tf.io.parse_single_example对 1 中得到的raw_data进行解析。

  4. 对解析后的数据,对应的部分进一步进行还原。

在解析TFRecord时,需要注意:

  • tf.io.FixedLenFeature 中要明确传入数据的形状。
  • tf.io.VarLenFeature虽然不用传入数据形状,但需要通过tf.sparse.to_dense对对应数据进行解析
  • 其中由于tensor是前面是通过转换为字符类型进行存储的,因此需要进行解码。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tensorflow as tf

def parse_function(example):
# 这里的dtype 类型只有 float32, int64, string
feature_description = {
'text': tf.io.FixedLenFeature(shape=(128,), dtype=tf.int64),
'label': tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
'weight': tf.io.FixedLenFeature(shape=(128,), dtype=tf.float32),
'matrix': tf.io.VarLenFeature(dtype=tf.float32),
'matrix_shape': tf.io.VarLenFeature(dtype=tf.int64),
'tensor': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
'tensor_shape': tf.io.FixedLenFeature(shape=(2,), dtype=tf.int64)
}

parse_example = tf.io.parse_single_example(example, feature_description)

# 对parse_example中的对应数据进一步解析
parse_example['matrix'] = tf.sparse.to_dense(parse_example['matrix'])
parse_example['matrix_shape'] = tf.sparse.to_dense(parse_example['matrix_shape'])
# 由于tensor是前面是通过转换为字符类型进行存储的,因此需要进行解码
parse_example['tensor'] = tf.io.decode_raw(parse_example['tensor'], tf.int64)
# 将相应矩阵进行reshape
parse_example['matrix'] = tf.reshape(parse_example['matrix'], parse_example['matrix_shape'])
parse_example['tensor'] = tf.reshape(parse_example['tensor'], parse_example['tensor_shape'])
return parse_example


if __name__ == '__main__':
tfrecord_save_path = 'data.tfrecord'

raw_dataset = tf.data.TFRecordDataset(tfrecord_save_path)

dataset = raw_dataset.map(parse_function)
# 在这里我们还可以对数据dataset进行shuffle和batch操作
# dataset = dataset.shuffle()
# dataset = dataset.batch()
for data in dataset:
print(data)

Reference

https://zhuanlan.zhihu.com/p/33223782