diff --git a/sqlstream/__init__.py b/sqlstream/__init__.py index 040cac5..5f93ed1 100644 --- a/sqlstream/__init__.py +++ b/sqlstream/__init__.py @@ -5,7 +5,12 @@ Parquet, and JSON with lazy evaluation and intelligent optimizations. """ -__version__ = "0.1.0" +try: + from importlib.metadata import version + + __version__ = version("sqlstream") +except Exception: + __version__ = "0.6.3" # Main API from sqlstream.core.query import query diff --git a/sqlstream/core/duckdb_executor.py b/sqlstream/core/duckdb_executor.py index 7bf9ba6..bc4db5c 100644 --- a/sqlstream/core/duckdb_executor.py +++ b/sqlstream/core/duckdb_executor.py @@ -213,7 +213,7 @@ def _register_sources_as_dataframes(self, sources: dict[str, str]): # DuckDB can query pandas DataFrames directly! self.conn.register(table_name, df) - except Exception: + except (ValueError, FileNotFoundError, OSError, KeyError): self._register_source(table_name, file_path) def _register_source(self, table_name: str, file_path: str, read_only: bool = True): @@ -287,7 +287,7 @@ def _ensure_httpfs(self): try: self.conn.execute("INSTALL httpfs") self.conn.execute("LOAD httpfs") - except Exception: + except duckdb.Error: # httpfs might already be loaded or not needed pass diff --git a/sqlstream/core/executor.py b/sqlstream/core/executor.py index 5af0228..7f5912e 100644 --- a/sqlstream/core/executor.py +++ b/sqlstream/core/executor.py @@ -8,6 +8,7 @@ from collections.abc import Callable, Iterator from typing import Any +from sqlstream.operators.distinct import Distinct from sqlstream.operators.filter import Filter from sqlstream.operators.groupby import GroupByOperator from sqlstream.operators.join import HashJoinOperator @@ -140,6 +141,10 @@ def _build_plan( raise ValueError("GROUP BY requires aggregate functions in SELECT") plan = GroupByOperator(plan, ast.group_by, ast.aggregates, ast.columns) + # Add HAVING filter after GroupBy (filters on aggregated results) + if ast.having: + plan = Filter(plan, ast.having.conditions) + # 
Add OrderBy if ORDER BY clause exists if ast.order_by: plan = OrderByOperator(plan, ast.order_by) @@ -149,6 +154,10 @@ def _build_plan( if not ast.group_by: plan = Project(plan, ast.columns) + # Add Distinct if SELECT DISTINCT + if ast.distinct: + plan = Distinct(plan) + # Add Limit if LIMIT clause exists if ast.limit is not None: plan = Limit(plan, ast.limit) diff --git a/sqlstream/core/pandas_executor.py b/sqlstream/core/pandas_executor.py index 017a64c..3382c5a 100644 --- a/sqlstream/core/pandas_executor.py +++ b/sqlstream/core/pandas_executor.py @@ -72,6 +72,10 @@ def execute( if ast.group_by or ast.aggregates: df = self._apply_groupby(df, ast) + # Step 4b: Apply HAVING filter (after GROUP BY) + if ast.having: + df = self._apply_filter(df, ast.having.conditions) + # Step 5: Apply ORDER BY if ast.order_by: df = self._apply_orderby(df, ast) @@ -80,7 +84,11 @@ def execute( if not ast.group_by: # GroupBy already handled columns df = self._apply_projection(df, ast.columns) - # Step 7: Apply LIMIT + # Step 7: Apply DISTINCT + if ast.distinct: + df = df.drop_duplicates() + + # Step 8: Apply LIMIT if ast.limit is not None: df = df.head(ast.limit) @@ -247,6 +255,7 @@ def _apply_join( "INNER": "inner", "LEFT": "left", "RIGHT": "right", + "FULL": "outer", } how = join_type_map.get(ast.join.join_type, "inner") @@ -271,6 +280,20 @@ def _apply_filter(self, df: pd.DataFrame, conditions: list[Condition]) -> pd.Dat op = condition.operator value = condition.value + # Handle IS NULL / IS NOT NULL (column may or may not exist) + if op == "IS NULL": + if col not in df.columns: + # Missing column is treated as all NULL + continue # mask stays True (all rows match) + mask &= df[col].isna() + continue + elif op == "IS NOT NULL": + if col not in df.columns: + mask &= False + continue + mask &= df[col].notna() + continue + if col not in df.columns: # Column doesn't exist, skip this filter continue diff --git a/sqlstream/core/query.py b/sqlstream/core/query.py index fac9f54..b85e10d 
100644 --- a/sqlstream/core/query.py +++ b/sqlstream/core/query.py @@ -56,8 +56,6 @@ def _can_parse_with_custom_parser(sql: str) -> bool: - Complex expressions (CASE, CAST, EXTRACT) - Subqueries - Set operations (UNION, INTERSECT, EXCEPT) - - HAVING clause - Args: sql: SQL query string to analyze @@ -72,7 +70,6 @@ def _can_parse_with_custom_parser(sql: str) -> bool: "OVER", # Window functions "PARTITION BY", # Window functions "WINDOW", # Window functions - "HAVING", # HAVING clause "UNION", # Set operations "INTERSECT", # Set operations "EXCEPT", # Set operations diff --git a/sqlstream/operators/distinct.py b/sqlstream/operators/distinct.py new file mode 100644 index 0000000..6858e5d --- /dev/null +++ b/sqlstream/operators/distinct.py @@ -0,0 +1,36 @@ +""" +Distinct operator - implements SELECT DISTINCT + +Deduplicates rows by keeping a set of seen row signatures. +""" + +from collections.abc import Iterator +from typing import Any + +from sqlstream.operators.base import Operator + + +class Distinct(Operator): + """ + DISTINCT operator - removes duplicate rows. + + Keeps a set of seen row tuples and only yields rows + that haven't been seen before. 
+ """ + + def __init__(self, child: Operator): + super().__init__(child) + + def __iter__(self) -> Iterator[dict[str, Any]]: + seen: set[tuple] = set() + for row in self.child: + # Convert row to a hashable key + key = tuple( + sorted((k, str(v) if isinstance(v, (list, dict)) else v) for k, v in row.items()) + ) + if key not in seen: + seen.add(key) + yield row + + def __repr__(self) -> str: + return "Distinct()" diff --git a/sqlstream/operators/filter.py b/sqlstream/operators/filter.py index 522c915..ed8e96a 100644 --- a/sqlstream/operators/filter.py +++ b/sqlstream/operators/filter.py @@ -9,6 +9,8 @@ from sqlstream.operators.base import Operator from sqlstream.sql.ast_nodes import Condition +from sqlstream.utils.condition_eval import evaluate_condition +from sqlstream.utils.errors import ColumnNotFoundError class Filter(Operator): @@ -39,7 +41,12 @@ def __iter__(self) -> Iterator[dict[str, Any]]: 2. If all are True, yield the row 3. Otherwise, skip it """ + validated = False for row in self.child: + # Validate column names on first row for better error messages + if not validated: + self._validate_columns(row) + validated = True if self._matches(row): yield row @@ -58,53 +65,20 @@ def _matches(self, row: dict[str, Any]) -> bool: return False return True - def _evaluate_condition(self, row: dict[str, Any], condition: Condition) -> bool: - """ - Evaluate a single condition against a row - - Args: - row: Row to check - condition: Condition to evaluate + @staticmethod + def _evaluate_condition(row: dict[str, Any], condition: Condition) -> bool: + """Evaluate a single condition against a row (delegates to shared utility).""" + return evaluate_condition(row, condition) - Returns: - True if condition is satisfied - """ - # Get column value - if condition.column not in row: - return False - - value = row[condition.column] - - # Handle NULL values - if value is None: - return False - - # Get expected value - expected = condition.value - - # Evaluate operator - op = 
condition.operator - - try: - if op == "=": - return value == expected - elif op == ">": - return value > expected - elif op == "<": - return value < expected - elif op == ">=": - return value >= expected - elif op == "<=": - return value <= expected - elif op == "!=": - return value != expected - else: - # Unknown operator - default to True to avoid filtering - return True - - except TypeError: - # Type mismatch (e.g., comparing string to int) - return False + def _validate_columns(self, row: dict[str, Any]) -> None: + """Check that all condition columns exist in the row, raising helpful errors.""" + available = list(row.keys()) + for condition in self.conditions: + # Skip IS NULL/IS NOT NULL — missing column is valid (treated as NULL) + if condition.operator in ("IS NULL", "IS NOT NULL"): + continue + if condition.column not in row: + raise ColumnNotFoundError(condition.column, available) def __repr__(self) -> str: cond_str = " AND ".join(str(c) for c in self.conditions) diff --git a/sqlstream/operators/groupby.py b/sqlstream/operators/groupby.py index 2abda8d..5bee0de 100644 --- a/sqlstream/operators/groupby.py +++ b/sqlstream/operators/groupby.py @@ -73,6 +73,13 @@ def __iter__(self) -> Iterator[dict[str, Any]]: value = row.get(agg_func.column) if agg_func.column != "*" else None aggregators[i].update(value) + # Handle empty input with no GROUP BY columns (SQL standard): + # SELECT COUNT(*) FROM empty_table should return one row with COUNT=0 + if not groups and not self.group_by_columns: + aggregators = self._create_aggregators() + yield self._build_output_row((), aggregators) + return + # Yield one row per group for group_key, aggregators in groups.items(): row = self._build_output_row(group_key, aggregators) diff --git a/sqlstream/operators/join.py b/sqlstream/operators/join.py index c2db616..6cd13c9 100644 --- a/sqlstream/operators/join.py +++ b/sqlstream/operators/join.py @@ -60,7 +60,7 @@ def __init__( self.right_key = right_key # Validate join type - if 
self.join_type not in ("INNER", "LEFT", "RIGHT"): + if self.join_type not in ("INNER", "LEFT", "RIGHT", "FULL"): raise ValueError(f"Unsupported join type: {join_type}") def __iter__(self) -> Iterator[dict[str, Any]]: @@ -76,6 +76,8 @@ def __iter__(self) -> Iterator[dict[str, Any]]: yield from self._left_join() elif self.join_type == "RIGHT": yield from self._right_join() + elif self.join_type == "FULL": + yield from self._full_outer_join() def _inner_join(self) -> Iterator[dict[str, Any]]: """ @@ -152,6 +154,38 @@ def _right_join(self) -> Iterator[dict[str, Any]]: if (join_key, idx) not in matched_right_rows: yield self._merge_rows(None, right_row) + def _full_outer_join(self) -> Iterator[dict[str, Any]]: + """ + Execute FULL OUTER JOIN + + Returns all rows from both tables. Matched rows are joined, + unmatched rows from either side have NULL for the other side's columns. + """ + hash_table = self._build_hash_table() + matched_right_rows: set[tuple] = set() + left_columns: list[str] | None = None + + # Probe phase: Scan left table + for left_row in self.left: + if left_columns is None: + left_columns = list(left_row.keys()) + + join_key = left_row.get(self.left_key) + + if join_key is not None and join_key in hash_table: + for idx, right_row in enumerate(hash_table[join_key]): + yield self._merge_rows(left_row, right_row) + matched_right_rows.add((join_key, idx)) + else: + # No match - output left row with NULL for right columns + yield self._merge_rows(left_row, None) + + # Output unmatched right rows with NULL for left columns + for join_key, right_rows in hash_table.items(): + for idx, right_row in enumerate(right_rows): + if (join_key, idx) not in matched_right_rows: + yield self._merge_rows(None, right_row, left_columns=left_columns) + def _build_hash_table(self) -> dict[Any, list[dict[str, Any]]]: """ Build hash table from right table @@ -180,7 +214,10 @@ def _build_hash_table(self) -> dict[Any, list[dict[str, Any]]]: return hash_table def _merge_rows( - 
self, left_row: dict[str, Any] | None, right_row: dict[str, Any] | None + self, + left_row: dict[str, Any] | None, + right_row: dict[str, Any] | None, + left_columns: list[str] | None = None, ) -> dict[str, Any]: """ Merge left and right rows into a single output row @@ -188,8 +225,9 @@ def _merge_rows( Handles column name conflicts by prefixing with table names if needed. Args: - left_row: Row from left table (None for RIGHT JOIN with no match) - right_row: Row from right table (None for LEFT JOIN with no match) + left_row: Row from left table (None for RIGHT/FULL JOIN with no match) + right_row: Row from right table (None for LEFT/FULL JOIN with no match) + left_columns: Known left column names (for producing NULLs when left_row is None) Returns: Merged row dictionary @@ -199,11 +237,10 @@ def _merge_rows( # Add left columns if left_row is not None: result.update(left_row) - elif right_row is not None: - # For RIGHT JOIN with no match, add NULL for all left columns - # We don't know the left schema, so we just don't add anything - # The columns will be added on first matched row - pass + elif left_columns is not None: + # For RIGHT/FULL JOIN with no match, add NULL for known left columns + for col in left_columns: + result[col] = None # Add right columns if right_row is not None: diff --git a/sqlstream/optimizers/condition_reorder.py b/sqlstream/optimizers/condition_reorder.py new file mode 100644 index 0000000..ff28e44 --- /dev/null +++ b/sqlstream/optimizers/condition_reorder.py @@ -0,0 +1,49 @@ +""" +Condition Reordering Optimizer + +Reorders WHERE conditions to evaluate cheapest/most selective conditions first. +This is safe because AND conditions are commutative. 
+""" + +from sqlstream.optimizers.base import Optimizer +from sqlstream.readers.base import BaseReader +from sqlstream.sql.ast_nodes import Condition, SelectStatement + +# Operator selectivity heuristic: equality is typically more selective +_OPERATOR_PRIORITY = { + "=": 0, + "!=": 1, + "IS NULL": 1, + "IS NOT NULL": 1, + "<": 2, + ">": 2, + "<=": 2, + ">=": 2, +} + + +class ConditionReorderingOptimizer(Optimizer): + """ + Reorder WHERE conditions by estimated cost and selectivity. + + Heuristic: equality conditions first, then range conditions. + This enables short-circuit evaluation to skip expensive checks. + """ + + def get_name(self) -> str: + return "Condition reordering" + + def can_optimize(self, ast: SelectStatement, reader: BaseReader) -> bool: + return ast.where is not None and len(ast.where.conditions) > 1 + + def optimize(self, ast: SelectStatement, reader: BaseReader) -> None: + original_order = list(ast.where.conditions) + ast.where.conditions.sort(key=self._condition_cost) + if ast.where.conditions != original_order: + self.applied = True + self.description = f"reordered {len(ast.where.conditions)} conditions" + + @staticmethod + def _condition_cost(condition: Condition) -> int: + """Lower cost = evaluated first.""" + return _OPERATOR_PRIORITY.get(condition.operator, 3) diff --git a/sqlstream/optimizers/planner.py b/sqlstream/optimizers/planner.py index b5d6f5a..6f22fee 100644 --- a/sqlstream/optimizers/planner.py +++ b/sqlstream/optimizers/planner.py @@ -7,6 +7,7 @@ from sqlstream.optimizers.base import Optimizer, OptimizerPipeline from sqlstream.optimizers.column_pruning import ColumnPruningOptimizer +from sqlstream.optimizers.condition_reorder import ConditionReorderingOptimizer from sqlstream.optimizers.join_reordering import JoinReorderingOptimizer from sqlstream.optimizers.limit_pushdown import LimitPushdownOptimizer from sqlstream.optimizers.partition_pruning import PartitionPruningOptimizer @@ -54,6 +55,7 @@ def __init__(self): [ 
JoinReorderingOptimizer(), PartitionPruningOptimizer(), + ConditionReorderingOptimizer(), PredicatePushdownOptimizer(), ColumnPruningOptimizer(), LimitPushdownOptimizer(), diff --git a/sqlstream/optimizers/predicate_pushdown.py b/sqlstream/optimizers/predicate_pushdown.py index 85d1f3d..4d6a36b 100644 --- a/sqlstream/optimizers/predicate_pushdown.py +++ b/sqlstream/optimizers/predicate_pushdown.py @@ -39,7 +39,9 @@ def can_optimize(self, ast: SelectStatement, reader: BaseReader) -> bool: Conditions: 1. Query has WHERE clause 2. Reader supports pushdown - 3. Not a JOIN query (complex - needs smarter analysis) + + For JOIN queries, only conditions referencing left table columns + are pushed down (conservative approach). Args: ast: Parsed SQL statement @@ -56,11 +58,6 @@ def can_optimize(self, ast: SelectStatement, reader: BaseReader) -> bool: if not reader.supports_pushdown(): return False - # Skip JOINs for now - WHERE conditions may reference either table - # TODO: Make this smarter by analyzing which conditions apply to which table - if ast.join: - return False - return True def optimize(self, ast: SelectStatement, reader: BaseReader) -> None: @@ -71,8 +68,19 @@ def optimize(self, ast: SelectStatement, reader: BaseReader) -> None: ast: Parsed SQL statement reader: Data source reader """ - # Extract conditions that can be pushed down - pushable = self._extract_pushable_conditions(ast.where.conditions) + conditions = ast.where.conditions + + if ast.join: + # For JOINs, only push conditions that reference left table columns + schema = reader.get_schema() + if schema is None: + return + left_columns = set(schema.columns.keys()) + pushable = [ + c for c in self._extract_pushable_conditions(conditions) if c.column in left_columns + ] + else: + pushable = self._extract_pushable_conditions(conditions) if pushable: reader.set_filter(pushable) diff --git a/sqlstream/readers/csv_reader.py b/sqlstream/readers/csv_reader.py index 8b676b0..86b3c41 100644 --- 
a/sqlstream/readers/csv_reader.py +++ b/sqlstream/readers/csv_reader.py @@ -13,6 +13,7 @@ from sqlstream.core.types import Schema from sqlstream.readers.base import BaseReader from sqlstream.sql.ast_nodes import Condition +from sqlstream.utils.condition_eval import matches_filter class CSVReader(BaseReader): @@ -178,67 +179,8 @@ def _infer_value_type(self, value: str) -> Any: return infer_type_from_string(value) def _matches_filter(self, row: dict[str, Any]) -> bool: - """ - Check if row matches all filter conditions - - Args: - row: Row to check - - Returns: - True if row matches all conditions (AND logic) - """ - for condition in self.filter_conditions: - if not self._evaluate_condition(row, condition): - return False - return True - - def _evaluate_condition(self, row: dict[str, Any], condition: Condition) -> bool: - """ - Evaluate a single condition against a row - - Args: - row: Row to check - condition: Condition to evaluate - - Returns: - True if condition is satisfied - """ - # Get column value - if condition.column not in row: - return False - - value = row[condition.column] - - # Handle NULL values - if value is None: - return False - - # Evaluate operator - op = condition.operator - expected = condition.value - - try: - if op == "=": - return value == expected - elif op == ">": - return value > expected - elif op == "<": - return value < expected - elif op == ">=": - return value >= expected - elif op == "<=": - return value <= expected - elif op == "!=": - return value != expected - else: - # Unknown operator, skip this condition - warnings.warn(f"Unknown operator: {op}", UserWarning, stacklevel=2) - return True - - except TypeError: - # Type mismatch (e.g., comparing string to int) - # This is fine - row just doesn't match - return False + """Check if row matches all filter conditions (AND logic).""" + return matches_filter(row, self.filter_conditions) def get_schema(self, sample_size: int = 100) -> Schema | None: """ diff --git 
a/sqlstream/readers/json_reader.py b/sqlstream/readers/json_reader.py index 1839f46..919c10b 100644 --- a/sqlstream/readers/json_reader.py +++ b/sqlstream/readers/json_reader.py @@ -10,6 +10,7 @@ from sqlstream.core.types import Schema from sqlstream.readers.base import BaseReader from sqlstream.sql.ast_nodes import Condition +from sqlstream.utils.condition_eval import matches_filter class JSONReader(BaseReader): @@ -300,41 +301,8 @@ def _navigate_simple(self, data: Any, path: str) -> Any: return current def _matches_filter(self, row: dict[str, Any]) -> bool: - """Check if row matches filter conditions""" - for condition in self.filter_conditions: - if not self._evaluate_condition(row, condition): - return False - return True - - def _evaluate_condition(self, row: dict[str, Any], condition: Condition) -> bool: - """Evaluate single condition""" - if condition.column not in row: - return False - - value = row[condition.column] - if value is None: - return False - - op = condition.operator - expected = condition.value - - try: - if op == "=": - return value == expected - elif op == ">": - return value > expected - elif op == "<": - return value < expected - elif op == ">=": - return value >= expected - elif op == "<=": - return value <= expected - elif op == "!=": - return value != expected - else: - return True - except TypeError: - return False + """Check if row matches filter conditions.""" + return matches_filter(row, self.filter_conditions) def get_schema(self) -> Schema | None: """Infer schema from data""" @@ -353,5 +321,5 @@ def get_schema(self) -> Schema | None: return None return Schema.from_rows(rows) - except Exception: + except (ValueError, KeyError, TypeError, IndexError): return None diff --git a/sqlstream/readers/jsonl_reader.py b/sqlstream/readers/jsonl_reader.py index 360914b..0b2bcf7 100644 --- a/sqlstream/readers/jsonl_reader.py +++ b/sqlstream/readers/jsonl_reader.py @@ -11,6 +11,7 @@ from sqlstream.core.types import Schema from 
sqlstream.readers.base import BaseReader from sqlstream.sql.ast_nodes import Condition +from sqlstream.utils.condition_eval import matches_filter class JSONLReader(BaseReader): @@ -129,41 +130,8 @@ def read_lazy(self) -> Iterator[dict[str, Any]]: continue def _matches_filter(self, row: dict[str, Any]) -> bool: - """Check if row matches filter conditions""" - for condition in self.filter_conditions: - if not self._evaluate_condition(row, condition): - return False - return True - - def _evaluate_condition(self, row: dict[str, Any], condition: Condition) -> bool: - """Evaluate single condition""" - if condition.column not in row: - return False - - value = row[condition.column] - if value is None: - return False - - op = condition.operator - expected = condition.value - - try: - if op == "=": - return value == expected - elif op == ">": - return value > expected - elif op == "<": - return value < expected - elif op == ">=": - return value >= expected - elif op == "<=": - return value <= expected - elif op == "!=": - return value != expected - else: - return True - except TypeError: - return False + """Check if row matches filter conditions.""" + return matches_filter(row, self.filter_conditions) def get_schema(self, sample_size: int = 100) -> Schema | None: """Infer schema by sampling first N lines""" @@ -176,7 +144,7 @@ def get_schema(self, sample_size: int = 100) -> Schema | None: sample_rows.append(next(iterator)) except StopIteration: break - except Exception: + except (json.JSONDecodeError, ValueError, OSError): pass if not sample_rows: diff --git a/sqlstream/readers/parquet_reader.py b/sqlstream/readers/parquet_reader.py index c03967a..816cef4 100644 --- a/sqlstream/readers/parquet_reader.py +++ b/sqlstream/readers/parquet_reader.py @@ -16,6 +16,7 @@ from sqlstream.core.types import DataType, Schema from sqlstream.readers.base import BaseReader from sqlstream.sql.ast_nodes import Condition +from sqlstream.utils.condition_eval import matches_filter class 
ParquetReader(BaseReader): @@ -369,66 +370,8 @@ def _read_row_group(self, rg_idx: int) -> Iterator[dict[str, Any]]: yield row def _matches_filter(self, row: dict[str, Any]) -> bool: - """ - Check if row matches all filter conditions - - Args: - row: Row to check - - Returns: - True if row matches all conditions (AND logic) - """ - for condition in self.filter_conditions: - if not self._evaluate_condition(row, condition): - return False - return True - - def _evaluate_condition(self, row: dict[str, Any], condition: Condition) -> bool: - """ - Evaluate a single condition against a row - - Args: - row: Row to check - condition: Condition to evaluate - - Returns: - True if condition is satisfied - """ - # Get column value - if condition.column not in row: - return False - - value = row[condition.column] - - # Handle NULL values - if value is None: - return False - - # Evaluate operator - op = condition.operator - expected = condition.value - - try: - if op == "=": - return value == expected - elif op == ">": - return value > expected - elif op == "<": - return value < expected - elif op == ">=": - return value >= expected - elif op == "<=": - return value <= expected - elif op == "!=": - return value != expected - else: - # Unknown operator, conservatively keep row - return True - - except TypeError: - # Type mismatch (e.g., comparing string to int) - # This is fine - row just doesn't match - return False + """Check if row matches all filter conditions (AND logic).""" + return matches_filter(row, self.filter_conditions) def get_schema(self) -> Schema: """ diff --git a/sqlstream/sql/ast_nodes.py b/sqlstream/sql/ast_nodes.py index b4aaa99..634b35e 100644 --- a/sqlstream/sql/ast_nodes.py +++ b/sqlstream/sql/ast_nodes.py @@ -104,10 +104,12 @@ class SelectStatement: source: str # Table/file name (FROM clause) where: WhereClause | None = None group_by: list[str] | None = None + having: WhereClause | None = None order_by: list[OrderByColumn] | None = None limit: int | None = 
None aggregates: list[AggregateFunction] | None = None # Aggregate functions in SELECT join: JoinClause | None = None # JOIN clause + distinct: bool = False # SELECT DISTINCT def __repr__(self) -> str: parts = [f"SELECT {', '.join(self.columns)}"] diff --git a/sqlstream/sql/parser.py b/sqlstream/sql/parser.py index f75b664..187668c 100644 --- a/sqlstream/sql/parser.py +++ b/sqlstream/sql/parser.py @@ -110,6 +110,12 @@ def _parse_select(self) -> SelectStatement: """Parse SELECT statement""" self.consume("SELECT") + # Check for DISTINCT keyword + distinct = False + if self.current() and self.current().upper() == "DISTINCT": + self.consume("DISTINCT") + distinct = True + # Parse columns (may include aggregates) columns, aggregates = self._parse_columns() @@ -124,19 +130,21 @@ def _parse_select(self) -> SelectStatement: elif self.current() and self.current().upper() not in ( "WHERE", "GROUP", + "HAVING", "ORDER", "LIMIT", "JOIN", "INNER", "LEFT", "RIGHT", + "FULL", ): # If next token is not a keyword, it's likely an alias self.consume() # Skip alias name # Optional JOIN clause join = None - if self.current() and self.current().upper() in ("INNER", "LEFT", "RIGHT", "JOIN"): + if self.current() and self.current().upper() in ("INNER", "LEFT", "RIGHT", "FULL", "JOIN"): join = self._parse_join() # Optional WHERE clause @@ -149,6 +157,11 @@ def _parse_select(self) -> SelectStatement: if self.current() and self.current().upper() == "GROUP": group_by = self._parse_group_by() + # Optional HAVING clause (must come after GROUP BY) + having = None + if self.current() and self.current().upper() == "HAVING": + having = self._parse_having() + # Optional ORDER BY clause order_by = None if self.current() and self.current().upper() == "ORDER": @@ -164,10 +177,12 @@ def _parse_select(self) -> SelectStatement: source=source, where=where, group_by=group_by, + having=having, order_by=order_by, limit=limit, aggregates=aggregates, join=join, + distinct=distinct, ) def _parse_columns(self): @@ 
-278,8 +293,22 @@ def _parse_condition(self) -> Condition: age > 25 name = 'Alice' city != 'NYC' + name IS NULL + name IS NOT NULL """ column = self.consume() + + # Check for IS NULL / IS NOT NULL + if self.current() and self.current().upper() == "IS": + self.consume("IS") + if self.current() and self.current().upper() == "NOT": + self.consume("NOT") + self.consume("NULL") + return Condition(column=column, operator="IS NOT NULL", value=None) + else: + self.consume("NULL") + return Condition(column=column, operator="IS NULL", value=None) + operator = self.consume() # Parse value (could be number, string, or identifier) @@ -372,6 +401,23 @@ def _parse_group_by(self) -> list[str]: return columns + def _parse_having(self) -> WhereClause: + """ + Parse HAVING clause (same syntax as WHERE, applied after GROUP BY) + + Example: HAVING COUNT(*) > 5 AND SUM(amount) >= 100 + """ + self.consume("HAVING") + + conditions = [] + conditions.append(self._parse_condition()) + + while self.current() and self.current().upper() == "AND": + self.consume("AND") + conditions.append(self._parse_condition()) + + return WhereClause(conditions=conditions) + def _parse_order_by(self) -> list[OrderByColumn]: """ Parse ORDER BY clause @@ -431,8 +477,11 @@ def _parse_join(self) -> JoinClause: # Parse join type (INNER/LEFT/RIGHT or just JOIN) current = self.current().upper() - if current in ("INNER", "LEFT", "RIGHT"): + if current in ("INNER", "LEFT", "RIGHT", "FULL"): join_type = self.consume().upper() + # Skip optional OUTER keyword (e.g., FULL OUTER JOIN, LEFT OUTER JOIN) + if self.current() and self.current().upper() == "OUTER": + self.consume("OUTER") self.consume("JOIN") elif current == "JOIN": self.consume("JOIN") diff --git a/sqlstream/utils/condition_eval.py b/sqlstream/utils/condition_eval.py new file mode 100644 index 0000000..4c59167 --- /dev/null +++ b/sqlstream/utils/condition_eval.py @@ -0,0 +1,104 @@ +""" +Shared condition evaluation logic for filters and reader pushdown. 
+ +Used by both the Filter operator and readers (CSV, JSON, JSONL, Parquet) +to evaluate WHERE conditions against rows. +""" + +import warnings +from typing import Any + +from sqlstream.sql.ast_nodes import Condition + +# Sentinel for SQL three-valued logic (TRUE/FALSE/UNKNOWN) +UNKNOWN = object() + + +def evaluate_condition(row: dict[str, Any], condition: Condition) -> bool: + """ + Evaluate a single condition against a row. + + Implements SQL three-valued logic internally: NULL comparisons yield UNKNOWN, + which is treated as False in WHERE context (standard SQL behavior). + + Args: + row: Row to check + condition: Condition to evaluate + + Returns: + True if condition is satisfied, False otherwise (UNKNOWN treated as False) + """ + result = _evaluate_condition_3vl(row, condition) + # In WHERE context, UNKNOWN is treated as False (SQL standard) + return result is True + + +def _evaluate_condition_3vl(row: dict[str, Any], condition: Condition) -> bool | object: + """ + Evaluate a condition with three-valued logic support. 
+ + Returns: + True, False, or UNKNOWN + """ + op = condition.operator + + # Handle IS NULL / IS NOT NULL operators + if op == "IS NULL": + if condition.column not in row: + return True + return row[condition.column] is None + elif op == "IS NOT NULL": + if condition.column not in row: + return False + return row[condition.column] is not None + + # Get column value + if condition.column not in row: + return False + + value = row[condition.column] + + # NULL comparisons yield UNKNOWN (SQL three-valued logic) + if value is None: + return UNKNOWN + + # Evaluate operator + expected = condition.value + + try: + if op == "=": + return value == expected + elif op == ">": + return value > expected + elif op == "<": + return value < expected + elif op == ">=": + return value >= expected + elif op == "<=": + return value <= expected + elif op == "!=": + return value != expected + else: + warnings.warn(f"Unknown operator: {op}", UserWarning, stacklevel=4) + return True + + except TypeError: + # Type mismatch (e.g., comparing string to int) + return False + + +def matches_filter(row: dict[str, Any], conditions: list[Condition]) -> bool: + """ + Check if a row matches all filter conditions (AND logic). + + Args: + row: Row to check + conditions: List of conditions to evaluate + + Returns: + True if row matches all conditions + """ + for condition in conditions: + if not evaluate_condition(row, condition): + return False + return True diff --git a/sqlstream/utils/errors.py b/sqlstream/utils/errors.py new file mode 100644 index 0000000..9255cdf --- /dev/null +++ b/sqlstream/utils/errors.py @@ -0,0 +1,18 @@ +""" +Custom error types with helpful suggestions. 
+""" + +import difflib + + +class ColumnNotFoundError(KeyError): + """Raised when a referenced column does not exist in the data.""" + + def __init__(self, column: str, available: list[str]): + self.column = column + self.available = available + suggestions = difflib.get_close_matches(column, available, n=3, cutoff=0.5) + msg = f"Column '{column}' not found. Available columns: {', '.join(available)}" + if suggestions: + msg += f". Did you mean: {', '.join(repr(s) for s in suggestions)}?" + super().__init__(msg) diff --git a/tests/test_basic.py b/tests/test_basic.py index c4fb526..e2fe1b5 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -8,7 +8,7 @@ def test_version(): """Test that version is defined""" assert hasattr(sqlstream, "__version__") - assert sqlstream.__version__ == "0.1.0" + assert sqlstream.__version__ == "0.6.3" def test_import(): diff --git a/tests/test_new_features.py b/tests/test_new_features.py new file mode 100644 index 0000000..893b8fe --- /dev/null +++ b/tests/test_new_features.py @@ -0,0 +1,583 @@ +""" +Tests for new SQL features: DISTINCT, FULL OUTER JOIN, HAVING, +IS NULL/IS NOT NULL, empty GROUP BY defaults, condition reordering, +predicate pushdown for JOINs, and column-not-found fuzzy matching. 
# ── Fixtures ──────────────────────────────────────────────────────


