TensorFlow(Keras)で大規模データを扱っていると、学習開始前のデータの読み出し(転送)でめっちゃ時間がかかることがあります。普通はGeneratorを利用して解決するのですが、AWSの環境などではそうもいかない場合があります。本記事は、学習におけるI/Oボトルネックを解消する際に用いられるTFRecordについて、自分なりに調べた内容をまとめます。
一応以下の記事と関りがあります。
TFRecordsの概要
Tensorflowの公式サイトによると「プロトコルバッファによりシリアライズされたバイナリデータを連続的に読み出せるファイルのセット(意訳)」ということらしいです(リンク)。
プロトコルバッファ(Protocol Buffers)とは、Googleが開発しているシリアライズフォーマットのことだそうです。スキーマ言語の一種ということなので、XMLやjsonなんかの仲間ということになります。TensorflowもGoogle謹製なので、自社開発の技術を使用しているということになりますね。
TFRecordは、以下のクラスが絡む「tf.Example」をシリアライズしたものになります。
- tf.Example は tf.train.Features をコンストラクタにとるクラス
- tf.train.Featuresは、{“string”: tf.train.Feature} の形式を持つ辞書をコンストラクタにとるクラス
- tf.train.Featureは、tf.train.BytesList・tf.train.FloatList・tf.train.Int64Listをコンストラクタにとるクラス
- tf.train.BytesList は string・byte のリストをコンストラクタにとるクラス
- tf.train.FloatList は float(float32)・double(float64) のリストをコンストラクタにとるクラス
- tf.train.Int64List は bool・enum・int32・uint32・int64・uint64 のリストをコンストラクタにとるクラス
いきなりこれを見ても「はぁ!?」感がすごいので、もう少し深く見ていきましょう。
tf.train.Feature / tf.train.Exampleとは?
公式サイトの説明文を見ていると、「.proto fileを参照」の項目があります。ご丁寧にリンクが張ってあるので、その先に飛ぶとtensorflowのgithubリポジトリに飛ばされます。更に言うと「tensorflow/tensorflow/core/example/feature.proto」のファイルが公開されているページに飛びます。中身を見てみると、54行目までコメントアウトされており、56行目に 「syntax =”proto3″」と書かれていることから protocol buffers の型定義ファイルと分かります。更に見ていくと67 ~ 75行目に、以下のような記述があることに気づきます
// Containers to hold repeated fundamental values.
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
それぞれ、tf.train.BytesList・tf.train.FloatList・tf.train.Int64List に関わる部分だと分かります。それぞれ、bytes・float・int64のデータが連続で入るようなメッセージデータであることを表しています。可変長のリストをコンストラクタの引数にとることとも整合するので恐らく合っているはずです。
実際にこれらを作成してみると、以下のようになります(公式サイト改変)。
import tensorflow as tf
import numpy as np
raw_feature = np.random.randint(0, 5, 10)
int64_list = tf.train.Int64List(value=raw_feature.tolist())
print(int64_list)
# 出力結果:
# value: 3
# value: 0
# value: 2
# value: 0
# value: 1
# value: 1
# value: 4
# value: 1
# value: 0
# value: 2
これを見ると、ランダムな整数が入ったリストがfeature.protoで定義されていたデータ形式に変換されていることが分かります。
次に、tf.train.Featureについてですが、これは
import tensorflow as tf
import numpy as np
raw_feature = np.random.randint(0, 5, 10)
feature = tf.train.Feature(
int64_list=tf.train.Int64List(value=raw_feature.tolist()))
print(feature)
# 出力結果:
# int64_list {
# value: 3
# value: 1
# value: 2
# value: 4
# value: 4
# value: 0
# value: 2
# value: 4
# value: 0
# value: 1
# }
のように、上記で作成したbytes_list, float_list, int64_listを保持しているオブジェクトになります。tf.train.FeatureはSerializeToString()というメソッドを持っており、これを使うことでデータ列をバイナリ文字にシリアライズすることができます。
### 上記のコードに追加 ###
print(feature.SerializeToString())
# 出力結果:
# b'\x1a\x05\n\x03\x02\x02\x00'
先ほど説明したtf.train.Featureを複数集めると、tf.train.Featuresになります。tf.train.Featuresは、valuesにtf.train.Featureのインスタンスを持つ辞書をコンストラクタにとるオブジェクトです。
具体的には以下のようになります。
import tensorflow as tf
import numpy as np
raw_feature = np.random.randint(0, 5, 3)
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature1 = tf.train.Feature(
int64_list=tf.train.Int64List(value=raw_feature.tolist()))
feature2 = tf.train.Feature(
bytes_list=tf.train.BytesList(value=strings.tolist()))
feature_dict = {"elem1": feature1,
"elem2": feature2}
features = tf.train.Features(feature=feature_dict)
print(features)
# 出力結果:
# feature {
# key: "elem1"
# value {
# int64_list {
# value: 4
# value: 2
# value: 0
# }
# }
# }
# feature {
# key: "elem2"
# value {
# bytes_list {
# value: "cat"
# value: "dog"
# value: "chicken"
# value: "horse"
# value: "goat"
# }
# }
# }
keyと紐づいた形で、bytes_listなどのデータを保持するようになっている訳ですね。そして、これをtf.train.Exampleのコンストラクタに渡すことで、tf.Example メッセージを作成することができ、シリアライズしてTFRecords形式で保存できるようになるということです。コード例は以下です。
### 上記コードに以下を追加 ###
example_proto = tf.train.Example(features=features)
print(example_proto)
# 出力結果:
# features {
# feature {
# key: "elem1"
# value {
# int64_list {
# value: 3
# value: 3
# value: 0
# }
# }
# }
# feature {
# key: "elem2"
# value {
# bytes_list {
# value: "cat"
# value: "dog"
# value: "chicken"
# value: "horse"
# value: "goat"
# }
# }
# }
# }
これがtf.train.Example の実体になります。こうして見ると、事前に設計されているフォーマットに沿って、データの列をprotocol buffers の階層構造に詰めていっているだけと解釈できそうです。初見では色々な要素が出てくるので怯みますが、一つ一つ見ていくと何ということはないわけですね。
ちなみに、tf.train.ExampleもSerializeToString()メソッドでシリアライズ可能なようです。
### 上記コードに追加 ###
print(example_proto.SerializeToString()
)
# 出力結果:
# b'\n?\n\x10\n\x05elem1\x12\x07\x1a\x05\n\x03\x04\x01\x02\n+\n\x05elem2\x12"\n \n\x03cat\n\x03dog\n\x07chicken\n\x05horse\n\x04goat'
余談
調べてみると、protoによる定義や、SerializeToString() メソッドはprotocol buffers とPythonの連携で提供されている記法のようですね。Tensorflowとは関係なく使えるようなので覚えておくといいかもしれません。
TFRecordsへの書き出し・読み出し
さて、前章までで作成したtf.train.Exampleをシリアライズすると、もうTFRecords形式のファイルに書き出しが可能です。書き出しには tf.io.TFRecordWriter() メソッドを用います。普通のファイルオープンなどと同様に扱えるので、withブロックを用いて以下のように書くとよいかと思います。
### 上記コードに追加 ###
filepath = "./example.tfrecords"
with tf.io.TFRecordWriter(filepath) as writer:
writer.write(example_proto.SerializeToString())
# このコードを実行したディレクトリに、"example.tfrecords"というファイルが生成されていればOK
また、読み出しは以下のように行います。
import tensorflow as tf
import numpy as np
# TFRecordsファイルを保存したパス
filepath = "./example.tfrecords"
raw_dataset = tf.data.TFRecordDataset(filepath)
for raw_record in raw_dataset.take(1):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
print(example)
# 出力結果
# features {
# feature {
# key: "elem1"
# value {
# int64_list {
# value: 4
# value: 0
# value: 4
# }
# }
# }
# feature {
# key: "elem2"
# value {
# bytes_list {
# value: "cat"
# value: "dog"
# value: "chicken"
# value: "horse"
# value: "goat"
# }
# }
# }
# }
無事、元のExampleデータを読み込めていることが分かります。ちなみに、各feature毎に値を取得したければ、例えば以下のようにすればOKです。
import tensorflow as tf
import numpy as np
filepath = "./example.tfrecords"
raw_dataset = tf.data.TFRecordDataset(filepath)
feature_description = {
"elem1": tf.io.VarLenFeature(tf.int64),
"elem2": tf.io.VarLenFeature(tf.string)
}
# デシリアライズのためのコールバック関数
def _parse_dataset(example_proto):
return tf.io.parse_single_example(example_proto, feature_description)
parsed_example = raw_dataset.map(_parse_dataset)
for features in parsed_example:
elem1 = features["elem1"]
elem2 = features["elem2"]
print(elem1)
print(elem2)
# 出力結果:
# SparseTensor(indices=tf.Tensor(
# [[0]
# [1]
# [2]], shape=(3, 1), dtype=int64), values=tf.Tensor([4 0 4], shape=(3,), dtype=int64), dense_shape=tf.Tensor([3], shape=(1,), dtype=int64))
# SparseTensor(indices=tf.Tensor(
# [[0]
# [1]
# [2]
# [3]
# [4]], shape=(5, 1), dtype=int64), values=tf.Tensor([b'cat' b'dog' b'chicken' b'horse' b'goat'], shape=(5,), dtype=string), dense_shape=tf.Tensor([5], shape=(1,), dtype=int64))
これで、TFRecordsファイルからTensorflowのTensorへ、データを復元できたことになります。なお、デフォルトではどうやらSparseTensorとして返ってくるようなので、これをDenseTensorに変換する必要はあります。
まとめ
この記事では、TFRecordsの扱いについて、自分の勉強を兼ねてまとめてみました。ぱっと見とっつきづらいんですが、要はXMLなんかと似たような形式であるProtocol Buffers の形式にデータを変換して、シリアライズしているだけということでした。また、デシリアライズする際には、データ構造を指定する必要があるのがポイントということも分かりました。Protocol Buffers はデータ効率がよいようなので、通信を扱う際には何かとお世話になるかもしれませんね。
ちなみに本来はSparseTensorをTFRecordsとして扱うにはどうすればよいのかを書くつもりでしたが、基本からやった方がよいだろうということで今回はこのような内容になりました。
より実践的な内容・SparseTensorのシリアライズ等に関しては後日別の記事でまとめようと思います。では、今回は以上とします。今回も最後までお読みいただきありがとうございました!
コメント