diff --git a/.gitignore b/.gitignore index 15eb7d1..be177a8 100644 --- a/.gitignore +++ b/.gitignore @@ -157,6 +157,7 @@ data/ scratch/ *_sample_*.csv test_*.py +!tests/test_*.py debug_*.py # Debug files @@ -172,3 +173,18 @@ results/ profiles/ docs/*.html docs/index_files/ + +# Local run artifacts (shard lists, input lists, temp dirs) +*.txt +.tmp/ +archive_quarantine/ +tmp_polars_run_*/ +subagent_packages/ + +# Operator-only env (do not commit) +.env.comed + +# Pricing pilot data +data/pilot_interval_parquet/ +CLAUDE.md +.cursor/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 23d8811..091b2d6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,6 +13,7 @@ repos: args: [--autofix, --no-sort-keys] - id: end-of-file-fixer - id: trailing-whitespace + - id: detect-private-key - repo: https://github.com/astral-sh/ruff-pre-commit rev: "v0.11.5" @@ -20,3 +21,11 @@ repos: - id: ruff args: [--exit-non-zero-on-fix] - id: ruff-format + + - repo: local + hooks: + - id: forbid-secrets + name: Block secrets and credential files + entry: "bash -c 'echo BLOCKED: secrets/credential file staged for commit >&2; exit 1'" + language: system + files: '(\.env$|\.env\.|\.secrets|\.secret|credentials\.json|\.pem$|\.key$|\.p12$|\.pfx$|\.jks$)' diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..b3a4c65 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,254 @@ +# Agent guide: smart-meter-analysis + +This file orients AI agents so they can work effectively in this repo — building data pipelines, running analysis, and managing regulatory-grade datasets — without reading the entire codebase. + +## What this repo is + +**smart-meter-analysis** is [Switchbox's](https://switch.box/) smart meter data pipeline and analysis repo. Switchbox is a nonprofit think tank that produces rigorous, accessible data on U.S. state climate policy for advocates, policymakers, and the public. 
+ +This repo processes ComEd smart meter data for the [Citizens Utility Board](https://www.citizensutilityboard.org/) (CUB) of Illinois, supporting regulatory proceedings that examine utility rate equity and energy affordability. It combines a **data engineering pipeline** (CSV-to-Parquet compaction) with **statistical analysis** (rate simulations, clustering, regression) to produce regulatory-grade datasets and publication-ready figures. + +The main inputs are ComEd interval meter data (CSV and Parquet), Census demographic data, and geographic shapefiles. The main outputs are compacted Parquet datasets, statistical analyses, GeoJSON maps, and figures for regulatory testimony. + +The companion repo [reports2](https://github.com/switchbox-data/reports2) produces the final published reports from these analysis outputs. See its AGENTS.md for report-writing conventions. + +## Layout + +| Path | Purpose | +| --------------------------------- | ------------------------------------------------------------------------------------------ | +| `scripts/csv_to_parquet/` | CSV-to-Parquet migration pipeline: ingestion, compaction, validation. | +| `scripts/analysis/` | Analysis scripts: rate comparisons, clustering, regression. | +| `scripts/bench/` | Benchmarking scripts for pipeline performance. | +| `smart_meter_analysis/` | Installable Python package (shared utilities). | +| `analysis/` | Exploratory analysis notebooks and one-off investigations. | +| `tests/` | Pytest test suite. | +| `config/` | Configuration files for pipeline runs. | +| `data/` | Local data cache (gitignored — real data lives on S3). | +| `results/` | Analysis outputs: tables, summary statistics. | +| `figures/` | Generated plots and maps. | +| `docs/` | MkDocs documentation. | +| `infra/` | Terraform and EC2 infrastructure (see `infra/README.md` for VM setup). | +| `logs/` | Pipeline run logs. | +| `archive/` | Archived scripts and old approaches (reference only). 
| +| `.devcontainer/` | Dev container configuration (Dockerfile, devcontainer.json). | +| `Justfile` | Root task runner: `install`, `check`, `test`, `dev-setup`, `dev-login`, `dev-teardown`. | +| `pyproject.toml` | Python dependencies (managed by uv). | + +## Pipeline architecture + +The primary data engineering work is compacting raw ComEd CSV exports (~30,000 files per month) into sorted, validated Parquet files suitable for regulatory testimony. This pipeline is the core of the repo — understand it before touching any pipeline code. + +### Data flow + +```text +Raw CSVs (S3/local) → Ingestion → Monthly Parquet files → Compaction → Validation → Validated output +``` + +1. **Ingestion** (`scripts/csv_to_parquet/migrate_month_runner.py`): Reads raw CSVs, converts to Parquet with consistent schema. +2. **Compaction** (`scripts/csv_to_parquet/compact_month_output.py`): Merges monthly Parquet files into fewer, larger files with correct sort order. Uses a two-pass k-way merge-sort to handle ~60 input files and up to 500M rows per month. +3. **Validation**: Checks schema consistency, sort order, null thresholds, duplicate detection, row count expectations. + +### Critical constraints + +These constraints are non-negotiable. They exist because pipeline outputs support regulatory testimony subject to cross-examination. + +- **Memory**: Pipeline runs on EC2 (m7i.2xlarge: 8 vCPUs, 32 GB RAM). All operations must stay within ~28 GB working memory. Use PyArrow `iter_batches()` for streaming reads. Never `collect()` or `to_pandas()` on full monthly datasets. +- **Sort order**: All compacted output must be sorted by `(account_id, date)`. Downstream analysis and regulatory reproducibility depend on it. Verify sort order explicitly after every compaction operation — do not trust it implicitly. +- **Data quality**: Every transformation must be auditable and reproducible. No silent data loss. No silent duplicate creation. 
+- **Naming conventions**: Output files follow Spark conventions with `_SUCCESS.json` metadata markers. + +### Reading large data + +**Polars (preferred for analysis):** + +```python +import polars as pl + +# Lazy scan — stays out of memory until .collect() +lf = pl.scan_parquet("/data.sb/comed/interval_data/2023/*.parquet") +result = lf.filter(pl.col("account_id") == "12345").collect() +``` + +**PyArrow (preferred for pipeline I/O):** + +```python +import pyarrow.parquet as pq + +# Streaming read for large files — never loads full file into memory +pf = pq.ParquetFile("/data.sb/comed/interval_data/2023/07.parquet") +for batch in pf.iter_batches(batch_size=100_000): + process(batch) +``` + +Stay in lazy execution as long as possible. Only `collect()` / `compute()` when you need the data in memory and have filtered first. + +### What NOT to do in pipeline code + +- Do not load full months into memory with `pl.read_parquet()` or `pd.read_parquet()`. Use `pl.scan_parquet()` or PyArrow's `pq.ParquetFile` with `iter_batches()`. +- Do not use pandas in production pipeline code. Use PyArrow for I/O and Polars for transforms. +- Do not hardcode file paths. Use config files or CLI arguments. +- Do not skip validation after compaction. Every output file must pass schema and sort-order checks. +- Do not assume sort order is preserved through joins or concatenations. Verify explicitly. + +## Analysis conventions + +Analysis scripts examine how alternative electricity rate structures (DTOU, Rate BEST) affect different customer segments. The analysis feeds into reports published via the reports2 repo. + +### Methods + +- **Rate simulation**: Computing hypothetical bills under alternative tariff structures for ~328,000 Chicago households. +- **DTW clustering**: Identifying distinct electricity usage patterns from interval data (19.8M household-day observations). +- **Multinomial logistic regression**: Quantifying how demographics explain usage pattern membership. 
+- **Geographic analysis**: GeoJSON maps showing rate impact by census block group. +- **Income regression**: Scatterplots examining equity (regressivity/progressivity) across income levels. + +### Output standards + +Analysis outputs may be used in regulatory testimony. Every number must be traceable to source data through documented transformations. + +- Figures go to `figures/` directory. +- Summary tables go to `results/` directory. +- All statistics must be reproducible from source data. +- No hardcoded statistics — compute from data. +- Document assumptions in code comments explaining _why_ a threshold, filter, or parameter was chosen. + +## Working with data + +All data lives on S3, mounted at `/data.sb/` on the EC2 VM. Never store data files in git. + +### Storage paths + +| Location | Path | Size | Persistent? | Use | +| ----------- | ----------- | --------- | ----------- | -------------------------------- | +| S3 mount | `/data.sb/` | Unlimited | Yes | Source data, shared datasets | +| EBS volume | `/ebs/` | 500 GB | Yes | Home directories, persistent work | +| Local cache | `data/` | — | No (git) | Temporary local data (gitignored) | + +### S3 naming conventions + +```text +s3://data.sb/// +``` + +- Lowercase with underscores. Date suffix reflects when data was downloaded. +- Always use a dataset directory, even for single files. +- Prefer Parquet format. + +### Local caching + +`data/` is gitignored. Use it for caching downloads and intermediate results, but the analysis must be reproducible from S3 alone. Never reference local-only files in committed code without a clear download/generation step. + +## Code quality + +Before considering any change done: + +- **`just check`**: Runs pre-commit hooks (ruff-check, ruff-format, trailing whitespace, end-of-file newline, YAML/JSON/TOML validation, no large files, no merge conflict markers). +- **`just test`**: Runs pytest suite. Add or extend tests for new or changed behavior. 
+ +Python formatting: Ruff for formatting and linting. Type checking: mypy or ty. + +## How to work in this repo + +### Tasks + +Use `just` as the main interface. The root `Justfile` handles dev tasks and VM management. + +### Dependencies + +- **Python**: `uv add ` (updates `pyproject.toml` + `uv.lock`). Never use `pip install`. + +### Computing contexts + +- Data scientists' laptops (Mac with Apple Silicon) +- EC2 VM via `just dev-login` (`m7i.2xlarge`: 8 vCPUs, 32 GB RAM, 500 GB EBS) +- Be aware of which context you're in (affects available memory, S3 latency, and data access patterns). + +### AWS + +Data is on S3 in `us-west-2`. The EC2 VM mounts S3 at `/data.sb/`. See `infra/README.md` for full VM setup, login, and teardown instructions. Always run `just dev-teardown` when done to avoid unnecessary AWS costs. + +## Commits, branches, and PRs + +### Commits + +- **Atomic**: One logical change per commit. +- **Message format**: Imperative verb, <50 char summary (e.g., "Fix compaction sort-key overlap"). +- **WIP commits**: Prefix with `WIP:` for work-in-progress snapshots. + +### Branches and PRs + +- **PR title** MUST start with `[area]` (e.g., `[pipeline] Fix memory overflow in compaction stage`) — this becomes the squash-merge commit message on `main`. +- **Create PRs early** (draft is fine). This gives the team visibility into in-flight work. +- PRs should **merge within the sprint**; break large work into smaller PRs if needed. +- **Delete branches** after merging. +- **Description**: Don't duplicate the issue. Write: high-level overview, reviewer focus, non-obvious implementation details. +- **Close the GitHub issue**: Include `Closes #` (not the Linear identifier). +- Do not add "Made with Cursor" or LLM attribution. + +## Issue conventions + +All work is tracked via Linear issues (which sync to GitHub Issues). When creating or updating tickets, use the Linear MCP tools. 
Every new issue MUST satisfy the following before it is created: + +### Issue fields + +- **Type**: One of **Code** (delivered via commits/PRs), **Research** (starts with a question, findings documented in issue comments), or **Other** (proposals, graphics, coordination — deliverables vary). +- **Title**: `[area] Brief description` starting with a verb (e.g., `[pipeline] Add sort-order validation to compaction stage`). +- **What**: High-level description. Anyone can understand scope at a glance. +- **Why**: Context, importance, value. +- **How** (skip only when the What is self-explanatory and implementation is trivial): + - For Code issues: numbered implementation steps, trade-offs, dependencies. + - For Research issues: background context, options to consider, evaluation criteria. +- **Deliverables**: Concrete, verifiable outputs that define "done": + - Code: "PR that adds ...", "Tests for ...", "Updated `data/` directory with ..." + - Research: "Comment in this issue documenting ... with rationale and sources" + - Other: "Google Doc at ...", "Slide deck for ...", link to external deliverable + - Never vague ("Finish the analysis") or unmeasurable ("Make it better"). +- **Project**: Must be set. +- **Status**: Default to Backlog. Options: Backlog, To Do, In Progress, Under Review, Done. +- **Milestone**: Set when applicable (strongly encouraged). +- **Assignee**: Set if known. +- **Priority**: Set when urgency/importance is clear. + +### Status transitions + +Keep status updated as work progresses — this is critical for team visibility: + +- **Backlog** -> **To Do**: Picked for the current sprint +- **To Do** -> **In Progress**: Work has started (branch created for code issues) +- **In Progress** -> **Under Review**: PR ready for review, or findings documented +- **Under Review** -> **Done**: PR merged (auto-closes), or reviewer approves and closes + +## Conventions agents should follow + +1. **Memory first.** Always consider RAM constraints. 
Use streaming/lazy patterns for any dataset over 1 GB. +2. **Sort order is sacred.** Compacted output must be sorted by `(account_id, date)`. Verify after every compaction operation. +3. **No pandas in pipeline code.** PyArrow for I/O, Polars for transforms. +4. **Validate everything.** After compaction, after joins, after any transformation that could silently drop or duplicate rows. +5. **Use Context7.** Always look up current library docs before writing PyArrow or Polars code. These APIs change frequently. Do not rely on training data for API signatures. +6. **Run `just check`** before considering a change done. +7. **Config over hardcoding.** File paths, thresholds, and parameters belong in config files or CLI arguments, not inline. +8. **Data never goes in git.** S3 and `/data.sb/` for real data, `data/` (gitignored) for local caches. +9. **Tests for pipeline changes.** Any change to ingestion, compaction, or validation must have corresponding tests. +10. **Document assumptions.** Pipeline code is regulatory evidence. Comments should explain _why_ a threshold was chosen, why a filter is applied, what the expected data shape is. + +## MCP Tools + +### Context7 + +When writing or modifying code that uses a library, use the Context7 MCP server to fetch up-to-date documentation. Do not rely on training data for API signatures or usage patterns. + +### Linear + +When a task involves creating, updating, or referencing issues, use the Linear MCP server to interact with the workspace directly. Follow the issue conventions above. 
+ +## Quick reference + +| Command | Where | What it does | +| ----------------------- | ----- | ----------------------------------------- | +| `just install` | Root | Set up dev environment | +| `just check` | Root | Lint, format, pre-commit hooks | +| `just test` | Root | Run pytest suite | +| `just dev-setup` | Root | Spin up EC2 VM (one-time admin) | +| `just dev-login` | Root | Log in to EC2 VM | +| `just dev-teardown` | Root | Stop VM, preserve data volume | +| `just dev-teardown-all` | Root | Destroy VM and all data (permanent) | diff --git a/Justfile b/Justfile index e38229e..01dabc2 100644 --- a/Justfile +++ b/Justfile @@ -9,7 +9,6 @@ default: # ============================================================================= install: - echo "🚀 Creating virtual environment using uv" uv sync uv run pre-commit install @@ -20,8 +19,6 @@ update: # 🔍 AWS # ============================================================================= -# Authenticate with AWS via SSO (for manual AWS CLI usage like S3 access) -# Automatically configures SSO if not already configured aws: .devcontainer/devpod/aws.sh @@ -29,159 +26,248 @@ aws: # 🚀 DEVELOPMENT ENVIRONMENT # ============================================================================= -# Ensure Terraform is installed (internal dependency). Depends on aws so credentials -# are valid before any Terraform or infra script runs. _terraform: aws bash infra/install-terraform.sh -# Set up EC2 instance (run once by admin) -# Idempotent: safe to run multiple times dev-setup: _terraform bash infra/dev-setup.sh -# Destroy EC2 instance but preserve data volume (to recreate, run dev-setup again) dev-teardown: _terraform bash infra/dev-teardown.sh -# Destroy everything including data volume (WARNING: destroys all data!) 
dev-teardown-all: _terraform bash infra/dev-teardown-all.sh -# User login (run by any authorized user) dev-login: aws bash infra/dev-login.sh # ============================================================================= -# 🔄 DATA PIPELINE +# 🔄 DATA PIPELINE (ANALYTICS) # ============================================================================= -test-pipeline-local: - uv run python scripts/run_comed_pipeline.py --source local - pipeline YEAR_MONTH: uv run python scripts/run_comed_pipeline.py --year-month {{YEAR_MONTH}} --source s3 -test-pipeline YEAR_MONTH MAX_FILES="10": - uv run mprof run scripts/run_comed_pipeline.py --year-month {{YEAR_MONTH}} --max-files {{MAX_FILES}} --source s3 - pipeline-skip-download YEAR_MONTH: uv run python scripts/run_comed_pipeline.py --year-month {{YEAR_MONTH}} --skip-download --source s3 pipeline-debug YEAR_MONTH: uv run python scripts/run_comed_pipeline.py --year-month {{YEAR_MONTH}} --debug --source s3 -download-transform YEAR_MONTH MAX_FILES="": - uv run python -m smart_meter_analysis.aws_loader {{YEAR_MONTH}} {{MAX_FILES}} - # ============================================================================= -# 🧪 SAMPLE DATA (S3 + Synthetic) +# 🔄 CSV → PARQUET MIGRATION (PORTABLE + OPEN SOURCE SAFE) # ============================================================================= +# Operator configuration: +# .env.comed (gitignored) may define: +# COMED_S3_PREFIX +# COMED_MIGRATE_OUT_BASE +# COMED_MIGRATE_BATCH_SIZE +# COMED_MIGRATE_WORKERS +# CONTINUE_ON_ERROR +# COMED_ORCHESTRATOR_LOG_DIR + +S3_PREFIX := env_var_or_default("COMED_S3_PREFIX", "") +MIGRATE_OUT_BASE := env_var_or_default("COMED_MIGRATE_OUT_BASE", "") +MIGRATE_BATCH_SIZE := env_var_or_default("COMED_MIGRATE_BATCH_SIZE", "100") +MIGRATE_WORKERS := env_var_or_default("COMED_MIGRATE_WORKERS", "6") +CONTINUE_ON_ERROR := env_var_or_default("CONTINUE_ON_ERROR", "") +ORCHESTRATOR_LOG_DIR := env_var_or_default("COMED_ORCHESTRATOR_LOG_DIR", "") +OUT_ROOT_TEMPLATE := 
env_var_or_default("COMED_OUT_ROOT_TEMPLATE", "") + +# ----------------------------------------------------------------------------- +# List available YYYYMM months from S3 +# ----------------------------------------------------------------------------- + +months-from-s3 OUT_FILE PREFIX=S3_PREFIX: + #!/usr/bin/env bash + set -euo pipefail + if [ -f ".env.comed" ]; then source ".env.comed"; fi -download-samples YEAR_MONTH="202308" NUM_FILES="5": - uv run python scripts/testing/download_samples_from_s3.py --year-month {{YEAR_MONTH}} --num-files {{NUM_FILES}} + prefix="{{PREFIX}}" + if [ -z "$prefix" ]; then prefix="${COMED_S3_PREFIX:-}"; fi + if [ -z "$prefix" ]; then + echo "ERROR: S3 prefix not set. Use COMED_S3_PREFIX or PREFIX=..." >&2 + exit 1 + fi + prefix="${prefix%/}/" -download-samples-small YEAR_MONTH="202308": - uv run python scripts/testing/download_samples_from_s3.py --year-month {{YEAR_MONTH}} --num-files 3 + AWS_PAGER="" aws s3 ls "$prefix" \ + | awk '/PRE/ {gsub(/\//,"",$2); if ($2 ~ /^[0-9]{6}$/) print $2}' \ + | sort -u > "{{OUT_FILE}}" -download-samples-large YEAR_MONTH="202308": - uv run python scripts/testing/download_samples_from_s3.py --year-month {{YEAR_MONTH}} --num-files 10 + echo "Wrote $(wc -l < "{{OUT_FILE}}") months to {{OUT_FILE}}" -generate-samples: - uv run python scripts/testing/generate_sample_data.py +# ----------------------------------------------------------------------------- +# Single-month migration (EC2 only) +# ----------------------------------------------------------------------------- -generate-samples-custom ACCOUNTS DAYS START_DATE: - uv run python scripts/testing/generate_sample_data.py --num-accounts {{ACCOUNTS}} --num-days {{DAYS}} --start-date {{START_DATE}} +migrate-month YEAR_MONTH: + #!/usr/bin/env bash + set -euo pipefail + if [ -f ".env.comed" ]; then source ".env.comed"; fi -validate-local: - uv run python scripts/diagnostics/validate_pipeline.py --input data/processed/comed_samples.parquet + if [ ! 
-d /ebs ]; then + echo "ERROR: /ebs not found. Must run on EC2 with EBS mounted." >&2 + exit 1 + fi -inspect-dst-local: - uv run python scripts/diagnostics/inspect_dst_days.py --input data/processed/comed_samples.parquet --start 2023-11-01 --end 2023-11-10 + prefix="{{S3_PREFIX}}" + if [ -z "$prefix" ]; then prefix="${COMED_S3_PREFIX:-}"; fi + if [ -z "$prefix" ]; then + echo "ERROR: S3 prefix not set. Use COMED_S3_PREFIX or S3_PREFIX=..." >&2 + exit 1 + fi + prefix="${prefix%/}/" -view-sample: - @ls data/samples/*.csv 2>/dev/null | head -1 | xargs head -n 5 || echo "No samples found. Run: just download-samples" + bucket=$(echo "$prefix" | sed 's|^s3://||' | cut -d/ -f1) -clean-samples: - rm -rf data/samples/*.csv - @echo "Sample data cleaned" + out_base="{{MIGRATE_OUT_BASE}}" + if [ -z "$out_base" ]; then out_base="${COMED_MIGRATE_OUT_BASE:-}"; fi + if [ -z "$out_base" ]; then out_base="/ebs/home/$(whoami)/runs"; fi -# ============================================================================= -# 🗄️ DATA COLLECTION -# ============================================================================= + INPUT_LIST="$HOME/s3_paths_{{YEAR_MONTH}}_full.txt" + OUT_ROOT="${out_base}/out_{{YEAR_MONTH}}_production" + + AWS_PAGER="" aws s3 ls "${prefix}{{YEAR_MONTH}}/" --recursive \ + | awk -v b="s3://${bucket}/" -v m="{{YEAR_MONTH}}" 'match($4,/ANONYMOUS_DATA_([0-9]{6})_/,a) && a[1]==m {print b $4}' \ + | sort -u > "$INPUT_LIST" -download-ameren: - uv run python scripts/data_collection/ameren_scraper.py + if [ "$(wc -l < "$INPUT_LIST")" -eq 0 ]; then + echo "ERROR: No CSVs found for {{YEAR_MONTH}}" >&2 + exit 1 + fi -download-ameren-force: - uv run python scripts/data_collection/ameren_scraper.py --force + echo "Wrote $(wc -l < "$INPUT_LIST") CSVs to $INPUT_LIST" -download-ameren-debug: - uv run python scripts/data_collection/ameren_scraper.py --debug + uv run python scripts/csv_to_parquet/migrate_month_runner.py \ + --input-list "$INPUT_LIST" \ + --out-root "$OUT_ROOT" \ + 
--year-month "{{YEAR_MONTH}}" \ + --batch-size "{{MIGRATE_BATCH_SIZE}}" \ + --workers "{{MIGRATE_WORKERS}}" \ + --resume \ + --exec-mode lazy_sink -# ============================================================================= -# 🏙️ CHICAGO-WIDE SAMPLER -# ============================================================================= +# ----------------------------------------------------------------------------- +# Multi-month migration (sequential) +# ----------------------------------------------------------------------------- + +migrate-months MONTHS_FILE: + #!/usr/bin/env bash + set -euo pipefail + if [ -f ".env.comed" ]; then source ".env.comed"; fi -sample-city zips start end out bucket prefix target="200" cm90="": + if [ ! -d /ebs ]; then + echo "ERROR: /ebs not found. Must run on EC2." >&2 + exit 1 + fi + + log_dir="{{ORCHESTRATOR_LOG_DIR}}" + if [ -z "$log_dir" ]; then log_dir="/ebs/home/$(whoami)/runs/_orchestrator_logs"; fi + mkdir -p "$log_dir" + + ts=$(date -u +%Y%m%dT%H%M%SZ) + log_file="$log_dir/migrate_${ts}.log" + + succeeded=0; failed=0; skipped=0; failures="" + + log() { echo "[$(date -u +%Y-%m-%dT%H:%M:%SZ)] $*" | tee -a "$log_file"; } + + while IFS= read -r line || [ -n "$line" ]; do + month=$(echo "$line" | sed 's/#.*//' | tr -d '[:space:]') + [ -z "$month" ] && continue + if ! echo "$month" | grep -qE '^[0-9]{6}$'; then + log "SKIP invalid month: $month" + skipped=$((skipped + 1)) + continue + fi + + rc=0 + log "START $month" + just migrate-month "$month" 2>&1 | tee -a "$log_file" || rc=$? 
+ log "END $month rc=$rc" + + if [ "$rc" -eq 0 ]; then + succeeded=$((succeeded + 1)) + else + failed=$((failed + 1)) + failures="$failures $month" + if [ "{{CONTINUE_ON_ERROR}}" != "1" ]; then + log "ABORT on first failure" + break + fi + fi + done < "{{MONTHS_FILE}}" + + log "DONE succeeded=$succeeded failed=$failed skipped=$skipped" + [ "$failed" -eq 0 ] + +# ----------------------------------------------------------------------------- +# Validation +# ----------------------------------------------------------------------------- + +validate-month YEAR_MONTH OUT_ROOT MAX_FILES="50" CHECK_MODE="sample" DST="1": #!/usr/bin/env bash set -euo pipefail - CM90="{{cm90}}" - if [ -n "$CM90" ]; then EXTRA="--cm90 $CM90"; else EXTRA=""; fi - python scripts/tasks/task_runner.py sample \ - --zips "{{zips}}" \ - --start "{{start}}" \ - --end "{{end}}" \ - --bucket "{{bucket}}" \ - --prefix-base "{{prefix}}" \ - --target-per-zip {{target}} \ - --out "{{out}}" \ - $EXTRA - -sample-city-file zips_file start end out bucket prefix target="100" cm90="": + run_base="{{OUT_ROOT}}/_runs/{{YEAR_MONTH}}" + run_dir=$(ls -1dt "$run_base"/*/ 2>/dev/null | head -1 || true) + run_dir="${run_dir%/}" + + if [ -z "$run_dir" ]; then + run_dir="$run_base/_unknown" + mkdir -p "$run_dir" + fi + + ts=$(date -u +%Y%m%dT%H%M%SZ) + report="$run_dir/validation_${ts}.json" + + dst_flag="" + if [ "{{DST}}" = "1" ]; then dst_flag="--dst-month-check"; fi + + python3 scripts/csv_to_parquet/validate_month_output.py \ + --out-root "{{OUT_ROOT}}" \ + --check-mode "{{CHECK_MODE}}" \ + --max-files "{{MAX_FILES}}" \ + $dst_flag \ + --run-dir "$run_dir" \ + --output-report "$report" + + echo "Report: $report" + +validate-months MONTHS_FILE OUT_BASE_DIR="/ebs/home/$(whoami)/runs": #!/usr/bin/env bash set -euo pipefail - CM90="{{cm90}}" - if [ -n "$CM90" ]; then EXTRA="--cm90 $CM90"; else EXTRA=""; fi - python scripts/tasks/task_runner.py sample \ - --zips-file "{{zips_file}}" \ - --start "{{start}}" \ - --end 
"{{end}}" \ - --bucket "{{bucket}}" \ - --prefix-base "{{prefix}}" \ - --target-per-zip {{target}} \ - --out "{{out}}" \ - $EXTRA - -viz inp out: - python scripts/tasks/task_runner.py viz --inp "{{inp}}" --out "{{out}}" -# ============================================================================= -# 📊 BENCHMARKS (eager vs lazy) -# ============================================================================= + log_dir="{{ORCHESTRATOR_LOG_DIR}}" + if [ -z "$log_dir" ]; then log_dir="$OUT_BASE_DIR/_orchestrator_logs"; fi + mkdir -p "$log_dir" -# Run a specific benchmark: N in {100, 1000, 10000} -bench-run N MODE="lazy": - uv run python scripts/bench/eager_vs_lazy_benchmarks.py run \ - --mode {{MODE}} \ - --n {{N}} - -# Build summary CSV from stored profiles -bench-summary: - uv run python scripts/bench/eager_vs_lazy_benchmarks.py summary - -# Plot memory curves (requires existing profiles) -bench-plot: - uv run python scripts/bench/eager_vs_lazy_benchmarks.py plot - -# Run all benchmarks for eager + lazy -bench-all: - just bench-run 100 eager - just bench-run 100 lazy - just bench-run 1000 eager - just bench-run 1000 lazy - just bench-run 10000 eager - # lazy 10k intentionally omitted (8+ hrs) - @echo "✔ Benchmark suite complete" + ts=$(date -u +%Y%m%dT%H%M%SZ) + log_file="$log_dir/validate_${ts}.log" + + while read -r month; do + [ -z "$month" ] && continue + out_root="$OUT_BASE_DIR/out_${month}_production" + just validate-month "$month" "$out_root" 2>&1 | tee -a "$log_file" + done < "{{MONTHS_FILE}}" + +# ----------------------------------------------------------------------------- +# Status dashboard +# ----------------------------------------------------------------------------- + +migration-status OUT_BASE_DIR="/ebs/home/$(whoami)/runs": + #!/usr/bin/env bash + for d in "$OUT_BASE_DIR"/out_*_production; do + [ -d "$d" ] || continue + m=$(basename "$d" | grep -oE '[0-9]{6}') + files=$(find "$d" -name "*.parquet" | wc -l) + run=$(ls -1dt "$d/_runs/$m/"* 
2>/dev/null | head -1) + if [ -f "$run/run_summary.json" ]; then + python3 -c 'import json; s=json.load(open("$run/run_summary.json")); print(f"{m} files={files} success={s['total_success']} failure={s['total_failure']}")' + else + echo "$m files=$files (no run_summary.json)" + fi + done # ============================================================================= # 🔍 CODE QUALITY & TESTING @@ -215,120 +301,3 @@ typecheck: test-coverage: uv run pytest --cov=smart_meter_analysis --cov-report=html - -# ============================================================================= -# 📚 DOCUMENTATION -# ============================================================================= - -docs-test: - uv run mkdocs build -s - -docs: - uv run mkdocs serve - -docs-serve: - uv run pdoc smart_meter_analysis - -# ============================================================================= -# 📊 DATA EXPLORATION -# ============================================================================= - -notebook: - uv run jupyter notebook - -lab: - uv run jupyter lab - -inspect-data FILE N="10": - uv run python -c "import polars as pl; df = pl.scan_parquet('{{FILE}}').limit({{N}}).collect(); print(df)" - -inspect-schema FILE: - uv run python -c "import polars as pl; print(pl.scan_parquet('{{FILE}}').collect_schema())" - -count-rows FILE: - uv run python -c "import polars as pl; print(pl.scan_parquet('{{FILE}}').select(pl.len()).collect())" - -# ============================================================================= -# 🧹 UTILITIES -# ============================================================================= - -clean: - rm -rf .pytest_cache - rm -rf .mypy_cache - rm -rf .ruff_cache - rm -rf htmlcov - rm -rf dist - rm -rf *.egg-info - find . -type d -name __pycache__ -exec rm -rf {} + - find . -type f -name "*.pyc" -delete - -clean-data: - #!/usr/bin/env bash - echo "This will delete processed data files!" - echo "Raw data in S3 will not be affected." - read -p "Are you sure? 
(y/N) " -n 1 -r - if [[ $$REPLY =~ ^[Yy]$ ]]; then - rm -rf data/processed/* - echo "Data cleaned" - fi - -du: - @echo "Data directory sizes:" - @du -sh data/* 2>/dev/null || echo "No data directories found" - -# ============================================================================= -# 📦 BUILD & RELEASE -# ============================================================================= - -clean-build: - #!/usr/bin/env bash - echo "🚀 Removing build artifacts" - rm -rf dist - echo "Removed 'dist' (if it existed)." - -build: clean-build - echo "🚀 Creating wheel file" - uvx --from build pyproject-build --installer uv - -publish: - echo "🚀 Publishing." - uvx twine upload --repository-url https://upload.pypi.org/legacy/ dist/* - -build-and-publish: build publish - -# ============================================================================= -# 💡 EXAMPLES -# ============================================================================= - -example-quick: - @echo "Step 1: Download 5 sample files from S3..." - just download-samples-small 202308 - @echo "" - @echo "Step 2: Run pipeline on samples..." - just test-pipeline-local - @echo "" - @echo "Step 3: Inspect results..." - just inspect-data data/processed/comed_samples.parquet 10 - -example-quick-offline: - @echo "Step 1: Generate synthetic sample data..." - just generate-samples - @echo "" - @echo "Step 2: Run pipeline on samples..." - just test-pipeline-local - @echo "" - @echo "Step 3: Inspect results..." - just inspect-data data/processed/comed_samples.parquet 10 - -example-test: - @echo "Running test pipeline with 10 files from S3..." - just test-pipeline 202308 10 - -example-full: - @echo "Running full pipeline for August 2023..." - @echo "This will take approximately 5-8 hours." - just pipeline 202308 - -example-rerun: - @echo "Re-running analysis on existing August 2023 data..." 
- just pipeline-skip-download 202308 diff --git a/README.md b/README.md index ebdd931..022508a 100644 --- a/README.md +++ b/README.md @@ -239,7 +239,7 @@ just dev-teardown-all ``` ⚠️ WARNING: This will destroy EVERYTHING including the data volume! All data on the EBS volume will be permanently deleted. -Are you sure? Type 'yes' to confirm: +Are you sure? Type 'yes' to confirm: ``` Type `yes` to confirm, then the cleanup proceeds. diff --git a/analysis/pipelines/chicago_sampler.py b/archive/chicago_sampler.py similarity index 99% rename from analysis/pipelines/chicago_sampler.py rename to archive/chicago_sampler.py index 513f87e..bc26826 100644 --- a/analysis/pipelines/chicago_sampler.py +++ b/archive/chicago_sampler.py @@ -1,7 +1,8 @@ #!/usr/bin/env python """ -CLI wrapper for Chicago smart meter sampling. +CLI wrapper for ComEd smart meter sampling. Handles multiple ZIP codes and writes output for each. +Pass --zips or --zips-file to specify which ZIP codes to sample. """ import argparse diff --git a/scripts/analysis/create_chicago_visualizations.py b/archive/create_chicago_visualizations.py similarity index 64% rename from scripts/analysis/create_chicago_visualizations.py rename to archive/create_chicago_visualizations.py index f0d543a..816112c 100644 --- a/scripts/analysis/create_chicago_visualizations.py +++ b/archive/create_chicago_visualizations.py @@ -1,15 +1,25 @@ #!/usr/bin/env python """ -Final visualizations for Chicago smart meter data (CM90 dataset): +Visualizations for ComEd smart meter data (CM90 dataset): - Heatmap shows MEAN kWh per 30-min per customer. - Monthly bar chart annotates each bar with the month's mean kWh. - Hourly profile with peak/baseload annotations. - Weekend vs weekday comparison. -Usage: +Illinois is the default geographic scope. Pass --geography and --zip-codes +to produce a named subset (e.g. Chicago). 
+ +Usage (Illinois, all ZIPs): + python scripts/analysis/create_chicago_visualizations.py \ + --input data/illinois_2024/CLIPPED_CM90.parquet \ + --output figures/illinois_2024 + +Usage (Chicago subset): python scripts/analysis/create_chicago_visualizations.py \ - --input analysis/chicago_2024/final/CLIPPED_CM90.parquet \ - --output analysis/chicago_2024/visualizations + --input data/illinois_2024/CLIPPED_CM90.parquet \ + --output figures/chicago_2024 \ + --geography Chicago \ + --zip-codes 60601 60602 60603 """ import argparse @@ -33,11 +43,18 @@ }) -def create_heatmap(data_path: str, output_path: Path): +def _apply_zip_filter(lf: pl.LazyFrame, zip_codes: list[str] | None) -> pl.LazyFrame: + """Restrict to a named ZIP code subset when zip_codes is provided.""" + if zip_codes is not None: + lf = lf.filter(pl.col("zip_code").is_in(zip_codes)) + return lf + + +def create_heatmap(data_path: str, output_path: Path, geography: str, zip_codes: list[str] | None): """Monthly-hourly heatmap (MEAN kWh per customer).""" print("\n📊 Creating heatmap (MEAN kWh per customer)...") - lf = pl.scan_parquet(data_path) + lf = _apply_zip_filter(pl.scan_parquet(data_path), zip_codes) stats = lf.select([ pl.col("account_identifier").n_unique().alias("n_customers"), @@ -54,6 +71,10 @@ def create_heatmap(data_path: str, output_path: Path): .collect(engine="streaming") ) + if monthly_hourly.is_empty(): + print(f"⚠️ No data for geography '{geography}' — skipping heatmap.") + return + matrix = monthly_hourly.pivot(index="hour", columns="sample_month", values="mean_kwh").fill_null(0) hour_labels = matrix.select("hour").to_series().to_list() @@ -83,7 +104,7 @@ def create_heatmap(data_path: str, output_path: Path): ax.set_ylabel("Hour of Day", fontsize=15, fontweight="bold", labelpad=12) ax.set_title( "Residential Electricity Load Patterns: Temporal Heat Map\n" - f"Chicago • {date_range} • {n_customers:,} Households", + f"{geography} • {date_range} • {n_customers:,} Households", fontsize=18, 
fontweight="bold", pad=25, @@ -92,17 +113,18 @@ def create_heatmap(data_path: str, output_path: Path): ax.invert_yaxis() plt.tight_layout(rect=[0, 0.03, 1, 1]) - output_file = output_path / "chicago_heatmap.png" + geo_slug = geography.lower().replace(" ", "_") + output_file = output_path / f"{geo_slug}_heatmap.png" plt.savefig(output_file, dpi=300, bbox_inches="tight", facecolor="white") print(f"✅ Saved: {output_file}") plt.close() -def create_hourly_profile(data_path: str, output_path: Path): +def create_hourly_profile(data_path: str, output_path: Path, geography: str, zip_codes: list[str] | None): """Average hourly profile across the year (mean and IQR).""" print("\n📊 Creating hourly profile...") - lf = pl.scan_parquet(data_path) + lf = _apply_zip_filter(pl.scan_parquet(data_path), zip_codes) n_customers = lf.select(pl.col("account_identifier").n_unique()).collect(engine="streaming")[0, 0] hourly = ( @@ -116,6 +138,10 @@ def create_hourly_profile(data_path: str, output_path: Path): .collect(engine="streaming") ) + if hourly.is_empty(): + print(f"⚠️ No data for geography '{geography}' — skipping hourly profile.") + return + _fig, ax = plt.subplots(figsize=(15, 8)) hours = hourly["hour"].to_list() mean = hourly["mean_kwh"].to_list() @@ -128,7 +154,7 @@ def create_hourly_profile(data_path: str, output_path: Path): ax.set_xlabel("Hour of Day", fontsize=15, fontweight="bold", labelpad=12) ax.set_ylabel("Energy Consumption (kWh per 30-min)", fontsize=15, fontweight="bold", labelpad=12) ax.set_title( - f"Average Hourly Electricity Usage Profile\n{n_customers:,} Chicago Households", + f"Average Hourly Electricity Usage Profile\n{n_customers:,} {geography} Households", fontsize=18, fontweight="bold", pad=25, @@ -167,17 +193,18 @@ def create_hourly_profile(data_path: str, output_path: Path): ax.legend(loc="upper left", framealpha=0.95, edgecolor="#000", fancybox=True, shadow=True, fontsize=12) plt.tight_layout() - output_file = output_path / "chicago_hourly_profile.png" + 
geo_slug = geography.lower().replace(" ", "_") + output_file = output_path / f"{geo_slug}_hourly_profile.png" plt.savefig(output_file, dpi=300, bbox_inches="tight", facecolor="white") print(f"✅ Saved: {output_file}") plt.close() -def create_monthly_profile(data_path: str, output_path: Path): +def create_monthly_profile(data_path: str, output_path: Path, geography: str, zip_codes: list[str] | None): """Monthly average bar chart with mean kWh annotations.""" print("\n📊 Creating monthly profile...") - lf = pl.scan_parquet(data_path) + lf = _apply_zip_filter(pl.scan_parquet(data_path), zip_codes) monthly = ( lf.group_by("sample_month") @@ -190,6 +217,10 @@ def create_monthly_profile(data_path: str, output_path: Path): .collect(engine="streaming") ) + if monthly.is_empty(): + print(f"⚠️ No data for geography '{geography}' — skipping monthly profile.") + return + _fig, ax = plt.subplots(figsize=(15, 8)) months = monthly["sample_month"].to_list() mean_kwh = monthly["mean_kwh"].to_list() @@ -214,7 +245,7 @@ def create_monthly_profile(data_path: str, output_path: Path): ax.set_xlabel("Month", fontsize=15, fontweight="bold", labelpad=12) ax.set_ylabel("Average Energy (kWh per 30-min)", fontsize=15, fontweight="bold", labelpad=12) - ax.set_title("Monthly Average Electricity Consumption\nChicago", fontsize=18, fontweight="bold", pad=25) + ax.set_title(f"Monthly Average Electricity Consumption\n{geography}", fontsize=18, fontweight="bold", pad=25) ax.set_xticks(range(len(months))) ax.set_xticklabels(month_labels, fontsize=12) @@ -222,17 +253,18 @@ def create_monthly_profile(data_path: str, output_path: Path): ax.grid(True, alpha=0.3, axis="y") plt.tight_layout(rect=[0, 0.03, 1, 1]) - output_file = output_path / "chicago_monthly_profile.png" + geo_slug = geography.lower().replace(" ", "_") + output_file = output_path / f"{geo_slug}_monthly_profile.png" plt.savefig(output_file, dpi=300, bbox_inches="tight", facecolor="white") print(f"✅ Saved: {output_file}") plt.close() -def 
create_weekend_comparison(data_path: str, output_path: Path): +def create_weekend_comparison(data_path: str, output_path: Path, geography: str, zip_codes: list[str] | None): """Weekday vs weekend mean kWh per 30-min.""" print("\n📊 Creating weekend comparison...") - lf = pl.scan_parquet(data_path) + lf = _apply_zip_filter(pl.scan_parquet(data_path), zip_codes) comparison = ( lf.group_by(["hour", "is_weekend"]) .agg(pl.col("kwh").mean().alias("mean_kwh")) @@ -240,9 +272,13 @@ def create_weekend_comparison(data_path: str, output_path: Path): .collect(engine="streaming") ) - weekday = comparison.filter(not pl.col("is_weekend")) + weekday = comparison.filter(~pl.col("is_weekend")) weekend = comparison.filter(pl.col("is_weekend")) + if weekday.is_empty() or weekend.is_empty(): + print(f"⚠️ Missing weekday or weekend data for geography '{geography}' — skipping weekend comparison.") + return + _fig, ax = plt.subplots(figsize=(15, 9)) sns.lineplot( @@ -268,7 +304,7 @@ def create_weekend_comparison(data_path: str, output_path: Path): ax.set_xlabel("Hour of Day", fontsize=15, fontweight="bold", labelpad=12) ax.set_ylabel("Average Energy (kWh per 30-min)", fontsize=15, fontweight="bold", labelpad=12) - ax.set_title("Weekday vs Weekend Load Profiles\nChicago", fontsize=18, fontweight="bold", pad=30) + ax.set_title(f"Weekday vs Weekend Load Profiles\n{geography}", fontsize=18, fontweight="bold", pad=30) ax.grid(True, alpha=0.4, linestyle="--", linewidth=0.8) ax.set_xticks(range(0, 24, 2)) ax.set_xlim(-0.5, 23.5) @@ -279,50 +315,86 @@ def create_weekend_comparison(data_path: str, output_path: Path): ax.legend(loc="upper left", framealpha=0.95, edgecolor="#000", fancybox=True, shadow=True, fontsize=13) plt.tight_layout() - output_file = output_path / "chicago_weekend_comparison.png" + geo_slug = geography.lower().replace(" ", "_") + output_file = output_path / f"{geo_slug}_weekend_comparison.png" plt.savefig(output_file, dpi=300, bbox_inches="tight", facecolor="white") print(f"✅ 
Saved: {output_file}") plt.close() def main(): - parser = argparse.ArgumentParser(description="Create visualizations from Chicago smart meter data") + parser = argparse.ArgumentParser( + description="Create visualizations from ComEd smart meter data. " + "Illinois is the default geographic scope; use --geography and --zip-codes " + "to produce a named subset (e.g. Chicago)." + ) parser.add_argument( "--input", required=True, - help="Path to input parquet file (e.g., analysis/chicago_2024/final/CLIPPED_CM90.parquet)", + help="Path to input parquet file (must contain a zip_code column).", ) parser.add_argument( "--output", required=True, - help="Output directory for visualizations (e.g., analysis/chicago_2024/visualizations)", + help="Output directory for visualizations.", + ) + parser.add_argument( + "--geography", + default="Illinois", + help="Geographic label used in figure titles and output filenames (default: Illinois).", + ) + parser.add_argument( + "--zip-codes", + nargs="+", + default=None, + metavar="ZIP", + help="Optional list of 5-digit ZIP codes to restrict analysis to a subset " + "(e.g. --zip-codes 60601 60602). Omit to include all ZIPs in the input file.", ) args = parser.parse_args() data_path = Path(args.input) output_dir = Path(args.output) + geography: str = args.geography + zip_codes: list[str] | None = args.zip_codes - # Validate input + # Validate input file exists if not data_path.exists(): print(f"❌ File not found: {data_path}") raise SystemExit(1) + # Preflight: check required columns before running any visualizations. + # sample_month and date are produced by chicago_sampler.py; hour/is_weekend + # come from add_time_columns(); kwh/account_identifier/zip_code are canonical. 
+ required_cols = {"account_identifier", "zip_code", "date", "sample_month", "hour", "kwh", "is_weekend"} + actual_cols = set(pl.scan_parquet(data_path).collect_schema().names()) + missing_cols = required_cols - actual_cols + if missing_cols: + print(f"❌ Input file is missing required columns: {sorted(missing_cols)}") + print(f" Found: {sorted(actual_cols)}") + raise SystemExit(1) + # Create output directory output_dir.mkdir(parents=True, exist_ok=True) print("=" * 80) - print("CHICAGO SMART METER VISUALIZATIONS") + print(f"{geography.upper()} SMART METER VISUALIZATIONS") print("=" * 80) - print(f"Input: {data_path}") - print(f"Output: {output_dir}") + print(f"Input: {data_path}") + print(f"Output: {output_dir}") + print(f"Geography: {geography}") + if zip_codes: + print(f"ZIP filter: {', '.join(zip_codes)}") + else: + print("ZIP filter: none (all ZIPs in input file)") print("=" * 80) # Create all visualizations - create_heatmap(str(data_path), output_dir) - create_hourly_profile(str(data_path), output_dir) - create_monthly_profile(str(data_path), output_dir) - create_weekend_comparison(str(data_path), output_dir) + create_heatmap(str(data_path), output_dir, geography, zip_codes) + create_hourly_profile(str(data_path), output_dir, geography, zip_codes) + create_monthly_profile(str(data_path), output_dir, geography, zip_codes) + create_weekend_comparison(str(data_path), output_dir, geography, zip_codes) print("\n" + "=" * 80) print("✅ ALL VISUALIZATIONS COMPLETE!") diff --git a/infra/README.md b/infra/README.md index db2bac6..d14e7e8 100644 --- a/infra/README.md +++ b/infra/README.md @@ -197,7 +197,7 @@ just dev-teardown-all ``` ⚠️ WARNING: This will destroy EVERYTHING including the data volume! All data on the EBS volume will be permanently deleted. -Are you sure? Type 'yes' to confirm: +Are you sure? Type 'yes' to confirm: ``` Type `yes` to confirm, then the cleanup proceeds. 
diff --git a/pyproject.toml b/pyproject.toml index c7297bb..11d8040 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "boto3>=1.40.46", "botocore>=1.40.47", "cenpy>=1.0.1", + "fsspec>=2026.1.0", "ipykernel>=6.30.1", "matplotlib>=3.9.4", "memory-profiler>=0.61.0", @@ -29,6 +30,7 @@ dependencies = [ "pyarrow>=14.0.0", "pyyaml>=6.0.3", "requests>=2.32.5", + "s3fs>=0.4.2", "scikit-learn>=1.6.1", "seaborn>=0.13.2", "selenium>=4.37.0", @@ -43,10 +45,12 @@ Repository = "https://github.com/switchbox-data/smart-meter-analysis" Documentation = "https://switchbox-data.github.io/smart-meter-analysis/" [tool.deptry] -exclude = ["archive", ".venv", "tests"] +exclude = ["archive", ".venv", "tests", ".claude_quarantine_*"] [tool.deptry.per_rule_ignores] +DEP001 = ["compact_month_output"] DEP002 = [ + "fsspec", "ipykernel", "cenpy", "openpyxl", @@ -58,6 +62,8 @@ DEP002 = [ "memory-profiler", "snakeviz", "tslearn", + "s3fs", + "statsmodels", ] DEP003 = ["botocore", "analysis", "smart_meter_analysis", "pandas", "scipy"] DEP004 = ["botocore"] @@ -109,6 +115,40 @@ ignore_missing_imports = true module = ["pandas.*", "scipy.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "polars.*" +ignore_missing_imports = true + +# Modules using polars (no stubs) or other untyped deps; allow Any from unfollowed imports +[[tool.mypy.overrides]] +module = ["smart_meter_analysis.wide_to_long", "smart_meter_analysis.transformation", "smart_meter_analysis.pipeline_validator", "smart_meter_analysis.census"] +disallow_any_unimported = false + +[[tool.mypy.overrides]] +module = "cenpy" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "smart_meter_analysis.census" +disable_error_code = ["import-untyped", "import-not-found", "unused-ignore"] + +[[tool.mypy.overrides]] +module = "smart_meter_analysis.config" +disable_error_code = ["import-untyped"] + +[[tool.mypy.overrides]] +module = "smart_meter_analysis.manifests" +disable_error_code = 
["no-any-return"] + +[[tool.mypy.overrides]] +module = "botocore.exceptions" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "pyarrow.*" +ignore_missing_imports = true + + [tool.pytest.ini_options] testpaths = ["tests"] markers = [ @@ -148,6 +188,11 @@ extend-ignore = ["TRY003", "TRY300", "TRY400"] "analysis/clustering/stage2_multinomial.py" = ["C901"] "tests/validate_total_comed_pipeline.py" = ["C901", "S603", "RUF001"] "scripts/testing/generate_sample_data.py" = ["UP035", "UP006", "UP007", "S311"] +"scripts/csv_to_parquet/migrate_month_runner.py" = ["C901"] +"scripts/csv_to_parquet/compact_month_output.py" = ["C901", "S603", "S607"] +"scripts/csv_to_parquet/restructure_for_export.py" = ["C901", "S603", "S607"] +"scripts/validate_wide_to_long_batched.py" = [] +"smart_meter_analysis/wide_to_long.py" = ["C901"] "tests/*" = ["S101", "RUF001"] diff --git a/scripts/csv_to_parquet/PREFLIGHT_200.md b/scripts/csv_to_parquet/PREFLIGHT_200.md new file mode 100644 index 0000000..16ddf4e --- /dev/null +++ b/scripts/csv_to_parquet/PREFLIGHT_200.md @@ -0,0 +1,181 @@ +# Preflight Validation Checklist: 200-File Run (202307) + +Target: Validate 200-file shard run before scaling to full ~30k month. 
+
+## Prerequisites
+
+- 25-file batch (shard 100) completed successfully
+- Output at: `/ebs/home/griffin_switch_box/runs/out_test_output_ec2/`
+- Run artifacts at: `/ebs/home/griffin_switch_box/runs/out_test_output_ec2/_runs/202307/<run_id>/`
+
+---
+
+## Step 1: Run the 200-file migration
+
+```bash
+# Prepare input list (200 files from the sorted CSV inventory)
+head -200 /path/to/all_csvs_202307_sorted.txt > /tmp/shard_200.txt
+wc -l /tmp/shard_200.txt # confirm 200
+
+# Run migration (adjust paths as needed)
+python scripts/csv_to_parquet/migrate_month_runner.py \
+    --input-list /tmp/shard_200.txt \
+    --out-root /ebs/home/griffin_switch_box/runs/out_200_preflight \
+    --year-month 202307 \
+    --shard-id 200 \
+    --batch-size 50 \
+    --workers 4 \
+    --exec-mode lazy_sink \
+    --fail-fast
+```
+
+Expect: 4 batch files (200 / 50 = 4 batches).
+
+---
+
+## Step 2: Quick sanity (before full validation)
+
+```bash
+# Confirm output structure
+find /ebs/home/griffin_switch_box/runs/out_200_preflight/year=2023/month=07/ \
+    -name '*.parquet' | sort
+
+# Expected: 4 files named shard_200_batch_0000.parquet through shard_200_batch_0003.parquet
+
+# Confirm run completed cleanly
+cat /ebs/home/griffin_switch_box/runs/out_200_preflight/_runs/202307/*/run_summary.json \
+    | python -m json.tool | grep -E '"total_(success|failure|skip)|batches_written|stop_requested"'
+
+# Expected: total_failure=0, total_success=200, batches_written=4, stop_requested=false
+```
+
+---
+
+## Step 3: Identify the run-dir
+
+```bash
+# List run directories to find the run_id
+ls /ebs/home/griffin_switch_box/runs/out_200_preflight/_runs/202307/
+
+# Set variable for convenience (replace <run_id> with the actual value)
+RUN_DIR="/ebs/home/griffin_switch_box/runs/out_200_preflight/_runs/202307/<run_id>"
+OUT_ROOT="/ebs/home/griffin_switch_box/runs/out_200_preflight"
+```
+
+---
+
+## Step 4: Full validation (all checks)
+
+```bash
+python scripts/csv_to_parquet/validate_month_output.py \
+    --out-root "$OUT_ROOT" \
+    --check-mode full 
\ + --dst-month-check \ + --run-dir "$RUN_DIR" \ + --output-report "$RUN_DIR/validation_report_200.json" +``` + +This single command validates all of the following: + +| Check | What it verifies | +|---|---| +| Schema contract | All 10 columns present, exact dtypes | +| Partition integrity | year=2023, month=7 in every file | +| No duplicates | No duplicate (zip_code, account_identifier, datetime) within any batch | +| Datetime invariants | No nulls, min=00:00, max=23:30, no spillover | +| DST Option B | Exactly 48 slots/day, no timestamps beyond 23:30 | +| Sortedness (full) | Lexicographic order by (zip_code, account_identifier, datetime) | +| Run artifact integrity | plan.json valid, run_summary.json clean, manifests 0 failures | +| Row counts | Per-file and total row counts reported | + +Expected output on success: +``` +OK: validated 4 parquet files across 1 partitions (discovered total parquet files=4, total rows validated=NNNNNN). +Validation report written to: .../validation_report_200.json +``` + +--- + +## Step 5: Review the validation report + +```bash +python -m json.tool "$RUN_DIR/validation_report_200.json" +``` + +Checklist for the report JSON: + +- [ ] `"status": "pass"` +- [ ] `"files_validated": 4` +- [ ] `"total_rows_validated"` is reasonable (expect ~200 files * ~N accounts * 48 slots * 31 days) +- [ ] `"checks_passed"` contains all 7 checks: + - `schema_contract` + - `partition_integrity` + - `no_duplicates` + - `datetime_invariants` + - `sortedness_full` + - `dst_option_b` + - `run_artifact_integrity` +- [ ] `"per_file_rows"` shows all 4 batch files with non-zero row counts +- [ ] `"run_artifacts"."summary_total_failure"` is 0 +- [ ] `"run_artifacts"."manifest_success_count"` is 200 + +--- + +## Step 6: Spot-check a parquet file interactively + +```python +import polars as pl + +f = "/ebs/home/griffin_switch_box/runs/out_200_preflight/year=2023/month=07/shard_200_batch_0000.parquet" +df = pl.read_parquet(f) + +print("Shape:", df.shape) 
+print("Schema:", df.schema) +print("Head:\n", df.head(5)) +print("Tail:\n", df.tail(5)) + +# Verify sort order visually +print("Sorted check:", df.select([ + pl.col("zip_code"), + pl.col("account_identifier"), + pl.col("datetime"), +]).head(20)) + +# Unique accounts +print("Unique accounts:", df["account_identifier"].n_unique()) +print("Date range:", df["datetime"].min(), "to", df["datetime"].max()) +``` + +--- + +## Step 7: Cross-check with 25-file run (optional determinism) + +If the 200-file input list's first 25 files overlap with the original 25-file shard: + +```bash +python scripts/csv_to_parquet/validate_month_output.py \ + --out-root "$OUT_ROOT" \ + --compare-root /ebs/home/griffin_switch_box/runs/out_test_output_ec2 \ + --check-mode sample +``` + +Note: This will only work if both roots share identical partition structure. +If shard IDs differ, compare individual batch files manually instead. + +--- + +## Go/No-Go Decision + +| Criterion | Required | +|---|---| +| Step 4 prints `OK` | YES | +| Validation report `status: pass` | YES | +| All 7 checks in `checks_passed` | YES | +| `total_rows_validated > 0` | YES | +| `run_artifacts.summary_total_failure == 0` | YES | +| `run_artifacts.manifest_success_count == 200` | YES | +| No unexpected files in output directory | YES | +| Spot-check schema + sort order looks correct | YES | + +If all criteria pass: proceed to full-month sharded run. +If any fail: investigate, fix, re-run the 200-file batch. diff --git a/scripts/csv_to_parquet/compact_month_output.py b/scripts/csv_to_parquet/compact_month_output.py new file mode 100644 index 0000000..8231a8f --- /dev/null +++ b/scripts/csv_to_parquet/compact_month_output.py @@ -0,0 +1,1167 @@ +#!/usr/bin/env python3 +"""Month-level Parquet compaction for the ComEd CSV→Parquet pipeline. + +Compacts all ``batch_*.parquet`` files produced by ``migrate_month_runner.py`` +into deterministic ``compacted_NNNN.parquet`` files targeting ~1 GiB each. 
+Invoked by the runner after all batches for a month complete with zero failures. + +Design decisions +---------------- +1. **Memory-safe file-by-file streaming** — reads one batch Parquet file at a + time, accumulates rows until a per-row-count budget is reached, then flushes. + Maximum in-memory footprint is approximately two batch files simultaneously + (current file + carry-over slice from the previous boundary). No global + collect of the entire month's data is performed. + +2. **No re-sort** — relies on the invariant (enforced by the runner) that batch + files are already globally sorted by ``(zip_code, account_identifier, + datetime)`` and that lexicographic filename order (batch_0000 < batch_0001 + < …) preserves global sort across file boundaries. The adjacent-key + validation pass verifies this contract on the *output* files. + +3. **Atomic directory swap** — staging output is written under + ``/compaction_staging/year=YYYY/month=MM/``. After validation the + original month directory is renamed to ``month=MM_precompact_`` and + the staging directory is renamed to the canonical month directory. Both + renames use ``os.replace()`` (single rename(2) syscall on Linux), which is + atomic within the same filesystem. A failed phase-2 rename triggers an + automatic rollback of phase 1. + +4. **Fail-loud** — any validation failure raises ``RuntimeError`` before the + atomic swap, leaving the original batch files completely untouched. + +5. **Audit trail** — five JSON artifacts are written under + ``/compaction/`` regardless of outcome (where possible). + +6. **Self-contained** — this module does NOT import from + ``migrate_month_runner`` to avoid circular imports (the runner imports this + module at its top level). Shared constants (``FINAL_LONG_COLS``, + ``SORT_KEYS``) are re-declared here with a cross-reference comment. 
+""" + +from __future__ import annotations + +import contextlib +import datetime as dt +import hashlib +import json +import os +import shutil +import subprocess +import time +import traceback +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import polars as pl +import pyarrow.parquet as pq + +# --------------------------------------------------------------------------- +# Canonical schema contract +# These constants MUST stay in sync with migrate_month_runner.py and +# validate_month_output.py. They are re-declared here (not imported) to +# prevent a circular import: the runner imports this module at top level, so +# this module cannot import the runner. +# --------------------------------------------------------------------------- + +FINAL_LONG_COLS: tuple[str, ...] = ( + "zip_code", + "delivery_service_class", + "delivery_service_name", + "account_identifier", + "datetime", + "energy_kwh", + "plc_value", + "nspl_value", + "year", + "month", +) + +SORT_KEYS: tuple[str, str, str] = ("zip_code", "account_identifier", "datetime") + +DEFAULT_COMPACT_TARGET_SIZE_BYTES: int = 1_073_741_824 # 1 GiB + +JsonDict = dict[str, Any] + +STREAMING_VALIDATOR_VERSION: str = "2.0.0-streaming-pyarrow" +DEFAULT_VALIDATION_BATCH_SIZE: int = 1_000_000 # rows per PyArrow iter_batches call +ROWS_PER_ROW_GROUP: int = 50_000_000 # rows per row group; bounds peak RSS to ~2 GiB per group + + +# --------------------------------------------------------------------------- +# Data model +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class CompactionConfig: + """Immutable configuration for one compaction run. + + Constructed by ``migrate_month_runner.main()`` from parsed CLI args and the + existing ``RunnerConfig``, then passed to ``run_compaction()``. 
+ """ + + year_month: str # YYYYMM — must match the runner's target month + run_id: str # same run_id as the enclosing migration run + out_root: Path # dataset root (Hive partitions live here) + run_dir: Path # _runs/YYYYMM// — audit artifacts go here + target_size_bytes: int # target on-disk size per output Parquet file + max_files: int | None # optional cap on number of output compacted files + overwrite: bool # allow overwriting existing compacted_*.parquet + dry_run: bool # plan only: write plan + original inventory + summary; skip write/validate/swap + no_swap: bool # run compaction + validation + write artifacts; skip atomic swap + validation_batch_size: int = DEFAULT_VALIDATION_BATCH_SIZE # rows per batch for streaming validation + + +# --------------------------------------------------------------------------- +# Utilities (self-contained; no runner imports) +# --------------------------------------------------------------------------- + + +def _now_utc_iso() -> str: + return dt.datetime.now(dt.timezone.utc).isoformat(timespec="seconds") + + +def _elapsed_ms(t0: float, t1: float) -> int: + return round((t1 - t0) * 1000.0) + + +def _write_json(path: Path, obj: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(obj, indent=2, sort_keys=True), encoding="utf-8") + + +def _try_git_sha() -> str | None: + """Return the current git HEAD SHA, or None if git is unavailable.""" + try: + cp = subprocess.run( + ["git", "rev-parse", "HEAD"], + check=False, + capture_output=True, + text=True, + ) + return cp.stdout.strip() if cp.returncode == 0 else None + except Exception: + return None + + +def _year_month_dirs(year_month: str) -> tuple[str, str]: + y = int(year_month[:4]) + m = int(year_month[4:6]) + return f"{y:04d}", f"{m:02d}" + + +def _file_list_hash(paths: list[Path]) -> str: + """Stable SHA-256 fingerprint of a sorted list of file paths.""" + content = "\n".join(str(p) for p in paths) + return 
hashlib.sha256(content.encode("utf-8")).hexdigest()[:16] + + +# --------------------------------------------------------------------------- +# Parquet metadata helpers — row counts from file footer; no data loaded +# --------------------------------------------------------------------------- + + +def _parquet_row_count(path: Path) -> int: + """Read row count from Parquet file footer metadata without loading data.""" + return pq.read_metadata(str(path)).num_rows + + +def _parquet_schema_names(path: Path) -> list[str]: + """Read column names from Parquet file schema without loading data.""" + arrow_schema = pq.read_schema(str(path)) + return list(arrow_schema.names) + + +def _file_inventory_entry(path: Path) -> JsonDict: + """Build a metadata inventory entry for a single Parquet file.""" + meta = pq.read_metadata(str(path)) + return { + "path": str(path), + "size_bytes": int(path.stat().st_size), + "num_rows": int(meta.num_rows), + "num_row_groups": int(meta.num_row_groups), + } + + +def _write_success_marker( + month_dir: Path, + output_names: list[str], + year_month: str, + run_id: str, + pre_rows: int, + post_rows: int, + total_output_bytes: int, + git_sha: str | None, + sort_keys: tuple[str, ...] = SORT_KEYS, + schema: tuple[str, ...] = FINAL_LONG_COLS, +) -> None: + """Write a _SUCCESS.json marker with compaction metadata. + + Follows the Spark convention of a per-partition success marker. Contains + metadata for downstream validation without needing to read Parquet footers. 
+ """ + files_manifest = [] + for name in output_names: + path = month_dir / name + meta = pq.read_metadata(str(path)) + files_manifest.append({ + "name": name, + "size_bytes": int(path.stat().st_size), + "num_rows": int(meta.num_rows), + "num_row_groups": int(meta.num_row_groups), + }) + + marker: JsonDict = { + "timestamp": _now_utc_iso(), + "git_sha": git_sha, + "year_month": year_month, + "compaction_run_id": run_id, + "n_files": len(output_names), + "total_rows": post_rows, + "total_bytes": total_output_bytes, + "sort_keys": list(sort_keys), + "schema": list(schema), + "files": files_manifest, + } + _write_json(month_dir / "_SUCCESS.json", marker) + + +# --------------------------------------------------------------------------- +# Core: memory-safe streaming write (multi-row-group ParquetWriter) +# --------------------------------------------------------------------------- + + +def _stream_write_chunks( + sorted_input_files: list[Path], + staging_month_dir: Path, + rows_per_row_group: int, + target_size_bytes: int, + max_files: int | None, + logger: Any, + log_ctx: JsonDict, +) -> list[Path]: + """Stream through sorted batch files and write compacted output files. + + Memory-safety guarantee + ----------------------- + Reads one batch Parquet file at a time. Maintains a ``carry`` DataFrame + of rows that didn't fill the previous row group. Maximum RSS footprint is + approximately two row groups in memory simultaneously (~2 x 1.5 GiB). + + No global sort is performed here. The function relies on the runner's + invariant that input batch files are already globally sorted by SORT_KEYS + and that lexicographic filename order preserves global sort order. The + post-write adjacent-key validation pass verifies this contract. + + Row-group and file-rollover semantics + ------------------------------------- + Rows are flushed to the current ``pq.ParquetWriter`` as row groups of up + to ``rows_per_row_group`` rows. 
After each flush, if the on-disk file + size reaches ``target_size_bytes`` the writer is closed and a new file is + opened for the next row group. The final partial row group is always + flushed regardless of size. + + Parameters + ---------- + sorted_input_files: + Input batch Parquet files in lexicographic order (deterministic). + staging_month_dir: + Directory to write ``compacted_NNNN.parquet`` files. + rows_per_row_group: + Maximum rows per row group; bounds peak RSS to ~2 GiB per group. + target_size_bytes: + Close the current file and open a new one when on-disk size reaches + this threshold after a row-group flush. + max_files: + Optional hard cap on the number of output files written. + logger: + JsonlLogger-compatible object (``logger.log(dict)``). + log_ctx: + Context fields included in every log event. + + Returns + ------- + list[Path] + Paths of written output Parquet files in write order. + """ + output_files: list[Path] = [] + file_idx = 0 + carry: pl.DataFrame | None = None + writer: pq.ParquetWriter | None = None + current_file_path: Path | None = None + arrow_schema: Any = None # pa.Schema, derived from first flush + + def flush_row_group(rg_df: pl.DataFrame) -> None: + nonlocal writer, file_idx, current_file_path, arrow_schema + + t0 = time.time() + arrow_table = rg_df.to_arrow() + if arrow_schema is None: + arrow_schema = arrow_table.schema + + if writer is None: + current_file_path = staging_month_dir / f"part-{file_idx:05d}.parquet" + writer = pq.ParquetWriter( + str(current_file_path), + arrow_schema, + compression="snappy", + write_statistics=True, + ) + output_files.append(current_file_path) + + writer.write_table(arrow_table) + if current_file_path is None: + raise RuntimeError("current_file_path is None after writer was opened") + size = int(current_file_path.stat().st_size) + + logger.log({ + **log_ctx, + "event": "compaction_write_row_group", + "status": "info", + "file_idx": file_idx, + "file_path": str(current_file_path), + 
"num_rows": rg_df.height, + "file_size_bytes": size, + "elapsed_ms": _elapsed_ms(t0, time.time()), + }) + + if size >= target_size_bytes: + writer.close() + writer = None + file_idx += 1 + current_file_path = None + + staging_month_dir.mkdir(parents=True, exist_ok=True) + logger.log({ + **log_ctx, + "event": "compaction_scan_start", + "status": "info", + "n_input_files": len(sorted_input_files), + "rows_per_row_group": rows_per_row_group, + "target_size_bytes": target_size_bytes, + }) + + try: + for input_path in sorted_input_files: + if max_files is not None and file_idx >= max_files: + break + + df = pl.read_parquet(str(input_path)) + + if carry is not None and carry.height > 0: + df = pl.concat([carry, df], how="vertical", rechunk=False) + carry = None + + while df.height >= rows_per_row_group: + if max_files is not None and file_idx >= max_files: + break + flush_row_group(df.slice(0, rows_per_row_group)) + df = df.slice(rows_per_row_group) + + carry = df if df.height > 0 else None + + if carry is not None and carry.height > 0 and (max_files is None or file_idx < max_files): + flush_row_group(carry) + + if writer is not None: + writer.close() + writer = None + except Exception: + if writer is not None: + with contextlib.suppress(Exception): + writer.close() + raise + + return output_files + + +# --------------------------------------------------------------------------- +# Validation: pre-flight check for existing compacted files +# --------------------------------------------------------------------------- + + +def _pre_validate_no_existing_compacted( + canonical_dir: Path, + overwrite: bool, +) -> None: + """Fail-loud if part-*.parquet already exist in the canonical dir. + + Called early in run_compaction() before any writes to prevent accidental + overwrites. Skipped when ``overwrite=True``. 
+ """ + if overwrite: + return + existing = sorted(canonical_dir.glob("part-*.parquet")) + if existing: + names = [p.name for p in existing[:5]] + suffix = f" (and {len(existing) - 5} more)" if len(existing) > 5 else "" + raise RuntimeError( + f"Canonical directory already contains {len(existing)} part-*.parquet " + f"file(s): {names}{suffix}. Use --overwrite-compact to allow replacement, " + f"or remove them manually before re-running compaction." + ) + + +# --------------------------------------------------------------------------- +# Validation: adjacent-key sort order and uniqueness — streaming (no group_by) +# --------------------------------------------------------------------------- + + +def _validate_adjacent_keys_streaming( + sorted_files: list[Path], + label: str, + batch_size: int = DEFAULT_VALIDATION_BATCH_SIZE, + logger: Any | None = None, + log_ctx: JsonDict | None = None, +) -> JsonDict: + """Validate global sort order and key uniqueness via true streaming batches. + + Algorithm + --------- + Uses ``pyarrow.ParquetFile.iter_batches()`` to read only the three key + columns in fixed-size batches. Memory is bounded to one batch (~1M rows + x 3 columns ~30-50 MB) plus a single carry-forward key tuple. + + For each batch: + + 1. **Cross-boundary check** — first key of current batch must be strictly + greater than ``prev_key`` (last key of previous batch/file). + 2. **Within-batch check** — vectorized Polars ``shift(1)`` on the three + key columns, bounded to ``batch_size`` rows. No ``group_by``, no + Python-level row iteration. + + The compound sort key is ``(zip_code, account_identifier, datetime)`` + with lexicographic ordering. + + Parameters + ---------- + sorted_files : + Files to validate, in logical sort order (filename order). + label : + Human-readable name for the dataset being validated. + batch_size : + Rows per PyArrow batch. Default 1,000,000. + logger : + Optional JsonlLogger for structured event logging. 
+ log_ctx : + Optional base dict merged into every log event. + + Returns + ------- + dict with keys: ``passed``, ``error``, ``error_location``, ``n_files``, + ``total_rows``, ``key_min``, ``key_max``, ``validator_version``, + ``validator_method``, ``batch_size``. + """ + KEY_COLS: list[str] = list(SORT_KEYS) + _log_ctx: JsonDict = log_ctx or {} + + prev_key: tuple[Any, ...] | None = None + total_rows: int = 0 + n_files: int = 0 + key_min: tuple[Any, ...] | None = None + key_max: tuple[Any, ...] | None = None + error: str | None = None + error_location: JsonDict | None = None + + for file_idx, path in enumerate(sorted_files): + pf = pq.ParquetFile(str(path)) + n_files += 1 + file_rows: int = 0 + batch_idx: int = 0 + + if logger is not None: + logger.log({ + **_log_ctx, + "event": "validation_file_start", + "status": "info", + "file": path.name, + "file_idx": file_idx, + "n_row_groups": pf.metadata.num_row_groups, + }) + + for batch in pf.iter_batches(batch_size=batch_size, columns=KEY_COLS): + n = batch.num_rows + if n == 0: + batch_idx += 1 + continue + + # Convert PyArrow RecordBatch → Polars DataFrame (zero-copy where possible). + df = pl.from_arrow(batch) + + # First/last key as Python tuples — constant cost per batch. + first_row = df.row(0) + last_row = df.row(df.height - 1) + batch_first_key: tuple[Any, ...] = (first_row[0], first_row[1], first_row[2]) + batch_last_key: tuple[Any, ...] 
= (last_row[0], last_row[1], last_row[2]) + + if key_min is None: + key_min = batch_first_key + key_max = batch_last_key + + # ── Cross-boundary check ────────────────────────────────────── + if prev_key is not None: + if batch_first_key < prev_key: + error = ( + f"{label}: sort violation at boundary: " + f"file={path.name} batch={batch_idx} " + f"first_key={batch_first_key!r} < prev_key={prev_key!r}" + ) + error_location = { + "file": path.name, + "file_idx": file_idx, + "batch_idx": batch_idx, + "row_offset_in_file": file_rows, + "row_offset_global": total_rows + file_rows, + } + break + if batch_first_key == prev_key: + error = ( + f"{label}: duplicate key at boundary: " + f"file={path.name} batch={batch_idx} " + f"key={batch_first_key!r}" + ) + error_location = { + "file": path.name, + "file_idx": file_idx, + "batch_idx": batch_idx, + "row_offset_in_file": file_rows, + "row_offset_global": total_rows + file_rows, + } + break + + # ── Within-batch adjacent-pair check (vectorized, bounded) ──── + if n > 1: + zip_prev = pl.col("zip_code").shift(1) + acct_prev = pl.col("account_identifier").shift(1) + dt_prev = pl.col("datetime").shift(1) + + zip_eq = zip_prev == pl.col("zip_code") + acct_eq = acct_prev == pl.col("account_identifier") + dt_eq = dt_prev == pl.col("datetime") + + zip_gt = zip_prev > pl.col("zip_code") + acct_gt = acct_prev > pl.col("account_identifier") + dt_gt = dt_prev > pl.col("datetime") + + sort_violation = zip_gt | (zip_eq & acct_gt) | (zip_eq & acct_eq & dt_gt) + dup_violation = zip_eq & acct_eq & dt_eq + + check = ( + df.with_row_index("_row_idx") + .with_columns([ + sort_violation.alias("_sort_viol"), + dup_violation.alias("_dup_viol"), + ]) + .slice(1) # row 0 has null shift values + ) + + sort_bad = check.filter(pl.col("_sort_viol")).head(1) + if sort_bad.height > 0: + bad = sort_bad.select([*KEY_COLS, "_row_idx"]).to_dicts()[0] + row_in_file = file_rows + int(bad["_row_idx"]) + error = ( + f"{label}: sort violation in file={path.name} " + 
f"batch={batch_idx} row_in_file={row_in_file} " + f"bad_row={bad}" + ) + error_location = { + "file": path.name, + "file_idx": file_idx, + "batch_idx": batch_idx, + "row_in_batch": int(bad["_row_idx"]), + "row_offset_in_file": row_in_file, + "row_offset_global": total_rows + row_in_file, + "bad_key": {k: str(bad[k]) for k in KEY_COLS}, + } + break + + dup_bad = check.filter(pl.col("_dup_viol")).head(1) + if dup_bad.height > 0: + bad = dup_bad.select([*KEY_COLS, "_row_idx"]).to_dicts()[0] + row_in_file = file_rows + int(bad["_row_idx"]) + error = ( + f"{label}: duplicate key in file={path.name} " + f"batch={batch_idx} row_in_file={row_in_file} " + f"bad_row={bad}" + ) + error_location = { + "file": path.name, + "file_idx": file_idx, + "batch_idx": batch_idx, + "row_in_batch": int(bad["_row_idx"]), + "row_offset_in_file": row_in_file, + "row_offset_global": total_rows + row_in_file, + "bad_key": {k: str(bad[k]) for k in KEY_COLS}, + } + break + + prev_key = batch_last_key + file_rows += n + batch_idx += 1 + + # Count rows scanned even on partial files. 
+ total_rows += file_rows + + if error is not None: + break + + if logger is not None: + logger.log({ + **_log_ctx, + "event": "validation_file_end", + "status": "info", + "file": path.name, + "file_idx": file_idx, + "file_rows": file_rows, + "running_total_rows": total_rows, + }) + + return { + "passed": error is None, + "error": error, + "error_location": error_location, + "n_files": n_files, + "total_rows": total_rows, + "key_min": str(key_min) if key_min is not None else None, + "key_max": str(key_max) if key_max is not None else None, + "validator_version": STREAMING_VALIDATOR_VERSION, + "validator_method": "adjacent_key_streaming_pyarrow_iter_batches", + "batch_size": batch_size, + } + + +# --------------------------------------------------------------------------- +# Validation: schema conformance +# --------------------------------------------------------------------------- + + +def _validate_schema(files: list[Path], label: str) -> JsonDict: + """Verify every file has exactly FINAL_LONG_COLS in canonical column order. + + Uses Parquet file footer metadata — no row data is loaded. + """ + errors: list[str] = [] + for path in files: + actual = _parquet_schema_names(path) + # pyarrow may surface internal ``__null_dask_index__`` or similar names; + # filter them out before comparing. + actual_clean = [c for c in actual if not c.startswith("__")] + if tuple(actual_clean) != FINAL_LONG_COLS: + errors.append(f"{path.name}: expected {list(FINAL_LONG_COLS)} got {actual_clean}") + return {"passed": len(errors) == 0, "errors": errors, "label": label} + + +# --------------------------------------------------------------------------- +# Validation: partition uniformity +# --------------------------------------------------------------------------- + + +def _validate_partition_uniformity(files: list[Path], year_month: str) -> JsonDict: + """Verify all rows belong to the target (year, month) partition. + + Reads only the ``year`` and ``month`` columns. 
    A violation here would mean
    cross-month row leakage — a serious data corruption that must block swap.
    """
    y = int(year_month[:4])
    m = int(year_month[4:6])
    errors: list[str] = []
    for path in files:
        df = pl.read_parquet(str(path), columns=["year", "month"])
        bad = df.filter((pl.col("year") != y) | (pl.col("month") != m)).height
        if bad > 0:
            errors.append(f"{path.name}: {bad} rows with wrong (year,month); expected ({y},{m})")
    return {"passed": len(errors) == 0, "errors": errors, "year": y, "month": m}


# ---------------------------------------------------------------------------
# Atomic directory swap
# ---------------------------------------------------------------------------


def _atomic_swap(
    month_canonical_dir: Path,
    staging_month_dir: Path,
    precompact_dir: Path,
    logger: Any,
    log_ctx: JsonDict,
) -> None:
    """Atomically swap staging compacted files into the canonical month directory.

    Two-phase rename sequence
    -------------------------
    Phase 1 (atomic):
        ``os.replace(month_canonical_dir → precompact_dir)``
        The canonical directory disappears; batch files are now in
        ``precompact_dir`` and remain safe.

    Phase 2 (atomic):
        ``os.replace(staging_month_dir → month_canonical_dir)``
        Staging becomes the new canonical directory.

    Rollback:
        If phase 2 fails, phase 1 is reversed by renaming ``precompact_dir``
        back to ``month_canonical_dir``. If the rollback also fails, a
        ``RuntimeError`` with manual recovery instructions is raised — the
        original files are in ``precompact_dir`` and can be restored manually.

    Filesystem constraint:
        ``os.replace()`` is atomic only within the same filesystem. We
        assert both directories share the same device before proceeding.
        Cross-device moves are refused with a clear error.

    Note: this function must NEVER touch ``/_runs/`` — the run
    artifacts directory is completely separate from Hive partition directories.
    """
    # Guard: same-filesystem requirement for atomic rename(2).
    canonical_dev = os.stat(str(month_canonical_dir.parent)).st_dev
    staging_dev = os.stat(str(staging_month_dir.parent)).st_dev
    if canonical_dev != staging_dev:
        raise RuntimeError(
            f"Staging and canonical parent directories are on different "
            f"filesystems (canonical_dev={canonical_dev}, "
            f"staging_dev={staging_dev}). Atomic rename is not possible. "
            f"Move staging under the same mount point as the output root."
        )

    logger.log({
        **log_ctx,
        "event": "compaction_atomic_swap",
        "status": "start",
        "month_canonical_dir": str(month_canonical_dir),
        "staging_month_dir": str(staging_month_dir),
        "precompact_dir": str(precompact_dir),
    })

    # Phase 1: canonical → precompact (atomic rename; canonical dir vanishes).
    os.replace(str(month_canonical_dir), str(precompact_dir))

    try:
        # Phase 2: staging → canonical (atomic rename; staging dir vanishes).
        os.replace(str(staging_month_dir), str(month_canonical_dir))

    except Exception as swap_err:
        # Phase 2 failed. Attempt rollback of phase 1 to restore original state.
        try:
            os.replace(str(precompact_dir), str(month_canonical_dir))
        except Exception as rollback_err:
            # Both phase-2 and rollback failed. The original batch files are
            # in precompact_dir; manual intervention is required.
            raise RuntimeError(
                "CRITICAL: compaction phase-2 swap failed AND rollback failed. "
                "Manual recovery required: "
                f"rename {precompact_dir} → {month_canonical_dir}. "
                f"swap_err={swap_err!r} rollback_err={rollback_err!r}"
            ) from swap_err

        # Rollback succeeded; original batch files restored.
        raise RuntimeError(
            f"Compaction phase-2 swap failed (rollback succeeded; original "
            f"batch files are intact): {swap_err!r}. "
            f"Staging files remain at: {staging_month_dir.parent}"
        ) from swap_err

    logger.log({
        **log_ctx,
        "event": "compaction_atomic_swap",
        "status": "success",
        "month_canonical_dir": str(month_canonical_dir),
        "precompact_dir": str(precompact_dir),
    })


# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------


def run_compaction(cfg: CompactionConfig, logger: Any) -> JsonDict:
    """Run month-level Parquet compaction.

    Caller contract
    ---------------
    - All batches for ``cfg.year_month`` completed with ``total_failure == 0``.
    - The cooperative stop flag is not set.
    - ``logger`` is a ``JsonlLogger``-compatible object exposing ``.log(dict)``.

    Stages
    ------
    1. Locate and sort input ``batch_*.parquet`` files (lexicographic order).
    2. Idempotency guard (fail if ``compacted_*`` files exist without
       ``--overwrite-compact``).
    3. Read pre-compaction row counts from Parquet footer metadata (no I/O).
    4. Validate input schema against ``FINAL_LONG_COLS``.
    5. Derive ``rows_per_row_group`` and estimate output file count.
    6. Write audit ``compaction_plan.json`` and ``original_file_inventory.json``.
    7. Stream-write compacted chunks to staging directory.
    8. Post-write validations: row count, schema, sort order, duplicates,
       partition uniformity.
    9. Write audit ``compaction_validation.json`` and
       ``compacted_file_inventory.json``.
    10. Atomic directory swap (skip if ``cfg.dry_run``).
    11. Write audit ``compaction_summary.json``.

    Returns
    -------
    Summary dict (also written to ``/compaction/compaction_summary.json``).

    Raises
    ------
    RuntimeError
        On any unrecoverable failure. Staging files are cleaned up before
        raising. Original batch files are never modified.