def _rows_reader(rows):
    """Return a minimal reader stand-in whose ``read_lazy`` iterates ``rows``."""
    return type("R", (), {"read_lazy": lambda self: iter(rows)})()


@pytest.fixture
def employees_csv(tmp_path):
    """Write the six-employee CSV into ``tmp_path`` and return its path."""
    target = tmp_path / "employees.csv"
    content = (
        "name,age,city,department,salary\n"
        "Alice,30,NYC,Engineering,90000\n"
        "Bob,25,LA,Marketing,60000\n"
        "Charlie,35,SF,Engineering,95000\n"
        "Diana,28,NYC,Marketing,65000\n"
        "Eve,32,LA,Engineering,85000\n"
        "Frank,40,SF,Sales,70000\n"
    )
    target.write_text(content)
    return target


@pytest.fixture
def employees_with_nulls_csv(tmp_path):
    """Write a CSV with blank ``age``/``city`` cells and return its path."""
    target = tmp_path / "employees_nulls.csv"
    target.write_text("name,age,city\nAlice,30,NYC\nBob,,LA\nCharlie,35,\nDiana,,\n")
    return target


@pytest.fixture
def empty_csv(tmp_path):
    """Write a header-only CSV (no data rows) and return its path."""
    target = tmp_path / "empty.csv"
    target.write_text("name,age,city\n")
    return target


