# Streaming large training and test files into TensorFlow's DNNClassifier

Check out the [`tf.data.Dataset`](https://www.tensorflow.org/versions/master/api_docs/python/tf/data) API. There are a number of ways to create a dataset. I'll outline four, but you only need to implement one.

I assume each row of your `csv` files is `n_features` float values followed by a single `int` value.

## Creating a `tf.data.Dataset`

#### Wrap a python generator with `Dataset.from_generator`

The easiest way is to wrap a native Python generator. I'd encourage you to try this first and only switch approaches if you see serious performance issues.

```
import tensorflow as tf

# Assumes `n_features` (the number of feature columns per row) is defined.

def read_csv(filename):
    with open(filename, 'r') as f:
        # Iterate lazily so the whole file is never held in memory
        # (f.readlines() would load every line up front).
        for line in f:
            record = line.rstrip().split(',')
            features = [float(n) for n in record[:-1]]
            label = int(record[-1])
            yield features, label

def get_dataset():
    filename = 'my_train_dataset.csv'
    generator = lambda: read_csv(filename)
    return tf.data.Dataset.from_generator(
        generator, (tf.float32, tf.int32), ((n_features,), ()))
```

This approach is highly versatile and allows you to test your generator function (`read_csv`) independently of TensorFlow.
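Because `read_csv` is plain Python, it can be checked without TensorFlow at all. A minimal sketch, assuming a throwaway two-row CSV with three feature columns written to a temporary file (the values are made up for illustration); the generator is repeated here so the snippet runs on its own:

```python
import os
import tempfile


def read_csv(filename):
    # Same generator as above: streams one line at a time.
    with open(filename, 'r') as f:
        for line in f:
            record = line.rstrip().split(',')
            features = [float(n) for n in record[:-1]]
            label = int(record[-1])
            yield features, label


# Hypothetical two-row file with n_features = 3.
with tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False) as tmp:
    tmp.write('0.5,1.0,2.0,1\n')
    tmp.write('3.0,4.5,6.0,0\n')
    path = tmp.name

rows = list(read_csv(path))
os.remove(path)

for features, label in rows:
    assert len(features) == 3
    assert isinstance(label, int)
print(rows)  # [([0.5, 1.0, 2.0], 1), ([3.0, 4.5, 6.0], 0)]
```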

#### Wrap an index-based python function

One downside of the above is that shuffling the resulting dataset with a shuffle buffer of size `n` requires `n` examples to be loaded first. This either creates periodic pauses in your pipeline (large `n`) or results in potentially poor shuffling (small `n`).
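To see why the buffer size matters, here is a pure-Python sketch of buffered shuffling. It illustrates the idea only and is not TensorFlow's implementation; `buffered_shuffle` is a hypothetical helper. Nothing is emitted until the buffer holds `buffer_size` items, and no item can be emitted more than `buffer_size - 1` positions earlier than its input position.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Yield items in a partially shuffled order using a fixed-size buffer,
    similar in spirit to tf.data.Dataset.shuffle (buffer_size must be >= 1)."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            # Fill phase: nothing is emitted until buffer_size items are loaded.
            buffer.append(item)
            continue
        # Emit a random buffered item and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain the remaining buffer in random order.
    rng.shuffle(buffer)
    for item in buffer:
        yield item


data = list(range(20))
print(list(buffered_shuffle(data, buffer_size=1)))            # original order
print(list(buffered_shuffle(data, buffer_size=20, seed=0)))   # full shuffle
```

With `buffer_size=1` the order is unchanged; with a buffer as large as the dataset you get a full shuffle, but only after every item has been loaded. The index-based approach sidesteps this trade-off by shuffling cheap integers instead of full records.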

```
import tensorflow as tf

def get_record(i):
    # Load the ith record using standard Python and return NumPy arrays:
    # features as np.float32 with shape (n_features,), label as an np.int32 scalar.
    raise NotImplementedError('load record %d from row-addressable storage' % i)

def get_inputs(batch_size, is_training):

    def tf_map_fn(index):
        features, labels = tf.py_func(
            get_record, (index,), (tf.float32, tf.int32), stateful=False)
        features.set_shape((n_features,))
        labels.set_shape(())
        # do data augmentation here
        return features, labels

    epoch_size = get_epoch_size()  # number of records; implement for your data
    dataset = tf.data.Dataset.from_tensor_slices(tf.range(epoch_size))
    if is_training:
        dataset = dataset.repeat().shuffle(epoch_size)
    # Dataset.map takes only the function (plus num_parallel_calls);
    # the output types are already declared in tf.py_func above.
    dataset = dataset.map(tf_map_fn, num_parallel_calls=8)
    dataset = dataset.batch(batch_size)
    # prefetch data to CPU while GPU processes previous batch
    dataset = dataset.prefetch(1)
    # Also possible
    # dataset = dataset.apply(
    #     tf.contrib.data.prefetch_to_device('/gpu:0'))
    features, labels = dataset.make_one_shot_iterator().get_next()
    return features, labels
```

In short, we create a dataset of just the record indices (or any small record ID that fits entirely in memory). We then do the shuffling/repeating on this minimal dataset, and only then `map` each index to the actual data via `tf.data.Dataset.map` and `tf.py_func`. See the `Using the dataset with estimators` and `Testing the dataset in isolation` sections below for usage. Note that this requires your data to be accessible by row index (random access), so you may need to convert from `csv` to some other format.
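One way to get row-level random access, sketched with only the standard library: rewrite the CSV once as fixed-size binary records so that row `i` lives at byte offset `i * record_size`. The file layout, `csv_to_binary`, and `N_FEATURES` are hypothetical choices for illustration (NumPy's `np.memmap` is a common alternative for the same idea), and this `get_record` takes the file path as an extra argument:

```python
import os
import struct
import tempfile

N_FEATURES = 3  # hypothetical number of feature columns per row
# Hypothetical layout: N_FEATURES little-endian float32 values, then an int32 label.
RECORD = struct.Struct('<%dfi' % N_FEATURES)


def csv_to_binary(csv_path, bin_path):
    """Rewrite the CSV as fixed-size records; return the record count."""
    count = 0
    with open(csv_path) as src, open(bin_path, 'wb') as dst:
        for line in src:
            record = line.rstrip().split(',')
            features = [float(n) for n in record[:-1]]
            dst.write(RECORD.pack(*features, int(record[-1])))
            count += 1
    return count


def get_record(i, bin_path):
    """Load the ith record by seeking straight to its byte offset."""
    with open(bin_path, 'rb') as f:
        f.seek(i * RECORD.size)
        values = RECORD.unpack(f.read(RECORD.size))
    return list(values[:-1]), values[-1]


# Usage with a made-up two-row CSV.
tmpdir = tempfile.mkdtemp()
csv_path = os.path.join(tmpdir, 'train.csv')
bin_path = os.path.join(tmpdir, 'train.bin')
with open(csv_path, 'w') as f:
    f.write('0.5,1.0,2.0,1\n3.0,4.5,6.0,0\n')

epoch_size = csv_to_binary(csv_path, bin_path)
print(epoch_size)                # 2
print(get_record(1, bin_path))   # ([3.0, 4.5, 6.0], 0)
```

To use something like this with `tf.py_func`, you would bind the path and convert the returned values to `np.float32` and `np.int32` arrays so they match the declared output types; the record count from the conversion step is a natural source for the epoch size.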

#### TextLineDataset

You can also read the `csv` file directly using a `tf.data.TextLineDataset`.

```
import tensorflow as tf

def get_record_defaults():
    # One default per column: float32 zeros for the features, an int32 zero for the label.
    zf = tf.zeros(shape=(1,), dtype=tf.float32)
    zi = tf.zeros(shape=(1,), dtype=tf.int32)
    return [zf]*n_features + [zi]

def parse_row(tf_string):
    data = tf.decode_csv(
        tf.expand_dims(tf_string, axis=0), get_record_defaults())
    features = data[:-1]
    features = tf.stack(features, axis=-1)
    label = data[-1]
    features = tf.squeeze(features, axis=0)
    label = tf.squeeze(label, axis=0)
    return features, label

def get_dataset():
    dataset = tf.data.TextLineDataset(['data.csv'])
    return dataset.map(parse_row, num_parallel_calls=8)
```

The `parse_row` function is a little convoluted since `tf.decode_csv` expects a batch. You can make it slightly simpler if you batch the dataset before parsing.

```
def parse_batch(tf_string):
    data = tf.decode_csv(tf_string, get_record_defaults())
    features = data[:-1]
    labels = data[-1]
    features = tf.stack(features, axis=-1)
    return features, labels

def get_batched_dataset(batch_size):
    dataset = tf.data.TextLineDataset(['data.csv'])
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(parse_batch)
    return dataset
```

#### TFRecordDataset

Alternatively, you can convert the `csv` files to TFRecord files and use a [TFRecordDataset](https://www.tensorflow.org/versions/master/api_docs/python/tf/data/TFRecordDataset). The TensorFlow [datasets programmer's guide](https://www.tensorflow.org/versions/master/programmers_guide/datasets) covers this in detail.

Step 1: Convert the `csv` data to TFRecord files. The example code below reuses `read_csv` from the `from_generator` example above.

```
import tensorflow as tf

with tf.python_io.TFRecordWriter("my_train_dataset.tfrecords") as writer:
    for features, label in read_csv('my_train_dataset.csv'):
        example = tf.train.Example()
        example.features.feature[
            "features"].float_list.value.extend(features)
        example.features.feature[
            "label"].int64_list.value.append(label)
        writer.write(example.SerializeToString())
```

This only needs to be run once.

Step 2: Write a dataset that decodes these record files.

```
import tensorflow as tf

def parse_function(example_proto):
    features = {
        'features': tf.FixedLenFeature((n_features,), tf.float32),
        'label': tf.FixedLenFeature((), tf.int64)
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features['features'], parsed_features['label']

def get_dataset():
    # Read the file written in Step 1.
    dataset = tf.data.TFRecordDataset(['my_train_dataset.tfrecords'])
    dataset = dataset.map(parse_function)
    return dataset
```

## Using the dataset with estimators

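```
import tensorflow as tf

def get_inputs(batch_size, shuffle_size):
    dataset = get_dataset()  # one of the above implementations
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.repeat()  # repeat indefinitely
    dataset = dataset.batch(batch_size)
    # prefetch data to CPU while GPU processes previous batch
    dataset = dataset.prefetch(1)
    # Also possible
    # dataset = dataset.apply(
    #     tf.contrib.data.prefetch_to_device('/gpu:0'))
    features, label = dataset.make_one_shot_iterator().get_next()
    return features, label

# Canned estimators such as DNNClassifier expect `features` to be a dict keyed
# by feature column name. 'x' is a hypothetical key; it must match the column
# the estimator was built with, e.g.
# tf.feature_column.numeric_column('x', shape=(n_features,)).
def train_input_fn():
    features, label = get_inputs(32, 1000)
    return {'x': features}, label

# `estimator` is your tf.estimator.DNNClassifier instance.
estimator.train(train_input_fn, max_steps=int(1e7))
```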

## Testing the dataset in isolation

I'd strongly encourage you to test your dataset independently of your estimator. Using the above `get_inputs`, it should be as simple as:

```
batch_size = 4
shuffle_size = 100
features, labels = get_inputs(batch_size, shuffle_size)
with tf.Session() as sess:
    f_data, l_data = sess.run([features, labels])
print(f_data, l_data)  # or some better visualization function
```

## Performance

Assuming you're using a GPU to run your network, you probably won't notice a performance difference between these approaches unless each row of your `csv` file is enormous and your network is tiny. This is because the `Estimator` implementation forces data loading/preprocessing to be performed on the CPU, and `prefetch` means the next batch can be prepared on the CPU while the current batch is training on the GPU. The only exception is a massive shuffle size on a dataset with a large amount of data per record: filling the shuffle buffer means loading many examples before anything runs on the GPU.
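The overlap that `prefetch(1)` buys can be illustrated with a toy producer/consumer in plain Python. This is an analogy only, not how `tf.data` is implemented; `prefetch`, `load_batches`, and `train` here are hypothetical stand-ins, with `time.sleep` faking the CPU loading cost and the GPU step:

```python
import queue
import threading
import time


def prefetch(iterator, buffer_size=1):
    """Toy analogue of Dataset.prefetch: a background thread keeps up to
    `buffer_size` items ready while the consumer works on the current one."""
    q = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the stream

    def producer():
        for item in iterator:
            q.put(item)  # blocks while the buffer is full
        q.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is done:
            return
        yield item


def load_batches(n):
    # Hypothetical CPU-side loading/parsing cost per batch.
    for i in range(n):
        time.sleep(0.05)
        yield i


def train(batches):
    # Hypothetical GPU step of similar duration.
    seen = []
    for batch in batches:
        time.sleep(0.05)
        seen.append(batch)
    return seen


start = time.perf_counter()
serial_result = train(load_batches(10))
serial = time.perf_counter() - start

start = time.perf_counter()
prefetched_result = train(prefetch(load_batches(10)))
overlapped = time.perf_counter() - start

print('serial: %.2fs, with prefetch: %.2fs' % (serial, overlapped))
```

Without prefetching, each batch costs roughly the sum of the loading and training time; with it, the total approaches the larger of the two, because loading batch `i + 1` happens while batch `i` is being consumed.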

