Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 173 additions & 37 deletions src/iceberg/expression/literal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@

#include <cmath>
#include <concepts>
#include <cstdint>
#include <string>

#include "iceberg/exception.h"
#include "iceberg/type_fwd.h"
#include "iceberg/util/checked_cast.h"
#include "iceberg/util/conversions.h"
#include "iceberg/util/macros.h"

namespace iceberg {

Expand Down Expand Up @@ -54,6 +56,30 @@ class LiteralCaster {
/// Cast from Float type to target type.
static Result<Literal> CastFromFloat(const Literal& literal,
const std::shared_ptr<PrimitiveType>& target_type);

/// Cast from Double type to target type.
static Result<Literal> CastFromDouble(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);

/// Cast from String type to target type.
static Result<Literal> CastFromString(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);

/// Cast from Timestamp type to target type.
static Result<Literal> CastFromTimestamp(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);

/// Cast from TimestampTz type to target type.
static Result<Literal> CastFromTimestampTz(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);

/// Cast from Binary type to target type.
static Result<Literal> CastFromBinary(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type);

/// Cast from Fixed type to target type.
static Result<Literal> CastFromFixed(const Literal& literal,
const std::shared_ptr<PrimitiveType>& target_type);
};

Literal LiteralCaster::BelowMinLiteral(std::shared_ptr<PrimitiveType> type) {
Expand All @@ -76,6 +102,8 @@ Result<Literal> LiteralCaster::CastFromInt(
return Literal::Float(static_cast<float>(int_val));
case TypeId::kDouble:
return Literal::Double(static_cast<double>(int_val));
case TypeId::kDate:
return Literal::Date(int_val);
default:
return NotSupported("Cast from Int to {} is not implemented",
target_type->ToString());
Expand All @@ -85,15 +113,14 @@ Result<Literal> LiteralCaster::CastFromInt(
Result<Literal> LiteralCaster::CastFromLong(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
auto long_val = std::get<int64_t>(literal.value_);
auto target_type_id = target_type->type_id();

switch (target_type_id) {
switch (target_type->type_id()) {
case TypeId::kInt: {
// Check for overflow
if (long_val >= std::numeric_limits<int32_t>::max()) {
if (long_val > std::numeric_limits<int32_t>::max()) {
return AboveMaxLiteral(target_type);
}
if (long_val <= std::numeric_limits<int32_t>::min()) {
if (long_val < std::numeric_limits<int32_t>::min()) {
return BelowMinLiteral(target_type);
}
return Literal::Int(static_cast<int32_t>(long_val));
Expand All @@ -102,6 +129,21 @@ Result<Literal> LiteralCaster::CastFromLong(
return Literal::Float(static_cast<float>(long_val));
case TypeId::kDouble:
return Literal::Double(static_cast<double>(long_val));
case TypeId::kDate: {
if (long_val > std::numeric_limits<int32_t>::max()) {
return AboveMaxLiteral(target_type);
}
if (long_val < std::numeric_limits<int32_t>::min()) {
return BelowMinLiteral(target_type);
}
return Literal::Date(static_cast<int32_t>(long_val));
}
case TypeId::kTime:
return Literal::Time(long_val);
case TypeId::kTimestamp:
return Literal::Timestamp(long_val);
case TypeId::kTimestampTz:
return Literal::TimestampTz(long_val);
default:
return NotSupported("Cast from Long to {} is not supported",
target_type->ToString());
Expand All @@ -111,9 +153,8 @@ Result<Literal> LiteralCaster::CastFromLong(
Result<Literal> LiteralCaster::CastFromFloat(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
auto float_val = std::get<float>(literal.value_);
auto target_type_id = target_type->type_id();

switch (target_type_id) {
switch (target_type->type_id()) {
case TypeId::kDouble:
return Literal::Double(static_cast<double>(float_val));
default:
Expand All @@ -122,6 +163,103 @@ Result<Literal> LiteralCaster::CastFromFloat(
}
}

Result<Literal> LiteralCaster::CastFromDouble(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
auto double_val = std::get<double>(literal.value_);

switch (target_type->type_id()) {
case TypeId::kFloat: {
if (double_val > static_cast<double>(std::numeric_limits<float>::max())) {
return AboveMaxLiteral(target_type);
}
if (double_val < static_cast<double>(std::numeric_limits<float>::lowest())) {
return BelowMinLiteral(target_type);
}
return Literal::Float(static_cast<float>(double_val));
}
default:
return NotSupported("Cast from Double to {} is not supported",
target_type->ToString());
}
}

Result<Literal> LiteralCaster::CastFromString(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
const auto& str_val = std::get<std::string>(literal.value_);

switch (target_type->type_id()) {
case TypeId::kDate:
case TypeId::kTime:
case TypeId::kTimestamp:
case TypeId::kTimestampTz:
case TypeId::kUuid:
return NotImplemented("Cast from String to {} is not implemented yet",
target_type->ToString());
default:
return NotSupported("Cast from String to {} is not supported",
target_type->ToString());
}
}

Result<Literal> LiteralCaster::CastFromTimestamp(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
auto timestamp_val = std::get<int64_t>(literal.value_);

switch (target_type->type_id()) {
case TypeId::kDate:
return NotImplemented("Cast from Timestamp to Date is not implemented yet");
case TypeId::kTimestampTz:
return Literal::TimestampTz(timestamp_val);
default:
return NotSupported("Cast from Timestamp to {} is not supported",
target_type->ToString());
}
}

Result<Literal> LiteralCaster::CastFromTimestampTz(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
auto micros = std::get<int64_t>(literal.value_);

switch (target_type->type_id()) {
case TypeId::kDate:
return NotImplemented("Cast from TimestampTz to Date is not implemented yet");
case TypeId::kTimestamp:
return Literal::Timestamp(micros);
default:
return NotSupported("Cast from TimestampTz to {} is not supported",
target_type->ToString());
}
}

Result<Literal> LiteralCaster::CastFromBinary(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
auto binary_val = std::get<std::vector<uint8_t>>(literal.value_);
switch (target_type->type_id()) {
case TypeId::kFixed: {
auto target_fixed_type = internal::checked_pointer_cast<FixedType>(target_type);
if (binary_val.size() == target_fixed_type->length()) {
return Literal::Fixed(std::move(binary_val));
}
return InvalidArgument("Failed to cast Binary with length {} to Fixed({})",
binary_val.size(), target_fixed_type->length());
}
default:
return NotSupported("Cast from Binary to {} is not supported",
target_type->ToString());
}
}

Result<Literal> LiteralCaster::CastFromFixed(
const Literal& literal, const std::shared_ptr<PrimitiveType>& target_type) {
switch (target_type->type_id()) {
case TypeId::kBinary:
return Literal::Binary(std::get<std::vector<uint8_t>>(literal.value_));
default:
return NotSupported("Cast from Fixed to {} is not supported",
target_type->ToString());
}
}

// Constructor
Literal::Literal(Value value, std::shared_ptr<PrimitiveType> type)
: value_(std::move(value)), type_(std::move(type)) {}
Expand Down Expand Up @@ -152,8 +290,8 @@ Literal Literal::Binary(std::vector<uint8_t> value) {
}

Literal Literal::Fixed(std::vector<uint8_t> value) {
auto length = static_cast<int32_t>(value.size());
return {Value{std::move(value)}, fixed(length)};
const auto size = value.size();
return {Value{std::move(value)}, fixed(size)};
}

Result<Literal> Literal::Deserialize(std::span<const uint8_t> data,
Expand Down Expand Up @@ -251,12 +389,7 @@ std::partial_ordering Literal::operator<=>(const Literal& other) const {
return this_val <=> other_val;
}

case TypeId::kBinary: {
auto& this_val = std::get<std::vector<uint8_t>>(value_);
auto& other_val = std::get<std::vector<uint8_t>>(other.value_);
return this_val <=> other_val;
}

case TypeId::kBinary:
case TypeId::kFixed: {
auto& this_val = std::get<std::vector<uint8_t>>(value_);
auto& other_val = std::get<std::vector<uint8_t>>(other.value_);
Expand Down Expand Up @@ -297,36 +430,29 @@ std::string Literal::ToString() const {
return std::to_string(std::get<double>(value_));
}
case TypeId::kString: {
return std::get<std::string>(value_);
return "\"" + std::get<std::string>(value_) + "\"";
}
case TypeId::kBinary: {
case TypeId::kBinary:
case TypeId::kFixed: {
const auto& binary_data = std::get<std::vector<uint8_t>>(value_);
std::string result;
result.reserve(binary_data.size() * 2); // 2 chars per byte
std::string result = "X'";
result.reserve(/*prefix*/ 2 + /*suffix*/ 1 + /*data*/ binary_data.size() * 2);
for (const auto& byte : binary_data) {
std::format_to(std::back_inserter(result), "{:02X}", byte);
}
result.push_back('\'');
return result;
}
case TypeId::kFixed: {
const auto& fixed_data = std::get<std::vector<uint8_t>>(value_);
std::string result;
result.reserve(fixed_data.size() * 2); // 2 chars per byte
for (const auto& byte : fixed_data) {
std::format_to(std::back_inserter(result), "{:02X}", byte);
}
return result;
}
case TypeId::kDecimal:
case TypeId::kUuid:
case TypeId::kDate:
case TypeId::kTime:
case TypeId::kTimestamp:
case TypeId::kTimestampTz: {
throw IcebergError("Not implemented: ToString for " + type_->ToString());
return std::to_string(std::get<int64_t>(value_));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these switch cases can be easily rewritten by return std::to_string(std::get<typename LiteralTraits<type_id>::ValueType>(value_)); once #185 is merged.

}
case TypeId::kDate: {
return std::to_string(std::get<int32_t>(value_));
}
default: {
throw IcebergError("Unknown type: " + type_->ToString());
return std::format("invalid literal of type {}", type_->ToString());
}
}
}
Expand Down Expand Up @@ -358,22 +484,32 @@ Result<Literal> LiteralCaster::CastTo(const Literal& literal,

// Delegate to specific cast functions based on source type
switch (source_type_id) {
case TypeId::kBoolean:
// No casts defined for Boolean, other than to itself.
break;
case TypeId::kInt:
return CastFromInt(literal, target_type);
case TypeId::kLong:
return CastFromLong(literal, target_type);
case TypeId::kFloat:
return CastFromFloat(literal, target_type);
case TypeId::kDouble:
case TypeId::kBoolean:
return CastFromDouble(literal, target_type);
case TypeId::kString:
return CastFromString(literal, target_type);
case TypeId::kBinary:
break;
return CastFromBinary(literal, target_type);
case TypeId::kFixed:
return CastFromFixed(literal, target_type);
case TypeId::kTimestamp:
return CastFromTimestamp(literal, target_type);
case TypeId::kTimestampTz:
return CastFromTimestampTz(literal, target_type);
default:
break;
}

return NotSupported("Cast from {} to {} is not implemented", literal.type_->ToString(),
return NotSupported("Cast from {} to {} is not supported", literal.type_->ToString(),
target_type->ToString());
}

Expand Down
7 changes: 3 additions & 4 deletions src/iceberg/expression/predicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,11 @@ std::string UnboundPredicate<B>::ToString() const {
return values_.size() == 1 ? std::format("{} != {}", term, values_[0])
: invalid_predicate_string(op);
case Expression::Operation::kStartsWith:
return values_.size() == 1 ? std::format("{} startsWith \"{}\"", term, values_[0])
return values_.size() == 1 ? std::format("{} startsWith {}", term, values_[0])
: invalid_predicate_string(op);
case Expression::Operation::kNotStartsWith:
return values_.size() == 1
? std::format("{} notStartsWith \"{}\"", term, values_[0])
: invalid_predicate_string(op);
return values_.size() == 1 ? std::format("{} notStartsWith {}", term, values_[0])
: invalid_predicate_string(op);
case Expression::Operation::kIn:
return std::format("{} in {}", term, values_);
case Expression::Operation::kNotIn:
Expand Down
Loading
Loading