+ """ + t_start = time.time() + git_sha = _try_git_sha() + ydir, mdir = _year_month_dirs(cfg.year_month) + + month_canonical_dir = cfg.out_root / ydir / mdir + staging_base = cfg.run_dir / "compaction_staging" + staging_month_dir = staging_base / ydir / mdir + audit_dir = cfg.run_dir / "compaction" + # Precompact backup name: month=MM_precompact_ — lives adjacent to + # the canonical month dir so it is on the same filesystem. + precompact_dir = month_canonical_dir.parent / f"{mdir}_precompact_{cfg.run_id}" + + log_ctx: JsonDict = { + "ts_utc": _now_utc_iso(), + "year_month": cfg.year_month, + "run_id": cfg.run_id, + } + + logger.log({ + **log_ctx, + "event": "compaction_start", + "status": "start", + "month_canonical_dir": str(month_canonical_dir), + "staging_month_dir": str(staging_month_dir), + "audit_dir": str(audit_dir), + "target_size_bytes": cfg.target_size_bytes, + "dry_run": cfg.dry_run, + "overwrite": cfg.overwrite, + }) + + # ── 1. Locate and sort input files ────────────────────────────────────── + if not month_canonical_dir.exists(): + raise RuntimeError(f"Month directory does not exist: {month_canonical_dir}") + + all_parquet = sorted(month_canonical_dir.glob("*.parquet")) + if not all_parquet: + raise RuntimeError(f"No .parquet files found in {month_canonical_dir}") + + # ── 2. Idempotency guard ───────────────────────────────────────────────── + existing_compacted = [p for p in all_parquet if p.name.startswith("part-")] + if existing_compacted and not cfg.overwrite: + raise RuntimeError( + f"Compacted files already exist in {month_canonical_dir}: " + f"{[p.name for p in existing_compacted]}. " + f"Pass --overwrite-compact to re-compact." + ) + + # Input files: only batch_*.parquet (never pre-existing compacted_* files). + input_files: list[Path] = sorted(p for p in all_parquet if p.name.startswith("batch_")) + if not input_files: + raise RuntimeError( + f"No batch_*.parquet input files found in {month_canonical_dir}. 
" + f"All files present: {[p.name for p in all_parquet]}" + ) + + # ── 3. Pre-compaction row count from Parquet footer metadata ───────────── + pre_rows = sum(_parquet_row_count(p) for p in input_files) + total_input_bytes = sum(p.stat().st_size for p in input_files) + + if pre_rows == 0: + raise RuntimeError("Input batch files contain zero rows; nothing to compact.") + + # ── 4. Input schema validation ─────────────────────────────────────────── + input_schema_result = _validate_schema(input_files, label="input") + if not input_schema_result["passed"]: + raise RuntimeError(f"Input schema validation failed: {input_schema_result['errors']}") + + # ── Pre-flight: check for existing compacted files in canonical dir ── + _pre_validate_no_existing_compacted(month_canonical_dir, cfg.overwrite) + + # ── 5. Derive row-group size and estimate output file count ────────────── + # rows_per_row_group is a fixed constant bounding peak RSS to ~2 GiB. + # estimated_n_output_files is for audit/logging only; actual file count is + # determined at runtime by on-disk size rollover in _stream_write_chunks. + bytes_per_row: float = total_input_bytes / pre_rows + rows_per_row_group: int = ROWS_PER_ROW_GROUP + estimated_n_output_files: int = max(1, round(total_input_bytes / cfg.target_size_bytes)) + + # ── 6. 
Write audit plan ────────────────────────────────────────────────── + audit_dir.mkdir(parents=True, exist_ok=True) + original_inventory = [_file_inventory_entry(p) for p in input_files] + compaction_plan: JsonDict = { + "ts_utc": _now_utc_iso(), + "git_sha": git_sha, + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "month_canonical_dir": str(month_canonical_dir), + "staging_month_dir": str(staging_month_dir), + "precompact_dir": str(precompact_dir), + "n_input_files": len(input_files), + "pre_rows": pre_rows, + "total_input_bytes": total_input_bytes, + "bytes_per_row_estimate": round(bytes_per_row, 4), + "rows_per_row_group": rows_per_row_group, + "estimated_n_output_files": estimated_n_output_files, + "target_size_bytes": cfg.target_size_bytes, + "max_files": cfg.max_files, + "dry_run": cfg.dry_run, + "overwrite": cfg.overwrite, + "sort_keys": list(SORT_KEYS), + "final_long_cols": list(FINAL_LONG_COLS), + "input_file_list_hash": _file_list_hash(input_files), + } + _write_json(audit_dir / "compaction_plan.json", compaction_plan) + _write_json(audit_dir / "original_file_inventory.json", original_inventory) + + # ── Dry-run: stop before writing any data ──────────────────────────────── + if cfg.dry_run: + logger.log({ + **log_ctx, + "event": "compaction_complete", + "status": "info", + "msg": "dry_run=True; validation plan written; swap skipped", + "pre_rows": pre_rows, + "rows_per_row_group": rows_per_row_group, + "n_input_files": len(input_files), + }) + dry_summary: JsonDict = { + "ts_utc": _now_utc_iso(), + "git_sha": git_sha, + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "status": "dry_run", + "pre_rows": pre_rows, + "n_input_files": len(input_files), + "rows_per_row_group": rows_per_row_group, + "target_size_bytes": cfg.target_size_bytes, + "elapsed_ms": _elapsed_ms(t_start, time.time()), + } + _write_json(audit_dir / "compaction_summary.json", dry_summary) + return dry_summary + + # ── 7. 
Stream-write compacted chunks to staging ────────────────────────── + staging_month_dir.mkdir(parents=True, exist_ok=True) + try: + output_files = _stream_write_chunks( + sorted_input_files=input_files, + staging_month_dir=staging_month_dir, + rows_per_row_group=rows_per_row_group, + target_size_bytes=cfg.target_size_bytes, + max_files=cfg.max_files, + logger=logger, + log_ctx=log_ctx, + ) + except Exception as write_err: + logger.log({ + **log_ctx, + "event": "compaction_failure", + "status": "failure", + "phase": "write", + "exception_type": type(write_err).__name__, + "exception_msg": str(write_err), + "traceback": traceback.format_exc(), + }) + shutil.rmtree(str(staging_month_dir), ignore_errors=True) + raise RuntimeError(f"Compaction write phase failed: {write_err}") from write_err + + if not output_files: + shutil.rmtree(str(staging_month_dir), ignore_errors=True) + raise RuntimeError("Compaction produced zero output files; aborting before swap.") + + # ── 8. Post-write validations ──────────────────────────────────────────── + # Row count: sum footer metadata (no data load). + post_rows = sum(_parquet_row_count(p) for p in output_files) + row_count_ok = post_rows == pre_rows + + output_schema_result = _validate_schema(output_files, label="output") + sort_dup_result = _validate_adjacent_keys_streaming( + output_files, + label="compacted", + batch_size=cfg.validation_batch_size, + logger=logger, + log_ctx=log_ctx, + ) + partition_result = _validate_partition_uniformity(output_files, cfg.year_month) + + total_output_bytes_staging = sum(p.stat().st_size for p in output_files) + + # ── 9. 
Write validation audit artifacts ────────────────────────────────── + compacted_inventory = [_file_inventory_entry(p) for p in output_files] + all_passed = ( + row_count_ok and output_schema_result["passed"] and sort_dup_result["passed"] and partition_result["passed"] + ) + validation_result: JsonDict = { + "ts_utc": _now_utc_iso(), + "git_sha": git_sha, + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "pre_rows": pre_rows, + "post_rows": post_rows, + "row_count_match": row_count_ok, + "schema_validation": output_schema_result, + "sort_dup_validation": sort_dup_result, + "partition_validation": partition_result, + "passed": all_passed, + } + _write_json(audit_dir / "compaction_validation.json", validation_result) + _write_json(audit_dir / "compacted_file_inventory.json", compacted_inventory) + + if not all_passed: + failure_reasons: list[str] = [] + if not row_count_ok: + failure_reasons.append(f"row_count_mismatch: pre={pre_rows} post={post_rows}") + if not output_schema_result["passed"]: + failure_reasons.append(f"schema: {output_schema_result['errors']}") + if not sort_dup_result["passed"]: + failure_reasons.append(f"sort_or_dup: {sort_dup_result['error']}") + if not partition_result["passed"]: + failure_reasons.append(f"partition: {partition_result['errors']}") + + logger.log({ + **log_ctx, + "event": "compaction_failure", + "status": "failure", + "phase": "validation", + "reasons": failure_reasons, + }) + shutil.rmtree(str(staging_month_dir), ignore_errors=True) + raise RuntimeError( + f"Compaction post-write validation failed: {failure_reasons}. " + f"Original batch files are untouched at {month_canonical_dir}." 
+ ) + + logger.log({ + **log_ctx, + "event": "compaction_validation_pass", + "status": "success", + "pre_rows": pre_rows, + "post_rows": post_rows, + "n_output_files": len(output_files), + "total_output_bytes_staging": total_output_bytes_staging, + }) + + # ── No-swap mode: keep staged outputs, skip atomic swap ────────────────── + if cfg.no_swap: + t_end = time.time() + summary: JsonDict = { + "ts_utc": _now_utc_iso(), + "git_sha": git_sha, + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "status": "no_swap", + "n_input_files": len(input_files), + "n_output_files": len(output_files), + "pre_rows": pre_rows, + "post_rows": post_rows, + "total_input_bytes": total_input_bytes, + "total_output_bytes_staging": total_output_bytes_staging, + "rows_per_row_group": rows_per_row_group, + "target_size_bytes": cfg.target_size_bytes, + "month_canonical_dir": str(month_canonical_dir), + "staging_month_dir": str(staging_month_dir), + "precompact_dir": str(precompact_dir), + "elapsed_ms": _elapsed_ms(t_start, t_end), + "sort_keys": list(SORT_KEYS), + "input_file_list_hash": _file_list_hash(input_files), + "staged_output_file_list_hash": _file_list_hash(output_files), + } + _write_json(audit_dir / "compaction_summary.json", summary) + + logger.log({ + **log_ctx, + "event": "compaction_complete", + "status": "info", + "msg": "no_swap=True; staging+validation complete; swap skipped", + "pre_rows": pre_rows, + "post_rows": post_rows, + "n_input_files": len(input_files), + "n_output_files": len(output_files), + "total_output_bytes_staging": total_output_bytes_staging, + "elapsed_ms": _elapsed_ms(t_start, t_end), + }) + return summary + + # ── 10. 
Atomic directory swap ───────────────────────────────────────────── + try: + _atomic_swap( + month_canonical_dir=month_canonical_dir, + staging_month_dir=staging_month_dir, + precompact_dir=precompact_dir, + logger=logger, + log_ctx=log_ctx, + ) + except Exception as swap_err: + logger.log({ + **log_ctx, + "event": "compaction_failure", + "status": "failure", + "phase": "atomic_swap", + "exception_type": type(swap_err).__name__, + "exception_msg": str(swap_err), + "traceback": traceback.format_exc(), + }) + # Staging files are still in staging_month_dir (swap failed before or + # during phase-2; _atomic_swap rolls back phase-1 automatically). + shutil.rmtree(str(staging_month_dir), ignore_errors=True) + raise + + # After successful swap the output files now live under month_canonical_dir. + # Compute output bytes from the canonical location (staging dir is gone). + output_names = [p.name for p in output_files] + total_output_bytes = sum((month_canonical_dir / name).stat().st_size for name in output_names) + + # Write _SUCCESS.json marker (Spark convention) with compaction metadata. + _write_success_marker( + month_dir=month_canonical_dir, + output_names=output_names, + year_month=cfg.year_month, + run_id=cfg.run_id, + pre_rows=pre_rows, + post_rows=post_rows, + total_output_bytes=total_output_bytes, + git_sha=git_sha, + ) + + # ── 11. 
Final summary ───────────────────────────────────────────────────── + t_end = time.time() + summary: JsonDict = { + "ts_utc": _now_utc_iso(), + "git_sha": git_sha, + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "status": "success", + "n_input_files": len(input_files), + "n_output_files": len(output_files), + "pre_rows": pre_rows, + "post_rows": post_rows, + "total_input_bytes": total_input_bytes, + "total_output_bytes": total_output_bytes, + "rows_per_row_group": rows_per_row_group, + "target_size_bytes": cfg.target_size_bytes, + "month_canonical_dir": str(month_canonical_dir), + "precompact_dir": str(precompact_dir), + "elapsed_ms": _elapsed_ms(t_start, t_end), + "sort_keys": list(SORT_KEYS), + "input_file_list_hash": _file_list_hash(input_files), + "output_file_list_hash": _file_list_hash([month_canonical_dir / n for n in output_names]), + } + _write_json(audit_dir / "compaction_summary.json", summary) + + logger.log({ + **log_ctx, + "event": "compaction_complete", + "status": "success", + "pre_rows": pre_rows, + "post_rows": post_rows, + "n_input_files": len(input_files), + "n_output_files": len(output_files), + "total_output_bytes": total_output_bytes, + "elapsed_ms": _elapsed_ms(t_start, t_end), + }) + + return summary diff --git a/scripts/csv_to_parquet/migrate_month_runner.py b/scripts/csv_to_parquet/migrate_month_runner.py new file mode 100644 index 0000000..9f59f79 --- /dev/null +++ b/scripts/csv_to_parquet/migrate_month_runner.py @@ -0,0 +1,1465 @@ +#!/usr/bin/env python3 +"""Deterministic, resumable CSV-to-Parquet month migration runner. + +Architecture +------------ +Orchestrates conversion of ~30k wide-format ComEd smart-meter CSVs per month +into Hive-partitioned Parquet (year=YYYY/month=MM). + +Key design decisions: + +1. **Batch-level atomicity** — Files are grouped into fixed-size batches. Each + batch produces exactly one Parquet file, written to a staging directory and + atomically published via ``os.replace()``. 
   Readers never see partial output.

