Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sqlstream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
Parquet, and JSON with lazy evaluation and intelligent optimizations.
"""

__version__ = "0.1.0"
try:
from importlib.metadata import version

__version__ = version("sqlstream")
except Exception:
__version__ = "0.6.3"

# Main API
from sqlstream.core.query import query
Expand Down
4 changes: 2 additions & 2 deletions sqlstream/core/duckdb_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def _register_sources_as_dataframes(self, sources: dict[str, str]):
# DuckDB can query pandas DataFrames directly!
self.conn.register(table_name, df)

except Exception:
except (ValueError, FileNotFoundError, OSError, KeyError):
self._register_source(table_name, file_path)

def _register_source(self, table_name: str, file_path: str, read_only: bool = True):
Expand Down Expand Up @@ -287,7 +287,7 @@ def _ensure_httpfs(self):
try:
self.conn.execute("INSTALL httpfs")
self.conn.execute("LOAD httpfs")
except Exception:
except duckdb.Error:
# httpfs might already be loaded or not needed
pass

Expand Down
9 changes: 9 additions & 0 deletions sqlstream/core/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Callable, Iterator
from typing import Any

from sqlstream.operators.distinct import Distinct
from sqlstream.operators.filter import Filter
from sqlstream.operators.groupby import GroupByOperator
from sqlstream.operators.join import HashJoinOperator
Expand Down Expand Up @@ -140,6 +141,10 @@ def _build_plan(
raise ValueError("GROUP BY requires aggregate functions in SELECT")
plan = GroupByOperator(plan, ast.group_by, ast.aggregates, ast.columns)

# Add HAVING filter after GroupBy (filters on aggregated results)
if ast.having:
plan = Filter(plan, ast.having.conditions)

# Add OrderBy if ORDER BY clause exists
if ast.order_by:
plan = OrderByOperator(plan, ast.order_by)
Expand All @@ -149,6 +154,10 @@ def _build_plan(
if not ast.group_by:
plan = Project(plan, ast.columns)

# Add Distinct if SELECT DISTINCT
if ast.distinct:
plan = Distinct(plan)

# Add Limit if LIMIT clause exists
if ast.limit is not None:
plan = Limit(plan, ast.limit)
Expand Down
25 changes: 24 additions & 1 deletion sqlstream/core/pandas_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def execute(
if ast.group_by or ast.aggregates:
df = self._apply_groupby(df, ast)

# Step 4b: Apply HAVING filter (after GROUP BY)
if ast.having:
df = self._apply_filter(df, ast.having.conditions)

# Step 5: Apply ORDER BY
if ast.order_by:
df = self._apply_orderby(df, ast)
Expand All @@ -80,7 +84,11 @@ def execute(
if not ast.group_by: # GroupBy already handled columns
df = self._apply_projection(df, ast.columns)

# Step 7: Apply LIMIT
# Step 7: Apply DISTINCT
if ast.distinct:
df = df.drop_duplicates()

# Step 8: Apply LIMIT
if ast.limit is not None:
df = df.head(ast.limit)

Expand Down Expand Up @@ -247,6 +255,7 @@ def _apply_join(
"INNER": "inner",
"LEFT": "left",
"RIGHT": "right",
"FULL": "outer",
}

how = join_type_map.get(ast.join.join_type, "inner")
Expand All @@ -271,6 +280,20 @@ def _apply_filter(self, df: pd.DataFrame, conditions: list[Condition]) -> pd.Dat
op = condition.operator
value = condition.value

# Handle IS NULL / IS NOT NULL (column may or may not exist)
if op == "IS NULL":
if col not in df.columns:
# Missing column is treated as all NULL
continue # mask stays True (all rows match)
mask &= df[col].isna()
continue
elif op == "IS NOT NULL":
if col not in df.columns:
mask &= False
continue
mask &= df[col].notna()
continue

if col not in df.columns:
# Column doesn't exist, skip this filter
continue
Expand Down
3 changes: 0 additions & 3 deletions sqlstream/core/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ def _can_parse_with_custom_parser(sql: str) -> bool:
- Complex expressions (CASE, CAST, EXTRACT)
- Subqueries
- Set operations (UNION, INTERSECT, EXCEPT)
- HAVING clause

Args:
sql: SQL query string to analyze

Expand All @@ -72,7 +70,6 @@ def _can_parse_with_custom_parser(sql: str) -> bool:
"OVER", # Window functions
"PARTITION BY", # Window functions
"WINDOW", # Window functions
"HAVING", # HAVING clause
"UNION", # Set operations
"INTERSECT", # Set operations
"EXCEPT", # Set operations
Expand Down
36 changes: 36 additions & 0 deletions sqlstream/operators/distinct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Distinct operator - implements SELECT DISTINCT

Deduplicates rows by keeping a set of seen row signatures.
"""

