留意:本文运用tensorflow1.x版本进行演示
运用本地Jupyter Notebook搭载TensorFlow相关库进行操作
1. 读取TFRecords文件
其实读取TFRecords文件大体思路与常规文件读取思路(结构行列、读取、解码、批处理行列)比较一致。但是,仍是有一点不相同,在解码操作之前,需求解析Example操作(由于TFRecords比其他文件多了个Example结构),TFRecords文件读取步骤如下所示:
- 结构文件名行列
- 读取
- 解析Example
- tf.parse_single_example()
- tf.FixedLenFeature(shape, dtype)
- tf.parse_single_example()
- 解码
- 结构批处理行列
接下来,咱们将对TFRecords文件读取中用到的函数进行详细说明:
-
tf.parse_single_example(serialized, features=None, name=None)
- 用来解析一个单一的Example原型
- serialized:标量字符串Tensor,一个序列化的Example
- features:dict字典数据,键为读取的姓名,值为FixedLenFeature
- return:回来一个键值对组成的字典,键为读取的姓名。想拿到解析后的example数据,需求经过字典形式访问。
-
tf.FixedLenFeature(shape, dtype)
- 这个函数和上一个函数其实是嵌套运用的,上一个函数中的features参数中的一部分(字典中值的部分)需求用本函数填充
- shape:输入数据的形状,一般不指定即为空列表
- dtype:输入数据类型,与存储进文件的类型要相同
- 类型只能是float32,int64,string
2. 代码演示
导入所需模块,由于本地下载的是Tensorflow2.x版本,想运转Tensorflow1的语法,需求敞开兼容模型,以支撑Tensorflow1语法正常运转。
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
从读取TFRecords文件的视点,进行函数定义,对已保存到本地的TFRecords文件进行读取。详细代码如下所示:
-
首先在函数中需求结构文件名行列,经过变量file_queue接纳
-
然后,运用tf.TFRecordReader()读取器,运用该读取器下面的read办法进行文件读取,运用变量key与value承受元组
-
接下来以上述介绍的API进行Example解析,能够将中间结果image与label打印出来看看
-
别忘记敞开会话tf.Session()才干看到详细的值
- 会话中tf.train.Coordinator()敞开线程
- sess.run()运转一下用以检查详细的值
- 收回资源,收回线程
-
然后是解码操作,咱们能够将其解码成uint8
-
打印出的是一维数组,咱们需求进行图画调整将其调整成32323
-
终究,将其放入批处理行列。
class Cifar():
def __init__(self):
# 设置图画巨细
self.height = 32
self.width = 32
self.channel = 3
# 设置图画字节数
self.image = self.height * self.width * self.channel
self.label = 1
self.sample = self.image + self.label
def read_tfrecords(self):
"""
读取tfrecords文件
"""
# 1. 结构文件名行列
file_queue = tf.train.string_input_producer(["cifar10.tfrecords"])
# 2. 读取与解码
# 2.1 读取
reader = tf.TFRecordReader()
key, value = reader.read(file_queue)
# 2.2 解析example
feature = tf.parse_single_example(value, features={
"image": tf.FixedLenFeature([], tf.string),
"label": tf.FixedLenFeature([], tf.int64)
})
image = feature["image"]
label = feature["label"]
print("read_tf_image:\n", image)
print("read_tf_label:\n", label)
# 2.3 解码
image_decoded = tf.decode_raw(image, tf.uint8)
print("image_decoded:\n", image_decoded)
# 图画形状调整
image_reshaped = tf.reshape(image_decoded, [self.height, self.width, self.channel])
# 3. 结构批处理行列
image_batch, label_batch = tf.train.batch([image_reshaped, label], batch_size=100, num_threads=5, capacity=100)
print("image_batch:\n", image_batch)
print("label_batch:\n", label_batch)
# 敞开会话
with tf.Session() as sess:
# 敞开线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
image_value, label_value, image_decoded, image_batch, label_batch = sess.run([image, label, image_decoded, image_batch, label_batch])
print("image_value:\n", image_value)
print("label_value:\n", label_value)
# 收回资源
coord.request_stop()
coord.join(threads)
return None
cifar = Cifar()
cifar.read_tfrecords()
部分读取结果如下图所示:
本文正在参加「金石方案 . 瓜分6万现金大奖」