来源: TF Boys (TensorFlow Boys ) 养成记(二): TensorFlow 数据读取 – Charles-Wan – 博客园
TensorFlow 的 How-Tos,讲解了这么几点:
1. 变量:创建,初始化,保存,加载,共享;
2. TensorFlow 的可视化学习,(r0.12版本后,加入了Embedding Visualization)
3. 数据的读取;
4. 线程和队列;
5. 分布式的TensorFlow;
6. 增加新的Ops;
7. 自定义数据读取;
由于各种原因,本人只看了前5个部分,剩下的2个部分还没来得及看,时间紧任务重,所以匆匆发车了,以后如果有用到的地方,再回过头来研究。学习过程中深感官方文档的繁杂冗余极多多,特别是第三部分数据读取,又臭又长,花了我好久时间,所以我想把第三部分整理如下,方便乘客们。
TensorFlow 有三种方法读取数据:1)供给数据,用placeholder;2)从文件读取;3)用常量或者是变量来预加载数据,适用于数据规模比较小的情况。供给数据没什么好说的,前面已经见过了,不难理解,我们就简单的说一下从文件读取数据。
官方的文档里,从文件读取数据是一段很长的描述,链接层出不穷,看完这个链接还没看几个字,就出现了下一个链接。
自己花了很久才认识路,所以想把这部分总结一下,带带我的乘客们。
首先要知道你要读取的文件的格式,选择对应的文件读取器;
然后,定位到数据文件夹下,用
["file0", "file1"] # or [("file%d" % i) for i in range(2)]) # or tf.train.match_filenames_once
选择要读取的文件的名字,用 tf.train.string_input_producer 函数来生成文件名队列,这个函数可以设置shuffle = Ture,来打乱队列,可以设置epoch = 5,过5遍训练数据。
最后,选择的文件读取器,读取文件名队列并解码,输入 tf.train.shuffle_batch 函数中,生成 batch 队列,传递给下一层。
1)假如你要读取的文件是像 CSV 那样的文本文件,用的文件读取器和解码器就是 TextLineReader 和 decode_csv 。
2)假如你要读取的数据是像 cifar10 那样的 .bin 格式的二进制文件,就用 tf.FixedLengthRecordReader 和 tf.decode_raw 读取固定长度的文件读取器和解码器。如下列出了我的参考代码,后面会有详细的解释,这边先大致了解一下:
class cifar10_data(object): def __init__(self, filename_queue): self.height = 32 self.width = 32 self.depth = 3 self.label_bytes = 1 self.image_bytes = self.height * self.width * self.depth self.record_bytes = self.label_bytes + self.image_bytes self.label, self.image = self.read_cifar10(filename_queue) def read_cifar10(self, filename_queue): reader = tf.FixedLengthRecordReader(record_bytes = self.record_bytes) key, value = reader.read(filename_queue) record_bytes = tf.decode_raw(value, tf.uint8) label = tf.cast(tf.slice(record_bytes, [0], [self.label_bytes]), tf.int32) image_raw = tf.slice(record_bytes, [self.label_bytes], [self.image_bytes]) image_raw = tf.reshape(image_raw, [self.depth, self.height, self.width]) image = tf.transpose(image_raw, (1,2,0)) image = tf.cast(image, tf.float32) return label, image def inputs(data_dir, batch_size, train = True, name = 'input'): with tf.name_scope(name): if train: filenames = [os.path.join(data_dir,'data_batch_%d.bin' % ii) for ii in range(1,6)] for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) filename_queue = tf.train.string_input_producer(filenames) read_input = cifar10_data(filename_queue) images = read_input.image images = tf.image.per_image_whitening(images) labels = read_input.label num_preprocess_threads = 16 image, label = tf.train.shuffle_batch( [images,labels], batch_size = batch_size, num_threads = num_preprocess_threads, min_after_dequeue = 20000, capacity = 20192) return image, tf.reshape(label, [batch_size]) else: filenames = [os.path.join(data_dir,'test_batch.bin')] for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) filename_queue = tf.train.string_input_producer(filenames) read_input = cifar10_data(filename_queue) images = read_input.image images = tf.image.per_image_whitening(images) labels = read_input.label num_preprocess_threads = 16 image, label = tf.train.shuffle_batch( [images,labels], batch_size = batch_size, num_threads = num_preprocess_threads, min_after_dequeue = 20000, capacity = 20192) return image, tf.reshape(label, [batch_size])
3)如果你要读取的数据是图片,或者是其他类型的格式,那么可以先把数据转换成 TensorFlow 的标准支持格式 tfrecords ,它其实是一种二进制文件,通过修改 tf.train.Example 的Features,将 protocol buffer 序列化为一个字符串,再通过 tf.python_io.TFRecordWriter 将序列化的字符串写入 tfrecords,然后再用跟上面一样的方式读取tfrecords,只是读取器变成了tf.TFRecordReader,之后通过一个解析器tf.parse_single_example ,然后用解码器 tf.decode_raw 解码。
例如,对于生成式对抗网络GAN,我采用了这个形式进行输入,部分代码如下,后面会有详细解释,这边先大致了解一下:
def _int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value = [value])) def _bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value])) def convert_to(data_path, name): """ Converts s dataset to tfrecords """ rows = 64 cols = 64 depth = DEPTH for ii in range(12): writer = tf.python_io.TFRecordWriter(name + str(ii) + '.tfrecords') for img_name in os.listdir(data_path)[ii*16384 : (ii+1)*16384]: img_path = data_path + img_name img = Image.open(img_path) h, w = img.size[:2] j, k = (h - OUTPUT_SIZE) / 2, (w - OUTPUT_SIZE) / 2 box = (j, k, j + OUTPUT_SIZE, k+ OUTPUT_SIZE) img = img.crop(box = box) img = img.resize((rows,cols)) img_raw = img.tobytes() example = tf.train.Example(features = tf.train.Features(feature = { 'height': _int64_feature(rows), 'weight': _int64_feature(cols), 'depth': _int64_feature(depth), 'image_raw': _bytes_feature(img_raw)})) writer.write(example.SerializeToString()) writer.close() def read_and_decode(filename_queue): """ read and decode tfrecords """ # filename_queue = tf.train.string_input_producer([filename_queue]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example,features = { 'image_raw':tf.FixedLenFeature([], tf.string)}) image = tf.decode_raw(features['image_raw'], tf.uint8) return image
这里,我的data_path下面有16384*12张图,通过12次写入Example操作,把图片数据转化成了12个tfrecords,每个tfrecords里面有16384张图。
4)如果想定义自己的读取数据操作,请参考https://www.tensorflow.org/how_tos/new_data_formats/。
好了,今天的车到站了,请带好随身物品准备下车,明天老司机还有一趟车,请记得准时乘坐,车不等人。
参考文献:
1. https://www.tensorflow.org/how_tos/
2. 没了