@pytest.fixture
def departments_csv(tmp_path):
    """Write the three-row departments CSV and return its path."""
    target = tmp_path / "departments.csv"
    target.write_text("dept_name,budget\nEngineering,500000\nMarketing,200000\nHR,150000\n")
    return target


# ── SELECT DISTINCT ──────────────────────────────────────────────


class TestSelectDistinct:
    """SELECT DISTINCT: parser flag, the Distinct operator, and full queries."""

    def test_parser_recognizes_distinct(self):
        parsed = parse("SELECT DISTINCT city FROM data")
        assert parsed.distinct is True
        assert parsed.columns == ["city"]

    def test_parser_without_distinct(self):
        parsed = parse("SELECT city FROM data")
        assert parsed.distinct is False

    def test_distinct_operator_deduplicates(self):
        source_rows = [{"city": name} for name in ("NYC", "LA", "NYC", "SF", "LA")]
        operator = Distinct(Scan(_rows_reader(source_rows)))

        output = list(operator)
        seen_cities = [row["city"] for row in output]

        assert len(output) == 3
        assert set(seen_cities) == {"NYC", "LA", "SF"}

    def test_distinct_preserves_order(self):
        source_rows = [{"city": name} for name in ("NYC", "LA", "NYC", "SF")]
        operator = Distinct(Scan(_rows_reader(source_rows)))

        # First occurrence wins, so the output order follows first appearance
        output = list(operator)
        assert [row["city"] for row in output] == ["NYC", "LA", "SF"]

    def test_distinct_end_to_end(self, employees_csv):
        sql = f"SELECT DISTINCT city FROM '{employees_csv}'"
        results = query(str(employees_csv)).sql(sql).to_list()

        cities = [row["city"] for row in results]
        assert len(cities) == 3
        assert set(cities) == {"NYC", "LA", "SF"}

    def test_distinct_with_multiple_columns(self, employees_csv):
        sql = f"SELECT DISTINCT city, department FROM '{employees_csv}'"
        results = query(str(employees_csv)).sql(sql).to_list()

        # Every (city, department) pair in the fixture occurs exactly once:
        # NYC-Engineering, NYC-Marketing, LA-Marketing, LA-Engineering,
        # SF-Engineering, SF-Sales
        assert len(results) == 6

    def test_distinct_all_same(self):
        source_rows = [{"x": 1}, {"x": 1}, {"x": 1}]
        operator = Distinct(Scan(_rows_reader(source_rows)))

        assert len(list(operator)) == 1

    def test_distinct_empty(self):
        operator = Distinct(Scan(_rows_reader([])))

        assert list(operator) == []


