2 changes: 2 additions & 0 deletions tensorflow_io/core/BUILD
@@ -186,6 +186,8 @@ cc_library(
"kernels/bigtable/bigtable_row_set.h",
"kernels/bigtable/bigtable_version_filters.cc",
"kernels/bigtable/bigtable_version_filters.h",
"kernels/bigtable/serialization.cc",
"kernels/bigtable/serialization.h",
"ops/bigtable_ops.cc",
],
copts = tf_io_copts(),
20 changes: 14 additions & 6 deletions tensorflow_io/core/kernels/bigtable/bigtable_dataset_kernel.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow_io/core/kernels/bigtable/serialization.h"
#include "tensorflow_io/core/kernels/bigtable/bigtable_row_set.h"
#include "tensorflow_io/core/kernels/bigtable/bigtable_version_filters.h"

@@ -169,8 +170,8 @@ class Iterator : public DatasetIterator<Dataset> {

VLOG(1) << "alocating tensor";
const std::size_t kNumCols = column_to_idx_.size();
Tensor res(ctx->allocator({}), DT_STRING, {(long)kNumCols});
auto res_data = res.tensor<tstring, 1>();
const DataType dtype = this->dataset()->output_type();
Tensor res(ctx->allocator({}), dtype, {(long)kNumCols});

VLOG(1) << "getting row";
const auto& row = *it_;
@@ -184,7 +185,8 @@ class Iterator : public DatasetIterator<Dataset> {
const auto column_idx = column_to_idx_.find(key);
if (column_idx != column_to_idx_.end()) {
VLOG(1) << "getting column:" << column_idx->second;
res_data(column_idx->second) = std::move(cell.value());
TF_RETURN_IF_ERROR(
io::PutCellValueInTensor(res, column_idx->second, dtype, cell));
} else {
LOG(ERROR) << "column " << cell.family_name() << ":"
<< cell.column_qualifier()
@@ -280,14 +282,15 @@ class Dataset : public DatasetBase {
Dataset(OpKernelContext* ctx,
const std::shared_ptr<cbt::DataClient>& data_client,
cbt::RowSet row_set, cbt::Filter filter, std::string table_id,
std::vector<std::string> columns)
std::vector<std::string> columns, DataType output_type)
: DatasetBase(DatasetContext(ctx)),
data_client_(data_client),
row_set_(std::move(row_set)),
filter_(std::move(filter)),
output_type_(std::move(output_type)),
table_id_(table_id),
columns_(columns) {
dtypes_.push_back(DT_STRING);
dtypes_.push_back({output_type_});
output_shapes_.push_back({});
}

@@ -306,6 +309,8 @@
return output_shapes_;
}

const DataType output_type() const { return output_type_; }

std::string DebugString() const override {
return "BigtableDatasetOp::Dataset";
}
@@ -338,6 +343,7 @@ class Dataset : public DatasetBase {
std::shared_ptr<cbt::DataClient> const& data_client_;
const cbt::RowSet row_set_;
cbt::Filter filter_;
DataType output_type_;
const std::string table_id_;
const std::vector<std::string> columns_;
DataTypeVector dtypes_;
@@ -349,6 +355,7 @@ class BigtableDatasetOp : public DatasetOpKernel {
explicit BigtableDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("table_id", &table_id_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("columns", &columns_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_type", &output_type_));
}

void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
@@ -370,12 +377,13 @@

*output = new Dataset(ctx, client_resource->data_client(),
row_set_resource->row_set(),
filter_resource->filter(), table_id_, columns_);
filter_resource->filter(), table_id_, columns_, output_type_);
}

private:
std::string table_id_;
std::vector<std::string> columns_;
DataType output_type_;
};

REGISTER_KERNEL_BUILDER(Name("BigtableDataset").Device(DEVICE_CPU),
130 changes: 130 additions & 0 deletions tensorflow_io/core/kernels/bigtable/serialization.cc
@@ -0,0 +1,130 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

#include "tensorflow_io/core/kernels/bigtable/serialization.h"
#include "rpc/xdr.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {
namespace io {

inline StatusOr<float> BytesToFloat(std::string const& s) {
float v;
XDR xdrs;
xdrmem_create(&xdrs, const_cast<char*>(s.data()), sizeof(v), XDR_DECODE);
if (!xdr_float(&xdrs, &v)) {
return errors::InvalidArgument("Error reading float from byte array.");
}
return v;
}

inline StatusOr<double> BytesToDouble(std::string const& s) {
double v;
XDR xdrs;
xdrmem_create(&xdrs, const_cast<char*>(s.data()), sizeof(v), XDR_DECODE);
if (!xdr_double(&xdrs, &v)) {
return errors::InvalidArgument("Error reading double from byte array.");
}
return v;
}

inline StatusOr<int64_t> BytesToInt64(std::string const& s) {
int64_t v;
XDR xdrs;
xdrmem_create(&xdrs, const_cast<char*>(s.data()), sizeof(v), XDR_DECODE);
if (!xdr_int64_t(&xdrs, &v)) {
return errors::InvalidArgument("Error reading int64 from byte array.");
}
return v;
}

inline StatusOr<int32_t> BytesToInt32(std::string const& s) {
int32_t v;
XDR xdrs;
xdrmem_create(&xdrs, const_cast<char*>(s.data()), sizeof(v), XDR_DECODE);
if (!xdr_int32_t(&xdrs, &v)) {
return errors::InvalidArgument("Error reading int32 from byte array.");
}
return v;
}

inline StatusOr<bool_t> BytesToBool(std::string const& s) {
bool_t v;
XDR xdrs;
xdrmem_create(&xdrs, const_cast<char*>(s.data()), sizeof(v), XDR_DECODE);
if (!xdr_bool(&xdrs, &v)) {
return errors::InvalidArgument("Error reading bool from byte array.");
}
return v;
}

Status PutCellValueInTensor(Tensor& tensor, size_t index,
DataType cell_type,
google::cloud::bigtable::Cell const& cell) {
switch (cell_type) {
case DT_STRING: {
auto tensor_data = tensor.tensor<tstring, 1>();
tensor_data(index) = std::string(cell.value());
} break;
case DT_BOOL: {
auto tensor_data = tensor.tensor<bool, 1>();
auto maybe_parsed_data = BytesToBool(cell.value());
if (!maybe_parsed_data.ok()) {
return maybe_parsed_data.status();
}
tensor_data(index) = maybe_parsed_data.ValueOrDie();
} break;
case DT_INT32: {
auto tensor_data = tensor.tensor<int32_t, 1>();
auto maybe_parsed_data = BytesToInt32(cell.value());
if (!maybe_parsed_data.ok()) {
return maybe_parsed_data.status();
}
tensor_data(index) = maybe_parsed_data.ValueOrDie();
} break;
case DT_INT64: {
auto tensor_data = tensor.tensor<int64_t, 1>();
auto maybe_parsed_data = BytesToInt64(cell.value());
if (!maybe_parsed_data.ok()) {
return maybe_parsed_data.status();
}
tensor_data(index) = maybe_parsed_data.ValueOrDie();
} break;
case DT_FLOAT: {
auto tensor_data = tensor.tensor<float, 1>();
auto maybe_parsed_data = BytesToFloat(cell.value());
if (!maybe_parsed_data.ok()) {
return maybe_parsed_data.status();
}
tensor_data(index) = maybe_parsed_data.ValueOrDie();
} break;
case DT_DOUBLE: {
auto tensor_data = tensor.tensor<double, 1>();
auto maybe_parsed_data = BytesToDouble(cell.value());
if (!maybe_parsed_data.ok()) {
return maybe_parsed_data.status();
}
tensor_data(index) = maybe_parsed_data.ValueOrDie();
} break;
default:
return errors::Unimplemented("Data type not supported.");
}
return Status::OK();
}

} // namespace io
} // namespace tensorflow
36 changes: 36 additions & 0 deletions tensorflow_io/core/kernels/bigtable/serialization.h
@@ -0,0 +1,36 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef SERIALIZATION_H
#define SERIALIZATION_H

#include "google/cloud/bigtable/table.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace io {

// Bigtable stores values only as byte buffers - except for int64, the
// server side has no notion of types. TensorFlow, however, needs to store
// shorter integers, floats and doubles, so we had to decide on an encoding.
// We follow what HBase does, since there is a migration path from HBase to
// Bigtable, and XDR appears to match HBase's encoding.
Status PutCellValueInTensor(Tensor& tensor, size_t index, DataType cell_type,
google::cloud::bigtable::Cell const& cell);

} // namespace io
} // namespace tensorflow

#endif /* SERIALIZATION_H */
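
The XDR routines above decode fixed-width, big-endian values (IEEE 754 for float/double, two's complement for the integer types), which also lines up with HBase's encoding for ints, longs, floats and doubles. As a hedged illustration only - this is not part of the change, and the helper names are made up - a writer could produce cell values that these parsers accept using Python's struct module:

import struct

def encode_float(v):
    # 4-byte big-endian IEEE 754, decoded by BytesToFloat.
    return struct.pack(">f", v)

def encode_double(v):
    # 8-byte big-endian IEEE 754, decoded by BytesToDouble.
    return struct.pack(">d", v)

def encode_int32(v):
    # 4-byte big-endian two's complement, decoded by BytesToInt32.
    return struct.pack(">i", v)

def encode_int64(v):
    # 8-byte big-endian two's complement, decoded by BytesToInt64.
    return struct.pack(">q", v)

def encode_bool(v):
    # XDR represents a bool as a 4-byte big-endian integer (0 or 1),
    # which is what BytesToBool expects.
    return struct.pack(">i", 1 if v else 0)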
1 change: 1 addition & 0 deletions tensorflow_io/core/ops/bigtable_ops.cc
@@ -31,6 +31,7 @@ REGISTER_OP("BigtableDataset")
.Input("filter: resource")
.Attr("table_id: string")
.Attr("columns: list(string) >= 1")
.Attr("output_type: type")
.Output("handle: variant")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
50 changes: 35 additions & 15 deletions tensorflow_io/python/ops/bigtable/bigtable_dataset_ops.py
@@ -1,19 +1,28 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import tensor_spec
from tensorflow_io.python.ops import core_ops
import tensorflow_io.python.ops.bigtable.bigtable_version_filters as filters
import tensorflow_io.python.ops.bigtable.bigtable_row_set as bigtable_row_set
import tensorflow_io.python.ops.bigtable.bigtable_row_range as bigtable_row_range
from tensorflow.python.framework import dtypes
import tensorflow as tf
from tensorflow.python.data.ops import dataset_ops

from tensorflow_io.python.ops.bigtable.bigtable_row_set import (
from_rows_or_ranges,
RowSet,
intersect,
)
from tensorflow_io.python.ops.bigtable.bigtable_row_range import infinite


class BigtableClient:
"""BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.
@@ -38,28 +47,38 @@ def __init__(self, client_resource, table_id: str):
def read_rows(
self,
columns: List[str],
row_set: RowSet,
row_set: bigtable_row_set.RowSet,
filter: filters.BigtableFilter = filters.latest(),
output_type=tf.string,
):
return _BigtableDataset(
self._client_resource, self._table_id, columns, row_set, filter
self._client_resource,
self._table_id,
columns,
row_set,
filter,
output_type,
)

def parallel_read_rows(
self,
columns: List[str],
num_parallel_calls=tf.data.AUTOTUNE,
row_set: RowSet = from_rows_or_ranges(infinite()),
row_set: bigtable_row_set.RowSet = bigtable_row_set.from_rows_or_ranges(
bigtable_row_range.infinite()
),
filter: filters.BigtableFilter = filters.latest(),
output_type=tf.string,
):

print("calling parallel read_rows with row_set:", row_set)
samples = core_ops.bigtable_split_row_set_evenly(
self._client_resource, row_set._impl, self._table_id, num_parallel_calls,
)

def map_func(idx):
return self.read_rows(columns, RowSet(samples[idx]), filter)
return self.read_rows(
columns, bigtable_row_set.RowSet(samples[idx]), filter, output_type
)

# We interleave a dataset of sample indexes instead of a dataset of
# samples, because Dataset.from_tensor_slices attempts to copy the
@@ -82,16 +101,17 @@ def __init__(
client_resource,
table_id: str,
columns: List[str],
row_set: RowSet,
row_set: bigtable_row_set.RowSet,
filter,
output_type,
):
self._table_id = table_id
self._columns = columns
self._filter = filter
self._element_spec = tf.TensorSpec(shape=[len(columns)], dtype=dtypes.string)
self._element_spec = tf.TensorSpec(shape=[len(columns)], dtype=output_type)

variant_tensor = core_ops.bigtable_dataset(
client_resource, row_set._impl, filter._impl, table_id, columns
client_resource, row_set._impl, filter._impl, table_id, columns, output_type
)
super().__init__(variant_tensor)

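
With the new output_type attribute plumbed through the op registration, the kernel and the Python wrapper, non-string columns can be read directly into typed tensors. A hedged usage sketch, not part of this change: how the table object is obtained depends on client code outside this diff, so that part is only indicated in comments.

import tensorflow as tf
from tensorflow_io.python.ops.bigtable import bigtable_row_set, bigtable_row_range

# Assumed: `table` exposes the read_rows API shown above and is obtained
# through this module's client (construction not shown in this diff).
ds = table.read_rows(
    columns=["cf1:col_a", "cf1:col_b"],
    row_set=bigtable_row_set.from_rows_or_ranges(bigtable_row_range.infinite()),
    output_type=tf.int64,  # previously every cell came back as tf.string
)
for row in ds:
    # Each element has shape [len(columns)] and dtype tf.int64.
    print(row)

parallel_read_rows accepts the same output_type argument and forwards it to each per-sample read_rows call.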