diff --git a/ggml.h b/ggml.h index bdbd128004332..4fa78a4273d82 100644 --- a/ggml.h +++ b/ggml.h @@ -255,8 +255,9 @@ extern "C" { #endif -#ifdef __ARM_NEON - // we use the built-in 16-bit float type +#if defined(__ARM_NEON) && defined(__CUDACC__) + typedef half ggml_fp16_t; +#elif defined(__ARM_NEON) typedef __fp16 ggml_fp16_t; #else typedef uint16_t ggml_fp16_t;