Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 48 additions & 36 deletions pydeequ/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,81 +727,93 @@ def isPositive(self, column, assertion=None, hint=None):
self._Check = self._Check.isPositive(column, assertion_func, hint)
return self

def _column_comparison(self, columnA, columnB, operator, description, assertion, hint):
"""Builds a column comparison constraint via Deequ's ``satisfies``.

Deequ's native ``isLessThan``/``isGreaterThan`` family (since Deequ 2.0.x)
forwards ``columns = List(columnA, columnB)`` to ``satisfies``, which makes
Deequ require *both* operands to be existing columns. That breaks the long
supported column-vs-literal usage (e.g. ``isGreaterThanOrEqualTo("col", "1")``),
failing with ``Input data does not include column 1!`` (see issue #227).

We instead build the same Spark SQL predicate ourselves and call ``satisfies``
with an empty ``columns`` list (the pre-2.0 behaviour), so ``columnB`` may be
either a column name or a SQL literal/expression.

``columnA`` is always a real column, so it is backtick-quoted to stay valid
when the name contains spaces/special characters or is a SQL reserved word.
``columnB`` is left raw on purpose: it may be a column, a literal, or a SQL
expression, and quoting is the caller's responsibility (same as Deequ's
``satisfies``).
"""
# The comparator family now genuinely routes through ``satisfies``, so its
# default assertion (``satisfies$default$3`` == ``_ == 1.0``) is the correct
# one to use; it matches each comparator's own Deequ default (also ``_ == 1.0``).
assertion_func = (
ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
if assertion
Comment thread
nikolauspschuetz marked this conversation as resolved.
else getattr(self._Check, "satisfies$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
Comment thread
nikolauspschuetz marked this conversation as resolved.
column_condition = f"`{columnA}` {operator} {columnB}"
constraint_name = f"{columnA} is {description} {columnB}"
self._Check = self._Check.satisfies(
column_condition,
constraint_name,
assertion_func,
hint,
self._jvm.scala.collection.Seq.empty(),
self._jvm.scala.Option.apply(None),
)
return self

def isLessThan(self, columnA, columnB, assertion=None, hint=None):
"""
Asserts that, in each row, the value of columnA is less than the value of columnB

:param str columnA: Column in DataFrame to run the assertion on.
:param str columnB: Column in DataFrame to run the assertion on.
:param str columnB: Column in DataFrame to compare against, or a SQL literal/expression.
:param lambda assertion: A function that accepts an int or float parameter.
:param str hint: A hint that states why a constraint could have failed.
:return: isLessThan self : A Check object that checks the assertion on the columns.
"""
assertion_func = (
ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
if assertion
else getattr(self._Check, "isLessThan$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._Check = self._Check.isLessThan(columnA, columnB, assertion_func, hint)
return self
return self._column_comparison(columnA, columnB, "<", "less than", assertion, hint)

def isLessThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None):
"""
Asserts that, in each row, the value of columnA is less than or equal to the value of columnB.

:param str columnA: Column in DataFrame to run the assertion on.
:param str columnB: Column in DataFrame to run the assertion on.
:param str columnB: Column in DataFrame to compare against, or a SQL literal/expression.
:param lambda assertion: A function that accepts an int or float parameter.
:param str hint: A hint that states why a constraint could have failed.
:return: isLessThanOrEqualTo self (isLessThanOrEqualTo): A Check object that checks the assertion on the columns.
"""
assertion_func = (
ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
if assertion
else getattr(self._Check, "isLessThanOrEqualTo$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._Check = self._Check.isLessThanOrEqualTo(columnA, columnB, assertion_func, hint)
return self
return self._column_comparison(columnA, columnB, "<=", "less than or equal to", assertion, hint)

def isGreaterThan(self, columnA, columnB, assertion=None, hint=None):
"""
Asserts that, in each row, the value of columnA is greater than the value of columnB

:param str columnA: Column in DataFrame to run the assertion on.
:param str columnB: Column in DataFrame to run the assertion on.
:param str columnB: Column in DataFrame to compare against, or a SQL literal/expression.
:param lambda assertion: A function that accepts an int or float parameter.
:param str hint: A hint that states why a constraint could have failed.
:return: isGreaterThan self: A Check object that runs the assertion on the columns.
"""
assertion_func = (
ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
if assertion
else getattr(self._Check, "isGreaterThan$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._Check = self._Check.isGreaterThan(columnA, columnB, assertion_func, hint)
return self
return self._column_comparison(columnA, columnB, ">", "greater than", assertion, hint)

def isGreaterThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None):
"""
Asserts that, in each row, the value of columnA is greather than or equal to the value of columnB

:param str columnA: Column in DataFrame to run the assertion on.
:param str columnB: Column in DataFrame to run the assertion on.
:param str columnB: Column in DataFrame to compare against, or a SQL literal/expression.
:param lambda assertion: A function that accepts an int or float parameter.
:param str hint: A hint that states why a constraint could have failed.
:return: isGreaterThanOrEqualTo self: A Check object that runs the assertion on the columns.
"""
assertion_func = (
ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
if assertion
else getattr(self._Check, "isGreaterThanOrEqualTo$default$3")()
)
hint = self._jvm.scala.Option.apply(hint)
self._Check = self._Check.isGreaterThanOrEqualTo(columnA, columnB, assertion_func, hint)
return self
return self._column_comparison(columnA, columnB, ">=", "greater than or equal to", assertion, hint)

def isContainedIn(self, column, allowed_values, assertion=None, hint=None):
"""
Expand Down
43 changes: 43 additions & 0 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,49 @@ def test_fail_isGreaterThanOrEqualTo(self):
self.isGreaterThanOrEqualTo("h", "f", lambda x: x == 1), [Row(constraint_status="Failure")]
)

def test_comparator_against_literal(self):
# Regression test for issue #227: comparing a column to a SQL literal
# (rather than another column) must not fail with
# "Input data does not include column <literal>!".
# Column "b" holds values 1, 2, 3 -> all are >= 1 and all are <= 10.
self.assertEqual(
self.isGreaterThanOrEqualTo("b", "1", hint="Cluster should have at least one element"),
[Row(constraint_status="Success")],
)
self.assertEqual(
self.isLessThanOrEqualTo("b", "10", hint="b never exceeds 10"),
[Row(constraint_status="Success")],
)
self.assertEqual(
self.isGreaterThan("b", "0"),
[Row(constraint_status="Success")],
)
self.assertEqual(
self.isLessThan("b", "100"),
[Row(constraint_status="Success")],
)

def test_fail_comparator_against_literal(self):
# Column "b" holds values 1, 2, 3 -> not all are >= 3.
self.assertEqual(
self.isGreaterThanOrEqualTo("b", "3"), [Row(constraint_status="Failure")]
)

def test_comparator_column_name_with_space(self):
# Regression test for issue #227 review: columnA is always a real column
# and must be backtick-quoted so a name containing spaces/special chars
# (or a SQL reserved word) produces valid SQL.
df = self.df.withColumnRenamed("b", "my col")
check = Check(self.spark, CheckLevel.Warning, "test spaced column").isGreaterThanOrEqualTo(
"my col", "1", hint="values in 'my col' are at least 1"
)
result = VerificationSuite(self.spark).onData(df).addCheck(check).run()
result_df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
self.assertEqual(
result_df.select("constraint_status").collect(),
[Row(constraint_status="Success")],
)

def test_where(self):
self.assertEqual(
self.where(lambda x: x == 2.0, "boolean='true'", "column 'boolean' has two values true"),
Expand Down
Loading