2. **Resume / checkpointing** — Per-file success is recorded in JSONL manifests.
   Re-running with ``--resume`` skips already-succeeded inputs, enabling safe
   restarts after crashes or OOMs without re-processing the entire month.

3. **Deterministic output** — Inputs are sorted lexicographically before batching
   so batch composition is reproducible. Within each batch, rows are globally
   sorted by ``(zip_code, account_identifier, datetime)`` before writing.

4. **Lazy-then-collect execution** — The ``lazy_sink`` mode builds LazyFrames
   per file, concatenates them, then *materializes* via ``.collect()`` before
   sorting and writing. This is required because Polars' streaming
   ``sink_parquet`` does not honor ``.sort()`` — it processes data in unordered
   chunks. Explicit collect → sort → write guarantees sorted output at the cost
   of batch-level memory.

5. **Thread-pool parallelism** — Batches execute concurrently via
   ``ThreadPoolExecutor``. Threads (not processes) are chosen because the
   per-batch workload is I/O-bound (CSV read → transform → Parquet write) and
   Polars releases the GIL during its native Rust operations.

6. **Full audit trail** — Every file- and batch-level event is logged to
   structured JSONL. ``plan.json``, ``run_summary.json``, and per-batch manifest
   files provide complete post-hoc reproducibility evidence.

