Skip to content

Commit 568f19f

Browse files
Jefffreyalamb
andauthored
Simplify Spark sha2 implementation (#19475)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Part of #12725 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Simplify implementation, also remove usage of user_defined. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Refactor to be simpler. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Added tests. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> No. <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 4b31aaa commit 568f19f

6 files changed

Lines changed: 147 additions & 162 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ recursive = "0.1.1"
181181
regex = "1.12"
182182
rstest = "0.26.1"
183183
serde_json = "1"
184+
sha2 = "^0.10.9"
184185
sqlparser = { version = "0.60.0", default-features = false, features = ["std", "visitor"] }
185186
strum = "0.27.2"
186187
strum_macros = "0.27.2"

datafusion/functions/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ md-5 = { version = "^0.10.0", optional = true }
8585
num-traits = { workspace = true }
8686
rand = { workspace = true }
8787
regex = { workspace = true, optional = true }
88-
sha2 = { version = "^0.10.9", optional = true }
88+
sha2 = { workspace = true, optional = true }
8989
unicode-segmentation = { version = "^1.7.1", optional = true }
9090
uuid = { workspace = true, features = ["v4"], optional = true }
9191

datafusion/spark/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ log = { workspace = true }
6060
percent-encoding = "2.3.2"
6161
rand = { workspace = true }
6262
sha1 = "0.10"
63+
sha2 = { workspace = true }
6364
url = { workspace = true }
6465

6566
[dev-dependencies]

datafusion/spark/src/function/hash/sha2.rs

Lines changed: 88 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,30 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
extern crate datafusion_functions;
19-
20-
use crate::function::error_utils::{
21-
invalid_arg_count_exec_err, unsupported_data_type_exec_err,
22-
};
23-
use crate::function::math::hex::spark_sha2_hex;
24-
use arrow::array::{ArrayRef, AsArray, StringArray};
18+
use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray};
2519
use arrow::datatypes::{DataType, Int32Type};
26-
use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
27-
use datafusion_expr::Signature;
28-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility};
29-
pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512};
20+
use datafusion_common::types::{
21+
NativeType, logical_binary, logical_int32, logical_string,
22+
};
23+
use datafusion_common::utils::take_function_args;
24+
use datafusion_common::{Result, internal_err};
25+
use datafusion_expr::{
26+
Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27+
TypeSignatureClass, Volatility,
28+
};
29+
use datafusion_functions::utils::make_scalar_function;
30+
use sha2::{self, Digest};
3031
use std::any::Any;
32+
use std::fmt::Write;
3133
use std::sync::Arc;
3234

35+
/// Differs from DataFusion version in allowing array input for bit lengths, and
36+
/// also hex encoding the output.
37+
///
3338
/// <https://spark.apache.org/docs/latest/api/sql/index.html#sha2>
3439
#[derive(Debug, PartialEq, Eq, Hash)]
3540
pub struct SparkSha2 {
3641
signature: Signature,
37-
aliases: Vec<String>,
3842
}
3943

4044
impl Default for SparkSha2 {
@@ -46,8 +50,21 @@ impl Default for SparkSha2 {
4650
impl SparkSha2 {
4751
pub fn new() -> Self {
4852
Self {
49-
signature: Signature::user_defined(Volatility::Immutable),
50-
aliases: vec![],
53+
signature: Signature::coercible(
54+
vec![
55+
Coercion::new_implicit(
56+
TypeSignatureClass::Native(logical_binary()),
57+
vec![TypeSignatureClass::Native(logical_string())],
58+
NativeType::Binary,
59+
),
60+
Coercion::new_implicit(
61+
TypeSignatureClass::Native(logical_int32()),
62+
vec![TypeSignatureClass::Integer],
63+
NativeType::Int32,
64+
),
65+
],
66+
Volatility::Immutable,
67+
),
5168
}
5269
}
5370
}
@@ -65,163 +82,73 @@ impl ScalarUDFImpl for SparkSha2 {
6582
&self.signature
6683
}
6784

68-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
69-
if arg_types[1].is_null() {
70-
return Ok(DataType::Null);
71-
}
72-
Ok(match arg_types[0] {
73-
DataType::Utf8View
74-
| DataType::LargeUtf8
75-
| DataType::Utf8
76-
| DataType::Binary
77-
| DataType::BinaryView
78-
| DataType::LargeBinary => DataType::Utf8,
79-
DataType::Null => DataType::Null,
80-
_ => {
81-
return exec_err!(
82-
"{} function can only accept strings or binary arrays.",
83-
self.name()
84-
);
85-
}
86-
})
85+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
86+
Ok(DataType::Utf8)
8787
}
8888

8989
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90-
let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
91-
internal_datafusion_err!("Expected 2 arguments for function sha2")
92-
})?;
93-
94-
sha2(args)
90+
make_scalar_function(sha2_impl, vec![])(&args.args)
9591
}
92+
}
9693

