1515// specific language governing permissions and limitations
1616// under the License.
1717
18- extern crate datafusion_functions;
19-
20- use crate :: function:: error_utils:: {
21- invalid_arg_count_exec_err, unsupported_data_type_exec_err,
22- } ;
23- use crate :: function:: math:: hex:: spark_sha2_hex;
24- use arrow:: array:: { ArrayRef , AsArray , StringArray } ;
18+ use arrow:: array:: { ArrayRef , AsArray , BinaryArrayType , Int32Array , StringArray } ;
2519use arrow:: datatypes:: { DataType , Int32Type } ;
26- use datafusion_common:: { Result , ScalarValue , exec_err, internal_datafusion_err} ;
27- use datafusion_expr:: Signature ;
28- use datafusion_expr:: { ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Volatility } ;
29- pub use datafusion_functions:: crypto:: basic:: { sha224, sha256, sha384, sha512} ;
20+ use datafusion_common:: types:: {
21+ NativeType , logical_binary, logical_int32, logical_string,
22+ } ;
23+ use datafusion_common:: utils:: take_function_args;
24+ use datafusion_common:: { Result , internal_err} ;
25+ use datafusion_expr:: {
26+ Coercion , ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
27+ TypeSignatureClass , Volatility ,
28+ } ;
29+ use datafusion_functions:: utils:: make_scalar_function;
30+ use sha2:: { self , Digest } ;
3031use std:: any:: Any ;
32+ use std:: fmt:: Write ;
3133use std:: sync:: Arc ;
3234
35+ /// Differs from DataFusion version in allowing array input for bit lengths, and
36+ /// also hex encoding the output.
37+ ///
3338/// <https://spark.apache.org/docs/latest/api/sql/index.html#sha2>
3439#[ derive( Debug , PartialEq , Eq , Hash ) ]
3540pub struct SparkSha2 {
3641 signature : Signature ,
37- aliases : Vec < String > ,
3842}
3943
4044impl Default for SparkSha2 {
@@ -46,8 +50,21 @@ impl Default for SparkSha2 {
4650impl SparkSha2 {
4751 pub fn new ( ) -> Self {
4852 Self {
49- signature : Signature :: user_defined ( Volatility :: Immutable ) ,
50- aliases : vec ! [ ] ,
53+ signature : Signature :: coercible (
54+ vec ! [
55+ Coercion :: new_implicit(
56+ TypeSignatureClass :: Native ( logical_binary( ) ) ,
57+ vec![ TypeSignatureClass :: Native ( logical_string( ) ) ] ,
58+ NativeType :: Binary ,
59+ ) ,
60+ Coercion :: new_implicit(
61+ TypeSignatureClass :: Native ( logical_int32( ) ) ,
62+ vec![ TypeSignatureClass :: Integer ] ,
63+ NativeType :: Int32 ,
64+ ) ,
65+ ] ,
66+ Volatility :: Immutable ,
67+ ) ,
5168 }
5269 }
5370}
@@ -65,163 +82,73 @@ impl ScalarUDFImpl for SparkSha2 {
6582 & self . signature
6683 }
6784
68- fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
69- if arg_types[ 1 ] . is_null ( ) {
70- return Ok ( DataType :: Null ) ;
71- }
72- Ok ( match arg_types[ 0 ] {
73- DataType :: Utf8View
74- | DataType :: LargeUtf8
75- | DataType :: Utf8
76- | DataType :: Binary
77- | DataType :: BinaryView
78- | DataType :: LargeBinary => DataType :: Utf8 ,
79- DataType :: Null => DataType :: Null ,
80- _ => {
81- return exec_err ! (
82- "{} function can only accept strings or binary arrays." ,
83- self . name( )
84- ) ;
85- }
86- } )
85+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
86+ Ok ( DataType :: Utf8 )
8787 }
8888
8989 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
90- let args: [ ColumnarValue ; 2 ] = args. args . try_into ( ) . map_err ( |_| {
91- internal_datafusion_err ! ( "Expected 2 arguments for function sha2" )
92- } ) ?;
93-
94- sha2 ( args)
90+ make_scalar_function ( sha2_impl, vec ! [ ] ) ( & args. args )
9591 }
92+ }
9693
97- fn aliases ( & self ) -> & [ String ] {
98- & self . aliases
99- }
94+ fn sha2_impl ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
95+ let [ values, bit_lengths] = take_function_args ( "sha2" , args) ?;
10096
101- fn coerce_types ( & self , arg_types : & [ DataType ] ) -> Result < Vec < DataType > > {
102- if arg_types. len ( ) != 2 {
103- return Err ( invalid_arg_count_exec_err (
104- self . name ( ) ,
105- ( 2 , 2 ) ,
106- arg_types. len ( ) ,
107- ) ) ;
97+ let bit_lengths = bit_lengths. as_primitive :: < Int32Type > ( ) ;
98+ let output = match values. data_type ( ) {
99+ DataType :: Binary => sha2_binary_impl ( & values. as_binary :: < i32 > ( ) , bit_lengths) ,
100+ DataType :: LargeBinary => {
101+ sha2_binary_impl ( & values. as_binary :: < i64 > ( ) , bit_lengths)
108102 }
109- let expr_type = match & arg_types[ 0 ] {
110- DataType :: Utf8View
111- | DataType :: LargeUtf8
112- | DataType :: Utf8
113- | DataType :: Binary
114- | DataType :: BinaryView
115- | DataType :: LargeBinary
116- | DataType :: Null => Ok ( arg_types[ 0 ] . clone ( ) ) ,
117- _ => Err ( unsupported_data_type_exec_err (
118- self . name ( ) ,
119- "String, Binary" ,
120- & arg_types[ 0 ] ,
121- ) ) ,
122- } ?;
123- let bit_length_type = if arg_types[ 1 ] . is_numeric ( ) {
124- Ok ( DataType :: Int32 )
125- } else if arg_types[ 1 ] . is_null ( ) {
126- Ok ( DataType :: Null )
127- } else {
128- Err ( unsupported_data_type_exec_err (
129- self . name ( ) ,
130- "Numeric Type" ,
131- & arg_types[ 1 ] ,
132- ) )
133- } ?;
134-
135- Ok ( vec ! [ expr_type, bit_length_type] )
136- }
103+ DataType :: BinaryView => sha2_binary_impl ( & values. as_binary_view ( ) , bit_lengths) ,
104+ dt => return internal_err ! ( "Unsupported datatype for sha2: {dt}" ) ,
105+ } ;
106+ Ok ( output)
137107}
138108
139- pub fn sha2 ( args : [ ColumnarValue ; 2 ] ) -> Result < ColumnarValue > {
140- match args {
141- [
142- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( expr_arg) ) ,
143- ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( bit_length_arg) ) ) ,
144- ] => compute_sha2 (
145- bit_length_arg,
146- & [ ColumnarValue :: from ( ScalarValue :: Utf8 ( expr_arg) ) ] ,
147- ) ,
148- [
149- ColumnarValue :: Array ( expr_arg) ,
150- ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( bit_length_arg) ) ) ,
151- ] => compute_sha2 ( bit_length_arg, & [ ColumnarValue :: from ( expr_arg) ] ) ,
152- [
153- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( expr_arg) ) ,
154- ColumnarValue :: Array ( bit_length_arg) ,
155- ] => {
156- let arr: StringArray = bit_length_arg
157- . as_primitive :: < Int32Type > ( )
158- . iter ( )
159- . map ( |bit_length| {
160- match sha2 ( [
161- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( expr_arg. clone ( ) ) ) ,
162- ColumnarValue :: Scalar ( ScalarValue :: Int32 ( bit_length) ) ,
163- ] )
164- . unwrap ( )
165- {
166- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( str) ) => str,
167- ColumnarValue :: Array ( arr) => arr
168- . as_string :: < i32 > ( )
169- . iter ( )
170- . map ( |str| str. unwrap ( ) . to_string ( ) )
171- . next ( ) , // first element
172- _ => unreachable ! ( ) ,
173- }
174- } )
175- . collect ( ) ;
176- Ok ( ColumnarValue :: Array ( Arc :: new ( arr) as ArrayRef ) )
177- }
178- [
179- ColumnarValue :: Array ( expr_arg) ,
180- ColumnarValue :: Array ( bit_length_arg) ,
181- ] => {
182- let expr_iter = expr_arg. as_string :: < i32 > ( ) . iter ( ) ;
183- let bit_length_iter = bit_length_arg. as_primitive :: < Int32Type > ( ) . iter ( ) ;
184- let arr: StringArray = expr_iter
185- . zip ( bit_length_iter)
186- . map ( |( expr, bit_length) | {
187- match sha2 ( [
188- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some (
189- expr. unwrap ( ) . to_string ( ) ,
190- ) ) ) ,
191- ColumnarValue :: Scalar ( ScalarValue :: Int32 ( bit_length) ) ,
192- ] )
193- . unwrap ( )
194- {
195- ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( str) ) => str,
196- ColumnarValue :: Array ( arr) => arr
197- . as_string :: < i32 > ( )
198- . iter ( )
199- . map ( |str| str. unwrap ( ) . to_string ( ) )
200- . next ( ) , // first element
201- _ => unreachable ! ( ) ,
202- }
203- } )
204- . collect ( ) ;
205- Ok ( ColumnarValue :: Array ( Arc :: new ( arr) as ArrayRef ) )
206- }
207- _ => exec_err ! ( "Unsupported argument types for sha2 function" ) ,
208- }
109+ fn sha2_binary_impl < ' a , BinaryArrType > (
110+ values : & BinaryArrType ,
111+ bit_lengths : & Int32Array ,
112+ ) -> ArrayRef
113+ where
114+ BinaryArrType : BinaryArrayType < ' a > ,
115+ {
116+ let array = values
117+ . iter ( )
118+ . zip ( bit_lengths. iter ( ) )
119+ . map ( |( value, bit_length) | match ( value, bit_length) {
120+ ( Some ( value) , Some ( 224 ) ) => {
121+ let mut digest = sha2:: Sha224 :: default ( ) ;
122+ digest. update ( value) ;
123+ Some ( hex_encode ( digest. finalize ( ) ) )
124+ }
125+ ( Some ( value) , Some ( 0 | 256 ) ) => {
126+ let mut digest = sha2:: Sha256 :: default ( ) ;
127+ digest. update ( value) ;
128+ Some ( hex_encode ( digest. finalize ( ) ) )
129+ }
130+ ( Some ( value) , Some ( 384 ) ) => {
131+ let mut digest = sha2:: Sha384 :: default ( ) ;
132+ digest. update ( value) ;
133+ Some ( hex_encode ( digest. finalize ( ) ) )
134+ }
135+ ( Some ( value) , Some ( 512 ) ) => {
136+ let mut digest = sha2:: Sha512 :: default ( ) ;
137+ digest. update ( value) ;
138+ Some ( hex_encode ( digest. finalize ( ) ) )
139+ }
140+ // Unknown bit-lengths go to null, same as in Spark
141+ _ => None ,
142+ } )
143+ . collect :: < StringArray > ( ) ;
144+ Arc :: new ( array)
209145}
210146
211- fn compute_sha2 (
212- bit_length_arg : i32 ,
213- expr_arg : & [ ColumnarValue ] ,
214- ) -> Result < ColumnarValue > {
215- match bit_length_arg {
216- 0 | 256 => sha256 ( expr_arg) ,
217- 224 => sha224 ( expr_arg) ,
218- 384 => sha384 ( expr_arg) ,
219- 512 => sha512 ( expr_arg) ,
220- _ => {
221- // Return null for unsupported bit lengths instead of error, because spark sha2 does not
222- // error out for this.
223- return Ok ( ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) ) ;
224- }
/// Encodes `data` as a lowercase hexadecimal string, two characters per byte.
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
    let bytes = data.as_ref();
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut encoded, b| {
            // Writing to a string never errors, so we can unwrap here.
            write!(encoded, "{b:02x}").unwrap();
            encoded
        })
}
0 commit comments