Streaming large training and test files into Tensorflow's DNNClassifier
https://stackoverflow.com/questions/45828616/streaming-large-training-and-test-files-into-tensorflows-dnnclassifier
Check out the tf.data.Dataset API. There are a number of ways to create a dataset. I'll outline four - but you'll only have to implement one.
I assume each row of your csv files is n_features float values followed by a single int value.
Creating a tf.data.Dataset
Wrap a python generator with Dataset.from_generator
The easiest way is to wrap a native python generator. I'd encourage you to try this first and only change if you see serious performance issues.
This approach is highly versatile and allows you to test your generator function (read_csv) independently of TensorFlow.
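A minimal sketch of this approach; the filename and the n_features constant are placeholders you'd replace with your own values:

```python
import tensorflow as tf

n_features = 4  # assumption: number of float feature columns per row


def read_csv(filename):
    """Plain-python generator yielding (features, label) for each csv row."""
    with open(filename, 'r') as f:
        for line in f:
            record = line.rstrip().split(',')
            features = [float(x) for x in record[:-1]]
            label = int(record[-1])
            yield features, label


def get_dataset(filename='my_train_dataset.csv'):
    # The lambda defers the call so the generator can be re-created as needed.
    return tf.data.Dataset.from_generator(
        lambda: read_csv(filename),
        output_types=(tf.float32, tf.int32),
        output_shapes=((n_features,), ()))
```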
Wrap an index-based python function
One of the downsides of the above is that shuffling the resulting dataset with a shuffle buffer of size n requires n examples to be loaded. This will either create periodic pauses in your pipeline (large n) or result in potentially poor shuffling (small n).
In short, we create a dataset just of the record indices (or any small record ID which we can load entirely into memory). We then do shuffling/repeating operations on this minimal dataset, then map the index to the actual data via tf.data.Dataset.map and tf.py_func. See the Using the dataset with estimators and Testing the dataset in isolation sections below for usage. Note this requires your data to be accessible by row, so you may need to convert from csv to some other format.
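A sketch of the idea. The get_record function, the per-record .npy storage scheme, and the n_records/n_features constants are all assumptions standing in for however you make your data row-accessible:

```python
import numpy as np
import tensorflow as tf

n_records = 10000  # assumption: total number of records, known up front
n_features = 4     # assumption: number of float feature columns per row


def get_record(index):
    """Hypothetical loader: fetch row `index` from row-accessible storage."""
    data = np.load('records/%d.npy' % index)  # placeholder storage scheme
    features = data[:-1].astype(np.float32)
    label = np.int32(data[-1])
    return features, label


def get_dataset():
    dataset = tf.data.Dataset.range(n_records)
    # Shuffling indices is cheap; repeating could also go here.
    dataset = dataset.shuffle(n_records)

    def map_fn(index):
        features, label = tf.py_func(
            get_record, (index,), (tf.float32, tf.int32), stateful=False)
        features.set_shape((n_features,))
        label.set_shape(())
        return features, label

    return dataset.map(map_fn)
```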
TextLineDataset
You can also read the csv file directly using a tf.data.TextLineDataset.
The parse_row function is a little convoluted since tf.decode_csv expects a batch. You can make it slightly simpler if you batch the dataset before parsing.
TFRecordDataset
Alternatively, you can convert the csv files to TFRecord files and use a TFRecordDataset. There's a thorough tutorial here.
Step 1: Convert the csv data to TFRecords data. This only needs to be run once. Example code below (see read_csv from the from_generator example above).
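A sketch of the conversion, reusing the read_csv generator from above; the filenames and feature keys ('features', 'label') are assumptions:

```python
import tensorflow as tf

with tf.python_io.TFRecordWriter('my_train_dataset.tfrecords') as writer:
    for features, label in read_csv('my_train_dataset.csv'):
        example = tf.train.Example()
        example.features.feature['features'].float_list.value.extend(features)
        example.features.feature['label'].int64_list.value.append(label)
        writer.write(example.SerializeToString())
```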
Step 2: Write a dataset that decodes these record files.
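A sketch, assuming the feature keys and filename used in the conversion step above:

```python
import tensorflow as tf

n_features = 4  # assumption


def parse_function(example_proto):
    spec = {
        'features': tf.FixedLenFeature((n_features,), tf.float32),
        'label': tf.FixedLenFeature((), tf.int64),
    }
    parsed = tf.parse_single_example(example_proto, spec)
    return parsed['features'], tf.cast(parsed['label'], tf.int32)


def get_dataset(filename='my_train_dataset.tfrecords'):
    return tf.data.TFRecordDataset([filename]).map(parse_function)
```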
Using the dataset with estimators
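Whichever get_dataset you pick, the input function looks the same. The sketch below feeds a DNNClassifier; the feature-column name 'x', the layer sizes, and the batch/epoch counts are placeholders:

```python
import tensorflow as tf

n_features = 4  # assumption


def get_inputs(batch_size, n_epochs):
    dataset = get_dataset()  # any of the implementations above
    dataset = dataset.repeat(n_epochs)
    dataset = dataset.batch(batch_size)
    features, label = dataset.make_one_shot_iterator().get_next()
    # Estimators expect features keyed by feature-column name.
    return {'x': features}, label


feature_columns = [tf.feature_column.numeric_column('x', shape=(n_features,))]
estimator = tf.estimator.DNNClassifier(
    hidden_units=[64, 32],
    feature_columns=feature_columns,
    n_classes=2)

estimator.train(lambda: get_inputs(batch_size=32, n_epochs=10))
```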
Testing the dataset in isolation
I'd strongly encourage you to test your dataset independently of your estimator. Using the above get_inputs, it should be as simple as pulling a batch in a session and inspecting the result.
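A minimal sketch, assuming the batched, dict-returning get_inputs defined above:

```python
import tensorflow as tf

features, labels = get_inputs(batch_size=4, n_epochs=1)

with tf.Session() as sess:
    f_data, l_data = sess.run([features, labels])

print(f_data)  # dict with a (4, n_features) float array under 'x'
print(l_data)  # (4,) array of int labels
```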
Performance
Assuming you're using a GPU to run your network, unless each row of your csv file is enormous and your network is tiny, you probably won't notice a difference in performance. This is because the Estimator implementation forces data loading/preprocessing to be performed on the CPU, and prefetch means the next batch can be prepared on the CPU while the current batch trains on the GPU. The only exception is if you have a massive shuffle size on a dataset with a large amount of data per record, which will take some time to load an initial number of examples before anything runs through the GPU.
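If you want that prefetching explicitly, a one-line addition to get_inputs is enough; the buffer size of one batch here is an assumption:

```python
dataset = dataset.batch(batch_size).prefetch(1)  # prepare next batch on CPU
```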