# ── FULL OUTER JOIN ──────────────────────────────────────────────


class TestFullOuterJoin:
    """FULL [OUTER] JOIN: parsing, the hash-join operator, and full queries."""

    @pytest.fixture
    def left_csv(self, tmp_path):
        path = tmp_path / "left.csv"
        path.write_text("id,name\n1,Alice\n2,Bob\n3,Charlie\n")
        return path

    @pytest.fixture
    def right_csv(self, tmp_path):
        path = tmp_path / "right.csv"
        path.write_text("id,dept\n2,Engineering\n3,Marketing\n4,Sales\n")
        return path

    def test_parser_full_outer_join(self):
        parsed = parse("SELECT * FROM left FULL OUTER JOIN right ON id = id")
        assert parsed.join is not None
        assert parsed.join.join_type == "FULL"

    def test_parser_full_join_without_outer(self):
        parsed = parse("SELECT * FROM left FULL JOIN right ON id = id")
        assert parsed.join.join_type == "FULL"

    def test_full_outer_join_operator(self):
        people = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        teams = [{"id": 2, "dept": "Eng"}, {"id": 3, "dept": "Sales"}]

        operator = HashJoinOperator(
            Scan(_rows_reader(people)),
            Scan(_rows_reader(teams)),
            join_type="FULL",
            left_key="id",
            right_key="id",
        )
        joined = list(operator)

        # One row per id: 1 (left only), 2 (matched), 3 (right only)
        assert len(joined) == 3

        # Unmatched left row keeps its name and has no dept
        alice_row = next(row for row in joined if row.get("name") == "Alice")
        assert alice_row.get("dept") is None

        # Matched row carries columns from both sides
        bob_row = next(row for row in joined if row.get("name") == "Bob")
        assert bob_row["dept"] == "Eng"

        # Unmatched right row gets an explicit NULL name
        sales_row = next(row for row in joined if row.get("dept") == "Sales")
        assert sales_row["name"] is None

    def test_full_outer_join_end_to_end(self, left_csv, right_csv):
        sql = f"SELECT * FROM '{left_csv}' FULL OUTER JOIN '{right_csv}' ON id = id"
        results = query(str(left_csv)).sql(sql).to_list()

        # ids 1..4: 1 is left-only, 2 and 3 match, 4 is right-only
        assert len(results) == 4

    def test_full_outer_join_both_empty(self):
        operator = HashJoinOperator(
            Scan(_rows_reader([])),
            Scan(_rows_reader([])),
            join_type="FULL",
            left_key="id",
            right_key="id",
        )
        assert list(operator) == []