from collections.abc import Iterator
from typing import Any

from sqlstream.operators.base import Operator


class Distinct(Operator):
"""
DISTINCT operator - removes duplicate rows.

Keeps a set of seen row tuples and only yields rows
that haven't been seen before.
"""

def __init__(self, child: Operator):
super().__init__(child)

def __iter__(self) -> Iterator[dict[str, Any]]:
seen: set[tuple] = set()
for row in self.child:
# Convert row to a hashable key
key = tuple(
sorted((k, str(v) if isinstance(v, (list, dict)) else v) for k, v in row.items())
)
if key not in seen:
seen.add(key)
yield row

def __repr__(self) -> str:
return "Distinct()"
66 changes: 20 additions & 46 deletions sqlstream/operators/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from sqlstream.operators.base import Operator
from sqlstream.sql.ast_nodes import Condition
from sqlstream.utils.condition_eval import evaluate_condition
from sqlstream.utils.errors import ColumnNotFoundError


class Filter(Operator):
Expand Down Expand Up @@ -39,7 +41,12 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
2. If all are True, yield the row
3. Otherwise, skip it
"""
validated = False
for row in self.child:
# Validate column names on first row for better error messages
if not validated:
self._validate_columns(row)
validated = True
if self._matches(row):
yield row

Expand All @@ -58,53 +65,20 @@ def _matches(self, row: dict[str, Any]) -> bool:
return False
return True

def _evaluate_condition(self, row: dict[str, Any], condition: Condition) -> bool:
"""
Evaluate a single condition against a row

Args:
row: Row to check
condition: Condition to evaluate
@staticmethod
def _evaluate_condition(row: dict[str, Any], condition: Condition) -> bool:
"""Evaluate a single condition against a row (delegates to shared utility)."""
return evaluate_condition(row, condition)

