Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hybridbackend/cpp/tensorflow/ops/parquet_dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class ParquetTabularDatasetOp::Dataset : public DatasetBase {
}

Status Open() {
VLOG(0) << "Starting to read " << filename_ << " ...";
VLOG(1) << "Starting to read " << filename_ << " ...";
return reader_->Open();
}

Expand Down
2 changes: 1 addition & 1 deletion hybridbackend/cpp/tensorflow/ops/rebatch_dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class RebatchTabularDatasetOp::Dataset::Iterator
return 0;
}
if (field_ranks_[0] > 0) {
return input_tensors[1].dim_size(0);
return input_tensors[1].dim_size(0) - 1;
}
return input_tensors[0].dim_size(0);
}
Expand Down
36 changes: 36 additions & 0 deletions tests/tensorflow/data/parquet_dataset_rebatch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import os
import random
from six.moves import xrange # pylint: disable=redefined-builtin
import tempfile
import tensorflow as tf
Expand Down Expand Up @@ -234,6 +237,39 @@ def test_thread_pool(self):
actual.nested_row_splits, expected.nested_row_splits)


class ParquetDatasetSequenceRebatchTest(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name
self._workspace = tempfile.mkdtemp()
self._filename = os.path.join(self._workspace, 'seqtest.parquet')
self._nrows = 1000
self._ncols = 10
self._data = {
'clicks': [
[random.randint(0, 100) for col in range(self._ncols)]
for row in range(self._nrows)]}
pq.write_table(pa.Table.from_pydict(self._data), self._filename)

def tearDown(self): # pylint: disable=invalid-name
os.remove(self._filename)

def test_ragged(self):
batch_size = 8
with tf.Graph().as_default() as graph:
ds = ParquetDataset(self._filename, batch_size=batch_size)
ds = ds.apply(rebatch(batch_size))
batch = make_one_shot_iterator(ds).get_next()

clicks = self._data['clicks']
with tf.Session(graph=graph) as sess:
for i in xrange(3):
actual = sess.run(batch['clicks'])
start_row = i * batch_size
end_row = (i + 1) * batch_size
expected = clicks[start_row:end_row]
expected_values = [v for sublist in expected for v in sublist]
np.testing.assert_equal(actual.values, expected_values)


if __name__ == '__main__':
register(['cpu', 'data'])
os.environ['CUDA_VISIBLE_DEVICES'] = ''
Expand Down