# ── HAVING clause ────────────────────────────────────────────────


class TestHavingClause:
    """HAVING: parsing and filtering of aggregated groups."""

    def test_parser_having(self):
        parsed = parse("SELECT city, COUNT(*) FROM data GROUP BY city HAVING count_* > 2")
        assert parsed.having is not None
        assert len(parsed.having.conditions) == 1

        only_condition = parsed.having.conditions[0]
        assert only_condition.column == "count_*"
        assert only_condition.operator == ">"
        assert only_condition.value == 2

    def test_parser_having_with_alias(self):
        parsed = parse("SELECT city, COUNT(*) AS cnt FROM data GROUP BY city HAVING cnt > 5")
        assert parsed.having is not None
        assert parsed.having.conditions[0].column == "cnt"

    def test_having_end_to_end(self, employees_csv):
        sql = (
            f"SELECT department, COUNT(*) AS cnt "
            f"FROM '{employees_csv}' "
            f"GROUP BY department HAVING cnt > 1"
        )
        results = query(str(employees_csv)).sql(sql).to_list()

        # Head counts are Engineering=3, Marketing=2, Sales=1, so only the
        # first two survive HAVING cnt > 1
        assert len(results) == 2
        surviving = {row["department"] for row in results}
        assert surviving == {"Engineering", "Marketing"}

    def test_having_filters_all(self, employees_csv):
        sql = (
            f"SELECT department, COUNT(*) AS cnt "
            f"FROM '{employees_csv}' "
            f"GROUP BY department HAVING cnt > 100"
        )
        results = query(str(employees_csv)).sql(sql).to_list()

        assert len(results) == 0

    def test_having_with_sum(self, employees_csv):
        sql = (
            f"SELECT department, SUM(salary) AS total "
            f"FROM '{employees_csv}' "
            f"GROUP BY department HAVING total > 100000"
        )
        results = query(str(employees_csv)).sql(sql).to_list()

        # Totals are Engineering=270000, Marketing=125000, Sales=70000;
        # the first two exceed the threshold
        assert len(results) == 2