Usage (via Justfile)::

    just migrate-month 202307
"""

from __future__ import annotations

import argparse
import concurrent.futures as cf
import dataclasses
import datetime as dt
import hashlib
import json
import os
import platform
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
import traceback
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

import polars as pl
from compact_month_output import (
    DEFAULT_COMPACT_TARGET_SIZE_BYTES,
    DEFAULT_VALIDATION_BATCH_SIZE,
    CompactionConfig,
    run_compaction,
)

from smart_meter_analysis.wide_to_long import transform_wide_to_long, transform_wide_to_long_lf

# Shared type aliases for JSON-serializable dicts and event status values.
JsonDict = dict[str, Any]
Status = Literal["start", "success", "failure", "skip", "warning", "info"]

# Canonical sort order for all batch output. This three-column key is the
# contract shared with validate_month_output.py — both must agree exactly.
# Ordering rationale: zip_code groups geographically co-located accounts,
# account_identifier within zip provides stable per-meter ordering, and
# datetime gives the natural time series within each meter.
SORT_KEYS: tuple[str, str, str] = ("zip_code", "account_identifier", "datetime")

# Minimum columns required in the upstream wide CSV to proceed with transform.
# If any are absent, the CSV is structurally invalid — fail-loud rather than
# attempting partial output that would silently corrupt downstream analysis.
REQUIRED_WIDE_COLS: tuple[str, ...] = (
    "ZIP_CODE",
    "DELIVERY_SERVICE_CLASS",
    "DELIVERY_SERVICE_NAME",
    "ACCOUNT_IDENTIFIER",
    "INTERVAL_READING_DATE",
    "INTERVAL_LENGTH",
    "TOTAL_REGISTERED_ENERGY",
    "PLC_VALUE",
    "NSPL_VALUE",
)

# Canonical output schema (exact column order). All Parquet files produced by
# this runner must have exactly these columns in this order. The validator
# (validate_month_output.py) cross-checks against this contract.
FINAL_LONG_COLS: tuple[str, ...] = (
    "zip_code",
    "delivery_service_class",
    "delivery_service_name",
    "account_identifier",
    "datetime",
    "energy_kwh",
    "plc_value",
    "nspl_value",
    "year",
    "month",
)

DEFAULT_WORKERS = 4
DEFAULT_BATCH_SIZE = 50
DEFAULT_MAX_ERRORS = 1000
DEFAULT_PRINT_FAILURES = 10
DEFAULT_EXEC_MODE = "lazy_sink"  # lazy-then-collect mode (see design note 4)
DEFAULT_SKIP_EXISTING_BATCH_OUTPUTS = True


# -----------------------------
# Data models
# -----------------------------


@dataclass(frozen=True)
class RunnerConfig:
    """Immutable run configuration resolved from CLI arguments at startup.

    Frozen to prevent accidental mutation during concurrent batch execution.
    All filesystem paths are resolved to absolute at construction time so that
    batch workers can operate independently of working-directory changes.
    """

    # When True, compaction stages+validates but skips the atomic swap.
    compact_no_swap: bool

    year_month: str  # YYYYMM
    input_list: Path
    out_root: Path  # dataset root (Hive partitions live here)
    run_id: str  # unique per invocation; used for artifact directory naming

    workers: int
    batch_size: int
    resume: bool  # when True, skip inputs already logged as success in manifests
    dry_run: bool
    fail_fast: bool
    max_errors: int  # per-batch error budget before aborting
    max_files: int | None  # optional cap on total inputs (for testing)

    shard_id: int | None  # enables filename-safe parallel sharding

    skip_existing_batch_outputs: bool  # batch-level idempotence guard
    overwrite: bool  # opt-in to overwrite existing batch outputs

    run_dir: Path  # _runs/<year_month>/<run_id>/ — all artifacts live here
    log_jsonl: Path  # structured event log (append-only)
    manifest_dir: Path  # per-batch JSONL manifests for resume
    staging_dir: Path  # temp write location for atomic publish

    print_failures: int

    exec_mode: Literal["eager", "lazy_sink"]
    debug_mem: bool
    debug_temp_scan: bool
    polars_temp_dir: str | None
+ # Compaction stage (optional; runs after all batches complete). + compact_month: bool + compact_target_size_bytes: int + compact_max_files: int | None + overwrite_compact: bool + compact_dry_run: bool + validation_batch_size: int + + +@dataclass(frozen=True) +class BatchPlan: + """A unit of work: a group of input CSVs that will produce one Parquet file. + + batch_id is zero-padded (batch_0000, batch_0001, ...) for deterministic + filesystem ordering and human readability in logs. + """ + + batch_id: str + inputs: list[str] + + +# ----------------------------- +# Logging +# ----------------------------- + + +class JsonlLogger: + """Thread-safe, append-only structured event logger. + + Uses a threading.Lock to serialize writes from concurrent batch workers. + JSONL (one JSON object per line) is chosen over CSV or multi-line JSON + because it is append-safe, grep-friendly, and trivially parseable for + post-hoc analysis (e.g., extracting failure events from a 100k-line log). + """ + + def __init__(self, path: Path) -> None: + self._path = path + self._lock = threading.Lock() + self._path.parent.mkdir(parents=True, exist_ok=True) + + def log(self, event: JsonDict) -> None: + line = json.dumps(event, ensure_ascii=False, sort_keys=True) + with self._lock, self._path.open("a", encoding="utf-8") as f: + f.write(line + "\n") + + +def now_utc_iso() -> str: + return dt.datetime.now(dt.timezone.utc).isoformat(timespec="seconds") + + +def elapsed_ms(t0: float, t1: float) -> int: + return round((t1 - t0) * 1000.0) + + +def stable_hash(s: str) -> str: + return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16] + + +def try_git_info() -> JsonDict: + """Capture git SHA and dirty state for the audit trail. + + Best-effort: returns None fields if git is unavailable (e.g., in Docker + images without git). This is logged in plan.json to tie output artifacts + back to the exact code version that produced them. 
+    """
+
+    def _run(args: list[str]) -> str | None:
+        # Best-effort subprocess wrapper: any failure (git missing, not a
+        # repo, non-zero exit) collapses to None rather than aborting the run.
+        try:
+            cp = subprocess.run(args, check=False, capture_output=True, text=True)  # noqa: S603
+            if cp.returncode != 0:
+                return None
+            return cp.stdout.strip()
+        except Exception:
+            return None
+
+    sha = _run(["git", "rev-parse", "HEAD"])
+    dirty = _run(["git", "status", "--porcelain"])
+    # is_dirty: True iff `git status --porcelain` printed anything; None when
+    # git itself was unavailable (distinguishes "clean" from "unknown").
+    return {"sha": sha, "is_dirty": bool(dirty) if dirty is not None else None}
+
+
+def build_env_info() -> JsonDict:
+    """Snapshot interpreter, platform, and Polars versions for the audit trail."""
+    return {
+        "python": sys.version.replace("\n", " "),
+        "platform": platform.platform(),
+        "polars": pl.__version__,
+        "cwd": str(Path.cwd()),
+    }
+
+
+# -----------------------------
+# Debug helpers (RSS / disk / temp)
+# -----------------------------
+
+
+def _read_rss_bytes() -> int | None:
+    """Read resident set size from /proc/self/status (Linux only).
+
+    Used with --debug-mem to track per-batch memory growth and detect leaks
+    during long-running migrations. Returns None on non-Linux platforms.
+    """
+    try:
+        with open("/proc/self/status", encoding="utf-8") as f:
+            for line in f:
+                if line.startswith("VmRSS:"):
+                    parts = line.split()
+                    # VmRSS line format: "VmRSS:  <kb> kB" — value is in kB.
+                    if len(parts) >= 2 and parts[1].isdigit():
+                        kb = int(parts[1])
+                        return kb * 1024
+    except Exception:
+        return None
+    return None
+
+
+def _disk_usage_bytes(path: Path) -> JsonDict:
+    """Return free/total/used bytes for *path*'s filesystem (error dict on failure)."""
+    try:
+        du = shutil.disk_usage(str(path))
+        return {"free": int(du.free), "total": int(du.total), "used": int(du.used)}
+    except Exception as e:
+        return {"error": type(e).__name__, "msg": str(e)}
+
+
+def _snapshot_dir(path: Path, limit: int = 2000) -> dict[str, int]:
+    """Map filename -> size for up to *limit* direct children of *path*.
+
+    Best-effort (used by --debug-temp-scan): unreadable entries and missing
+    directories yield whatever was collected so far, never an exception.
+    """
+    out: dict[str, int] = {}
+    try:
+        if not path.exists() or not path.is_dir():
+            return out
+        for i, p in enumerate(path.iterdir()):
+            if i >= limit:
+                break
+            try:
+                if p.is_file():
+                    out[p.name] = int(p.stat().st_size)
+            except Exception:  # noqa: S112
+                continue
+    except Exception:
+        return out
+    return out
+
+
+# -----------------------------
+# Planning / inputs
+# -----------------------------
+
+
+def normalize_input_path(p: 
str) -> str:
+    """Canonicalize an input path for deterministic deduplication.
+
+    S3 URIs are kept as-is (already canonical); local paths are resolved to
+    absolute so that the same file referenced via different relative paths
+    (e.g., ./foo.csv vs ../dir/foo.csv) produces the same manifest key.
+    """
+    p = p.strip()
+    if not p:
+        return p
+    if p.startswith("s3://"):
+        return p
+    return str(Path(p).expanduser().resolve())
+
+
+def load_inputs(input_list: Path) -> list[str]:
+    """Load and sort the input file list.
+
+    Blank lines and lines starting with ``#`` are ignored, so the list file
+    can carry operator comments.
+
+    Sorting is critical for determinism: it ensures that the same set of inputs
+    always produces the same batch assignments, regardless of the order in which
+    ``aws s3 ls`` or ``find`` emits them. This makes runs reproducible across
+    retries and enables meaningful determinism comparisons between outputs.
+    """
+    if not input_list.exists():
+        raise SystemExit(f"--input-list not found: {input_list}")
+    raw = input_list.read_text(encoding="utf-8").splitlines()
+    inputs = [normalize_input_path(x) for x in raw if x.strip() and not x.strip().startswith("#")]
+    inputs_sorted = sorted(inputs)
+    if not inputs_sorted:
+        raise SystemExit("No inputs found in --input-list after filtering comments/blank lines.")
+    return inputs_sorted
+
+
+def make_batches(inputs_sorted: list[str], batch_size: int) -> list[BatchPlan]:
+    """Partition the sorted input list into fixed-size, sequentially-numbered batches.
+
+    Sequential numbering (batch_00000, batch_00001, ...) 
 is required for:
+    - deterministic output filenames that sort naturally on disk
+    - resume correctness (batch_id is the checkpoint key)
+    - human readability in logs and manifest files
+    """
+    if batch_size <= 0:
+        raise SystemExit("--batch-size must be > 0")
+    out: list[BatchPlan] = []
+    n = len(inputs_sorted)
+    for i in range(0, n, batch_size):
+        j = i // batch_size
+        out.append(BatchPlan(batch_id=f"batch_{j:05d}", inputs=inputs_sorted[i : i + batch_size]))
+    return out
+
+
+def write_json(path: Path, obj: Any) -> None:
+    """Write *obj* as pretty-printed, key-sorted JSON, creating parent dirs as needed."""
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_text(json.dumps(obj, indent=2, sort_keys=True), encoding="utf-8")
+
+
+def to_jsonable(x: Any) -> Any:
+    """Recursively convert Paths, dataclasses, and containers to JSON-serializable values.
+
+    Anything not specially handled is returned unchanged and left to
+    json.dumps to serialize (or reject).
+    """
+    if isinstance(x, Path):
+        return str(x)
+    if dataclasses.is_dataclass(x) and not isinstance(x, type):
+        return to_jsonable(dataclasses.asdict(x))
+    if isinstance(x, dict):
+        return {str(k): to_jsonable(v) for k, v in x.items()}
+    if isinstance(x, (list, tuple)):
+        return [to_jsonable(v) for v in x]
+    return x
+
+
+# -----------------------------
+# Resume / checkpointing
+# -----------------------------
+
+
+def iter_manifest_success_inputs(manifest_dir: Path) -> set[str]:
+    """Build the set of input paths that previously succeeded (for --resume).
+
+    Scans all manifest JSONL files and collects input_path values with
+    status=success. This set is used as a skip-list so that resumed runs
+    don't re-transform files that already completed. The manifest is the
+    source of truth — not the presence of output files — because a crash
+    could leave partial staging files without a success record. 
+ """ + if not manifest_dir.exists(): + return set() + + success: set[str] = set() + for p in sorted(manifest_dir.glob("manifest_*.jsonl")): + with p.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + rec = json.loads(line) + except json.JSONDecodeError: + continue + if rec.get("status") == "success" and isinstance(rec.get("input_path"), str): + success.add(rec["input_path"]) + return success + + +# ----------------------------- +# Schema / validation helpers +# ----------------------------- + + +def build_wide_schema() -> dict[str, pl.DataType]: + """Construct an explicit Polars schema for the upstream wide CSV. + + An explicit schema is used instead of inference because: + - Inference is nondeterministic across files (a column that happens to have + all-integer values in one file may be inferred as Int64, while another + file with the same column may be inferred as Float64). + - ZIP_CODE and ACCOUNT_IDENTIFIER must be read as Utf8 to preserve leading + zeros (e.g., ZIP 01234). + - INTERVAL_READING_DATE is read as Utf8 and parsed downstream with an + explicit date format to avoid DD/MM vs MM/DD ambiguity. + """ + schema: dict[str, pl.DataType] = { + "ZIP_CODE": pl.Utf8, + "DELIVERY_SERVICE_CLASS": pl.Utf8, + "DELIVERY_SERVICE_NAME": pl.Utf8, + "ACCOUNT_IDENTIFIER": pl.Utf8, + "INTERVAL_READING_DATE": pl.Utf8, + "INTERVAL_LENGTH": pl.Int32, + "TOTAL_REGISTERED_ENERGY": pl.Float64, + "PLC_VALUE": pl.Float64, + "NSPL_VALUE": pl.Float64, + } + + # Standard 0030..2400 (48 cols) + DST extras 2430/2500 (2 cols) + for minutes in [*list(range(30, 1441, 30)), 1470, 1500]: + hh, mm = divmod(minutes, 60) + schema[f"INTERVAL_HR{hh:02d}{mm:02d}_ENERGY_QTY"] = pl.Float64 + + return schema + + +def validate_wide_contract(df: pl.DataFrame) -> None: + """Fail-loud pre-transform contract check on an eager DataFrame. + + Checks required columns exist and INTERVAL_LENGTH is uniformly 1800s (30 min). 
+ This is the authoritative guard against upstream schema drift — catching it + here prevents corrupt long output from being written to the batch Parquet file. + """ + missing = [c for c in REQUIRED_WIDE_COLS if c not in df.columns] + if missing: + raise ValueError(f"Missing required wide columns: {missing}") + + if df.schema.get("INTERVAL_LENGTH") not in { + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + }: + raise ValueError(f"INTERVAL_LENGTH dtype must be integer seconds. observed={df.schema.get('INTERVAL_LENGTH')}") + + bad = df.filter(pl.col("INTERVAL_LENGTH").is_null() | (pl.col("INTERVAL_LENGTH") != 1800)).height + if bad > 0: + sample = ( + df.filter(pl.col("INTERVAL_LENGTH").is_null() | (pl.col("INTERVAL_LENGTH") != 1800)) + .select(["ZIP_CODE", "ACCOUNT_IDENTIFIER", "INTERVAL_READING_DATE", "INTERVAL_LENGTH"]) + .head(10) + .to_dicts() + ) + raise ValueError( + f"INTERVAL_LENGTH contract violation: expected 1800 everywhere. bad_rows={bad} sample={sample}" + ) + + +def validate_wide_contract_lf(lf: pl.LazyFrame) -> None: + """Lazy-mode equivalent of validate_wide_contract. + + Both eager and lazy variants exist because the runner supports two execution + modes. The lazy variant avoids full materialization — it collects only the + bad-row count and a small diagnostic sample. 
+ """ + cols = lf.collect_schema().names() + missing = [c for c in REQUIRED_WIDE_COLS if c not in cols] + if missing: + raise ValueError(f"Missing required wide columns: {missing}") + + bad = ( + lf.filter(pl.col("INTERVAL_LENGTH").is_null() | (pl.col("INTERVAL_LENGTH") != 1800)) + .select(pl.len().alias("bad_rows")) + .collect(engine="streaming") + .item() + ) + if int(bad) > 0: + sample = ( + lf.filter(pl.col("INTERVAL_LENGTH").is_null() | (pl.col("INTERVAL_LENGTH") != 1800)) + .select(["ZIP_CODE", "ACCOUNT_IDENTIFIER", "INTERVAL_READING_DATE", "INTERVAL_LENGTH"]) + .head(10) + .collect(engine="streaming") + .to_dicts() + ) + raise ValueError( + f"INTERVAL_LENGTH contract violation: expected 1800 everywhere. bad_rows={int(bad)} sample={sample}" + ) + + +def shape_long_after_transform(df: pl.DataFrame) -> pl.DataFrame: + """Enforce canonical column names, dtypes, and order on the transform output. + + This is a defensive layer between wide_to_long.py (which owns the transform + logic) and the Parquet writer (which requires an exact schema). It handles: + - Legacy column naming (interval_energy → energy_kwh) + - Dtype coercion to the canonical 10-column schema + - Adding year/month partition columns derived from datetime + - Projecting to FINAL_LONG_COLS in exact order + """ + out = df + if "energy_kwh" not in out.columns and "interval_energy" in out.columns: + out = out.rename({"interval_energy": "energy_kwh"}) + + required = [ + "zip_code", + "delivery_service_class", + "delivery_service_name", + "account_identifier", + "datetime", + "energy_kwh", + "plc_value", + "nspl_value", + ] + missing = [c for c in required if c not in out.columns] + if missing: + raise ValueError(f"Transform output missing required columns: {missing} present_cols={out.columns}") + + if out.schema.get("datetime") == pl.Utf8: + raise ValueError("datetime is Utf8. 
transform must output Datetime.") + + out = out.with_columns([ + pl.col("zip_code").cast(pl.Utf8), + pl.col("account_identifier").cast(pl.Utf8), + pl.col("delivery_service_class").cast(pl.Categorical), + pl.col("delivery_service_name").cast(pl.Categorical), + pl.col("energy_kwh").cast(pl.Float64, strict=False), + pl.col("plc_value").cast(pl.Float64, strict=False), + pl.col("nspl_value").cast(pl.Float64, strict=False), + pl.col("datetime").cast(pl.Datetime("us")), + ]).with_columns([ + pl.col("datetime").dt.year().cast(pl.Int32).alias("year"), + pl.col("datetime").dt.month().cast(pl.Int8).alias("month"), + ]) + + return out.select(list(FINAL_LONG_COLS)) + + +def shape_long_after_transform_lf(lf: pl.LazyFrame) -> pl.LazyFrame: + """Lazy-mode equivalent of shape_long_after_transform.""" + cols = lf.collect_schema().names() + if "energy_kwh" not in cols and "interval_energy" in cols: + lf = lf.rename({"interval_energy": "energy_kwh"}) + + cols = lf.collect_schema().names() + required = [ + "zip_code", + "delivery_service_class", + "delivery_service_name", + "account_identifier", + "datetime", + "energy_kwh", + "plc_value", + "nspl_value", + ] + missing = [c for c in required if c not in cols] + if missing: + raise ValueError(f"Transform output missing required columns: {missing} present_cols={cols}") + + lf = lf.with_columns([ + pl.col("zip_code").cast(pl.Utf8), + pl.col("account_identifier").cast(pl.Utf8), + pl.col("delivery_service_class").cast(pl.Categorical), + pl.col("delivery_service_name").cast(pl.Categorical), + pl.col("energy_kwh").cast(pl.Float64, strict=False), + pl.col("plc_value").cast(pl.Float64, strict=False), + pl.col("nspl_value").cast(pl.Float64, strict=False), + pl.col("datetime").cast(pl.Datetime("us")), + ]).with_columns([ + pl.col("datetime").dt.year().cast(pl.Int32).alias("year"), + pl.col("datetime").dt.month().cast(pl.Int8).alias("month"), + ]) + + return lf.select(list(FINAL_LONG_COLS)) + + +def validate_year_month(df: pl.DataFrame, 
year_month: str) -> None: + """Guard against partition spillover: every row must belong to the target month. + + A CSV file dated in July that contains even one row with a datetime in August + would corrupt the Hive partition. Catching this at transform time (rather + than post-hoc validation) prevents bad data from being written to disk. + """ + y = int(year_month[:4]) + m = int(year_month[4:6]) + bad = df.filter((pl.col("year") != y) | (pl.col("month") != m)).height + if bad > 0: + raise ValueError(f"--year-month {year_month} validation failed: bad_rows={bad}") + + +def validate_year_month_lf(lf: pl.LazyFrame, year_month: str) -> None: + """Lazy-mode equivalent of validate_year_month.""" + y = int(year_month[:4]) + m = int(year_month[4:6]) + bad = ( + lf.filter((pl.col("year") != y) | (pl.col("month") != m)) + .select(pl.len().alias("bad_rows")) + .collect(engine="streaming") + .item() + ) + if int(bad) > 0: + raise ValueError(f"--year-month {year_month} validation failed: bad_rows={int(bad)}") + + +# --------------------------------------------------------------------------- +# Paths / deterministic output naming +# +# Output path structure: /year=YYYY/month=MM/.parquet +# Staging path structure: //year=YYYY/month=MM/.parquet +# +# The staging directory mirrors the final path hierarchy so that atomic_publish +# can use a single os.replace() call. When sharding is enabled, filenames +# include the shard_id prefix to avoid collisions between parallel shards. 
+# ---------------------------------------------------------------------------
+
+
+def year_month_dirs(year_month: str) -> tuple[str, str]:
+    """Split a YYYYMM string into zero-padded (year, month) directory names.
+
+    NOTE(review): these are plain "YYYY"/"MM" directory names; the overview
+    comment above describes "year=YYYY/month=MM" Hive-style keys — confirm
+    which form downstream readers actually expect.
+    """
+    y = int(year_month[:4])
+    m = int(year_month[4:6])
+    return f"{y:04d}", f"{m:02d}"
+
+
+def batch_output_filename(batch_id: str, shard_id: int | None) -> str:
+    """Deterministic Parquet filename for a batch; shard-prefixed when sharding is on."""
+    if shard_id is None:
+        return f"{batch_id}.parquet"
+    return f"shard_{shard_id:02d}_{batch_id}.parquet"
+
+
+def canonical_batch_out_path(cfg: RunnerConfig, batch_id: str) -> Path:
+    """Final published location for a batch: <out_root>/<YYYY>/<MM>/<batch file>."""
+    ydir, mdir = year_month_dirs(cfg.year_month)
+    return cfg.out_root / ydir / mdir / batch_output_filename(batch_id, cfg.shard_id)
+
+
+def staging_batch_out_path(cfg: RunnerConfig, batch_id: str) -> Path:
+    """Temp write location under the run's staging dir, mirroring the final hierarchy."""
+    ydir, mdir = year_month_dirs(cfg.year_month)
+    return cfg.staging_dir / batch_id / ydir / mdir / batch_output_filename(batch_id, cfg.shard_id)
+
+
+def atomic_publish(staging_path: Path, final_path: Path, overwrite: bool) -> None:
+    """Move a completed batch file from staging to its final location.
+
+    Uses os.replace() for atomicity on POSIX filesystems: the destination either
+    has the old file or the new file, never a partially-written one. This is
+    essential because downstream readers (validators, queries) may access the
+    output directory concurrently during long-running migrations. 
+ """ + final_path.parent.mkdir(parents=True, exist_ok=True) + if overwrite: + os.replace(str(staging_path), str(final_path)) + return + if final_path.exists(): + raise FileExistsError(f"Refusing to overwrite existing output: {final_path}") + os.replace(str(staging_path), str(final_path)) + + +# ----------------------------- +# Batch execution +# ----------------------------- + + +def batch_manifest_paths(cfg: RunnerConfig, batch_id: str) -> tuple[Path, Path]: + manifest = cfg.manifest_dir / f"manifest_{batch_id}.jsonl" + summary = cfg.manifest_dir / f"summary_{batch_id}.json" + return manifest, summary + + +def _raise_batch_multi_year_month(uniq: pl.DataFrame) -> None: + """Raise ValueError with batch (year,month) values for diagnostics.""" + raise ValueError(f"Batch contains multiple (year,month) values: {uniq.sort(['year', 'month']).to_dicts()}") + + +def run_batch( + *, + cfg: RunnerConfig, + batch: BatchPlan, + logger: JsonlLogger, + skip_set: set[str], + stop_flag: threading.Event, +) -> JsonDict: + """Execute one batch: read CSVs, transform, sort, write Parquet. + + This is the core unit of work. Each batch: + 1. Checks whether the final output already exists (batch-level idempotence). + 2. Iterates over input files, transforming each wide CSV to long format. + 3. Concatenates all long DataFrames within the batch. + 4. Sorts by SORT_KEYS and writes a single Parquet file to staging. + 5. Atomically publishes the staging file to the final output location. + 6. Records per-file status in the batch manifest for resume support. + + Returns a batch summary dict logged to both the JSONL log and a JSON file. 
+ """ + t_batch0 = time.time() + manifest_path, summary_path = batch_manifest_paths(cfg, batch.batch_id) + manifest_path.parent.mkdir(parents=True, exist_ok=True) + + batch_ctx: JsonDict = { + "ts_utc": now_utc_iso(), + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "batch_id": batch.batch_id, + "shard_id": cfg.shard_id, + } + + final_out = canonical_batch_out_path(cfg, batch.batch_id) + + # Batch-level checkpoint (default ON): + if cfg.skip_existing_batch_outputs and final_out.exists() and not cfg.overwrite: + summary: JsonDict = { + **batch_ctx, + "status": "skip", + "skip_reason": "existing_batch_output", + "n_inputs": len(batch.inputs), + "n_success": 0, + "n_failure": 0, + "n_skip": 0, + "elapsed_ms": elapsed_ms(t_batch0, time.time()), + "manifest_jsonl": str(manifest_path), + "final_out_path": str(final_out), + "wrote_file": False, + "exec_mode": cfg.exec_mode, + "sort_keys": list(SORT_KEYS), + } + write_json(summary_path, summary) + logger.log({ + **batch_ctx, + "event": "batch_skip_existing_output", + "status": "skip", + "final_out_path": str(final_out), + }) + return summary + + logger.log({ + **batch_ctx, + "event": "batch_start", + "status": "start", + "n_inputs": len(batch.inputs), + "final_out_path": str(final_out), + }) + + tmp_dir = Path(tempfile.gettempdir()) + polars_tmp = os.environ.get("POLARS_TEMP_DIR") + wide_schema = build_wide_schema() + + frames: list[pl.DataFrame] = [] + lfs: list[pl.LazyFrame] = [] + + n_success = 0 + n_failure = 0 + n_skip = 0 + errors: list[JsonDict] = [] + + with manifest_path.open("a", encoding="utf-8") as mf: + for input_path in batch.inputs: + if stop_flag.is_set(): + break + + if input_path in skip_set: + n_skip += 1 + mf.write( + json.dumps( + {**batch_ctx, "input_path": input_path, "status": "skip", "reason": "resume_success"}, + sort_keys=True, + ) + + "\n" + ) + logger.log({ + **batch_ctx, + "event": "file_skip", + "status": "skip", + "input_path": input_path, + "reason": "resume_success", + }) + 
continue + + t0 = time.time() + file_ctx: JsonDict = {**batch_ctx, "input_path": input_path} + logger.log({**file_ctx, "event": "file_start", "status": "start"}) + + try: + if cfg.exec_mode == "eager": + df_wide = pl.read_csv( + input_path, + schema=wide_schema, + has_header=True, + infer_schema_length=0, + ignore_errors=False, + try_parse_dates=False, + ) + validate_wide_contract(df_wide) + + df_long = transform_wide_to_long(df_wide, strict=True, sort_output=False) + df_long = shape_long_after_transform(df_long) + validate_year_month(df_long, cfg.year_month) + + frames.append(df_long) + rows_wide = int(df_wide.height) + rows_long = int(df_long.height) + else: + lf_wide = pl.scan_csv( + input_path, + schema=wide_schema, + has_header=True, + ignore_errors=False, + try_parse_dates=False, + ) + validate_wide_contract_lf(lf_wide) + + lf_long = transform_wide_to_long_lf(lf_wide, strict=True, sort_output=False) + lf_long = shape_long_after_transform_lf(lf_long) + validate_year_month_lf(lf_long, cfg.year_month) + + rows_wide = int(lf_wide.select(pl.len()).collect(engine="streaming").item()) + rows_long = int(lf_long.select(pl.len()).collect(engine="streaming").item()) + lfs.append(lf_long) + + n_success += 1 + t1 = time.time() + mf.write( + json.dumps( + { + **file_ctx, + "status": "success", + "elapsed_ms": elapsed_ms(t0, t1), + "rows_wide": rows_wide, + "rows_long": rows_long, + }, + sort_keys=True, + ) + + "\n" + ) + logger.log({ + **file_ctx, + "event": "file_success", + "status": "success", + "elapsed_ms": elapsed_ms(t0, t1), + "rows_wide": rows_wide, + "rows_long": rows_long, + }) + + except Exception as e: + n_failure += 1 + t1 = time.time() + mf.write( + json.dumps( + { + **file_ctx, + "status": "failure", + "elapsed_ms": elapsed_ms(t0, t1), + "exception_type": type(e).__name__, + "exception_msg": str(e), + }, + sort_keys=True, + ) + + "\n" + ) + logger.log({ + **file_ctx, + "event": "file_failure", + "status": "failure", + "elapsed_ms": elapsed_ms(t0, t1), + 
"exception_type": type(e).__name__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + }) + errors.append({"input_path": input_path, "exception_type": type(e).__name__, "exception_msg": str(e)}) + + if cfg.fail_fast or n_failure >= cfg.max_errors: + break + + wrote_file = False + write_bytes = 0 + staging_out = staging_batch_out_path(cfg, batch.batch_id) + staging_out.parent.mkdir(parents=True, exist_ok=True) + + if cfg.debug_mem: + logger.log({ + **batch_ctx, + "event": "debug_env", + "status": "info", + "exec_mode": cfg.exec_mode, + "tmp_dir": str(tmp_dir), + "polars_temp_dir_env": polars_tmp, + "rss_bytes": _read_rss_bytes(), + "disk_tmp": _disk_usage_bytes(tmp_dir), + "disk_out_root": _disk_usage_bytes(cfg.out_root), + "final_out_path": str(final_out), + "staging_out_path": str(staging_out), + }) + + try: + if cfg.exec_mode == "eager" and frames: + df_batch = pl.concat(frames, how="vertical", rechunk=False) + df_batch = df_batch.sort(list(SORT_KEYS), maintain_order=True) + + uniq = df_batch.select(["year", "month"]).unique() + if uniq.height != 1: + _raise_batch_multi_year_month(uniq) + + df_batch.write_parquet(str(staging_out), compression="snappy", statistics=True, use_pyarrow=False) + wrote_file = True + + if cfg.exec_mode == "lazy_sink" and lfs: + lf_batch = pl.concat(lfs, how="vertical") + uniq = lf_batch.select(["year", "month"]).unique().collect(engine="streaming") + if uniq.height != 1: + _raise_batch_multi_year_month(uniq) + + # Collect before sort+write: sink_parquet uses the streaming engine + # which does not honor .sort() — it processes data in unordered chunks, + # silently producing unsorted output. Materializing first guarantees + # write_parquet emits rows in sorted order. 
+ df_batch = lf_batch.collect(engine="streaming") + df_batch = df_batch.sort(list(SORT_KEYS), maintain_order=True) + df_batch.write_parquet(str(staging_out), compression="snappy", statistics=True, use_pyarrow=False) + wrote_file = True + + if wrote_file: + write_bytes = staging_out.stat().st_size + atomic_publish(staging_out, final_out, overwrite=cfg.overwrite) + + try: + staging_batch_root = cfg.staging_dir / batch.batch_id + if staging_batch_root.exists(): + shutil.rmtree(staging_batch_root, ignore_errors=True) + except Exception: # noqa: S110 + pass + + except FileExistsError as e: + logger.log({ + **batch_ctx, + "event": "batch_publish_collision", + "status": "warning", + "exception_msg": str(e), + "final_out_path": str(final_out), + "staging_out_path": str(staging_out), + }) + n_failure += 1 + wrote_file = False + except Exception as e: + logger.log({ + **batch_ctx, + "event": "batch_write_failure", + "status": "failure", + "exception_type": type(e).__name__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + }) + n_failure += 1 + wrote_file = False + + t_batch1 = time.time() + batch_summary: JsonDict = { + "ts_utc": now_utc_iso(), + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "batch_id": batch.batch_id, + "shard_id": cfg.shard_id, + "n_inputs": len(batch.inputs), + "n_success": n_success, + "n_failure": n_failure, + "n_skip": n_skip, + "elapsed_ms": elapsed_ms(t_batch0, t_batch1), + "errors_sample": errors[:10], + "manifest_jsonl": str(manifest_path), + "final_out_path": str(final_out), + "staging_out_path": str(staging_out), + "wrote_file": wrote_file, + "write_bytes": write_bytes, + "sort_keys": list(SORT_KEYS), + "exec_mode": cfg.exec_mode, + "tmp_dir": str(tmp_dir), + "polars_temp_dir_env": polars_tmp, + } + write_json(summary_path, batch_summary) + logger.log({**batch_ctx, "event": "batch_end", "status": "info", **batch_summary}) + return batch_summary + + +# ----------------------------- +# CLI / main +# 
----------------------------- + + +def parse_args(argv: Sequence[str]) -> RunnerConfig: + ap = argparse.ArgumentParser( + prog="migrate_month_runner", + description="Deterministic, resumable CSV→Parquet month runner (single-file per batch; shard-safe filenames).", + ) + ap.add_argument("--input-list", required=True, type=Path, help="Newline-delimited input paths (local or s3://).") + ap.add_argument("--out-root", required=True, type=Path, help="Output dataset root (Hive partitions).") + ap.add_argument("--year-month", required=True, help="Target month in YYYYMM, e.g. 202307") + ap.add_argument("--run-id", default=None, help="Optional run id. Default: UTC timestamp + stable hash.") + ap.add_argument("--workers", type=int, default=DEFAULT_WORKERS) + ap.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE) + ap.add_argument("--resume", action="store_true") + ap.add_argument("--fail-fast", action="store_true") + ap.add_argument("--dry-run", action="store_true") + ap.add_argument("--max-errors", type=int, default=DEFAULT_MAX_ERRORS) + ap.add_argument("--max-files", type=int, default=None) + ap.add_argument("--shard-id", type=int, default=None, help="Shard identifier (used in output filenames).") + ap.add_argument( + "--skip-existing-batch-outputs", + action="store_true", + default=DEFAULT_SKIP_EXISTING_BATCH_OUTPUTS, + help="Skip a batch if its expected output file already exists (default: on).", + ) + ap.add_argument( + "--no-skip-existing-batch-outputs", + action="store_false", + dest="skip_existing_batch_outputs", + help="Disable skip-existing behavior (not recommended).", + ) + ap.add_argument( + "--overwrite", + action="store_true", + help="Overwrite existing batch output files (dangerous; opt-in).", + ) + ap.add_argument( + "--exec-mode", + choices=["eager", "lazy_sink"], + default=DEFAULT_EXEC_MODE, + help="Execution mode. 
Default is lazy_sink (sort+sink_parquet in streaming).", + ) + ap.add_argument("--debug-mem", action="store_true", help="Log RSS/disk/timing per batch stage.") + ap.add_argument("--debug-temp-scan", action="store_true", help="Snapshot temp dir before/after sink.") + ap.add_argument( + "--polars-temp-dir", + default=None, + help="If set, exports POLARS_TEMP_DIR for this process (helps prove spill location).", + ) + ap.add_argument("--print-failures", type=int, default=DEFAULT_PRINT_FAILURES) + + # Compaction flags (all optional; compaction is off by default). + ap.add_argument( + "--compact-month", + action="store_true", + help="Run month-level compaction after all batches complete successfully.", + ) + ap.add_argument( + "--compact-target-size-bytes", + type=int, + default=DEFAULT_COMPACT_TARGET_SIZE_BYTES, + help="Target on-disk size per compacted Parquet file (default 1 GiB).", + ) + ap.add_argument( + "--compact-max-files", + type=int, + default=None, + help="Optional cap on the number of compacted output files.", + ) + ap.add_argument( + "--overwrite-compact", + action="store_true", + help="Allow overwriting existing compacted_*.parquet files.", + ) + ap.add_argument( + "--compact-dry-run", + action="store_true", + help="Plan-only: write compaction_plan.json + original inventory + summary; do not write compacted outputs.", + ) + ap.add_argument( + "--compact-no-swap", + action="store_true", + help=( + "Run full month compaction into staging and perform all post-write validations, " + "but DO NOT atomically swap staged outputs into the canonical month directory." + ), + ) + ap.add_argument( + "--validation-batch-size", + type=int, + default=DEFAULT_VALIDATION_BATCH_SIZE, + help=( + "Rows per PyArrow batch during streaming adjacent-key validation. " + "Lower values reduce peak memory; higher values improve throughput. " + f"Default: {DEFAULT_VALIDATION_BATCH_SIZE:,}." 
+ ), + ) + + ns = ap.parse_args(list(argv)) + + ym = ns.year_month.strip() + if len(ym) != 6 or (not ym.isdigit()): + raise SystemExit("--year-month must be YYYYMM (6 digits)") + + out_root = ns.out_root.expanduser().resolve() + out_root.mkdir(parents=True, exist_ok=True) + + if ns.polars_temp_dir is not None: + os.environ["POLARS_TEMP_DIR"] = str(Path(ns.polars_temp_dir).expanduser().resolve()) + + if ns.run_id is None: + ts = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%dT%H%M%SZ") + rid = f"{ts}_{stable_hash(ym + '|' + str(out_root))}" + else: + rid = ns.run_id.strip() + + run_dir = out_root / "_runs" / ym / rid + log_jsonl = run_dir / "logs" / "run_log.jsonl" + manifest_dir = run_dir / "manifests" + staging_dir = run_dir / "staging" + + return RunnerConfig( + year_month=ym, + input_list=ns.input_list.expanduser().resolve(), + out_root=out_root, + run_id=rid, + workers=ns.workers, + batch_size=ns.batch_size, + resume=ns.resume, + dry_run=ns.dry_run, + fail_fast=ns.fail_fast, + max_errors=ns.max_errors, + max_files=ns.max_files, + shard_id=ns.shard_id, + skip_existing_batch_outputs=ns.skip_existing_batch_outputs, + overwrite=ns.overwrite, + run_dir=run_dir, + log_jsonl=log_jsonl, + manifest_dir=manifest_dir, + staging_dir=staging_dir, + print_failures=ns.print_failures, + exec_mode=ns.exec_mode, + debug_mem=ns.debug_mem, + debug_temp_scan=ns.debug_temp_scan, + polars_temp_dir=ns.polars_temp_dir, + compact_month=ns.compact_month, + compact_target_size_bytes=ns.compact_target_size_bytes, + compact_max_files=ns.compact_max_files, + overwrite_compact=ns.overwrite_compact, + compact_dry_run=ns.compact_dry_run, + compact_no_swap=ns.compact_no_swap, + validation_batch_size=ns.validation_batch_size, + ) + + +def sample_failures_from_log(log_path: Path, n: int) -> list[dict[str, Any]]: + """Extract a bounded sample of failure events from the run log for stderr output. 
+ + Provides immediate diagnostic visibility at the end of a run without + requiring the operator to manually parse the full JSONL log. + """ + if n <= 0 or (not log_path.exists()): + return [] + out: list[dict[str, Any]] = [] + for line in log_path.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line: + continue + try: + rec = json.loads(line) + except (json.JSONDecodeError, ValueError): + continue + if ( + rec.get("event") in ("file_failure", "batch_write_failure", "batch_publish_collision") + or rec.get("status") == "failure" + ): + out.append({ + k: rec.get(k) + for k in ["batch_id", "shard_id", "input_path", "exception_type", "exception_msg", "final_out_path"] + }) + if len(out) >= n: + break + return out + + +def main(argv: Sequence[str]) -> int: + """Entry point: plan → (optionally resume) → execute batches → summarize. + + Signal handling: SIGINT/SIGTERM set a cooperative stop flag rather than + killing workers abruptly. In-flight batches complete their current file + and exit cleanly, ensuring manifests reflect actual work done. This is + critical for resume correctness — an unrecorded partial write would cause + duplicate processing on retry. + """ + cfg = parse_args(argv) + + # Cooperative shutdown: workers check stop_flag between files. 
+ stop_flag = threading.Event() + + def _handle_signal(_signum: int, _frame: Any) -> None: + stop_flag.set() + + signal.signal(signal.SIGINT, _handle_signal) + signal.signal(signal.SIGTERM, _handle_signal) + + cfg.run_dir.mkdir(parents=True, exist_ok=True) + cfg.manifest_dir.mkdir(parents=True, exist_ok=True) + cfg.staging_dir.mkdir(parents=True, exist_ok=True) + + logger = JsonlLogger(cfg.log_jsonl) + + inputs_sorted = load_inputs(cfg.input_list) + if cfg.max_files is not None: + inputs_sorted = inputs_sorted[: cfg.max_files] + + batches = make_batches(inputs_sorted, cfg.batch_size) + + plan = { + "ts_utc": now_utc_iso(), + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "inputs_sorted": inputs_sorted, + "batches": [{"batch_id": b.batch_id, "n_inputs": len(b.inputs)} for b in batches], + "config": to_jsonable(cfg) | {"sort_keys": list(SORT_KEYS)}, + "env": build_env_info(), + "git": try_git_info(), + "notes": { + "deterministic_sort_keys": "zip_code, account_identifier, datetime", + "single_file_per_batch_month": True, + "lazy_sink_note": "lazy_sink uses LazyFrame.sink_parquet (streaming mode).", + "skip_existing_batch_outputs_default": DEFAULT_SKIP_EXISTING_BATCH_OUTPUTS, + }, + } + plan_path = cfg.run_dir / "plan.json" + write_json(plan_path, plan) + + logger.log({ + "ts_utc": now_utc_iso(), + "event": "run_start", + "status": "start", + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "shard_id": cfg.shard_id, + "n_inputs": len(inputs_sorted), + "n_batches": len(batches), + "workers": cfg.workers, + "batch_size": cfg.batch_size, + "resume": cfg.resume, + "dry_run": cfg.dry_run, + "out_root": str(cfg.out_root), + "plan_path": str(plan_path), + "log_jsonl": str(cfg.log_jsonl), + "manifest_dir": str(cfg.manifest_dir), + "staging_dir": str(cfg.staging_dir), + "sort_keys": list(SORT_KEYS), + "exec_mode": cfg.exec_mode, + "polars_temp_dir_env": os.environ.get("POLARS_TEMP_DIR"), + "skip_existing_batch_outputs": cfg.skip_existing_batch_outputs, + 
"overwrite": cfg.overwrite, + }) + + if cfg.dry_run: + print( + json.dumps( + { + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "shard_id": cfg.shard_id, + "n_inputs": len(inputs_sorted), + "n_batches": len(batches), + "first_inputs": inputs_sorted[:5], + "first_batches": [{"batch_id": b.batch_id, "n_inputs": len(b.inputs)} for b in batches[:3]], + "plan_path": str(plan_path), + "log_jsonl": str(cfg.log_jsonl), + "manifest_dir": str(cfg.manifest_dir), + "out_root": str(cfg.out_root), + "sort_keys": list(SORT_KEYS), + "exec_mode": cfg.exec_mode, + "skip_existing_batch_outputs": cfg.skip_existing_batch_outputs, + "overwrite": cfg.overwrite, + }, + indent=2, + sort_keys=True, + ) + ) + logger.log({"ts_utc": now_utc_iso(), "event": "run_end", "status": "info", "msg": "dry_run complete"}) + return 0 + + skip_set: set[str] = set() + if cfg.resume: + skip_set = iter_manifest_success_inputs(cfg.manifest_dir) + logger.log({ + "ts_utc": now_utc_iso(), + "event": "resume_loaded", + "status": "info", + "run_id": cfg.run_id, + "year_month": cfg.year_month, + "shard_id": cfg.shard_id, + "n_success_already": len(skip_set), + }) + + t0 = time.time() + summaries: list[JsonDict] = [] + + # ThreadPoolExecutor is preferred over ProcessPoolExecutor because: + # - Polars releases the GIL during native Rust operations (CSV parse, sort, + # Parquet write), so threads achieve true parallelism for the heavy work. + # - Threads share memory, avoiding the serialization overhead of passing + # DataFrames between processes. + # - Simpler error propagation and signal handling. 
+ with cf.ThreadPoolExecutor(max_workers=cfg.workers) as ex: + futs: dict[cf.Future[JsonDict], BatchPlan] = {} + for b in batches: + futs[ex.submit(run_batch, cfg=cfg, batch=b, logger=logger, skip_set=skip_set, stop_flag=stop_flag)] = b + + for fut in cf.as_completed(futs): + b = futs[fut] + try: + summary = fut.result() + summaries.append(summary) + + if cfg.fail_fast and int(summary.get("n_failure", 0)) > 0: + stop_flag.set() + + except Exception as e: + logger.log({ + "ts_utc": now_utc_iso(), + "event": "batch_future_failure", + "status": "failure", + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "shard_id": cfg.shard_id, + "batch_id": b.batch_id, + "exception_type": type(e).__name__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + }) + if cfg.fail_fast: + stop_flag.set() + + try: + if cfg.staging_dir.exists() and not any(cfg.staging_dir.iterdir()): + cfg.staging_dir.rmdir() + except Exception: # noqa: S110 + pass + + t1 = time.time() + total_success = sum(int(x.get("n_success", 0)) for x in summaries) + total_failure = sum(int(x.get("n_failure", 0)) for x in summaries) + total_skip = sum(int(x.get("n_skip", 0)) for x in summaries) + + # ── Optional month-level compaction ────────────────────────────────────── + # Runs only when explicitly requested AND the month completed cleanly: + # - zero file-level failures across all batches + # - cooperative stop flag was never set (no mid-run abort) + # - every planned batch produced a summary (no futures dropped silently) + compaction_summary: JsonDict | None = None + if cfg.compact_month: + compaction_eligible = total_failure == 0 and not stop_flag.is_set() and len(summaries) == len(batches) + if not compaction_eligible: + logger.log({ + "ts_utc": now_utc_iso(), + "event": "compaction_skipped", + "status": "warning", + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "reason": ("total_failure > 0 or stop_flag set or incomplete batches"), + "total_failure": total_failure, + 
"stop_requested": stop_flag.is_set(), + "n_summaries": len(summaries), + "n_batches_planned": len(batches), + }) + else: + compact_cfg = CompactionConfig( + year_month=cfg.year_month, + run_id=cfg.run_id, + out_root=cfg.out_root, + run_dir=cfg.run_dir, + target_size_bytes=cfg.compact_target_size_bytes, + max_files=cfg.compact_max_files, + overwrite=cfg.overwrite_compact, + dry_run=cfg.compact_dry_run, + no_swap=cfg.compact_no_swap, + validation_batch_size=cfg.validation_batch_size, + ) + try: + compaction_summary = run_compaction(compact_cfg, logger) + except Exception as compact_err: + logger.log({ + "ts_utc": now_utc_iso(), + "event": "compaction_failure", + "status": "failure", + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "exception_type": type(compact_err).__name__, + "exception_msg": str(compact_err), + "traceback": traceback.format_exc(), + }) + # Compaction failure is surfaced in the run summary but does + # NOT retroactively fail the batch migration exit code — the + # batch Parquet files are intact and usable. 
+ compaction_summary = { + "status": "failure", + "exception_type": type(compact_err).__name__, + "exception_msg": str(compact_err), + } + + t1 = time.time() + batches_written = sum(1 for x in summaries if x.get("wrote_file") is True) + batches_skipped_existing_output = sum(1 for x in summaries if x.get("skip_reason") == "existing_batch_output") + batches_with_failures = sum(1 for x in summaries if int(x.get("n_failure", 0)) > 0) + + run_summary = { + "ts_utc": now_utc_iso(), + "year_month": cfg.year_month, + "run_id": cfg.run_id, + "shard_id": cfg.shard_id, + "out_root": str(cfg.out_root), + "n_inputs": len(inputs_sorted), + "n_batches_planned": len(batches), + "n_batches_completed": len(summaries), + "batches_written": batches_written, + "batches_skipped_existing_output": batches_skipped_existing_output, + "batches_with_failures": batches_with_failures, + "total_success": total_success, + "total_failure": total_failure, + "total_skip": total_skip, + "elapsed_ms": elapsed_ms(t0, t1), + "plan_path": str(plan_path), + "log_jsonl": str(cfg.log_jsonl), + "manifest_dir": str(cfg.manifest_dir), + "stop_requested": stop_flag.is_set(), + "sort_keys": list(SORT_KEYS), + "exec_mode": cfg.exec_mode, + "polars_temp_dir_env": os.environ.get("POLARS_TEMP_DIR"), + "skip_existing_batch_outputs": cfg.skip_existing_batch_outputs, + "overwrite": cfg.overwrite, + "compaction": compaction_summary, + } + write_json(cfg.run_dir / "run_summary.json", run_summary) + logger.log({"ts_utc": now_utc_iso(), "event": "run_end", "status": "info", **run_summary}) + + print(json.dumps(run_summary, indent=2, sort_keys=True)) + fails = sample_failures_from_log(cfg.log_jsonl, cfg.print_failures) + if fails: + print("Sample failures:") + for r in fails: + print(json.dumps(r, ensure_ascii=False)) + + return 1 if total_failure > 0 else 0 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/scripts/csv_to_parquet/restructure_for_export.py 
b/scripts/csv_to_parquet/restructure_for_export.py new file mode 100644 index 0000000..3d9880c --- /dev/null +++ b/scripts/csv_to_parquet/restructure_for_export.py @@ -0,0 +1,580 @@ +#!/usr/bin/env python3 +"""Restructure 49 months of compacted Parquet data for export. + +Copies files from the blessed-run layout into a clean, Spark-compatible layout: + + SOURCE: /out_YYYYMM_blessed/year=YYYY/month=MM/compacted_NNNN.parquet + EXPORT: /YYYY/MM/part-NNNNN.parquet + _SUCCESS.json + +Three transformations applied during copy: + 1. year=YYYY/month=MM/ → YYYY/MM/ (drop Hive prefixes) + 2. compacted_NNNN → part-NNNNN (Spark naming, 5-digit zero-pad) + 3. _SUCCESS.json generated per month from pyarrow footer metadata + +Source files are NEVER modified or deleted. The export directory is a fresh +copy. Use --force to overwrite an existing export. + +Usage +----- + uv run python scripts/csv_to_parquet/restructure_for_export.py \\ + --source-root /ebs/.../runs_bs0500_w2 \\ + --export-root /ebs/.../runs_bs0500_w2/export \\ + [--dry-run] [--force] +""" + +from __future__ import annotations + +import argparse +import datetime as dt +import json +import re +import shutil +import subprocess +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import pyarrow.parquet as pq + +# --------------------------------------------------------------------------- +# Constants — must stay in sync with the pipeline schema contract +# --------------------------------------------------------------------------- + +SORT_KEYS: tuple[str, ...] = ("zip_code", "account_identifier", "datetime") + +FINAL_LONG_COLS: tuple[str, ...] = ( + "zip_code", + "delivery_service_class", + "delivery_service_name", + "account_identifier", + "datetime", + "energy_kwh", + "plc_value", + "nspl_value", + "year", + "month", +) + +EXPECTED_MONTHS: tuple[str, ...] 
= ( + "202103", + "202104", + "202105", + "202106", + "202107", + "202108", + "202202", + "202203", + "202204", + "202205", + "202206", + "202207", + "202208", + "202209", + "202210", + "202211", + "202212", + "202301", + "202302", + "202303", + "202304", + "202305", + "202306", + "202307", + "202308", + "202309", + "202310", + "202311", + "202312", + "202401", + "202402", + "202403", + "202404", + "202405", + "202406", + "202407", + "202408", + "202409", + "202410", + "202411", + "202412", + "202501", + "202502", + "202503", + "202504", + "202505", + "202506", + "202507", + "202508", +) + +EXPECTED_MONTH_COUNT: int = 49 + +# Pattern: out_YYYYMM_blessed +_BLESSED_DIR_RE = re.compile(r"^out_(\d{6})_blessed$") + +# Pattern: compacted_NNNN.parquet → capture the 4-digit index +_COMPACTED_RE = re.compile(r"^compacted_(\d{4})\.parquet$") + +JsonDict = dict[str, Any] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _now_utc_iso() -> str: + return dt.datetime.now(dt.timezone.utc).isoformat(timespec="seconds") + + +def _git_sha() -> str | None: + try: + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + check=True, + ) + return result.stdout.strip() or None + except Exception: + return None + + +def _write_json(path: Path, obj: JsonDict) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(obj, indent=2) + "\n", encoding="utf-8") + + +def _compacted_to_part_name(compacted_name: str) -> str: + """Convert compacted_NNNN.parquet → part-NNNNN.parquet (4-digit → 5-digit).""" + m = _COMPACTED_RE.match(compacted_name) + if not m: + raise ValueError(f"Unexpected filename (expected compacted_NNNN.parquet): {compacted_name!r}") + idx = int(m.group(1)) + return f"part-{idx:05d}.parquet" + + +# --------------------------------------------------------------------------- +# Data 
model +# --------------------------------------------------------------------------- + + +@dataclass +class MonthPlan: + year_month: str # YYYYMM + year: str # YYYY + month: str # MM + source_dir: Path # .../year=YYYY/month=MM/ + dest_dir: Path # export/YYYY/MM/ + # sorted list of (src_path, dest_name) pairs + files: list[tuple[Path, str]] = field(default_factory=list) + + @property + def n_files(self) -> int: + return len(self.files) + + +@dataclass +class MonthResult: + year_month: str + n_files: int + total_rows: int + total_bytes: int + status: str # "OK" or "FAIL: " + + +# --------------------------------------------------------------------------- +# Discovery: scan source root for blessed dirs +# --------------------------------------------------------------------------- + + +def discover_months(source_root: Path) -> list[MonthPlan]: + """Find all out_YYYYMM_blessed dirs and build MonthPlan list.""" + plans: dict[str, MonthPlan] = {} + + blessed_dirs = sorted(d for d in source_root.iterdir() if d.is_dir() and _BLESSED_DIR_RE.match(d.name)) + + if not blessed_dirs: + sys.exit(f"ERROR: No out_*_blessed directories found under {source_root}\n Check --source-root path.") + + for bdir in blessed_dirs: + m = _BLESSED_DIR_RE.match(bdir.name) + if m is None: + raise RuntimeError(f"BUG: {bdir.name!r} passed filter but didn't match regex") + year_month = m.group(1) + year = year_month[:4] + month = year_month[4:6] + + # Locate year=YYYY/month=MM/ sub-path + hive_year_dirs = sorted(bdir.glob("year=*")) + if not hive_year_dirs: + sys.exit(f"ERROR: {bdir} has no year=* subdirectory.\n Expected: {bdir}/year={year}/month={month}/") + + found_source: Path | None = None + for ydir in hive_year_dirs: + candidate = ydir / f"month={month}" + if candidate.is_dir(): + found_source = candidate + break + + if found_source is None: + sys.exit( + f"ERROR: {bdir} has no month={month} subdirectory under any year=* dir.\n" + f" Dirs present: {[d.name for d in hive_year_dirs]}" + ) + + 
compacted = sorted(found_source.glob("compacted_*.parquet")) + if not compacted: + sys.exit( + f"ERROR: {found_source} contains no compacted_*.parquet files.\n" + f" Files present: {[p.name for p in found_source.iterdir()]}" + ) + + # Validate all filenames match expected pattern + file_pairs: list[tuple[Path, str]] = [] + for src in compacted: + dest_name = _compacted_to_part_name(src.name) + file_pairs.append((src, dest_name)) + + if year_month in plans: + sys.exit(f"ERROR: Duplicate year_month {year_month} found in source root.") + + plans[year_month] = MonthPlan( + year_month=year_month, + year=year, + month=month, + source_dir=found_source, + dest_dir=Path("__placeholder__"), # set below after export_root known + files=file_pairs, + ) + + return list(plans.values()) + + +def attach_dest_dirs(plans: list[MonthPlan], export_root: Path) -> None: + for p in plans: + p.dest_dir = export_root / p.year / p.month + + +# --------------------------------------------------------------------------- +# Validation: expected months +# --------------------------------------------------------------------------- + + +def validate_month_set(plans: list[MonthPlan]) -> None: + found = {p.year_month for p in plans} + expected = set(EXPECTED_MONTHS) + + missing = sorted(expected - found) + extra = sorted(found - expected) + + errors: list[str] = [] + if missing: + errors.append(f"Missing months ({len(missing)}): {missing}") + if extra: + errors.append(f"Unexpected months ({len(extra)}): {extra}") + if len(found) != EXPECTED_MONTH_COUNT: + errors.append(f"Expected {EXPECTED_MONTH_COUNT} months, found {len(found)}") + + if errors: + sys.exit("ERROR: Month set mismatch:\n " + "\n ".join(errors)) + + +# --------------------------------------------------------------------------- +# Dry-run: print plan without copying +# --------------------------------------------------------------------------- + + +def print_dry_run(plans: list[MonthPlan], export_root: Path) -> None: + total_files = 
sum(p.n_files for p in plans) + total_bytes = sum(src.stat().st_size for p in plans for src, _ in p.files) + + print(f"\nDRY-RUN — {len(plans)} months, {total_files} files, {total_bytes / 1_073_741_824:.1f} GiB total\n") + print(f" export root: {export_root}\n") + print(f" {'MONTH':<8} {'FILES':>5} {'BYTES (GiB)':>11} DEST DIR") + print(f" {'-' * 8} {'-' * 5} {'-' * 11} {'-' * 50}") + + for p in sorted(plans, key=lambda x: x.year_month): + month_bytes = sum(src.stat().st_size for src, _ in p.files) + print(f" {p.year_month:<8} {p.n_files:>5} {month_bytes / 1_073_741_824:>10.3f} {p.dest_dir}") + for src, dest_name in p.files: + print(f" {src.name:>20} → {dest_name}") + + print(f"\n TOTAL: {len(plans)} months / {total_files} files / {total_bytes / 1_073_741_824:.2f} GiB\n") + + +# --------------------------------------------------------------------------- +# Copy: one month +# --------------------------------------------------------------------------- + + +def copy_month(plan: MonthPlan, force: bool, git_sha: str | None) -> MonthResult: + dest_dir = plan.dest_dir + dest_dir.mkdir(parents=True, exist_ok=True) + + # Check for existing part files + existing_parts = sorted(dest_dir.glob("part-*.parquet")) + if existing_parts and not force: + sys.exit( + f"ERROR: {dest_dir} already contains {len(existing_parts)} part-*.parquet file(s).\n" + f" Use --force to overwrite." 
+ ) + if existing_parts and force: + for p in existing_parts: + p.unlink() + success_marker = dest_dir / "_SUCCESS.json" + if success_marker.exists(): + success_marker.unlink() + + # Copy files + for src, dest_name in plan.files: + dest_path = dest_dir / dest_name + shutil.copy2(src, dest_path) + + # Read metadata from destination files for _SUCCESS.json + files_manifest: list[JsonDict] = [] + total_rows = 0 + total_bytes = 0 + + for _src, dest_name in plan.files: + dest_path = dest_dir / dest_name + meta = pq.read_metadata(str(dest_path)) + size = int(dest_path.stat().st_size) + rows = int(meta.num_rows) + total_rows += rows + total_bytes += size + files_manifest.append({ + "name": dest_name, + "size_bytes": size, + "num_rows": rows, + "num_row_groups": int(meta.num_row_groups), + }) + + marker: JsonDict = { + "timestamp": _now_utc_iso(), + "git_sha": git_sha, + "year_month": plan.year_month, + "n_files": plan.n_files, + "total_rows": total_rows, + "total_bytes": total_bytes, + "sort_keys": list(SORT_KEYS), + "schema": list(FINAL_LONG_COLS), + "files": files_manifest, + } + _write_json(dest_dir / "_SUCCESS.json", marker) + + return MonthResult( + year_month=plan.year_month, + n_files=plan.n_files, + total_rows=total_rows, + total_bytes=total_bytes, + status="OK", + ) + + +# --------------------------------------------------------------------------- +# Post-copy verification +# --------------------------------------------------------------------------- + + +def verify_export(plans: list[MonthPlan], export_root: Path) -> list[MonthResult]: + """Verify all destination files: count, size, readability.""" + results: list[MonthResult] = [] + failures: list[str] = [] + + for plan in sorted(plans, key=lambda x: x.year_month): + dest_dir = plan.dest_dir + dest_parts = sorted(dest_dir.glob("part-*.parquet")) + + # 1. 
File count + if len(dest_parts) != plan.n_files: + msg = f"{plan.year_month}: file count mismatch — expected {plan.n_files}, found {len(dest_parts)}" + failures.append(msg) + results.append(MonthResult(plan.year_month, plan.n_files, 0, 0, f"FAIL: {msg}")) + continue + + month_rows = 0 + month_bytes = 0 + month_ok = True + + for (src, dest_name), dest_path in zip(plan.files, dest_parts): + # 2. Byte-for-byte size match + src_size = src.stat().st_size + dst_size = dest_path.stat().st_size + if src_size != dst_size: + msg = f"{plan.year_month}/{dest_name}: size mismatch — src {src_size}, dst {dst_size}" + failures.append(msg) + month_ok = False + continue + + # 3. Parquet readability + try: + meta = pq.read_metadata(str(dest_path)) + month_rows += int(meta.num_rows) + month_bytes += dst_size + except Exception as exc: + msg = f"{plan.year_month}/{dest_name}: pq.read_metadata failed — {exc}" + failures.append(msg) + month_ok = False + + status = "OK" if month_ok else "FAIL: see above" + results.append(MonthResult(plan.year_month, plan.n_files, month_rows, month_bytes, status)) + + # 4. Total month count + export_month_dirs = list(export_root.rglob("_SUCCESS.json")) + if len(export_month_dirs) != EXPECTED_MONTH_COUNT: + failures.append( + f"Total month count: expected {EXPECTED_MONTH_COUNT}, found {len(export_month_dirs)} _SUCCESS.json files" + ) + + if failures: + print("\nVERIFICATION FAILURES:") + for f in failures: + print(f" ✗ {f}") + return results + + return results + + +# --------------------------------------------------------------------------- +# Summary table +# --------------------------------------------------------------------------- + + +def print_summary(results: list[MonthResult]) -> bool: + """Print summary table. 
Returns True if all OK.""" + header = f"{'MONTH':<8} | {'FILES':>5} | {'ROWS':>16} | {'BYTES':>14} | STATUS" + sep = "-" * len(header) + print(f"\n{header}") + print(sep) + + total_files = 0 + total_rows = 0 + total_bytes = 0 + n_ok = 0 + + for r in sorted(results, key=lambda x: x.year_month): + row = f"{r.year_month:<8} | {r.n_files:>5} | {r.total_rows:>16,} | {r.total_bytes:>14,} | {r.status}" + print(row) + total_files += r.n_files + total_rows += r.total_rows + total_bytes += r.total_bytes + if r.status == "OK": + n_ok += 1 + + print(sep) + all_ok = n_ok == len(results) + overall = f"{n_ok}/{len(results)} OK" + print(f"{'TOTAL':<8} | {total_files:>5} | {total_rows:>16,} | {total_bytes:>14,} | {overall}") + print() + return all_ok + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Restructure 49 months of compacted Parquet for export.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--source-root", + required=True, + type=Path, + help="Root containing out_YYYYMM_blessed/ directories.", + ) + parser.add_argument( + "--export-root", + required=True, + type=Path, + help="Destination root for YYYY/MM/ export layout.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print plan without copying any files.", + ) + parser.add_argument( + "--force", + action="store_true", + help="Overwrite existing part-*.parquet files in export dir.", + ) + args = parser.parse_args() + + source_root: Path = args.source_root.resolve() + export_root: Path = args.export_root.resolve() + + # Safety guard: export_root must be on /ebs + if not str(export_root).startswith("/ebs"): + sys.exit(f"ERROR: export-root must be under /ebs. 
Got: {export_root}") + if not str(source_root).startswith("/ebs"): + sys.exit(f"ERROR: source-root must be under /ebs. Got: {source_root}") + + # Export root must not be the same as source root + if export_root == source_root: + sys.exit("ERROR: --export-root must differ from --source-root.") + + # Export root must not be inside a blessed dir + if any(part.endswith("_blessed") for part in export_root.parts): + sys.exit("ERROR: --export-root must not be inside a blessed directory.") + + print(f"source-root : {source_root}") + print(f"export-root : {export_root}") + print(f"dry-run : {args.dry_run}") + print(f"force : {args.force}") + + # ── 1. Discover months ──────────────────────────────────────────────────── + print("\nDiscovering source months...") + plans = discover_months(source_root) + attach_dest_dirs(plans, export_root) + + # ── 2. Validate month set ───────────────────────────────────────────────── + validate_month_set(plans) + print(f"Found {len(plans)} months ({sum(p.n_files for p in plans)} files total). ✓") + + # ── 3. Dry-run shortcut ─────────────────────────────────────────────────── + if args.dry_run: + print_dry_run(plans, export_root) + return + + # ── 4. Copy months ──────────────────────────────────────────────────────── + git_sha = _git_sha() + print(f"\nCopying {len(plans)} months to {export_root} ...") + results: list[MonthResult] = [] + + for i, plan in enumerate(sorted(plans, key=lambda x: x.year_month), 1): + month_bytes = sum(src.stat().st_size for src, _ in plan.files) + print( + f" [{i:>2}/{len(plans)}] {plan.year_month} " + f"{plan.n_files} file(s) {month_bytes / 1_073_741_824:.3f} GiB ...", + end="", + flush=True, + ) + try: + result = copy_month(plan, force=args.force, git_sha=git_sha) + print(" OK") + results.append(result) + except Exception as exc: + print(f" FAILED: {exc}") + sys.exit(f"\nERROR: Copy failed for {plan.year_month}: {exc}") + + # ── 5. 
Verify ───────────────────────────────────────────────────────────── + print("\nVerifying export...") + verify_results = verify_export(plans, export_root) + + # ── 6. Summary ──────────────────────────────────────────────────────────── + all_ok = print_summary(verify_results) + + if not all_ok: + sys.exit("EXPORT FAILED — see verification failures above.") + + print(f"Export complete. {len(plans)} months written to {export_root}") + + +if __name__ == "__main__": + main() diff --git a/scripts/csv_to_parquet/validate_month_output.py b/scripts/csv_to_parquet/validate_month_output.py new file mode 100644 index 0000000..5088ded --- /dev/null +++ b/scripts/csv_to_parquet/validate_month_output.py @@ -0,0 +1,1138 @@ +# scripts/csv_to_parquet/validate_month_output.py +from __future__ import annotations + +import argparse +import datetime as dt_mod +import json +import random +import re +import sys +from collections.abc import Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, NoReturn + +import polars as pl +import pyarrow.parquet as pq + +JsonDict = dict[str, Any] + +""" +Month-output Validator (QA + determinism + contract enforcement) for ComEd CSV->Parquet migration. + +What this validates (fail-loud; raises ValueError with actionable diagnostics): +1) Discovery: + - Walks --out-root recursively and discovers Hive partitions year=YYYY/month=MM (no filename/count assumptions). + - Finds all parquet files under discovered partitions; fails if none found. + +2) Schema contract (metadata-first where possible): + - Required columns exist exactly (no silent passing on missing columns). 
+ - Dtypes match contract: + zip_code Utf8 + delivery_service_class Categorical + delivery_service_name Categorical + account_identifier Utf8 + datetime Datetime + energy_kwh/plc_value/nspl_value Float64 + year Int32 + month Int8 + Note: year/month accept Int16/Int32 only if explicitly allowed via flags is NOT implemented; contract default is strict. + +3) Partition integrity (per-file): + - year/month columns exist, non-null, and min/max match partition directory year=... month=... + - Detects mismatches and reports offending files. + +4) Datetime invariants (per partition, collected per-file and merged): + - No null datetime. + - min(datetime) has (hour,minute)==(0,0) + - max(datetime) has (hour,minute)==(23,30) + - All datetime values fall within (partition year, partition month) (no spillover). + +5) DST Option B invariants (optional: --dst-month-check, collected per-file and merged): + - Exactly 48 distinct time slots per day (no 49/50 slot days). + - Ensures no timestamps beyond 23:30. + - Spot-checks that (23:00 and 23:30) exist on at least one day with non-null energy_kwh (coarse sanity). + +6) Sortedness + Uniqueness (non-tautological): + - Validates strict lexicographic ordering by (zip_code, account_identifier, datetime). + - Modes: + --check-mode full : PyArrow streaming pass across files; O(batch_size) memory; checks strictly + increasing composite key (sortedness + no duplicates in one pass). + --check-mode sample : checks first/last K rows and deterministic random windows per file; + also checks boundaries and strictly-increasing keys within windows. + +7) Determinism compare (optional: --compare-root): + - Compares directory trees (relative paths) and per-file sizes between two outputs. + - Optionally row-counts for a limited number of files (controlled by --max-files in compare pass). 
+ +How to run: + python scripts/csv_to_parquet/validate_month_output.py --out-root /path/to/month_output_root --check-mode sample + python scripts/csv_to_parquet/validate_month_output.py --out-root ... --check-mode sample --dst-month-check + python scripts/csv_to_parquet/validate_month_output.py --out-root run1 --compare-root run2 --check-mode sample +""" + + +RE_YEAR_DIR = re.compile(r"^year=(?P\d{4})$") +RE_MONTH_DIR = re.compile(r"^month=(?P\d{1,2})$") + + +REQUIRED_SCHEMA: dict[str, pl.DataType] = { + "zip_code": pl.Utf8, + "delivery_service_class": pl.Categorical, + "delivery_service_name": pl.Categorical, + "account_identifier": pl.Utf8, + "datetime": pl.Datetime, + "energy_kwh": pl.Float64, + "plc_value": pl.Float64, + "nspl_value": pl.Float64, + "year": pl.Int32, + "month": pl.Int8, +} + +SORT_KEY_COLS: tuple[str, str, str] = ("zip_code", "account_identifier", "datetime") + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class Partition: + year: int + month: int + path: Path + + +@dataclass +class _DtStats: + """Aggregated datetime statistics for a single file or merged partition.""" + + dt_nulls: int = 0 + dt_min: dt_mod.datetime | None = None + dt_max: dt_mod.datetime | None = None + year_min: int | None = None + year_max: int | None = None + month_min: int | None = None + month_max: int | None = None + + +@dataclass +class _DstFileStats: + """Per-file DST statistics for merge across a partition.""" + + day_slots: dict[dt_mod.date, set[tuple[int, int]]] = field(default_factory=dict) + day_nonnull_late_slots: dict[dt_mod.date, set[tuple[int, int]]] = field(default_factory=dict) + has_beyond_2330: bool = False + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _fail(msg: 
str) -> NoReturn: + """Abort validation with a diagnostic message. + + Typed as NoReturn so that mypy narrows Optional types after guard clauses + that call _fail() — e.g., after ``if x is None: _fail(...)``, mypy knows + x is not None on the subsequent line. + """ + raise ValueError(msg) + + +def _is_parquet(p: Path) -> bool: + return p.is_file() and p.suffix.lower() == ".parquet" + + +def _read_parquet_schema(path: Path) -> dict[str, pl.DataType]: + """Extract column names and dtypes without reading row data. + + Uses a two-level fallback chain because Polars' ``read_parquet_schema`` + API has varied across versions; ``scan_parquet(...).schema`` is the + reliable alternative. Both are metadata-only operations (O(1) data I/O). + """ + try: + schema = pl.read_parquet_schema(str(path)) + return dict(schema) + except Exception: + try: + return dict(pl.scan_parquet(str(path)).schema) + except Exception as e: + _fail(f"Failed to read parquet schema for {path}: {e}") + return {} + + +def _dtype_eq(observed: pl.DataType, expected: pl.DataType) -> bool: + """Compare observed dtype against the schema contract. + + Special-cases Datetime because the contract requires "is a Datetime" without + constraining time_unit (us/ns/ms) or time_zone. Polars' Datetime is + parameterized, so a naive ``==`` check would reject valid ``Datetime('us')`` + when the contract specifies the unparameterized ``pl.Datetime``. + """ + if expected == pl.Datetime: + return isinstance(observed, pl.Datetime) or observed == pl.Datetime + return observed == expected + + +def _composite_key_expr() -> pl.Expr: + """Build a single-string composite key for sortedness and uniqueness checks. + + Uses U+001F (Unit Separator) as delimiter because it is a non-printable + control character that cannot appear in zip codes, account identifiers, or + datetime strings. 
This guarantees the composite key comparison is equivalent + to a lexicographic tuple comparison of the three sort-key columns, without + the overhead of maintaining and comparing three separate columns. + """ + return pl.concat_str([ + pl.col("zip_code").cast(pl.Utf8), + pl.lit("\u001f"), # unit separator + pl.col("account_identifier").cast(pl.Utf8), + pl.lit("\u001f"), + pl.col("datetime").cast(pl.Utf8), + ]).alias("_k") + + +def _get_row_count_metadata(path: Path) -> int: + """Get row count from parquet file metadata (O(1), no data scan).""" + pf = pq.ParquetFile(str(path)) + return pf.metadata.num_rows + + +# --------------------------------------------------------------------------- +# Phase 1: Discovery +# --------------------------------------------------------------------------- + + +def _discover_partitions(out_root: Path) -> list[Partition]: # noqa: C901 + if not out_root.exists(): + _fail(f"--out-root does not exist: {out_root}") + if not out_root.is_dir(): + _fail(f"--out-root is not a directory: {out_root}") + + parts: list[Partition] = [] + # Walk directories; find .../year=YYYY/month=MM + for year_dir in out_root.rglob("*"): + if not year_dir.is_dir(): + continue + # Skip _runs/ artifact directories + if "_runs" in year_dir.parts: + continue + m_y = RE_YEAR_DIR.match(year_dir.name) + if not m_y: + continue + year = int(m_y.group("year")) + for month_dir in year_dir.iterdir(): + if not month_dir.is_dir(): + continue + m_m = RE_MONTH_DIR.match(month_dir.name) + if not m_m: + continue + month = int(m_m.group("month")) + if not (1 <= month <= 12): + _fail(f"Invalid month directory detected: {month_dir} (month={month})") + parts.append(Partition(year=year, month=month, path=month_dir)) + + if not parts: + _fail(f"No Hive partitions found under out-root={out_root}. 
Expected directories like year=YYYY/month=MM.") + # Deterministic ordering + parts.sort(key=lambda p: (p.year, p.month, str(p.path))) + return parts + + +def _discover_parquet_files(partitions: Sequence[Partition]) -> dict[Partition, list[Path]]: + mapping: dict[Partition, list[Path]] = {} + total = 0 + for part in partitions: + files = [p for p in part.path.rglob("*.parquet") if _is_parquet(p)] + files.sort() + mapping[part] = files + total += len(files) + + if total == 0: + _fail( + "Discovery succeeded but found zero parquet files under discovered partitions. " + "Check out-root and conversion output." + ) + return mapping + + +# --------------------------------------------------------------------------- +# Phase 2: Metadata checks (schema + partition integrity) +# --------------------------------------------------------------------------- + + +def _validate_schema_on_file(path: Path) -> None: + """Validate that a single Parquet file conforms to the canonical schema. + + Runs metadata-only (no row data read). Checks both column presence and + dtype compatibility. Fails on the first file that violates the contract, + providing the exact filename and mismatch details for rapid diagnosis. + """ + schema = _read_parquet_schema(path) + + missing = [c for c in REQUIRED_SCHEMA if c not in schema] + if missing: + _fail( + f"Schema missing required columns in file {path}:\n missing={missing}\n observed_cols={sorted(schema.keys())}" + ) + + mismatches: list[str] = [] + for col, expected in REQUIRED_SCHEMA.items(): + observed = schema[col] + if not _dtype_eq(observed, expected): + mismatches.append(f"{col}: expected={expected}, observed={observed}") + + if mismatches: + _fail(f"Dtype mismatches in file {path}:\n " + "\n ".join(mismatches)) + + +def _validate_partition_integrity_file(path: Path, part: Partition) -> None: + """Verify that year/month column values match the Hive directory they reside in. 
+ + A row with year=2023 in a ``year=2024`` directory would silently corrupt + queries that rely on Hive partition pruning. This check reads only six + scalar aggregates (min/max/null_count for year and month) — negligible I/O. + """ + # Read tiny aggregates only. + lf = pl.scan_parquet(str(path)).select([ + pl.col("year").null_count().alias("year_nulls"), + pl.col("month").null_count().alias("month_nulls"), + pl.col("year").min().alias("year_min"), + pl.col("year").max().alias("year_max"), + pl.col("month").min().alias("month_min"), + pl.col("month").max().alias("month_max"), + ]) + try: + row = lf.collect(engine="streaming").row(0) + except Exception as e: + _fail(f"Failed to collect partition integrity stats for {path}: {e}") + + year_nulls, month_nulls, year_min, year_max, month_min, month_max = row + if year_nulls != 0 or month_nulls != 0: + _fail(f"Null partition keys in file {path}: year_nulls={year_nulls}, month_nulls={month_nulls}") + if year_min != part.year or year_max != part.year or month_min != part.month or month_max != part.month: + _fail( + f"Partition key mismatch in file {path} (dir year={part.year}, month={part.month}) " + f"but columns have year_min={year_min}, year_max={year_max}, month_min={month_min}, month_max={month_max}" + ) + + +# --------------------------------------------------------------------------- +# Phase 3a: Streaming sort + duplicate check (full mode) +# --------------------------------------------------------------------------- + + +def _streaming_sort_and_dup_check( + files: Sequence[Path], + batch_size: int = 65_536, +) -> tuple[int, list[dict[str, object]]]: + """Combined streaming sortedness + uniqueness check across ordered files. + + Leverages the global sort order: data sorted by (zip_code, account_identifier, datetime) + means duplicates are always adjacent. Checks each composite key is strictly greater than + the previous (sort order AND uniqueness in a single pass). 
+ + Uses PyArrow iter_batches for O(batch_size) memory per pass. + + Returns (total_rows, per_file_rows). + """ + prev_key: str | None = None + total_rows = 0 + per_file_rows: list[dict[str, object]] = [] + + for fpath in files: + pf = pq.ParquetFile(str(fpath)) + file_rows = 0 + + for batch in pf.iter_batches(batch_size=batch_size, columns=list(SORT_KEY_COLS)): + n = batch.num_rows + if n == 0: + continue + + # Convert PyArrow batch -> Polars DataFrame for composite key + df = pl.from_arrow(batch) + keys = df.select(_composite_key_expr())["_k"] + + # -- Cross-batch/file boundary check -- + first_key = str(keys[0]) + if prev_key is not None: + if first_key < prev_key: + _fail( + f"Sort violation at batch boundary (row ~{total_rows + file_rows}) " + f"in {fpath}: prev_key={prev_key!r} > first_key={first_key!r}" + ) + elif first_key == prev_key: + _fail( + f"Duplicate key at batch boundary (row ~{total_rows + file_rows}) in {fpath}: key={first_key!r}" + ) + + # -- Within-batch: strictly increasing check -- + if n > 1: + violations = ( + df.select([_composite_key_expr()]) + .with_row_index("_idx") + .with_columns(pl.col("_k").shift(1).alias("_kp")) + .filter(pl.col("_kp").is_not_null() & (pl.col("_k") <= pl.col("_kp"))) + .head(1) + ) + if violations.height > 0: + r = violations.row(0) + idx_in_batch, k, kp = r + abs_row = total_rows + file_rows + idx_in_batch + kind = "Duplicate key" if k == kp else "Sort violation" + _fail(f"{kind} at row ~{abs_row} in {fpath}: prev_key={kp!r}, key={k!r}") + + prev_key = str(keys[-1]) + file_rows += n + + if file_rows == 0: + _fail(f"Empty parquet file (0 rows): {fpath}") + + per_file_rows.append({"file": fpath.name, "rows": file_rows}) + total_rows += file_rows + + return total_rows, per_file_rows + + +# --------------------------------------------------------------------------- +# Phase 3b: Sample-mode sort + duplicate check +# --------------------------------------------------------------------------- + + +def _slice_keys(path: 
Path, offset: int, length: int) -> pl.DataFrame: + # Do NOT use engine="streaming" here: streaming may reorder rows for + # sliced reads, which defeats the purpose of sortedness validation. + # Slices are small (head_k / window_k rows of 3 key cols) so default + # engine is both correct and fast enough. + lf = pl.scan_parquet(str(path)).select([pl.col(c) for c in SORT_KEY_COLS]).slice(offset, length) + try: + return lf.collect() + except Exception as e: + _fail(f"Failed to slice keys for {path} offset={offset} length={length}: {e}") + + +def _keys_strictly_increasing_df(df: pl.DataFrame) -> bool: + """Check that composite keys in df are strictly increasing (sorted + unique). + + "Strictly increasing" (k[i] < k[i+1] for all i) validates both sortedness + AND uniqueness in a single pass: if any adjacent pair has k[i] == k[i+1], + the check fails. This is more efficient than separate sort + deduplicate + checks and is sound because the data is globally sorted by SORT_KEY_COLS. + """ + if df.height <= 1: + return True + violations = ( + df.select([_composite_key_expr()]) + .with_row_index("_idx") + .with_columns(pl.col("_k").shift(1).alias("_kp")) + .filter(pl.col("_kp").is_not_null() & (pl.col("_k") <= pl.col("_kp"))) + ) + return violations.height == 0 + + +def _first_last_key(df: pl.DataFrame) -> tuple[str, str]: + k = df.select( + pl.concat_str([ + pl.col("zip_code").cast(pl.Utf8), + pl.lit("\u001f"), + pl.col("account_identifier").cast(pl.Utf8), + pl.lit("\u001f"), + pl.col("datetime").cast(pl.Utf8), + ]).alias("_k") + )["_k"] + return str(k[0]), str(k[-1]) + + +def _check_sorted_sample(path: Path, seed: int, max_windows: int, window_k: int, head_k: int) -> None: + """Probabilistic sortedness check: validate sort order in sampled windows. + + Checks head, tail, and several deterministic random windows within a file. 
+ Each window validates strictly-increasing composite keys internally, and + cross-window boundary checks confirm ordering between adjacent non-overlapping + windows. + + The overlap guard (``off >= prev_end``) is essential: random windows may + overlap with the head/tail slices or with each other. Comparing the last + key of slice A to the first key of slice B is only valid when B starts at + or after the end of A; otherwise the "boundary" is inside A and the + comparison is semantically meaningless (and produces false positives). + """ + # Get row count cheaply from parquet metadata + n = _get_row_count_metadata(path) + if n <= 1: + return + + rng = random.Random(seed) # noqa: S311 + + slices: list[tuple[int, int, str]] = [] + # Head and tail + slices.append((0, min(head_k, n), "head")) + tail_len = min(head_k, n) + slices.append((max(0, n - tail_len), tail_len, "tail")) + + # Deterministic random windows + if n > window_k and max_windows > 0: + for i in range(max_windows): + off = rng.randrange(0, n - window_k + 1) + slices.append((off, window_k, f"win{i}")) + + # Sort slices by offset to allow boundary checks + slices.sort(key=lambda t: t[0]) + + prev_last_key: str | None = None + prev_tag: str | None = None + prev_end: int = 0 # offset + length of previous slice + + for off, length, tag in slices: + df = _slice_keys(path, off, length) + + if not _keys_strictly_increasing_df(df): + # Locate the violation for diagnostics + viol = ( + df.select([_composite_key_expr()]) + .with_row_index("_idx") + .with_columns(pl.col("_k").shift(1).alias("_kp")) + .filter(pl.col("_kp").is_not_null() & (pl.col("_k") <= pl.col("_kp"))) + .head(1) + ) + if viol.height > 0: + r = viol.row(0) + idx, k, kp = r + kind = "Duplicate key" if k == kp else "Sort violation" + _fail( + f"{kind} in slice tag={tag} offset={off}+{idx} in file {path}: " + f"prev_key={kp!r}, key={k!r}. " + f"Re-run with --check-mode full for exact break index." 
+ ) + _fail( + f"Strictly-increasing violation in slice tag={tag} offset={off} in file {path}. " + f"Re-run with --check-mode full for exact break index." + ) + + first_k, last_k = _first_last_key(df) + # Cross-slice boundary check: only valid when slices do NOT overlap. + # Random windows can overlap with head/tail or each other; comparing + # last-key-of-A to first-key-of-B is meaningless if B starts inside A. + if prev_last_key is not None and off >= prev_end: + if first_k < prev_last_key: + _fail( + f"Sort violation across slice boundary in file {path}: " + f"prev_slice={prev_tag} last_key={prev_last_key!r} > " + f"slice={tag} first_key={first_k!r}. " + f"Re-run with --check-mode full for exact break index." + ) + elif first_k == prev_last_key: + _fail( + f"Duplicate key across slice boundary in file {path}: " + f"prev_slice={prev_tag} key={first_k!r}. " + f"Re-run with --check-mode full for exact break index." + ) + + prev_last_key = last_k + prev_tag = tag + prev_end = off + length + + +# --------------------------------------------------------------------------- +# Phase 4: Datetime invariants (per-file collect + merge) +# --------------------------------------------------------------------------- + + +def _collect_datetime_stats_file(path: Path) -> _DtStats: + """Collect datetime aggregate stats from a single file (cheap aggregates).""" + lf = pl.scan_parquet(str(path)).select([ + pl.col("datetime").null_count().alias("dt_nulls"), + pl.col("datetime").min().alias("dt_min"), + pl.col("datetime").max().alias("dt_max"), + pl.col("datetime").dt.year().min().alias("dt_year_min"), + pl.col("datetime").dt.year().max().alias("dt_year_max"), + pl.col("datetime").dt.month().min().alias("dt_month_min"), + pl.col("datetime").dt.month().max().alias("dt_month_max"), + ]) + try: + row = lf.collect(engine="streaming").row(0) + except Exception as e: + _fail(f"Failed to collect datetime stats for {path}: {e}") + + return _DtStats( + dt_nulls=row[0], + dt_min=row[1], + 
dt_max=row[2], + year_min=row[3], + year_max=row[4], + month_min=row[5], + month_max=row[6], + ) + + +def _merge_dt_stats(stats_list: Sequence[_DtStats]) -> _DtStats: + """Merge per-file datetime stats into partition-level stats.""" + merged = _DtStats() + for s in stats_list: + merged.dt_nulls += s.dt_nulls + if s.dt_min is not None: + merged.dt_min = min(merged.dt_min, s.dt_min) if merged.dt_min is not None else s.dt_min + if s.dt_max is not None: + merged.dt_max = max(merged.dt_max, s.dt_max) if merged.dt_max is not None else s.dt_max + if s.year_min is not None: + merged.year_min = min(merged.year_min, s.year_min) if merged.year_min is not None else s.year_min + if s.year_max is not None: + merged.year_max = max(merged.year_max, s.year_max) if merged.year_max is not None else s.year_max + if s.month_min is not None: + merged.month_min = min(merged.month_min, s.month_min) if merged.month_min is not None else s.month_min + if s.month_max is not None: + merged.month_max = max(merged.month_max, s.month_max) if merged.month_max is not None else s.month_max + return merged + + +def _validate_datetime_stats_for_partition(merged: _DtStats, part: Partition) -> None: + """Validate merged datetime stats against partition expectations.""" + if merged.dt_nulls != 0: + _fail(f"Null datetime found in partition {part.path}: dt_nulls={merged.dt_nulls}") + + if merged.dt_min is None or merged.dt_max is None: + _fail(f"Datetime min/max unexpectedly None in partition {part.path}") + + # Ensure within partition month + if ( + merged.year_min != part.year + or merged.year_max != part.year + or merged.month_min != part.month + or merged.month_max != part.month + ): + _fail( + f"Datetime spillover in partition {part.path} (dir year={part.year}, month={part.month}) " + f"but datetime year range=({merged.year_min},{merged.year_max}) " + f"month range=({merged.month_min},{merged.month_max})" + ) + + # Time-of-day checks + if (merged.dt_min.hour, merged.dt_min.minute) != (0, 0): + 
_fail(f"Partition {part.path} has dt_min={merged.dt_min} but expected time-of-day 00:00") + if (merged.dt_max.hour, merged.dt_max.minute) != (23, 30): + _fail(f"Partition {part.path} has dt_max={merged.dt_max} but expected time-of-day 23:30") + + +# --------------------------------------------------------------------------- +# Phase 5: DST Option B (per-file collect + merge) +# --------------------------------------------------------------------------- + + +def _collect_dst_stats_file(path: Path) -> _DstFileStats: + """Collect DST-relevant stats from a single file. + + Returns per-day unique (h,m) slot sets and spot-check data. + Memory: O(days_in_month * 48) = ~1500 entries max. + """ + lf_base = ( + pl.scan_parquet(str(path)) + .select(["datetime", "energy_kwh"]) + .with_columns([ + pl.col("datetime").dt.date().alias("d"), + pl.col("datetime").dt.hour().alias("h"), + pl.col("datetime").dt.minute().alias("m"), + ]) + ) + + # Unique (d, h, m) — at most 31 * 48 = 1488 rows regardless of account count + slots_df = lf_base.select(["d", "h", "m"]).unique().collect(engine="streaming") + day_slots: dict[dt_mod.date, set[tuple[int, int]]] = {} + for row in slots_df.iter_rows(): + d, h, m = row + day_slots.setdefault(d, set()).add((h, m)) + + # Beyond 23:30 check + beyond_count = int( + lf_base.filter((pl.col("h") > 23) | ((pl.col("h") == 23) & (pl.col("m") > 30))) + .select(pl.len()) + .collect(engine="streaming") + .row(0)[0] + ) + + # Non-null energy at 23:00 and 23:30 (for spot-check merge) + late_df = ( + lf_base.filter((pl.col("h") == 23) & pl.col("m").is_in([0, 30]) & pl.col("energy_kwh").is_not_null()) + .select(["d", "h", "m"]) + .unique() + .collect(engine="streaming") + ) + day_nonnull: dict[dt_mod.date, set[tuple[int, int]]] = {} + for row in late_df.iter_rows(): + d, h, m = row + day_nonnull.setdefault(d, set()).add((h, m)) + + return _DstFileStats( + day_slots=day_slots, + day_nonnull_late_slots=day_nonnull, + has_beyond_2330=beyond_count > 0, + ) + + +def 
_validate_dst_for_partition(part: Partition, files: Sequence[Path]) -> None: + """Validate DST Option B by collecting per-file stats and merging.""" + merged_slots: dict[dt_mod.date, set[tuple[int, int]]] = {} + merged_nonnull: dict[dt_mod.date, set[tuple[int, int]]] = {} + any_beyond = False + + for f in files: + stats = _collect_dst_stats_file(f) + for d, s in stats.day_slots.items(): + merged_slots.setdefault(d, set()).update(s) + for d, s in stats.day_nonnull_late_slots.items(): + merged_nonnull.setdefault(d, set()).update(s) + if stats.has_beyond_2330: + any_beyond = True + + # Check 1: exactly 48 unique time slots per day + bad_days = [(d, len(s)) for d, s in merged_slots.items() if len(s) != 48] + if bad_days: + bad_days.sort() + sample = [{"date": str(d), "slots": n} for d, n in bad_days[:10]] + _fail(f"DST Option B violation: days with slots!=48 in partition {part.path}. Examples (up to 10): {sample}") + + # Check 2: no timestamps beyond 23:30 + if any_beyond: + _fail(f"DST Option B violation: found datetime beyond 23:30 in partition {part.path}.") + + # Check 3: at least one day has non-null energy_kwh at both 23:00 and 23:30 + days_with_both = sum(1 for s in merged_nonnull.values() if (23, 0) in s and (23, 30) in s) + if days_with_both == 0: + _fail( + f"DST Option B spot-check failed in partition {part.path}: " + f"did not find any day with non-null energy_kwh at both 23:00 and 23:30." + ) + + +# --------------------------------------------------------------------------- +# File selection +# --------------------------------------------------------------------------- + + +def _select_files_for_mode(files: Sequence[Path], mode: str, max_files: int | None, seed: int) -> list[Path]: + """Select a subset of files for validation when --max-files is set. + + Full mode uses deterministic first-N selection (reproducible, bias toward + early batches). 
Sample mode uses seeded random selection to provide + coverage across the full output without examining every file. The seed + ensures that repeated runs with the same arguments validate the same files. + """ + if max_files is None or max_files <= 0 or max_files >= len(files): + return list(files) + + if mode == "full": + return list(files)[:max_files] + + rng = random.Random(seed) # noqa: S311 + idxs = list(range(len(files))) + rng.shuffle(idxs) + chosen = sorted(idxs[:max_files]) + return [files[i] for i in chosen] + + +# --------------------------------------------------------------------------- +# Determinism compare +# --------------------------------------------------------------------------- + + +def _compare_roots(root_a: Path, root_b: Path, max_files: int | None, seed: int) -> None: # noqa: C901 + """Compare two output directories for determinism (same code + inputs → same output). + + Three-tier comparison strategy, each progressively more expensive: + 1. Directory tree structure (file paths must match exactly) + 2. File sizes (cheap; catches most non-determinism from different row counts + or compression differences) + 3. Row counts for a sample of Parquet files (controlled by --max-files) + + This does NOT do byte-for-byte comparison because Parquet writer versions + and compression settings may produce bitwise-different files with identical + logical content. Size + row count is sufficient for migration QA purposes. 
+ """ + if not root_b.exists() or not root_b.is_dir(): + _fail(f"--compare-root is not a directory: {root_b}") + + def list_rel_files(root: Path) -> list[Path]: + rels = [] + for p in root.rglob("*"): + if p.is_file(): + rels.append(p.relative_to(root)) + rels.sort() + return rels + + a_files = list_rel_files(root_a) + b_files = list_rel_files(root_b) + + a_set = set(a_files) + b_set = set(b_files) + if a_set != b_set: + only_a = sorted(a_set - b_set)[:20] + only_b = sorted(b_set - a_set)[:20] + _fail( + "Determinism compare failed: directory trees differ.\n" + f" only_in_out_root (up to 20): {only_a}\n" + f" only_in_compare_root (up to 20): {only_b}" + ) + + # Compare sizes (cheap and stable) + mismatches: list[str] = [] + for rel in a_files: + pa = root_a / rel + pb = root_b / rel + sa = pa.stat().st_size + sb = pb.stat().st_size + if sa != sb: + mismatches.append(f"{rel}: size_out={sa}, size_compare={sb}") + if len(mismatches) >= 50: + break + + if mismatches: + _fail( + "Determinism compare failed: file sizes differ (note: writer versions may legitimately differ; " + "this check is intentionally strict on size).\n " + "\n ".join(mismatches) + ) + + # Optional: row counts for up to max_files parquet files (controlled) + parquet_rels = [rel for rel in a_files if rel.suffix.lower() == ".parquet"] + if not parquet_rels: + return + + chosen: list[Path] + if max_files is None or max_files <= 0 or max_files >= len(parquet_rels): + chosen = parquet_rels + else: + rng = random.Random(seed) # noqa: S311 + idxs = list(range(len(parquet_rels))) + rng.shuffle(idxs) + chosen = [parquet_rels[i] for i in sorted(idxs[:max_files])] + + row_mismatches: list[str] = [] + for rel in chosen: + pa = root_a / rel + pb = root_b / rel + try: + na = int(pl.scan_parquet(str(pa)).select(pl.len()).collect(engine="streaming").row(0)[0]) + nb = int(pl.scan_parquet(str(pb)).select(pl.len()).collect(engine="streaming").row(0)[0]) + except Exception as e: + _fail(f"Determinism compare failed 
reading row counts for {rel}: {e}") + if na != nb: + row_mismatches.append(f"{rel}: rows_out={na}, rows_compare={nb}") + if len(row_mismatches) >= 50: + break + + if row_mismatches: + _fail("Determinism compare failed: row counts differ.\n " + "\n ".join(row_mismatches)) + + +# --------------------------------------------------------------------------- +# Run artifact validation +# --------------------------------------------------------------------------- + + +def _validate_run_artifacts(run_dir: Path, expected_parquet_count: int | None = None) -> JsonDict: # noqa: C901 + """Validate runner artifacts under a _runs/// directory. + + Checks: + - plan.json exists and is valid JSON + - run_summary.json exists, is valid JSON, and reports total_failure=0 + - Manifest JSONL files exist; all file-level entries are success or skip + - If expected_parquet_count is provided, cross-checks batches_written + + Returns a dict of artifact-check results for inclusion in the validation report. + """ + if not run_dir.exists() or not run_dir.is_dir(): + _fail(f"--run-dir does not exist or is not a directory: {run_dir}") + + results: JsonDict = {"run_dir": str(run_dir)} + + # -- plan.json -- + plan_path = run_dir / "plan.json" + if not plan_path.exists(): + _fail(f"Missing plan.json in run artifacts: {plan_path}") + try: + plan = json.loads(plan_path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError) as e: + _fail(f"Invalid plan.json: {plan_path}: {e}") + results["plan_n_inputs"] = len(plan.get("inputs_sorted", [])) + results["plan_n_batches"] = len(plan.get("batches", [])) + + # -- run_summary.json -- + summary_path = run_dir / "run_summary.json" + if not summary_path.exists(): + _fail(f"Missing run_summary.json in run artifacts: {summary_path}") + try: + summary = json.loads(summary_path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError) as e: + _fail(f"Invalid run_summary.json: {summary_path}: {e}") + + total_failure = 
int(summary.get("total_failure", -1)) + total_success = int(summary.get("total_success", 0)) + total_skip = int(summary.get("total_skip", 0)) + batches_written = int(summary.get("batches_written", 0)) + stop_requested = summary.get("stop_requested", False) + + if total_failure != 0: + _fail( + f"run_summary.json reports total_failure={total_failure} (must be 0). " + f"total_success={total_success}, total_skip={total_skip}. " + f"Investigate logs at: {run_dir / 'logs' / 'run_log.jsonl'}" + ) + + if stop_requested: + _fail(f"run_summary.json reports stop_requested=True. Run was interrupted: {summary_path}") + + results["summary_total_success"] = total_success + results["summary_total_failure"] = total_failure + results["summary_total_skip"] = total_skip + results["summary_batches_written"] = batches_written + + if expected_parquet_count is not None and batches_written != expected_parquet_count: + _fail( + f"Batch count mismatch: run_summary.json reports batches_written={batches_written} " + f"but discovered {expected_parquet_count} parquet files on disk." 
+ ) + + # -- manifest JSONL -- + manifest_dir = run_dir / "manifests" + if not manifest_dir.exists(): + _fail(f"Missing manifests directory: {manifest_dir}") + + manifest_files = sorted(manifest_dir.glob("manifest_*.jsonl")) + if not manifest_files: + _fail(f"No manifest_*.jsonl files found in {manifest_dir}") + + manifest_failures: list[str] = [] + manifest_success_count = 0 + manifest_skip_count = 0 + + for mf in manifest_files: + try: + for line in mf.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line: + continue + rec = json.loads(line) + status = rec.get("status", "") + if status == "success": + manifest_success_count += 1 + elif status == "skip": + manifest_skip_count += 1 + elif status == "failure": + inp = rec.get("input_path", "?") + exc = rec.get("exception_msg", "?") + manifest_failures.append(f"{inp}: {exc}") + except (json.JSONDecodeError, OSError) as e: + _fail(f"Error reading manifest file {mf}: {e}") + + if manifest_failures: + sample = manifest_failures[:10] + _fail(f"Manifest contains {len(manifest_failures)} failure entries (must be 0). 
Sample (up to 10): {sample}") + + results["manifest_files_checked"] = len(manifest_files) + results["manifest_success_count"] = manifest_success_count + results["manifest_skip_count"] = manifest_skip_count + + # -- batch summaries -- + summary_files = sorted(manifest_dir.glob("summary_*.json")) + batch_failures = [] + for sf in summary_files: + try: + bs = json.loads(sf.read_text(encoding="utf-8")) + if int(bs.get("n_failure", 0)) > 0: + batch_failures.append(f"{sf.name}: n_failure={bs.get('n_failure')}") + except (json.JSONDecodeError, OSError): + batch_failures.append(f"{sf.name}: unreadable") + + if batch_failures: + _fail(f"Batch summary files report failures: {batch_failures[:10]}") + + results["batch_summaries_checked"] = len(summary_files) + + return results + + +# --------------------------------------------------------------------------- +# Main (phase-based architecture) +# --------------------------------------------------------------------------- + + +def main(argv: Sequence[str] | None = None) -> int: # noqa: C901 + """Orchestrate validation in sequential phases. + + Phase architecture rationale: phases are ordered from cheapest to most + expensive. If a cheap check (schema, partition integrity) fails, expensive + checks (streaming sort, DST) are never reached. This fail-fast approach + minimizes wall-clock time when data is corrupt. + + Phases: + 1. Discovery — find partitions and Parquet files + 1b. Compare — structural determinism check (optional, fail-fast) + 2. Metadata — schema contract + partition integrity (metadata-only I/O) + 3. Sortedness + duplicates — streaming or sample-based (configurable) + 4. Datetime invariants — per-file collect + merge (all files) + 5. DST Option B — per-file collect + merge (optional) + 6. Run artifacts — plan.json, manifests, summaries (optional) + 7. 
Report — build and write validation summary + """ + p = argparse.ArgumentParser(description="Validate ComEd month-output parquet dataset contract.") + p.add_argument( + "--out-root", required=True, help="Converted dataset output root containing year=YYYY/month=MM partitions." + ) + p.add_argument( + "--check-mode", choices=["full", "sample"], default="sample", help="Validation intensity for sortedness checks." + ) + p.add_argument( + "--dst-month-check", action="store_true", help="Enable DST Option B shape checks (48 slots/day; no extras)." + ) + p.add_argument( + "--compare-root", default=None, help="Optional second output root to compare for determinism invariants." + ) + p.add_argument( + "--max-files", type=int, default=None, help="Max parquet files to validate (selection depends on mode)." + ) + p.add_argument("--seed", type=int, default=42, help="Deterministic seed for sampling selection/windows.") + p.add_argument("--output-report", default=None, help="Write validation report JSON to this path.") + p.add_argument( + "--run-dir", + default=None, + help="Runner artifact directory (_runs/YYYYMM//) to validate plan.json, run_summary.json, manifests.", + ) + args = p.parse_args(list(argv) if argv is not None else None) + + out_root = Path(args.out_root).resolve() + + # ── Phase 1: Discovery ────────────────────────────────────────────── + partitions = _discover_partitions(out_root) + mapping = _discover_parquet_files(partitions) + + # ── Phase 1b: Compare mode (structural, fail fast) ────────────────── + if args.compare_root is not None: + _compare_roots(out_root, Path(args.compare_root).resolve(), args.max_files, args.seed) + + # ── Phase 2: Metadata checks (schema + partition integrity) ───────── + total_files = sum(len(v) for v in mapping.values()) + checked_files = 0 + + for part in partitions: + files = mapping[part] + if not files: + _fail( + f"Discovered partition {part.path} (year={part.year}, month={part.month}) " + f"but found zero parquet files under 
it." + ) + + selected = _select_files_for_mode(files, args.check_mode, args.max_files, args.seed) + for f in selected: + _validate_schema_on_file(f) + _validate_partition_integrity_file(f, part) + checked_files += 1 + + if checked_files == 0: + _fail("No files validated (unexpected). Check --max-files and discovered outputs.") + + # ── Phase 3: Sortedness + duplicates + row counts ─────────────────── + total_rows = 0 + per_file_rows: list[dict[str, object]] = [] + + for part in partitions: + files = mapping[part] + selected = _select_files_for_mode(files, args.check_mode, args.max_files, args.seed) + + if args.check_mode == "full": + # Combined streaming sort+dup check — O(batch_size) memory + partition_rows, partition_per_file = _streaming_sort_and_dup_check(selected) + total_rows += partition_rows + per_file_rows.extend(partition_per_file) + else: + # Sample mode: enhanced strict-increasing check per file + for f in selected: + _check_sorted_sample( + f, + seed=args.seed, + max_windows=3, + window_k=5_000, + head_k=5_000, + ) + frows = _get_row_count_metadata(f) + total_rows += frows + per_file_rows.append({"file": f.name, "rows": frows}) + + # ── Phase 4: Datetime invariants (all files, per-file + merge) ────── + for part in partitions: + files = mapping[part] + dt_stats_list = [_collect_datetime_stats_file(f) for f in files] + merged = _merge_dt_stats(dt_stats_list) + _validate_datetime_stats_for_partition(merged, part) + + # ── Phase 5: DST Option B (all files, per-file + merge) ──────────── + if args.dst_month_check: + for part in partitions: + _validate_dst_for_partition(part, mapping[part]) + + # ── Phase 6: Run artifact integrity (optional) ───────────────────── + run_artifact_results: JsonDict | None = None + if args.run_dir is not None: + run_artifact_results = _validate_run_artifacts( + Path(args.run_dir).resolve(), + expected_parquet_count=total_files, + ) + + # ── Phase 7: Build validation report ──────────────────────────────── + checks_passed = [ + 
"schema_contract", + "partition_integrity", + "no_duplicates", + "datetime_invariants", + f"sortedness_{args.check_mode}", + ] + + if args.dst_month_check: + checks_passed.append("dst_option_b") + + if args.compare_root: + checks_passed.append("determinism_compare") + + if run_artifact_results is not None: + checks_passed.append("run_artifact_integrity") + + report: JsonDict = { + "status": "pass", + "timestamp": dt_mod.datetime.now(dt_mod.timezone.utc).isoformat(), + "out_root": str(out_root), + "partitions_validated": len(partitions), + "partition_details": [{"year": p.year, "month": p.month, "files": len(mapping[p])} for p in partitions], + "files_validated": checked_files, + "total_files_discovered": total_files, + "total_rows_validated": total_rows, + "per_file_rows": per_file_rows, + "check_mode": args.check_mode, + "dst_month_check": args.dst_month_check, + "checks_passed": checks_passed, + "sort_order": list(SORT_KEY_COLS), + } + + if args.compare_root: + report["compare_root"] = str(Path(args.compare_root).resolve()) + + if run_artifact_results is not None: + report["run_artifacts"] = run_artifact_results + + # Write report if requested + if args.output_report: + report_path = Path(args.output_report) + report_path.parent.mkdir(parents=True, exist_ok=True) + with open(report_path, "w") as outfile: + json.dump(report, outfile, indent=2) + print(f"Validation report written to: {report_path}") + + # Minimal success signal (no prints during failure). + print( + f"OK: validated {checked_files} parquet files across {len(partitions)} partitions " + f"(discovered total parquet files={total_files}, total rows validated={total_rows})." 
+ ) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/edit_geojson.py b/scripts/edit_geojson.py new file mode 100644 index 0000000..8ebe5cd --- /dev/null +++ b/scripts/edit_geojson.py @@ -0,0 +1,140 @@ +import argparse +import glob +import json +import os +import shutil +from typing import Any, Optional + +ANCHOR_MIN_ID = "__DOMAIN_ANCHOR_MIN__" +ANCHOR_MAX_ID = "__DOMAIN_ANCHOR_MAX__" + + +def _load_bound_sym(folder: str, bound_override: Optional[float]) -> float: + if bound_override is not None: + return float(bound_override) + + range_path = os.path.join(folder, "range_global.json") + if os.path.exists(range_path): + with open(range_path) as f: + meta = json.load(f) + if "bound_sym" not in meta: + raise RuntimeError(f"{range_path} exists but has no 'bound_sym' key.") + return float(meta["bound_sym"]) + + # Fallback: last known bound from your run (range_global.json you pasted) + return 24.143564317219816 + + +def _has_anchor(features: list[dict[str, Any]]) -> bool: + for feat in features: + props = feat.get("properties") or {} + geoid = props.get("geoid_bg") + if geoid in (ANCHOR_MIN_ID, ANCHOR_MAX_ID): + return True + return False + + +def _make_anchor(geoid_bg: str, value: float) -> dict[str, Any]: + return { + "type": "Feature", + "geometry": None, + "properties": { + "geoid_bg": geoid_bg, + "n_households": 0, + "mean_delta": None, + "mean_delta_cap_global_sym": value, + "is_domain_anchor": True, + }, + } + + +def _warn_on_invariant(path: str, features: list[dict[str, Any]]) -> None: + # Fail-loud would be better upstream; here we just warn. 
+ bad = 0 + for feat in features: + props = feat.get("properties") or {} + geoid = props.get("geoid_bg") + if geoid in (ANCHOR_MIN_ID, ANCHOR_MAX_ID): + continue + nh = props.get("n_households") + v = props.get("mean_delta_cap_global_sym") + try: + nh_i = int(nh) if nh is not None else 0 + except Exception: + nh_i = 0 + if nh_i > 0: + try: + float(v) + except Exception: + bad += 1 + if bad: + print( + f"[WARN] {os.path.basename(path)}: {bad} features have n_households>0 but invalid mean_delta_cap_global_sym" + ) + + +def process_file(path: str, bound_sym: float, make_backup: bool) -> bool: + with open(path) as f: + data = json.load(f) + + if data.get("type") != "FeatureCollection": + print(f"[SKIP] {path}: not a FeatureCollection") + return False + + features = data.get("features") + if not isinstance(features, list): + print(f"[SKIP] {path}: missing/invalid 'features' array") + return False + + if _has_anchor(features): + print(f"[OK] {os.path.basename(path)}: anchors already present (skipping)") + _warn_on_invariant(path, features) + return False + + features.append(_make_anchor(ANCHOR_MIN_ID, -bound_sym)) + features.append(_make_anchor(ANCHOR_MAX_ID, bound_sym)) + + _warn_on_invariant(path, features) + + if make_backup: + shutil.copy2(path, path + ".bak") + + # Write compact JSON (smaller files, fast upload). Use indent=2 if you prefer readable. + tmp_path = path + ".tmp" + with open(tmp_path, "w") as f: + json.dump(data, f, ensure_ascii=False) + os.replace(tmp_path, path) + + print(f"[WRITE] {os.path.basename(path)}: added anchors ±{bound_sym:.6f}") + return True + + +def main() -> None: + ap = argparse.ArgumentParser(description="Append ±bound_sym domain anchor features to all GeoJSONs in a folder.") + ap.add_argument("folder", help="Folder containing *.geojson files (and optionally range_global.json).") + ap.add_argument( + "--bound-sym", type=float, default=None, help="Override bound_sym (else read range_global.json if present)." 
+ ) + ap.add_argument("--no-backup", action="store_true", help="Do not write .bak backups.") + args = ap.parse_args() + + folder = os.path.abspath(os.path.expanduser(args.folder)) + if not os.path.isdir(folder): + raise SystemExit(f"Not a directory: {folder}") + + bound_sym = _load_bound_sym(folder, args.bound_sym) + geojson_paths = sorted(glob.glob(os.path.join(folder, "*.geojson"))) + + if not geojson_paths: + raise SystemExit(f"No *.geojson files found in {folder}") + + changed = 0 + for p in geojson_paths: + if process_file(p, bound_sym, make_backup=(not args.no_backup)): + changed += 1 + + print(f"Done. Changed {changed} files out of {len(geojson_paths)}. bound_sym={bound_sym}") + + +if __name__ == "__main__": + main() diff --git a/scripts/validate_wide_to_long_batched.py b/scripts/validate_wide_to_long_batched.py new file mode 100644 index 0000000..38dad10 --- /dev/null +++ b/scripts/validate_wide_to_long_batched.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +import argparse +import hashlib +import json +import time +from collections.abc import Sequence +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import polars as pl + +from smart_meter_analysis.wide_to_long import transform_wide_to_long_lf + +# -------------------------------------------------------------------------------------- +# Batched wide_to_long validator (resumable, JSONL checkpoints) +# +# What is this? +# - Month-scale Zip4 validation in the Docker devcontainer can wedge Docker if we do +# a global sort or force full materialization of a full month. Many checks are +# expensive if they require multiple passes over the dataset. +# - This script validates correctness (schema contracts, Daylight Saving Time +# behavior, datetime bounds, and 48-interval invariants) at full-month scale +# by processing input CSVs in bounded batches with checkpoints. 
+# If the container crashes or Docker becomes unstable, we can resume without +# redoing completed work. +# -------------------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class BatchResult: + """ + One record for one batch attempt. + + Stored as JSONL for two reasons: + 1) resume: a successful batch is a stable checkpoint (skip on rerun) + 2) audit: gives an append-only provenance trail of what was validated + + We record both configuration (strict/sort/engine) and outcomes (rows, datetime bounds), + because "it passed" is not meaningful unless we can tie it to the exact validation mode. + """ + + run_id: str + batch_id: str + batch_index: int + batch_size: int + n_files: int + first_path: str + last_path: str + started_at_utc: str + finished_at_utc: str + elapsed_sec: float + + strict: bool + sort_output: bool + engine: str + infer_schema_length: int + + # Validation outputs + long_rows: int | None + long_rows_mod_48: int | None + min_datetime: str | None + max_datetime: str | None + any_null_datetime: bool | None + schema_fingerprint: str | None + + ok: bool + error_type: str | None + error_message: str | None + + +def _utc_now_iso() -> str: + """ + Return a stable, timezone-aware UTC timestamp for logs. + + Why: datetime.utcnow() is deprecated in newer Python versions and produces naive + datetimes. We intentionally write Zulu timestamps into JSONL to keep log output + consistent across environments and to avoid local-time ambiguity. + """ + return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z") + + +def _read_paths(list_path: str) -> list[str]: + """ + Read an input manifest of CSV paths (S3 or local), one path per line. + + Why: + - We want the driver/orchestrator to own discovery. This tool should be deterministic + and replayable: given a manifest, it validates exactly that set of files. + - We support comment lines (# ...) to allow simple manifest curation. 
+ """ + p = Path(list_path) + if not p.exists(): + raise FileNotFoundError(f"Input list file not found: {list_path}") + + paths: list[str] = [] + for line in p.read_text().splitlines(): + s = line.strip() + if not s or s.startswith("#"): + continue + paths.append(s) + + if not paths: + raise ValueError(f"No usable paths found in list file: {list_path}") + + return paths + + +def _chunk_paths(paths: Sequence[str], batch_size: int) -> list[list[str]]: + """ + Partition the manifest into batches of bounded size. + + Why: + - Memory and swap pressure is the primary failure mode in Docker Desktop devcontainers. + Batching is the simplest, most reliable control for peak memory use. + - We prefer deterministic partitioning (simple slicing) so resume behavior is stable: + batch_00000 always contains the same files for the same manifest and batch_size. + """ + if batch_size <= 0: + raise ValueError(f"batch_size must be positive; got {batch_size}") + return [list(paths[i : i + batch_size]) for i in range(0, len(paths), batch_size)] + + +def _schema_fingerprint(schema: pl.Schema) -> str: + """ + Compute a stable fingerprint of the output schema (name + dtype). + + Why: + - When validating at scale, we want a compact way to detect drift across batches. + If one file has a surprising type coercion behavior, schema fingerprints will diverge. + - This fingerprint is not meant to be cryptographic security; SHA256 is convenient, + ubiquitous, and stable. + """ + pairs = [(name, str(dtype)) for name, dtype in schema.items()] + payload = json.dumps(pairs, sort_keys=False, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(payload).hexdigest() + + +def _load_completed_batches(checkpoint_jsonl: Path) -> set[str]: + """ + Read JSONL checkpoints and return the set of batch_ids that completed successfully. + + Why: + - We treat successful batches as durable checkpoints, allowing safe resume after + a wedge/crash without reprocessing. 
+ - We ignore malformed lines rather than failing the resume path; the checkpoint + file is append-only and may be truncated in a crash scenario. + """ + completed: set[str] = set() + if not checkpoint_jsonl.exists(): + return completed + + for line in checkpoint_jsonl.read_text().splitlines(): + line = line.strip() + if not line: + continue + try: + rec = json.loads(line) + except json.JSONDecodeError: + continue + if rec.get("ok") is True and isinstance(rec.get("batch_id"), str): + completed.add(rec["batch_id"]) + + return completed + + +def _append_jsonl(checkpoint_jsonl: Path, rec: dict[str, Any]) -> None: + """ + Append one JSON record to the JSONL checkpoint file. + + Why: + - Append-only writing is resilient: if a run is interrupted, prior records remain valid. + - JSONL is easy to inspect, grep, parse, and archive for audit trails. + """ + checkpoint_jsonl.parent.mkdir(parents=True, exist_ok=True) + with checkpoint_jsonl.open("a", encoding="utf-8") as f: + f.write(json.dumps(rec, sort_keys=True) + "\n") + + +def _validate_batch( + *, + run_id: str, + batch_id: str, + batch_index: int, + batch_paths: Sequence[str], + batch_size: int, + strict: bool, + sort_output: bool, + engine: str, + infer_schema_length: int, +) -> BatchResult: + """ + Validate a single batch of CSV files. + + Validation strategy: + - Build one LazyFrame scan over the batch. + - Apply wide_to_long (transform-only). + - Compute all validation metrics in one select + one collect to avoid multiple + passes over potentially large datasets. + + Why one collect: + - In Polars, separate .collect() calls can trigger separate executions. At scale, + multiple passes are expensive and can increase peak memory usage and I/O. + - We constrain validation to the minimum set of invariants that provide strong + correctness guarantees without needing to materialize the full long table. + + Note on sorting: + - sort_output is intentionally configurable. 
Month-scale semantic validation uses + sort_output=False to avoid global sorts that can wedge Docker Desktop. + - Deterministic ordering can be validated separately on bounded samples with + sort_output=True. + """ + t0 = time.time() + started_iso = _utc_now_iso() + + # Batching against resource exhaustion. + wide_lf = pl.scan_csv(batch_paths, infer_schema_length=infer_schema_length) + out_lf = transform_wide_to_long_lf(lf=wide_lf, strict=strict, sort_output=sort_output) + + # Single-pass metrics collection (streaming engine recommended). + metrics_df = out_lf.select([ + pl.len().alias("long_rows"), + (pl.len() % 48).alias("long_rows_mod_48"), + pl.col("datetime").min().alias("min_datetime"), + pl.col("datetime").max().alias("max_datetime"), + pl.col("datetime").is_null().any().alias("any_null_datetime"), + ]).collect(engine=engine) + + long_rows = int(metrics_df["long_rows"][0]) + long_rows_mod_48 = int(metrics_df["long_rows_mod_48"][0]) + mn = metrics_df["min_datetime"][0] + mx = metrics_df["max_datetime"][0] + any_null = bool(metrics_df["any_null_datetime"][0]) + + mn_s = mn.isoformat() if mn is not None else None + mx_s = mx.isoformat() if mx is not None else None + + # Schema fingerprint is cheap (no data scan); it uses the logical schema post-transform. + schema_fp = _schema_fingerprint(out_lf.collect_schema()) + + # Invariants (fail-loud): + # + # These checks are intentionally chosen because they have high diagnostic value: + # - long_rows % 48 == 0 catches interval count issues (missing/extra intervals). + # - min/max datetime validate the core datetime semantics and DST folding behavior. + # - null datetime indicates parsing or datetime math failures (must never happen in strict mode). 
+ if long_rows == 0: + raise ValueError("Batch produced 0 long rows (unexpected).") + + if long_rows_mod_48 != 0: + raise ValueError(f"Batch long_rows not divisible by 48: long_rows={long_rows} mod_48={long_rows_mod_48}") + + if any_null: + raise ValueError("Batch contains null datetime values.") + + if mn is None or mx is None: + raise ValueError("Batch min/max datetime is null (unexpected).") + + if (mn.hour, mn.minute) != (0, 0): + raise ValueError(f"Batch min datetime not at 00:00: {mn!r}") + + if (mx.hour, mx.minute) != (23, 30): + raise ValueError(f"Batch max datetime not at 23:30: {mx!r}") + + finished_iso = _utc_now_iso() + elapsed = time.time() - t0 + + return BatchResult( + run_id=run_id, + batch_id=batch_id, + batch_index=batch_index, + batch_size=batch_size, + n_files=len(batch_paths), + first_path=batch_paths[0], + last_path=batch_paths[-1], + started_at_utc=started_iso, + finished_at_utc=finished_iso, + elapsed_sec=elapsed, + strict=strict, + sort_output=sort_output, + engine=engine, + infer_schema_length=infer_schema_length, + long_rows=long_rows, + long_rows_mod_48=long_rows_mod_48, + min_datetime=mn_s, + max_datetime=mx_s, + any_null_datetime=any_null, + schema_fingerprint=schema_fp, + ok=True, + error_type=None, + error_message=None, + ) + + +def main(argv: Sequence[str] | None = None) -> int: + """ + CLI entrypoint. 
+ + This tool is intentionally a validator rather than a writer: + - it proves correctness at scale without entangling file output concerns + - it is safe to run repeatedly (idempotent resume semantics) + - it produces a stable, append-only audit log (JSONL checkpoints) + """ + ap = argparse.ArgumentParser(description="Batched wide_to_long validator with JSONL checkpoints (resumable).") + + ap.add_argument("--input-list", required=True, help="Text file of CSV paths (one per line).") + ap.add_argument("--batch-size", type=int, default=25, help="Files per batch (e.g., 25, 10).") + ap.add_argument( + "--out-dir", + default="/workspaces/smart-meter-analysis/data/validation", + help="Directory for checkpoints/logs.", + ) + ap.add_argument("--run-id", default=None, help="Run identifier (default: timestamp-based).") + ap.add_argument( + "--resume", + action="store_true", + help="Skip batches already marked ok in checkpoints.jsonl.", + ) + + ap.add_argument("--strict", action="store_true", help="Enable strict wide_to_long validations.") + ap.add_argument("--no-strict", dest="strict", action="store_false", help="Disable strict mode.") + ap.set_defaults(strict=True) + + ap.add_argument( + "--sort-output", + action="store_true", + help="Enable global sort inside wide_to_long (use for determinism checks on bounded samples).", + ) + ap.add_argument( + "--no-sort-output", + dest="sort_output", + action="store_false", + help="Disable global sort inside wide_to_long (recommended for month-scale semantic validation).", + ) + ap.set_defaults(sort_output=False) + + ap.add_argument( + "--engine", + default="streaming", + choices=["streaming", "in_memory"], + help="Polars collect engine.", + ) + ap.add_argument( + "--infer-schema-length", + type=int, + default=0, + help="Polars scan_csv infer_schema_length (0 = full scan of types).", + ) + + ap.add_argument("--max-batches", type=int, default=None, help="Process at most this many batches.") + ap.add_argument("--start-batch", type=int, 
default=0, help="Start at this batch index (0-based).") + ap.add_argument( + "--continue-on-error", + action="store_true", + help="Log failure and continue to next batch.", + ) + + args = ap.parse_args(list(argv) if argv is not None else None) + + paths = _read_paths(args.input_list) + batches = _chunk_paths(paths, args.batch_size) + + run_id = args.run_id or datetime.now(timezone.utc).strftime("wide_to_long_validate_%Y%m%dT%H%M%SZ") + out_dir = Path(args.out_dir) / run_id + checkpoint_jsonl = out_dir / "checkpoints.jsonl" + + completed = _load_completed_batches(checkpoint_jsonl) if args.resume else set() + + total_files = len(paths) + print(f"run_id={run_id}") + print(f"input_list={args.input_list}") + print(f"total_files={total_files}") + print(f"batch_size={args.batch_size}") + print(f"n_batches={len(batches)}") + print(f"out_dir={out_dir}") + print(f"checkpoint_jsonl={checkpoint_jsonl}") + print( + f"strict={args.strict} sort_output={args.sort_output} engine={args.engine} infer_schema_length={args.infer_schema_length}" + ) + print(f"resume={args.resume} completed_batches={len(completed)}") + + out_dir.mkdir(parents=True, exist_ok=True) + + batches_ok = 0 + batches_failed = 0 + files_ok = 0 + files_failed = 0 + + for i in range(args.start_batch, len(batches)): + if args.max_batches is not None and (batches_ok + batches_failed) >= args.max_batches: + print(f"Reached --max-batches={args.max_batches}; stopping.") + break + + batch_paths = batches[i] + batch_id = f"batch_{i:05d}" + + # Resume behavior: skip batches already confirmed OK by prior runs. 
+ if args.resume and batch_id in completed: + print(f"[SKIP] {batch_id} already ok (resume).") + continue + + print(f"[RUN ] {batch_id} n_files={len(batch_paths)} first={batch_paths[0]}") + + try: + res = _validate_batch( + run_id=run_id, + batch_id=batch_id, + batch_index=i, + batch_paths=batch_paths, + batch_size=args.batch_size, + strict=args.strict, + sort_output=args.sort_output, + engine=args.engine, + infer_schema_length=args.infer_schema_length, + ) + _append_jsonl(checkpoint_jsonl, asdict(res)) + batches_ok += 1 + files_ok += res.n_files + + print( + f"[OK ] {batch_id} files={res.n_files} files_ok={files_ok}/{total_files} " + f"long_rows={res.long_rows} min={res.min_datetime} max={res.max_datetime} " + f"elapsed_sec={res.elapsed_sec:.2f}" + ) + except Exception as e: + # On failure, we still checkpoint the error. This makes failures reproducible + # and supports later triage without rerunning the full month. + batches_failed += 1 + files_failed += len(batch_paths) + + rec: dict[str, Any] = { + "run_id": run_id, + "batch_id": batch_id, + "batch_index": i, + "batch_size": args.batch_size, + "n_files": len(batch_paths), + "first_path": batch_paths[0] if batch_paths else "", + "last_path": batch_paths[-1] if batch_paths else "", + "started_at_utc": None, + "finished_at_utc": _utc_now_iso(), + "elapsed_sec": None, + "strict": args.strict, + "sort_output": args.sort_output, + "engine": args.engine, + "infer_schema_length": args.infer_schema_length, + "ok": False, + "error_type": type(e).__name__, + "error_message": str(e), + } + _append_jsonl(checkpoint_jsonl, rec) + + print( + f"[FAIL] {batch_id} files={len(batch_paths)} files_failed={files_failed}/{total_files} " + f"error_type={type(e).__name__} error={e}" + ) + + if not args.continue_on_error: + print("Stopping on first failure (use --continue-on-error to proceed).") + print( + "Summary: " + f"batches_ok={batches_ok} batches_failed={batches_failed} " + f"files_ok={files_ok} files_failed={files_failed} 
total_files={total_files} " + f"checkpoint_jsonl={checkpoint_jsonl}" + ) + return 2 + + print( + "Done. " + f"batches_ok={batches_ok} batches_failed={batches_failed} " + f"files_ok={files_ok} files_failed={files_failed} total_files={total_files} " + f"checkpoint_jsonl={checkpoint_jsonl}" + ) + return 0 if batches_failed == 0 else 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/smart_meter_analysis/wide_to_long.py b/smart_meter_analysis/wide_to_long.py new file mode 100644 index 0000000..c979e55 --- /dev/null +++ b/smart_meter_analysis/wide_to_long.py @@ -0,0 +1,514 @@ +from __future__ import annotations + +import re +from collections.abc import Iterable, Sequence +from dataclasses import dataclass + +import polars as pl + +__all__ = ["IntervalColSpec", "transform_wide_to_long", "transform_wide_to_long_lf"] + +# ------------------------------------------------------------------------------------------------- +# Zip4 wide → long canonicalization (ComEd smart-meter interval data) +# +# Context: +# - Source data arrives as one row per account per day with 48 interval "end time" columns +# (0030..2400). DST anomalies may appear as extra end-time columns (2430, 2500). +# - Downstream clustering and regression require a canonical "long" representation: +# one row per account per interval with a true interval START timestamp. +# +# Primary design goals: +# 1) Fail-loud contract enforcement in strict mode (regulatory defensibility / auditability). +# 2) Stable, canonical output schema (order + dtypes) independent of input inference quirks. +# 3) DST policy is: fold extras into their base intervals with null-preserving semantics. +# 4) Determinism is required for partitioned Parquet writing, but global sorts are operationally +# expensive at month scale. We therefore gate sorting behind sort_output. 
# Operational note:
# - For month-scale validation in constrained environments (e.g., Docker devcontainer),
#   prefer sort_output=False and validate determinism separately on bounded samples or
#   in a higher-memory runtime.

# Exact header match only (no IGNORECASE). The HHMM end-time is captured as a
# NAMED group so parsing below can call m.group("hhmm").
# Fix: the pattern previously read (?P\d{4}) — an invalid group spec that makes
# re.compile raise re.error at import time; the group name was missing.
_INTERVAL_COL_RE = re.compile(r"^INTERVAL_HR(?P<hhmm>\d{4})_ENERGY_QTY$")

# Standard 48 end-times: 0030..2400 (0000 absent) at 30-min cadence,
# expressed as minutes since midnight for simple set arithmetic.
_STANDARD_END_MINUTES: set[int] = set(range(30, 1441, 30))

# DST extras appear as end-times 24:30 and 25:00 (minutes 1470, 1500).
_DST_EXTRA_END_MINUTES: set[int] = {1470, 1500}  # 2430, 2500

# DST fold-in map: each extra column is summed into its base column, then dropped.
_DST_FOLD_MAP = {
    "INTERVAL_HR2430_ENERGY_QTY": "INTERVAL_HR2330_ENERGY_QTY",
    "INTERVAL_HR2500_ENERGY_QTY": "INTERVAL_HR2400_ENERGY_QTY",
}

# Used historically for dtype enforcement; kept as an immutable set for defensive checks if needed.
_INTEGER_DTYPES = frozenset({
    pl.Int8,
    pl.Int16,
    pl.Int32,
    pl.Int64,
    pl.UInt8,
    pl.UInt16,
    pl.UInt32,
    pl.UInt64,
})


@dataclass(frozen=True)
class IntervalColSpec:
    """
    Parsed metadata for one interval column.

    Why keep this structure:
    - We want a durable, inspectable representation of the interval headers rather than
      relying on ad hoc string slicing scattered across the transform.
    - Using end_minutes/start_minutes (not just HHMM) makes validation and datetime math simpler.
    """

    colname: str  # original wide-format column name
    hhmm: int  # end-time label as HHMM (e.g. 30, 2400, 2500)
    end_minutes: int  # interval END as minutes since midnight
    start_minutes: int  # interval START as minutes since midnight (end - 30)


def _format_list_preview(x: list, max_items: int = 10) -> str:
    """
    Format a bounded preview of a python list for error messages.

    Why:
    - When strict validation fails, we want diagnostic information without dumping huge values.
    - This keeps exceptions readable in CI logs and in interactive debugging sessions.
    """
    if len(x) <= max_items:
        return str(x)
    return str(x[:max_items])[:-1] + f", ...] (n={len(x)})"


def _require_columns_from_names(schema_names: set[str], required: Iterable[str]) -> None:
    """
    Fail-loud if required wide columns are missing.

    Why:
    - Missing required columns almost always indicates upstream schema drift, not something
      we should guess at or attempt to "repair" inside the transform.
    """
    missing = [c for c in required if c not in schema_names]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")


def _enforce_total_columns_59_from_names(schema_names_in_order: Sequence[str]) -> None:
    """
    Enforce the "59 columns total" wide schema contract in strict mode.

    Why:
    - The source format is treated as a contract. Silent acceptance of schema changes
      is a common root cause of hard-to-debug downstream failures.
    - This is deliberately an exact count check, not "at least" or "contains".
    """
    if len(schema_names_in_order) != 59:
        cols = list(schema_names_in_order)
        raise ValueError(
            "Contract violation: expected exactly 59 columns in the wide CSV schema.\n"
            f"- observed_n_columns={len(cols)}\n"
            f"- first_20_columns={cols[:20]}\n"
        )


def _parse_interval_specs_from_columns(columns: Sequence[str]) -> list[IntervalColSpec]:
    """
    Parse all interval columns (including possible DST extras) from exact headers.

    Why:
    - Header parsing is the authoritative place to enforce interval-label invariants.
    - We do not attempt to normalize invalid HHMMs; invalid headers are treated as data
      contract violations and should fail loudly.

    Returns:
        IntervalColSpec list sorted by (end_minutes, colname).

    Raises:
        ValueError: on any header that matches the interval pattern but violates
            the locked contract (bad minutes, hour > 25, or an end-time of 0000).
    """
    specs: list[IntervalColSpec] = []
    for c in columns:
        m = _INTERVAL_COL_RE.match(c)
        if not m:
            continue

        hhmm = int(m.group("hhmm"))
        hh, mm = divmod(hhmm, 100)

        # Reject invalid minutes and invalid hour range. The regex guarantees
        # four digits, so hh is always >= 0; only the upper bound can fail
        # (the previous `hh < 0` branch was unreachable and has been removed).
        if mm not in (0, 30):
            raise ValueError(f"Contract violation: invalid interval column minutes (expected 00/30).\n- column={c}\n")
        if hh > 25:
            raise ValueError(f"Contract violation: invalid interval column hour (expected 00..25).\n- column={c}\n")

        # Authoritative 0000 rejection (locked contract).
        # Why: the dataset is defined in terms of end-times 0030..2400; 0000 must not appear.
        if hhmm == 0:
            raise ValueError(
                f"Contract violation: found an interval ending at 0000 (HHMM=0000). Do not guess.\n- column={c}\n"
            )

        end_minutes = 60 * hh + mm
        start_minutes = end_minutes - 30

        # Defensive redundancy. This should be unreachable given the explicit 0000 rejection,
        # but it provides additional protection against malformed headers.
        if start_minutes < 0:
            raise ValueError(
                "Contract violation: interval implies start_minutes < 0 (likely 0000), which must not exist.\n"
                f"- column={c}\n"
            )

        specs.append(
            IntervalColSpec(
                colname=c,
                hhmm=hhmm,
                end_minutes=end_minutes,
                start_minutes=start_minutes,
            )
        )

    specs.sort(key=lambda s: (s.end_minutes, s.colname))
    return specs
+ """ + observed_end = {s.end_minutes for s in interval_specs} + + if strict: + missing_standard = sorted(_STANDARD_END_MINUTES - observed_end) + if missing_standard: + raise ValueError( + "Interval HHMM set missing standard end-minutes (expected 48 columns).\n" + f"- missing_end_minutes={missing_standard}\n" + ) + + allowed = set(_STANDARD_END_MINUTES) | set(_DST_EXTRA_END_MINUTES) + unexpected = sorted(observed_end - allowed) + if unexpected: + raise ValueError( + "Interval HHMM set has unexpected end-minutes (allowed extras only 1470/1500).\n" + f"- unexpected_end_minutes={unexpected}\n" + ) + + observed_standard = observed_end & _STANDARD_END_MINUTES + if len(observed_standard) != 48: + raise ValueError( + "Contract violation: standard interval end-times are not exactly 48 distinct values.\n" + f"- observed_standard_count={len(observed_standard)}\n" + ) + + +def _validate_interval_length_1800_lf(lf: pl.LazyFrame) -> None: + """ + Contract check: INTERVAL_LENGTH must represent constant 1800 seconds (fail-loud). + + Why keep this validation even though INTERVAL_LENGTH is dropped from output: + - It's assumed in the datetime semantics. If intervals aren't 30 minutes, + the entire long representation becomes invalid. + - In practice, S3 scans often infer INTERVAL_LENGTH as String. The contract is the *value*, + not the storage dtype, so we accept either as long as it parses to 1800 everywhere. + """ + col = "INTERVAL_LENGTH" + schema = lf.collect_schema() + if col not in schema.names(): + raise ValueError("Missing required column: INTERVAL_LENGTH") + + # Accept either integer-typed or string-typed inputs; enforce that the value parses to 1800. + # We use strict=False cast to avoid blowing up on benign string representations like "1800". 
+ il_int = pl.col(col).cast(pl.Int32, strict=False) + invalid = il_int.is_null() | (il_int != pl.lit(1800, dtype=pl.Int32)) + + any_invalid = bool(lf.select(invalid.any()).collect().item()) + if not any_invalid: + return + + # Provide a small raw sample of offending values for debugging (bounded to 30). + bad_vals = lf.filter(invalid).select(pl.col(col).cast(pl.Utf8).unique().head(30)).collect().to_series().to_list() + raise ValueError( + "INTERVAL_LENGTH contract violation: values must be 1800 seconds (string or integer accepted).\n" + f"- raw_values_sample={_format_list_preview(bad_vals, max_items=30)}\n" + ) + + +def _validate_reading_date_parses_strict_lf(lf: pl.LazyFrame, *, colname: str) -> None: + """ + Locked contract: INTERVAL_READING_DATE parses with %m/%d/%Y only (fail-loud). + + Why: + - Date parsing ambiguity is a classic source of silent data corruption (e.g., DD/MM vs MM/DD). + - We do not accept "best effort" parsing here; strictness is intentional. + """ + parsed = pl.col(colname).cast(pl.Utf8).str.strptime(pl.Date, format="%m/%d/%Y", strict=False) + bad_mask = parsed.is_null() & pl.col(colname).is_not_null() + any_bad = bool(lf.select(bad_mask.any()).collect().item()) + if not any_bad: + return + + bad_vals = lf.filter(bad_mask).select(pl.col(colname).unique().head(30)).collect().to_series().to_list() + raise ValueError( + f"Failed to parse {colname} into Date for some rows using %m/%d/%Y.\n" + f"- raw_values_failed_parse_sample={_format_list_preview(bad_vals, max_items=30)}\n" + ) + + +def _fold_in_preserve_nulls(base: pl.Expr, extra: pl.Expr) -> pl.Expr: + """ + Policy: + - HR2330 := HR2330 + HR2430 + - HR2400 := HR2400 + HR2500 + - Drop extras after fold-in. + + Null semantics are important: + - If both base and extra are null, output must remain null (unknown). + - Otherwise treat null as 0.0 for summation. 
+ + Why: + - This reflects how DST extras behave operationally: an extra interval is additive if present, + but we must not turn a fully-missing pair into a synthetic 0. + """ + base_f = base.cast(pl.Float64, strict=False) + extra_f = extra.cast(pl.Float64, strict=False) + return ( + pl.when(base_f.is_null() & extra_f.is_null()) + .then(pl.lit(None, dtype=pl.Float64)) + .otherwise(base_f.fill_null(0.0) + extra_f.fill_null(0.0)) + ) + + +def transform_wide_to_long_lf( + lf: pl.LazyFrame, + *, + strict: bool = True, + sort_output: bool = True, +) -> pl.LazyFrame: + """ + Wide CSV -> Long (canonical) LazyFrame transform (transform-only; no writing). + + This function is intentionally "pure transform": + - It does not read/write files directly. + - It does not manage batching. + - It does not choose execution resources. + Those concerns belong to the driver/orchestrator layer. + + Determinism: + - sort_output=True enforces deterministic global ordering on + (zip_code, account_identifier, datetime). + - Month-scale validation in constrained environments should typically use + sort_output=False and validate determinism separately on bounded samples. + + Final output schema (exact order + dtypes): + 1) zip_code: Utf8 + 2) delivery_service_class: Categorical + 3) delivery_service_name: Categorical + 4) account_identifier: Utf8 + 5) datetime: Datetime(us) + 6) energy_kwh: Float64 + 7) plc_value: Float64 + 8) nspl_value: Float64 + 9) year: Int32 + 10) month: Int8 + """ + required = [ + "ZIP_CODE", + "DELIVERY_SERVICE_CLASS", + "DELIVERY_SERVICE_NAME", + "ACCOUNT_IDENTIFIER", + "INTERVAL_READING_DATE", + "INTERVAL_LENGTH", + "PLC_VALUE", + "NSPL_VALUE", + ] + + # Collecting schema is metadata-only and does not scan data. We use it to make + # validation decisions without triggering a full execution. 
+ schema = lf.collect_schema() + schema_cols_in_order = schema.names() + schema_names = set(schema_cols_in_order) + + _require_columns_from_names(schema_names, required) + + if strict: + _enforce_total_columns_59_from_names(schema_cols_in_order) + + interval_specs_all = _parse_interval_specs_from_columns(schema_cols_in_order) + if not interval_specs_all: + raise ValueError("Contract violation: no interval columns found matching ^INTERVAL_HR\\d{4}_ENERGY_QTY$.\n") + + _validate_interval_set(interval_specs_all, strict=strict) + + if strict: + _validate_interval_length_1800_lf(lf) + _validate_reading_date_parses_strict_lf(lf, colname="INTERVAL_READING_DATE") + + # Derive the canonical "standard" interval columns from the observed schema. We do not + # hardcode the headers to avoid dependence on input ordering; strict mode ensures the set. + standard_specs = [s for s in interval_specs_all if s.end_minutes in _STANDARD_END_MINUTES] + standard_cols = [s.colname for s in standard_specs] + + if strict and len(standard_cols) != 48: + raise ValueError( + "Contract violation: expected exactly 48 standard interval columns.\n" + f"- observed_n_standard_cols={len(standard_cols)}\n" + ) + + # Fail-loud if fold targets missing. We must always have base columns 2330 and 2400 + # since DST fold-in adds into them. + if "INTERVAL_HR2330_ENERGY_QTY" not in schema_names or "INTERVAL_HR2400_ENERGY_QTY" not in schema_names: + raise ValueError("Contract violation: missing required standard columns HR2330 or HR2400.\n") + + # Parse date as Date (not Datetime) first; this keeps semantics explicit and avoids + # timezone ambiguity. We later cast to Datetime(us) for interval math. 
+ reading_date_expr = ( + pl.col("INTERVAL_READING_DATE") + .cast(pl.Utf8) + .str.strptime(pl.Date, format="%m/%d/%Y", strict=True) + .alias("interval_reading_date") + ) + + # Project early to minimize memory pressure: + # - keep only identifier columns + PLC/NSPL + reading_date + interval columns + # - drop filler columns and any other wide fields not needed for the canonical long output + dst_extra_cols = [extra for extra in _DST_FOLD_MAP if extra in schema_names] + wide = lf.select([ + pl.col("ZIP_CODE").cast(pl.Utf8).alias("zip_code"), + pl.col("DELIVERY_SERVICE_CLASS").cast(pl.Categorical).alias("delivery_service_class"), + pl.col("DELIVERY_SERVICE_NAME").cast(pl.Categorical).alias("delivery_service_name"), + pl.col("ACCOUNT_IDENTIFIER").cast(pl.Utf8).alias("account_identifier"), + pl.col("PLC_VALUE").cast(pl.Float64, strict=False).alias("plc_value"), + pl.col("NSPL_VALUE").cast(pl.Float64, strict=False).alias("nspl_value"), + reading_date_expr, + *[pl.col(c).cast(pl.Float64, strict=False) for c in standard_cols], + *[pl.col(c).cast(pl.Float64, strict=False) for c in dst_extra_cols], + ]) + + # Apply DST Option B fold-in via mapping, then drop the extra columns. + fold_exprs: list[pl.Expr] = [] + for extra_col, base_col in _DST_FOLD_MAP.items(): + if extra_col in schema_names: + fold_exprs.append(_fold_in_preserve_nulls(pl.col(base_col), pl.col(extra_col)).alias(base_col)) + if fold_exprs: + wide = wide.with_columns(fold_exprs) + + if dst_extra_cols: + wide = wide.drop(dst_extra_cols) + + # id_vars define the "identity" columns that are repeated for each unpivoted interval. + # interval_reading_date is kept only until we compute datetime; it is not part of final output. + id_vars = [ + "zip_code", + "delivery_service_class", + "delivery_service_name", + "account_identifier", + "plc_value", + "nspl_value", + "interval_reading_date", + ] + + # Unpivot produces one row per (id_vars, interval_col). 
We immediately cast energy_kwh + # to Float64 to enforce canonical dtype regardless of upstream inference. + long = wide.unpivot( + index=id_vars, + on=standard_cols, + variable_name="interval_col", + value_name="energy_kwh", + ).with_columns(pl.col("energy_kwh").cast(pl.Float64, strict=False)) + + # Extract end-time HHMM from the interval column label. This is intentionally strict: + # interval headers are part of the upstream contract; if they don't match, we should fail. + long = long.with_columns( + pl.col("interval_col") + .str.extract(r"^INTERVAL_HR(\d{4})_ENERGY_QTY$", 1) + .cast(pl.Int32, strict=True) + .alias("hhmm") + ).with_columns((((pl.col("hhmm") // 100) * 60) + (pl.col("hhmm") % 100)).alias("end_minutes")) + + if strict: + # After unpivot, ensure only standard end-times remain. + allowed = sorted(_STANDARD_END_MINUTES) + any_bad_end = bool(long.select((~pl.col("end_minutes").is_in(allowed)).any()).collect().item()) + if any_bad_end: + bad_cols = ( + long.filter(~pl.col("end_minutes").is_in(allowed)) + .select(pl.col("interval_col").unique().head(30)) + .collect() + .to_series() + .to_list() + ) + raise ValueError( + "Contract violation: unexpected interval columns appeared after unpivot.\n" + f"- unexpected_interval_cols_sample={_format_list_preview(bad_cols, max_items=30)}\n" + ) + + # datetime = interval START time: + # - Input labels are end-times (e.g., HR0030 ends at 00:30). + # - We subtract 30 minutes to get the interval start. + # - HR2400 therefore maps to 23:30 same day (no rollover), matching the locked semantics. + long = long.with_columns( + ( + pl.col("interval_reading_date").cast(pl.Datetime("us")) + + pl.duration(minutes=pl.col("end_minutes").cast(pl.Int64) - pl.lit(30)) + ).alias("datetime") + ) + + # Derived partition columns. These must come from datetime (not from INTERVAL_READING_DATE), + # because datetime semantics are the canonical time representation. 
+ long = long.with_columns([ + pl.col("datetime").dt.year().cast(pl.Int32).alias("year"), + pl.col("datetime").dt.month().cast(pl.Int8).alias("month"), + ]) + + # Drop helper columns promptly to reduce downstream memory footprint. + long = long.drop(["interval_col", "hhmm", "end_minutes"]) + + if sort_output: + # Sorting is intentionally optional: + # - required for deterministic output in write paths + # - avoided in month-scale validation in constrained environments + long = long.sort(["zip_code", "account_identifier", "datetime"]) + + # Authoritative final projection: + # - enforces schema order and dtypes + # - ensures interval_reading_date is not in the final output + return long.select([ + pl.col("zip_code").cast(pl.Utf8), + pl.col("delivery_service_class").cast(pl.Categorical), + pl.col("delivery_service_name").cast(pl.Categorical), + pl.col("account_identifier").cast(pl.Utf8), + pl.col("datetime").cast(pl.Datetime("us")), + pl.col("energy_kwh").cast(pl.Float64, strict=False), + pl.col("plc_value").cast(pl.Float64, strict=False), + pl.col("nspl_value").cast(pl.Float64, strict=False), + pl.col("year").cast(pl.Int32), + pl.col("month").cast(pl.Int8), + ]) + + +def transform_wide_to_long( + df: pl.DataFrame, + *, + strict: bool = True, + sort_output: bool = True, +) -> pl.DataFrame: + """ + Backward-compatible DataFrame API wrapper. + + Why keep this: + - Some call sites prefer an eager DataFrame API (e.g., unit tests, small local files). + - We keep the LazyFrame transform as the source of truth and collect at the boundary. + """ + return transform_wide_to_long_lf(df.lazy(), strict=strict, sort_output=sort_output).collect() diff --git a/tests/test_compact_month_output.py b/tests/test_compact_month_output.py new file mode 100644 index 0000000..892b976 --- /dev/null +++ b/tests/test_compact_month_output.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +"""Tests for _stream_write_chunks and related constants in compact_month_output.py. 
+ +Covers: +- ROWS_PER_ROW_GROUP constant value +- compaction_plan.json contains rows_per_row_group and estimated_n_output_files keys +- _stream_write_chunks: single small batch -> one file +- _stream_write_chunks: multi-row-group packing into a single file +- _stream_write_chunks: file rollover when target_size_bytes is reached +- _stream_write_chunks: total row count preserved across all output files +- _stream_write_chunks: sort order within a single output file +- _stream_write_chunks: sort order preserved across rolled-over files +""" + +from __future__ import annotations + +import datetime as dt +import json +import sys +from pathlib import Path +from typing import Any + +import polars as pl +import pyarrow.parquet as pq +import pytest + +# The scripts/ directory is not an installed package; add project root to path +# so pytest can resolve the namespace package on both local and CI runs. +sys.path.insert(0, str(Path(__file__).parents[1])) + +from scripts.csv_to_parquet.compact_month_output import ( + DEFAULT_COMPACT_TARGET_SIZE_BYTES, + ROWS_PER_ROW_GROUP, + SORT_KEYS, + CompactionConfig, + _stream_write_chunks, + run_compaction, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +YEAR_MONTH = "202307" +RUN_ID = "test_run_001" + + +def _canonical_df(n_rows: int = 20) -> pl.DataFrame: + """Return a minimal DataFrame matching the canonical 10-column schema.""" + base_ts = dt.datetime(2023, 7, 1, 0, 0, 0) + return pl.DataFrame({ + "zip_code": ["60601"] * n_rows, + "delivery_service_class": ["DS1"] * n_rows, + "delivery_service_name": ["Residential"] * n_rows, + "account_identifier": [f"ACCT{i:04d}" for i in range(n_rows)], + "datetime": [base_ts + dt.timedelta(hours=i) for i in range(n_rows)], + "energy_kwh": [float(i) * 0.5 for i in range(n_rows)], + "plc_value": [0.0] * n_rows, + "nspl_value": [0.0] * n_rows, + "year": [2023] * n_rows, 
+ "month": [7] * n_rows, + }).with_columns( + pl.col("zip_code").cast(pl.Utf8), + pl.col("delivery_service_class").cast(pl.Categorical), + pl.col("delivery_service_name").cast(pl.Categorical), + pl.col("account_identifier").cast(pl.Utf8), + pl.col("datetime").cast(pl.Datetime("us")), + pl.col("energy_kwh").cast(pl.Float64), + pl.col("plc_value").cast(pl.Float64), + pl.col("nspl_value").cast(pl.Float64), + pl.col("year").cast(pl.Int32), + pl.col("month").cast(pl.Int8), + ) + + +def _write_batch(df: pl.DataFrame, month_dir: Path, name: str = "batch_0000.parquet") -> Path: + month_dir.mkdir(parents=True, exist_ok=True) + path = month_dir / name + df.write_parquet(path) + return path + + +def _make_cfg(tmp_path: Path, target_size_bytes: int, dry_run: bool = True) -> CompactionConfig: + return CompactionConfig( + year_month=YEAR_MONTH, + run_id=RUN_ID, + out_root=tmp_path / "out", + run_dir=tmp_path / "_runs" / YEAR_MONTH / RUN_ID, + target_size_bytes=target_size_bytes, + max_files=None, + overwrite=False, + dry_run=dry_run, + no_swap=False, + ) + + +class _SilentLogger: + """No-op logger compatible with run_compaction's logger.log(dict) protocol.""" + + def log(self, event: dict[str, Any]) -> None: + pass + + +def _setup_month_dir(tmp_path: Path) -> Path: + """Write one batch_0000.parquet to the canonical month directory.""" + month_dir = tmp_path / "out" / "2023" / "07" + _write_batch(_canonical_df(), month_dir) + return month_dir + + +# --------------------------------------------------------------------------- +# Unit tests — constants and plan keys +# --------------------------------------------------------------------------- + + +def test_rows_per_row_group_constant() -> None: + """ROWS_PER_ROW_GROUP must be exactly 50_000_000.""" + assert ROWS_PER_ROW_GROUP == 50_000_000 + + +def test_plan_contains_all_new_keys(tmp_path: pytest.TempPathFactory) -> None: + """compaction_plan.json must contain rows_per_row_group and estimated_n_output_files.""" + 
_setup_month_dir(tmp_path) + cfg = _make_cfg(tmp_path, target_size_bytes=DEFAULT_COMPACT_TARGET_SIZE_BYTES) + run_compaction(cfg, _SilentLogger()) + + plan_path = tmp_path / "_runs" / YEAR_MONTH / RUN_ID / "compaction" / "compaction_plan.json" + plan = json.loads(plan_path.read_text()) + + assert "rows_per_row_group" in plan, "Missing key 'rows_per_row_group' in compaction_plan.json" + assert "estimated_n_output_files" in plan, "Missing key 'estimated_n_output_files' in compaction_plan.json" + assert plan["rows_per_row_group"] == ROWS_PER_ROW_GROUP + assert plan["estimated_n_output_files"] >= 1 + + +# --------------------------------------------------------------------------- +# Unit tests — _stream_write_chunks directly +# --------------------------------------------------------------------------- + + +def test_single_small_batch_one_file(tmp_path: Path) -> None: + """5 rows, rows_per_row_group=10, huge target -> 1 output file with 1 row group.""" + staging = tmp_path / "staging" + batch_path = _write_batch(_canonical_df(5), tmp_path / "input", "batch_0000.parquet") + + output_files = _stream_write_chunks( + sorted_input_files=[batch_path], + staging_month_dir=staging, + rows_per_row_group=10, + target_size_bytes=10 * 1024**3, # 10 GiB — never triggers rollover + max_files=None, + logger=_SilentLogger(), + log_ctx={}, + ) + + assert len(output_files) == 1 + meta = pq.read_metadata(str(output_files[0])) + assert meta.num_row_groups == 1 + assert meta.num_rows == 5 + + +def test_multi_row_group_single_file(tmp_path: Path) -> None: + """25 rows, rows_per_row_group=10, huge target -> 1 file with 3 row groups (10+10+5).""" + staging = tmp_path / "staging" + batch_path = _write_batch(_canonical_df(25), tmp_path / "input", "batch_0000.parquet") + + output_files = _stream_write_chunks( + sorted_input_files=[batch_path], + staging_month_dir=staging, + rows_per_row_group=10, + target_size_bytes=10 * 1024**3, # 10 GiB — never triggers rollover + max_files=None, + 
logger=_SilentLogger(), + log_ctx={}, + ) + + assert len(output_files) == 1 + meta = pq.read_metadata(str(output_files[0])) + assert meta.num_row_groups == 3 + assert meta.num_rows == 25 + + +def test_file_rollover_at_target_size(tmp_path: Path) -> None: + """30 rows, rows_per_row_group=10, target_size_bytes=1 -> rollover after each rg -> 3 files.""" + staging = tmp_path / "staging" + batch_path = _write_batch(_canonical_df(30), tmp_path / "input", "batch_0000.parquet") + + output_files = _stream_write_chunks( + sorted_input_files=[batch_path], + staging_month_dir=staging, + rows_per_row_group=10, + target_size_bytes=1, # always triggers rollover + max_files=None, + logger=_SilentLogger(), + log_ctx={}, + ) + + assert len(output_files) == 3 + for f in output_files: + meta = pq.read_metadata(str(f)) + assert meta.num_row_groups == 1 + assert meta.num_rows == 10 + + +def test_row_count_preserved(tmp_path: Path) -> None: + """2 batch files of 15 rows each -> total 30 rows across all output files.""" + staging = tmp_path / "staging" + input_dir = tmp_path / "input" + b0 = _write_batch(_canonical_df(15), input_dir, "batch_0000.parquet") + # Second batch: distinct account_identifiers to avoid duplicate key issues + df2 = _canonical_df(15).with_columns(pl.Series("account_identifier", [f"ACCT{i:04d}" for i in range(15, 30)])) + b1 = _write_batch(df2, input_dir, "batch_0001.parquet") + + output_files = _stream_write_chunks( + sorted_input_files=[b0, b1], + staging_month_dir=staging, + rows_per_row_group=10, + target_size_bytes=10 * 1024**3, + max_files=None, + logger=_SilentLogger(), + log_ctx={}, + ) + + total = sum(pq.read_metadata(str(f)).num_rows for f in output_files) + assert total == 30 + + +def test_sort_order_within_file(tmp_path: Path) -> None: + """Rows read back from a single output file are sorted by SORT_KEYS.""" + staging = tmp_path / "staging" + input_dir = tmp_path / "input" + b0 = _write_batch(_canonical_df(15), input_dir, "batch_0000.parquet") + df2 = 
_canonical_df(15).with_columns(pl.Series("account_identifier", [f"ACCT{i:04d}" for i in range(15, 30)])) + b1 = _write_batch(df2, input_dir, "batch_0001.parquet") + + output_files = _stream_write_chunks( + sorted_input_files=[b0, b1], + staging_month_dir=staging, + rows_per_row_group=10, + target_size_bytes=10 * 1024**3, + max_files=None, + logger=_SilentLogger(), + log_ctx={}, + ) + + for f in output_files: + df = pl.read_parquet(str(f)) + sort_keys = list(SORT_KEYS) + sorted_df = df.sort(sort_keys) + assert df.select(sort_keys).equals(sorted_df.select(sort_keys)), ( + f"Rows in {f.name} are not sorted by {sort_keys}" + ) + + +def test_sort_order_across_files(tmp_path: Path) -> None: + """Last sort key of file N < first sort key of file N+1 when rollover occurs.""" + staging = tmp_path / "staging" + input_dir = tmp_path / "input" + b0 = _write_batch(_canonical_df(15), input_dir, "batch_0000.parquet") + df2 = _canonical_df(15).with_columns(pl.Series("account_identifier", [f"ACCT{i:04d}" for i in range(15, 30)])) + b1 = _write_batch(df2, input_dir, "batch_0001.parquet") + + output_files = _stream_write_chunks( + sorted_input_files=[b0, b1], + staging_month_dir=staging, + rows_per_row_group=10, + target_size_bytes=1, # rollover after each row group -> multiple files + max_files=None, + logger=_SilentLogger(), + log_ctx={}, + ) + + assert len(output_files) > 1, "Expected multiple output files with target_size_bytes=1" + sort_keys = list(SORT_KEYS) + prev_last: tuple[Any, ...] 
| None = None + for f in output_files: + df = pl.read_parquet(str(f)) + first = tuple(df.select(sort_keys).row(0)) + last = tuple(df.select(sort_keys).row(-1)) + if prev_last is not None: + assert prev_last <= first, ( + f"Sort order broken across files: last of prev={prev_last} >= first of next={first}" + ) + prev_last = last diff --git a/uv.lock b/uv.lock index f299c53..9dd291c 100644 --- a/uv.lock +++ b/uv.lock @@ -716,6 +716,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/93/0dd45cd283c32dea1545151d8c3637b4b8c53cdb3a625aeb2885b184d74d/fonttools-4.60.1-py3-none-any.whl", hash = "sha256:906306ac7afe2156fcf0042173d6ebbb05416af70f6b370967b47f8f00103bbb", size = 1143175, upload-time = "2025-09-29T21:13:24.134Z" }, ] +[[package]] +name = "fsspec" +version = "2026.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/7c/f60c259dcbf4f0c47cc4ddb8f7720d2dcdc8888c8e5ad84c73ea4531cc5b/fsspec-2026.2.0.tar.gz", hash = "sha256:6544e34b16869f5aacd5b90bdf1a71acb37792ea3ddf6125ee69a22a53fb8bff", size = 313441, upload-time = "2026-02-05T21:50:53.743Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/ab/fb21f4c939bb440104cc2b396d3be1d9b7a9fd3c6c2a53d98c45b3d7c954/fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437", size = 202505, upload-time = "2026-02-05T21:50:51.819Z" }, +] + [[package]] name = "fuzzywuzzy" version = "0.18.0" @@ -2589,6 +2598,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/bd/4168a751ddbbf43e86544b4de8b5c3b7be8d7167a2a5cb977d274e04f0a1/ruff-0.14.4-py3-none-win_arm64.whl", hash = "sha256:dd09c292479596b0e6fec8cd95c65c3a6dc68e9ad17b8f2382130f87ff6a75bb", size = 12663065, upload-time = "2025-11-06T22:07:42.603Z" }, ] +[[package]] +name = "s3fs" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "fsspec" }, +] 
+sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/504cb277632c4d325beabbd03bb43778f0decb9be22d9e0e6c62f44540c7/s3fs-0.4.2.tar.gz", hash = "sha256:2ca5de8dc18ad7ad350c0bd01aef0406aa5d0fff78a561f0f710f9d9858abdd0", size = 57527, upload-time = "2020-03-31T15:24:26.388Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/e4/b8fc59248399d2482b39340ec9be4bb2493846ac23641b43115a7e5cd675/s3fs-0.4.2-py3-none-any.whl", hash = "sha256:91c1dfb45e5217bd441a7a560946fe865ced6225ff7eb0fb459fe6e601a95ed3", size = 19791, upload-time = "2020-03-31T15:24:24.952Z" }, +] + [[package]] name = "s3transfer" version = "0.14.0" @@ -2898,6 +2920,7 @@ dependencies = [ { name = "boto3" }, { name = "botocore" }, { name = "cenpy" }, + { name = "fsspec" }, { name = "ipykernel" }, { name = "matplotlib" }, { name = "memory-profiler" }, @@ -2908,6 +2931,7 @@ dependencies = [ { name = "pyarrow" }, { name = "pyyaml" }, { name = "requests" }, + { name = "s3fs" }, { name = "scikit-learn" }, { name = "seaborn" }, { name = "selenium" }, @@ -2938,6 +2962,7 @@ requires-dist = [ { name = "boto3", specifier = ">=1.40.46" }, { name = "botocore", specifier = ">=1.40.47" }, { name = "cenpy", specifier = ">=1.0.1" }, + { name = "fsspec", specifier = ">=2026.1.0" }, { name = "ipykernel", specifier = ">=6.30.1" }, { name = "matplotlib", specifier = ">=3.9.4" }, { name = "memory-profiler", specifier = ">=0.61.0" }, @@ -2947,6 +2972,7 @@ requires-dist = [ { name = "pyarrow", specifier = ">=14.0.0" }, { name = "pyyaml", specifier = ">=6.0.3" }, { name = "requests", specifier = ">=2.32.5" }, + { name = "s3fs", specifier = ">=0.4.2" }, { name = "scikit-learn", specifier = ">=1.6.1" }, { name = "seaborn", specifier = ">=0.13.2" }, { name = "selenium", specifier = ">=4.37.0" },