97-
fn aliases(&self) -> &[String] {
98-
&self.aliases
99-
}
94+
fn sha2_impl(args: &[ArrayRef]) -> Result<ArrayRef> {
95+
let [values, bit_lengths] = take_function_args("sha2", args)?;
10096

101-
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
102-
if arg_types.len() != 2 {
103-
return Err(invalid_arg_count_exec_err(
104-
self.name(),
105-
(2, 2),
106-
arg_types.len(),
107-
));
97+
let bit_lengths = bit_lengths.as_primitive::<Int32Type>();
98+
let output = match values.data_type() {
99+
DataType::Binary => sha2_binary_impl(&values.as_binary::<i32>(), bit_lengths),
100+
DataType::LargeBinary => {
101+
sha2_binary_impl(&values.as_binary::<i64>(), bit_lengths)
108102
}
109-
let expr_type = match &arg_types[0] {
110-
DataType::Utf8View
111-
| DataType::LargeUtf8
112-
| DataType::Utf8
113-
| DataType::Binary
114-
| DataType::BinaryView
115-
| DataType::LargeBinary
116-
| DataType::Null => Ok(arg_types[0].clone()),
117-
_ => Err(unsupported_data_type_exec_err(
118-
self.name(),
119-
"String, Binary",
120-
&arg_types[0],
121-
)),
122-
}?;
123-
let bit_length_type = if arg_types[1].is_numeric() {
124-
Ok(DataType::Int32)
125-
} else if arg_types[1].is_null() {
126-
Ok(DataType::Null)
127-
} else {
128-
Err(unsupported_data_type_exec_err(
129-
self.name(),
130-
"Numeric Type",
131-
&arg_types[1],
132-
))
133-
}?;
134-
135-
Ok(vec![expr_type, bit_length_type])
136-
}
103+
DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths),
104+
dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
105+
};
106+
Ok(output)
137107
}
138108