# ── IS NULL / IS NOT NULL ────────────────────────────────────────


def _lazy_reader(rows):
    """Return a minimal reader stand-in whose ``read_lazy`` iterates ``rows``."""
    return type("R", (), {"read_lazy": lambda self: iter(rows)})()


def _capability_free_reader():
    """Return an empty reader stand-in that declines every optional capability."""
    return type(
        "R",
        (),
        {
            "read_lazy": lambda self: iter([]),
            "supports_pushdown": lambda self: False,
            "supports_column_selection": lambda self: False,
            "supports_limit": lambda self: False,
            "get_schema": lambda self, *a, **kw: None,
        },
    )()


class TestIsNull:
    """IS NULL / IS NOT NULL: parsing, row evaluation, and full queries."""

    def test_parser_is_null(self):
        parsed = parse("SELECT * FROM data WHERE name IS NULL")
        first = parsed.where.conditions[0]
        assert first.column == "name"
        assert first.operator == "IS NULL"
        assert first.value is None

    def test_parser_is_not_null(self):
        parsed = parse("SELECT * FROM data WHERE name IS NOT NULL")
        first = parsed.where.conditions[0]
        assert first.operator == "IS NOT NULL"

    def test_evaluate_is_null_with_none(self):
        check = Condition(column="name", operator="IS NULL", value=None)
        record = {"name": None, "age": 30}
        assert evaluate_condition(record, check) is True

    def test_evaluate_is_null_with_value(self):
        check = Condition(column="name", operator="IS NULL", value=None)
        record = {"name": "Alice", "age": 30}
        assert evaluate_condition(record, check) is False

    def test_evaluate_is_null_missing_column(self):
        check = Condition(column="name", operator="IS NULL", value=None)
        record = {"age": 30}
        assert evaluate_condition(record, check) is True

    def test_evaluate_is_not_null_with_value(self):
        check = Condition(column="name", operator="IS NOT NULL", value=None)
        record = {"name": "Alice"}
        assert evaluate_condition(record, check) is True

    def test_evaluate_is_not_null_with_none(self):
        check = Condition(column="name", operator="IS NOT NULL", value=None)
        record = {"name": None}
        assert evaluate_condition(record, check) is False

    def test_is_null_end_to_end(self, employees_with_nulls_csv):
        sql = f"SELECT * FROM '{employees_with_nulls_csv}' WHERE age IS NULL"
        results = query(str(employees_with_nulls_csv)).sql(sql).to_list()

        # The fixture leaves age blank for Bob and Diana
        matched = {row["name"] for row in results}
        assert matched == {"Bob", "Diana"}

    def test_is_not_null_end_to_end(self, employees_with_nulls_csv):
        sql = f"SELECT * FROM '{employees_with_nulls_csv}' WHERE age IS NOT NULL"
        results = query(str(employees_with_nulls_csv)).sql(sql).to_list()

        matched = {row["name"] for row in results}
        assert matched == {"Alice", "Charlie"}


