11use crate :: error:: ExTokenizersError ;
2+ use rustler:: resource:: ResourceArc ;
3+ use rustler:: { Binary , Env } ;
24use tokenizers:: utils:: padding:: PaddingDirection ;
35use tokenizers:: utils:: truncation:: TruncationDirection ;
46use tokenizers:: Encoding ;
@@ -8,7 +10,7 @@ pub struct ExTokenizersEncodingRef(pub Encoding);
810#[ derive( rustler:: NifStruct ) ]
911#[ module = "Tokenizers.Encoding" ]
1012pub struct ExTokenizersEncoding {
11- pub resource : rustler :: resource :: ResourceArc < ExTokenizersEncodingRef > ,
13+ pub resource : ResourceArc < ExTokenizersEncodingRef > ,
1214}
1315
1416impl ExTokenizersEncodingRef {
@@ -20,7 +22,7 @@ impl ExTokenizersEncodingRef {
2022impl ExTokenizersEncoding {
2123 pub fn new ( data : Encoding ) -> Self {
2224 Self {
23- resource : rustler :: resource :: ResourceArc :: new ( ExTokenizersEncodingRef :: new ( data) ) ,
25+ resource : ResourceArc :: new ( ExTokenizersEncodingRef :: new ( data) ) ,
2426 }
2527 }
2628}
@@ -35,23 +37,60 @@ pub fn get_ids(encoding: ExTokenizersEncoding) -> Result<Vec<u32>, ExTokenizersE
3537 Ok ( encoding. resource . 0 . get_ids ( ) . to_vec ( ) )
3638}
3739
40+ #[ rustler:: nif]
41+ pub fn get_u32_ids ( env : Env , encoding : ExTokenizersEncoding ) -> Result < Binary , ExTokenizersError > {
42+ Ok ( encoding
43+ . resource
44+ . make_binary ( env, |r| slice_u32_to_u8 ( r. 0 . get_ids ( ) ) ) )
45+ }
46+
3847#[ rustler:: nif]
3948pub fn get_attention_mask ( encoding : ExTokenizersEncoding ) -> Result < Vec < u32 > , ExTokenizersError > {
4049 Ok ( encoding. resource . 0 . get_attention_mask ( ) . to_vec ( ) )
4150}
4251
52+ #[ rustler:: nif]
53+ pub fn get_u32_attention_mask (
54+ env : Env ,
55+ encoding : ExTokenizersEncoding ,
56+ ) -> Result < Binary , ExTokenizersError > {
57+ Ok ( encoding
58+ . resource
59+ . make_binary ( env, |r| slice_u32_to_u8 ( r. 0 . get_attention_mask ( ) ) ) )
60+ }
61+
4362#[ rustler:: nif]
4463pub fn get_type_ids ( encoding : ExTokenizersEncoding ) -> Result < Vec < u32 > , ExTokenizersError > {
4564 Ok ( encoding. resource . 0 . get_type_ids ( ) . to_vec ( ) )
4665}
4766
67+ #[ rustler:: nif]
68+ pub fn get_u32_type_ids (
69+ env : Env ,
70+ encoding : ExTokenizersEncoding ,
71+ ) -> Result < Binary , ExTokenizersError > {
72+ Ok ( encoding
73+ . resource
74+ . make_binary ( env, |r| slice_u32_to_u8 ( r. 0 . get_type_ids ( ) ) ) )
75+ }
76+
4877#[ rustler:: nif]
4978pub fn get_special_tokens_mask (
5079 encoding : ExTokenizersEncoding ,
5180) -> Result < Vec < u32 > , ExTokenizersError > {
5281 Ok ( encoding. resource . 0 . get_special_tokens_mask ( ) . to_vec ( ) )
5382}
5483
84+ #[ rustler:: nif]
85+ pub fn get_u32_special_tokens_mask (
86+ env : Env ,
87+ encoding : ExTokenizersEncoding ,
88+ ) -> Result < Binary , ExTokenizersError > {
89+ Ok ( encoding
90+ . resource
91+ . make_binary ( env, |r| slice_u32_to_u8 ( r. 0 . get_special_tokens_mask ( ) ) ) )
92+ }
93+
5594#[ rustler:: nif]
5695pub fn get_offsets (
5796 encoding : ExTokenizersEncoding ,
@@ -99,3 +138,7 @@ pub fn pad(
99138 new_encoding. pad ( target_length, pad_id, pad_type_id, pad_token, direction) ;
100139 Ok ( ExTokenizersEncoding :: new ( new_encoding) )
101140}
141+
142+ fn slice_u32_to_u8 ( slice : & [ u32 ] ) -> & [ u8 ] {
143+ unsafe { std:: slice:: from_raw_parts ( slice. as_ptr ( ) as * const u8 , slice. len ( ) * 4 ) }
144+ }
0 commit comments