139-
pub fn sha2(args: [ColumnarValue; 2]) -> Result<ColumnarValue> {
140-
match args {
141-
[
142-
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)),
143-
ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))),
144-
] => compute_sha2(
145-
bit_length_arg,
146-
&[ColumnarValue::from(ScalarValue::Utf8(expr_arg))],
147-
),
148-
[
149-
ColumnarValue::Array(expr_arg),
150-
ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg))),
151-
] => compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]),
152-
[
153-
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)),
154-
ColumnarValue::Array(bit_length_arg),
155-
] => {
156-
let arr: StringArray = bit_length_arg
157-
.as_primitive::<Int32Type>()
158-
.iter()
159-
.map(|bit_length| {
160-
match sha2([
161-
ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())),
162-
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
163-
])
164-
.unwrap()
165-
{
166-
ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
167-
ColumnarValue::Array(arr) => arr
168-
.as_string::<i32>()
169-
.iter()
170-
.map(|str| str.unwrap().to_string())
171-
.next(), // first element
172-
_ => unreachable!(),
173-
}
174-
})
175-
.collect();
176-
Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
177-
}
178-
[
179-
ColumnarValue::Array(expr_arg),
180-
ColumnarValue::Array(bit_length_arg),
181-
] => {
182-
let expr_iter = expr_arg.as_string::<i32>().iter();
183-
let bit_length_iter = bit_length_arg.as_primitive::<Int32Type>().iter();
184-
let arr: StringArray = expr_iter
185-
.zip(bit_length_iter)
186-
.map(|(expr, bit_length)| {
187-
match sha2([
188-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(
189-
expr.unwrap().to_string(),
190-
))),
191-
ColumnarValue::Scalar(ScalarValue::Int32(bit_length)),
192-
])
193-
.unwrap()
194-
{
195-
ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str,
196-
ColumnarValue::Array(arr) => arr
197-
.as_string::<i32>()
198-
.iter()
199-
.map(|str| str.unwrap().to_string())
200-
.next(), // first element
201-
_ => unreachable!(),
202-
}
203-
})
204-
.collect();
205-
Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef))
206-
}
207-
_ => exec_err!("Unsupported argument types for sha2 function"),
208-
}
109+
fn sha2_binary_impl<'a, BinaryArrType>(
110+
values: &BinaryArrType,
111+
bit_lengths: &Int32Array,
112+
) -> ArrayRef
113+
where
114+
BinaryArrType: BinaryArrayType<'a>,
115+
{
116+
let array = values
117+
.iter()
118+
.zip(bit_lengths.iter())
119+
.map(|(value, bit_length)| match (value, bit_length) {
120+
(Some(value), Some(224)) => {
121+
let mut digest = sha2::Sha224::default();
122+
digest.update(value);
123+
Some(hex_encode(digest.finalize()))
124+
}
125+
(Some(value), Some(0 | 256)) => {
126+
let mut digest = sha2::Sha256::default();
127+
digest.update(value);
128+
Some(hex_encode(digest.finalize()))
129+
}
130+
(Some(value), Some(384)) => {
131+
let mut digest = sha2::Sha384::default();
132+
digest.update(value);
133+
Some(hex_encode(digest.finalize()))
134+
}
135+
(Some(value), Some(512)) => {
136+
let mut digest = sha2::Sha512::default();
137+
digest.update(value);
138+
Some(hex_encode(digest.finalize()))
139+
}
140+
// Unknown bit-lengths go to null, same as in Spark
141+
_ => None,
142+
})
143+
.collect::<StringArray>();
144+
Arc::new(array)
209145
}
210146

211-
fn compute_sha2(
212-
bit_length_arg: i32,
213-
expr_arg: &[ColumnarValue],
214-
) -> Result<ColumnarValue> {
215-
match bit_length_arg {
216-
0 | 256 => sha256(expr_arg),
217-
224 => sha224(expr_arg),
218-
384 => sha384(expr_arg),
219-
512 => sha512(expr_arg),
220-
_ => {
221-
// Return null for unsupported bit lengths instead of error, because spark sha2 does not
222-
// error out for this.
223-
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
224-
}
147+
fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
148+
let mut s = String::with_capacity(data.as_ref().len() * 2);
149+
for b in data.as_ref() {
150+
// Writing to a string never errors, so we can unwrap here.
151+
write!(&mut s, "{b:02x}").unwrap();
225152
}
226-
.map(|hashed| spark_sha2_hex(&[hashed]).unwrap())
153+
s
227154
}

datafusion/sqllogictest/test_files/spark/hash/sha2.slt

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,58 @@ SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('ba
7575
967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91
7676
8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52
7777
NULL
78+
79+
# All string types
80+
query T
81+
SELECT sha2(arrow_cast('foo', 'Utf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
82+
----
83+
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
84+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
85+
86+
query T
87+
SELECT sha2(arrow_cast('foo', 'LargeUtf8'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
88+
----
89+
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
90+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
91+
92+
query T
93+
SELECT sha2(arrow_cast('foo', 'Utf8View'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
94+
----
95+
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
96+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
97+
98+
# All binary types
99+
query T
100+
SELECT sha2(arrow_cast('foo', 'Binary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
101+
----
102+
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
103+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
104+
105+
query T
106+
SELECT sha2(arrow_cast('foo', 'LargeBinary'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
107+
----
108+
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
109+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
110+
111+
query T
112+
SELECT sha2(arrow_cast('foo', 'BinaryView'), bit_length) FROM VALUES (224::INT), (256::INT) AS t(bit_length);
113+
----
114+
0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db
115+
2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae
116+
117+
118+
# Null cases
119+
query T
120+
select sha2(null, 0);
121+
----
122+
NULL
123+
124+
query T
125+
select sha2('a', null);
126+
----
127+
NULL
128+
129+
query T
130+
select sha2('a', null::int);
131+
----
132+
NULL

0 commit comments

Comments
 (0)