# ── NULL three-valued logic ──────────────────────────────────────


class TestNullThreeValuedLogic:
    """Comparisons against NULL follow SQL three-valued logic."""

    def test_null_equals_returns_unknown(self):
        check = Condition(column="age", operator="=", value=25)
        outcome = _evaluate_condition_3vl({"age": None}, check)
        assert outcome is UNKNOWN

    def test_null_greater_than_returns_unknown(self):
        check = Condition(column="age", operator=">", value=25)
        outcome = _evaluate_condition_3vl({"age": None}, check)
        assert outcome is UNKNOWN

    def test_unknown_treated_as_false_in_where(self):
        check = Condition(column="age", operator="=", value=25)
        # WHERE semantics: evaluate_condition collapses UNKNOWN to False
        assert evaluate_condition({"age": None}, check) is False

    def test_non_null_comparison_works(self):
        check = Condition(column="age", operator=">", value=25)
        outcome = _evaluate_condition_3vl({"age": 30}, check)
        assert outcome is True


# ── Empty GROUP BY defaults ──────────────────────────────────────


class TestEmptyGroupByDefaults:
    """Aggregating over no input rows yields the aggregate defaults."""

    def test_count_star_empty_table(self, empty_csv):
        sql = f"SELECT COUNT(*) AS cnt FROM '{empty_csv}'"
        results = query(str(empty_csv)).sql(sql).to_list()

        assert len(results) == 1
        assert results[0]["cnt"] == 0

    def test_count_star_empty_operator(self):
        aggregates = [AggregateFunction("COUNT", "*", "cnt")]
        operator = GroupByOperator(Scan(_lazy_reader([])), [], aggregates, ["cnt"])

        output = list(operator)
        assert len(output) == 1
        assert output[0]["cnt"] == 0

    def test_sum_empty_table_returns_none(self):
        aggregates = [AggregateFunction("SUM", "amount", "total")]
        operator = GroupByOperator(Scan(_lazy_reader([])), [], aggregates, ["total"])

        output = list(operator)
        assert len(output) == 1
        assert output[0]["total"] is None

    def test_avg_empty_table_returns_none(self):
        aggregates = [AggregateFunction("AVG", "amount", "avg_amount")]
        operator = GroupByOperator(Scan(_lazy_reader([])), [], aggregates, ["avg_amount"])

        output = list(operator)
        assert len(output) == 1
        assert output[0]["avg_amount"] is None

    def test_group_by_columns_empty_returns_no_rows(self):
        """Grouping columns plus no input data means no groups, hence no rows."""
        aggregates = [AggregateFunction("COUNT", "*", "cnt")]
        operator = GroupByOperator(
            Scan(_lazy_reader([])), ["city"], aggregates, ["city", "cnt"]
        )

        output = list(operator)
        assert len(output) == 0