Returns:
True if condition is satisfied
"""
# Get column value
if condition.column not in row:
return False

value = row[condition.column]

# Handle NULL values
if value is None:
return False

# Get expected value
expected = condition.value

# Evaluate operator
op = condition.operator

try:
if op == "=":
return value == expected
elif op == ">":
return value > expected
elif op == "<":
return value < expected
elif op == ">=":
return value >= expected
elif op == "<=":
return value <= expected
elif op == "!=":
return value != expected
else:
# Unknown operator - default to True to avoid filtering
return True

except TypeError:
# Type mismatch (e.g., comparing string to int)
return False
def _validate_columns(self, row: dict[str, Any]) -> None:
"""Check that all condition columns exist in the row, raising helpful errors."""
available = list(row.keys())
for condition in self.conditions:
# Skip IS NULL/IS NOT NULL — missing column is valid (treated as NULL)
if condition.operator in ("IS NULL", "IS NOT NULL"):
continue
if condition.column not in row:
raise ColumnNotFoundError(condition.column, available)

def __repr__(self) -> str:
cond_str = " AND ".join(str(c) for c in self.conditions)
Expand Down
7 changes: 7 additions & 0 deletions sqlstream/operators/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
value = row.get(agg_func.column) if agg_func.column != "*" else None
aggregators[i].update(value)

# Handle empty input with no GROUP BY columns (SQL standard):
# SELECT COUNT(*) FROM empty_table should return one row with COUNT=0
if not groups and not self.group_by_columns:
aggregators = self._create_aggregators()
yield self._build_output_row((), aggregators)
return

# Yield one row per group
for group_key, aggregators in groups.items():
row = self._build_output_row(group_key, aggregators)
Expand Down
55 changes: 46 additions & 9 deletions sqlstream/operators/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self.right_key = right_key

# Validate join type
if self.join_type not in ("INNER", "LEFT", "RIGHT"):
if self.join_type not in ("INNER", "LEFT", "RIGHT", "FULL"):
raise ValueError(f"Unsupported join type: {join_type}")

def __iter__(self) -> Iterator[dict[str, Any]]:
Expand All @@ -76,6 +76,8 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
yield from self._left_join()
elif self.join_type == "RIGHT":
yield from self._right_join()
elif self.join_type == "FULL":
yield from self._full_outer_join()

def _inner_join(self) -> Iterator[dict[str, Any]]:
"""
Expand Down Expand Up @@ -152,6 +154,38 @@ def _right_join(self) -> Iterator[dict[str, Any]]:
if (join_key, idx) not in matched_right_rows:
yield self._merge_rows(None, right_row)

def _full_outer_join(self) -> Iterator[dict[str, Any]]:
"""
Execute FULL OUTER JOIN

Returns all rows from both tables. Matched rows are joined,
unmatched rows from either side have NULL for the other side's columns.
"""
hash_table = self._build_hash_table()
matched_right_rows: set[tuple] = set()
left_columns: list[str] | None = None

# Probe phase: Scan left table
for left_row in self.left:
if left_columns is None:
left_columns = list(left_row.keys())

join_key = left_row.get(self.left_key)

if join_key is not None and join_key in hash_table:
for idx, right_row in enumerate(hash_table[join_key]):
yield self._merge_rows(left_row, right_row)
matched_right_rows.add((join_key, idx))
else:
# No match - output left row with NULL for right columns
yield self._merge_rows(left_row, None)

# Output unmatched right rows with NULL for left columns
for join_key, right_rows in hash_table.items():
for idx, right_row in enumerate(right_rows):
if (join_key, idx) not in matched_right_rows:
yield self._merge_rows(None, right_row, left_columns=left_columns)

def _build_hash_table(self) -> dict[Any, list[dict[str, Any]]]:
"""
Build hash table from right table
Expand Down Expand Up @@ -180,16 +214,20 @@ def _build_hash_table(self) -> dict[Any, list[dict[str, Any]]]:
return hash_table

def _merge_rows(
self, left_row: dict[str, Any] | None, right_row: dict[str, Any] | None
self,
left_row: dict[str, Any] | None,
right_row: dict[str, Any] | None,
left_columns: list[str] | None = None,
) -> dict[str, Any]:
"""
Merge left and right rows into a single output row

Handles column name conflicts by prefixing with table names if needed.

Args:
left_row: Row from left table (None for RIGHT JOIN with no match)
right_row: Row from right table (None for LEFT JOIN with no match)
left_row: Row from left table (None for RIGHT/FULL JOIN with no match)
right_row: Row from right table (None for LEFT/FULL JOIN with no match)
left_columns: Known left column names (for producing NULLs when left_row is None)

Returns:
Merged row dictionary
Expand All @@ -199,11 +237,10 @@ def _merge_rows(
# Add left columns
if left_row is not None:
result.update(left_row)
elif right_row is not None:
# For RIGHT JOIN with no match, add NULL for all left columns
# We don't know the left schema, so we just don't add anything
# The columns will be added on first matched row
pass
elif left_columns is not None:
# For RIGHT/FULL JOIN with no match, add NULL for known left columns
for col in left_columns:
result[col] = None

# Add right columns
if right_row is not None:
Expand Down
Loading
Loading