Skip to content

Commit e51c1d9

Browse files
committed
add support for f16
1 parent d40cdc1 commit e51c1d9

File tree

9 files changed

+48
-21
lines changed

9 files changed

+48
-21
lines changed

arrow/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ serde_json = { version = "1.0", features = ["preserve_order"] }
4343
indexmap = "1.6"
4444
rand = { version = "0.8", optional = true }
4545
num = "0.4"
46+
half = "1.8"
4647
csv_crate = { version = "1.1", optional = true, package="csv" }
4748
regex = "1.3"
4849
lazy_static = "1.4"

arrow/src/alloc/types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use crate::datatypes::DataType;
19+
use half::f16;
1920

2021
/// A type that Rust's custom allocator knows how to allocate and deallocate.
2122
/// This is implemented for all Arrow's physical types whose in-memory representation
@@ -67,5 +68,6 @@ create_native!(
6768
i64,
6869
DataType::Int64 | DataType::Date64 | DataType::Time64(_) | DataType::Timestamp(_, _)
6970
);
71+
create_native!(f16, DataType::Float16);
7072
create_native!(f32, DataType::Float32);
7173
create_native!(f64, DataType::Float64);

arrow/src/array/array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
240240
DataType::UInt16 => Arc::new(UInt16Array::from(data)) as ArrayRef,
241241
DataType::UInt32 => Arc::new(UInt32Array::from(data)) as ArrayRef,
242242
DataType::UInt64 => Arc::new(UInt64Array::from(data)) as ArrayRef,
243-
DataType::Float16 => panic!("Float16 datatype not supported"),
243+
DataType::Float16 => Arc::new(Float16Array::from(data)) as ArrayRef,
244244
DataType::Float32 => Arc::new(Float32Array::from(data)) as ArrayRef,
245245
DataType::Float64 => Arc::new(Float64Array::from(data)) as ArrayRef,
246246
DataType::Date32 => Arc::new(Date32Array::from(data)) as ArrayRef,
@@ -393,7 +393,7 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef {
393393
DataType::UInt8 => new_null_sized_array::<UInt8Type>(data_type, length),
394394
DataType::Int16 => new_null_sized_array::<Int16Type>(data_type, length),
395395
DataType::UInt16 => new_null_sized_array::<UInt16Type>(data_type, length),
396-
DataType::Float16 => unreachable!(),
396+
DataType::Float16 => new_null_sized_array::<Float16Type>(data_type, length),
397397
DataType::Int32 => new_null_sized_array::<Int32Type>(data_type, length),
398398
DataType::UInt32 => new_null_sized_array::<UInt32Type>(data_type, length),
399399
DataType::Float32 => new_null_sized_array::<Float32Type>(data_type, length),

arrow/src/array/data.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,16 @@
1818
//! Contains `ArrayData`, a generic representation of Arrow array data which encapsulates
1919
//! common attributes and operations for Arrow array.
2020
21-
use std::convert::TryInto;
22-
use std::mem;
23-
use std::sync::Arc;
24-
2521
use crate::datatypes::{DataType, IntervalUnit};
2622
use crate::error::{ArrowError, Result};
2723
use crate::{bitmap::Bitmap, datatypes::ArrowNativeType};
2824
use crate::{
2925
buffer::{Buffer, MutableBuffer},
3026
util::bit_util,
3127
};
28+
use std::convert::TryInto;
29+
use std::mem;
30+
use std::sync::Arc;
3231

3332
use super::equal::equal;
3433

@@ -89,6 +88,10 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
8988
MutableBuffer::new(capacity * mem::size_of::<i64>()),
9089
empty_buffer,
9190
],
91+
DataType::Float16 => [
92+
MutableBuffer::new(capacity * mem::size_of::<u16>()),
93+
empty_buffer,
94+
],
9295
DataType::Float32 => [
9396
MutableBuffer::new(capacity * mem::size_of::<f32>()),
9497
empty_buffer,
@@ -178,7 +181,6 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
178181
],
179182
_ => unreachable!(),
180183
},
181-
DataType::Float16 => unreachable!(),
182184
DataType::FixedSizeList(_, _) | DataType::Struct(_) => {
183185
[empty_buffer, MutableBuffer::new(0)]
184186
}
@@ -319,7 +321,7 @@ impl ArrayData {
319321
buffers: Vec<Buffer>,
320322
child_data: Vec<ArrayData>,
321323
) -> Result<Self> {
322-
// Safetly justification: `validate` is (will be) called below
324+
// Safety justification: `validate` is (will be) called below
323325
let new_self = unsafe {
324326
Self::new_unchecked(
325327
data_type,
@@ -519,6 +521,7 @@ impl ArrayData {
519521
| DataType::Int16
520522
| DataType::Int32
521523
| DataType::Int64
524+
| DataType::Float16
522525
| DataType::Float32
523526
| DataType::Float64
524527
| DataType::Date32
@@ -554,7 +557,6 @@ impl ArrayData {
554557
DataType::Dictionary(_, data_type) => {
555558
vec![Self::new_empty(data_type)]
556559
}
557-
DataType::Float16 => unreachable!(),
558560
};
559561

560562
// Data was constructed correctly above

arrow/src/array/equal/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,12 @@ fn equal_values(
251251
),
252252
_ => unreachable!(),
253253
},
254-
DataType::Float16 => unreachable!(),
254+
DataType::Float16 => {
255+
use half::f16;
256+
primitive_equal::<f16>(
257+
lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
258+
)
259+
}
255260
DataType::Map(_, _) => {
256261
list_equal::<i32>(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
257262
}

arrow/src/array/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,14 @@ pub type UInt64Array = PrimitiveArray<UInt64Type>;
192192
///
193193
/// # Example: Using `collect`
194194
/// ```
195+
/// # use arrow::array::Float16Array;
196+
/// use half::f16;
197+
/// let arr : Float16Array = [Some(f16::from_f64(1.0)), Some(f16::from_f64(2.0))].into_iter().collect();
198+
/// ```
199+
pub type Float16Array = PrimitiveArray<Float16Type>;
200+
///
201+
/// # Example: Using `collect`
202+
/// ```
195203
/// # use arrow::array::Float32Array;
196204
/// let arr : Float32Array = [Some(1.0), Some(2.0)].into_iter().collect();
197205
/// ```

arrow/src/array/transform/mod.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,20 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use super::{
19+
data::{into_buffers, new_buffers},
20+
ArrayData, ArrayDataBuilder,
21+
};
22+
use crate::array::StringOffsetSizeTrait;
1823
use crate::{
1924
buffer::MutableBuffer,
2025
datatypes::DataType,
2126
error::{ArrowError, Result},
2227
util::bit_util,
2328
};
29+
use half::f16;
2430
use std::mem;
2531

26-
use super::{
27-
data::{into_buffers, new_buffers},
28-
ArrayData, ArrayDataBuilder,
29-
};
30-
use crate::array::StringOffsetSizeTrait;
31-
3232
mod boolean;
3333
mod fixed_binary;
3434
mod list;
@@ -266,7 +266,7 @@ fn build_extend(array: &ArrayData) -> Extend {
266266
DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
267267
DataType::Struct(_) => structure::build_extend(array),
268268
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
269-
DataType::Float16 => unreachable!(),
269+
DataType::Float16 => primitive::build_extend::<f16>(array),
270270
/*
271271
DataType::FixedSizeList(_, _) => {}
272272
DataType::Union(_) => {}
@@ -315,7 +315,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
315315
},
316316
DataType::Struct(_) => structure::extend_nulls,
317317
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
318-
DataType::Float16 => unreachable!(),
318+
DataType::Float16 => primitive::extend_nulls::<f16>,
319319
/*
320320
DataType::FixedSizeList(_, _) => {}
321321
DataType::Union(_) => {}
@@ -429,6 +429,7 @@ impl<'a> MutableArrayData<'a> {
429429
| DataType::Int16
430430
| DataType::Int32
431431
| DataType::Int64
432+
| DataType::Float16
432433
| DataType::Float32
433434
| DataType::Float64
434435
| DataType::Date32
@@ -467,7 +468,6 @@ impl<'a> MutableArrayData<'a> {
467468
}
468469
// the dictionary type just appends keys and clones the values.
469470
DataType::Dictionary(_, _) => vec![],
470-
DataType::Float16 => unreachable!(),
471471
DataType::Struct(fields) => match capacities {
472472
Capacities::Struct(capacity, Some(ref child_capacities)) => {
473473
array_capacity = capacity;

arrow/src/datatypes/native.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use serde_json::{Number, Value};
19-
2018
use super::DataType;
19+
use half::f16;
20+
use serde_json::{Number, Value};
2121

2222
/// Trait declaring any type that is serializable to JSON. This includes all primitive types (bool, i32, etc.).
2323
pub trait JsonSerializable: 'static {
@@ -293,6 +293,12 @@ impl ArrowNativeType for u64 {
293293
}
294294
}
295295

296+
impl JsonSerializable for f16 {
297+
fn into_json_value(self) -> Option<Value> {
298+
Number::from_f64(f64::round(f64::from(self) * 1000.0) / 1000.0).map(Value::Number)
299+
}
300+
}
301+
296302
impl JsonSerializable for f32 {
297303
fn into_json_value(self) -> Option<Value> {
298304
Number::from_f64(f64::round(self as f64 * 1000.0) / 1000.0).map(Value::Number)
@@ -305,6 +311,7 @@ impl JsonSerializable for f64 {
305311
}
306312
}
307313

314+
impl ArrowNativeType for f16 {}
308315
impl ArrowNativeType for f32 {}
309316
impl ArrowNativeType for f64 {}
310317

arrow/src/datatypes/types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use super::{ArrowPrimitiveType, DataType, IntervalUnit, TimeUnit};
19+
use half::f16;
1920

2021
// BooleanType is special: its bit-width is not the size of the primitive type, and its `index`
2122
// operation assumes bit-packing.
@@ -46,6 +47,7 @@ make_type!(UInt8Type, u8, DataType::UInt8);
4647
make_type!(UInt16Type, u16, DataType::UInt16);
4748
make_type!(UInt32Type, u32, DataType::UInt32);
4849
make_type!(UInt64Type, u64, DataType::UInt64);
50+
make_type!(Float16Type, f16, DataType::Float16);
4951
make_type!(Float32Type, f32, DataType::Float32);
5052
make_type!(Float64Type, f64, DataType::Float64);
5153
make_type!(

0 commit comments

Comments
 (0)