# ── Column-not-found fuzzy matching ──────────────────────────────


class TestColumnNotFoundError:
    """Misspelled column names produce an error that suggests the real name."""

    def test_error_includes_available_columns(self):
        with pytest.raises(ColumnNotFoundError, match="Available columns"):
            raise ColumnNotFoundError("nme", ["name", "age", "city"])

    def test_error_suggests_close_match(self):
        with pytest.raises(ColumnNotFoundError, match="Did you mean"):
            raise ColumnNotFoundError("nme", ["name", "age", "city"])

    def test_error_suggests_name_for_nme(self):
        with pytest.raises(ColumnNotFoundError, match="'name'"):
            raise ColumnNotFoundError("nme", ["name", "age", "city"])

    def test_filter_raises_on_bad_column(self):
        source_rows = [{"name": "Alice", "age": 30, "city": "NYC"}]
        typo = Condition(column="nme", operator="=", value="Alice")
        # Build the pipeline outside the raises-block: only iteration is
        # expected to fail
        pipeline = Filter(Scan(_lazy_reader(source_rows)), [typo])

        with pytest.raises(ColumnNotFoundError, match="'name'"):
            list(pipeline)

    def test_filter_does_not_raise_for_valid_column(self):
        source_rows = [{"name": "Alice", "age": 30}]
        valid = Condition(column="name", operator="=", value="Alice")
        pipeline = Filter(Scan(_lazy_reader(source_rows)), [valid])

        output = list(pipeline)
        assert len(output) == 1

    def test_filter_does_not_raise_for_is_null_missing_column(self):
        """A column absent from the row counts as NULL, so IS NULL must not raise."""
        source_rows = [{"name": "Alice"}]
        null_check = Condition(column="age", operator="IS NULL", value=None)
        pipeline = Filter(Scan(_lazy_reader(source_rows)), [null_check])

        output = list(pipeline)
        assert len(output) == 1

    def test_fuzzy_match_on_join_where(self, employees_csv, departments_csv):
        """With a JOIN, pushdown cannot apply, so Filter is what validates columns."""
        sql = (
            f"SELECT * FROM '{employees_csv}' "
            f"INNER JOIN '{departments_csv}' ON department = dept_name "
            f"WHERE nme = 'Alice'"
        )
        with pytest.raises(ColumnNotFoundError, match="nme"):
            query(str(employees_csv)).sql(sql, backend="python").to_list()


# ── Condition reordering optimizer ───────────────────────────────


class TestConditionReordering:
    """WHERE conditions are reordered by a selectivity heuristic."""

    def test_equality_before_range(self):
        from sqlstream.optimizers.condition_reorder import ConditionReorderingOptimizer

        parsed = parse("SELECT * FROM data WHERE age > 25 AND name = 'Alice'")
        reader = _capability_free_reader()

        optimizer = ConditionReorderingOptimizer()
        assert optimizer.can_optimize(parsed, reader)
        optimizer.optimize(parsed, reader)

        # The equality test is expected to move ahead of the range test
        reordered = parsed.where.conditions
        assert reordered[0].operator == "="
        assert reordered[1].operator == ">"

    def test_no_reorder_single_condition(self):
        from sqlstream.optimizers.condition_reorder import ConditionReorderingOptimizer

        parsed = parse("SELECT * FROM data WHERE age > 25")
        reader = _capability_free_reader()

        optimizer = ConditionReorderingOptimizer()
        assert not optimizer.can_optimize(parsed, reader)


# ── Parser edge cases for new syntax ────────────────────────────


class TestParserNewSyntax:
    """Parser behaviour when the new clauses are combined with existing ones."""

    def test_is_null_with_and(self):
        parsed = parse("SELECT * FROM data WHERE name IS NULL AND age > 25")
        assert len(parsed.where.conditions) == 2
        assert parsed.where.conditions[0].operator == "IS NULL"
        assert parsed.where.conditions[1].operator == ">"

    def test_having_multiple_conditions(self):
        sql = "SELECT city, COUNT(*) AS cnt FROM data GROUP BY city HAVING cnt > 1 AND cnt < 100"
        parsed = parse(sql)
        assert parsed.having is not None
        assert len(parsed.having.conditions) == 2

    def test_distinct_with_where_and_limit(self):
        parsed = parse("SELECT DISTINCT city FROM data WHERE age > 25 LIMIT 5")
        assert parsed.distinct is True
        assert parsed.where is not None
        assert parsed.limit == 5

    def test_full_outer_join_with_where(self):
        parsed = parse("SELECT * FROM left FULL OUTER JOIN right ON id = id WHERE age > 25")
        assert parsed.join.join_type == "FULL"
        assert parsed.where is not None

    def test_left_outer_join(self):
        """The OUTER keyword is optional, so LEFT OUTER JOIN parses as a LEFT join."""
        parsed = parse("SELECT * FROM left LEFT OUTER JOIN right ON id = id")
        assert parsed.join.join_type == "LEFT"