Full Source Code Repository

Complete production-grade code for the Bank Statement AI. All modules are fully synchronized with the 7-stage architecture.

Project Features:

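  • Gemini 2.5 Pro: Semantic extraction in line batches, with the last lines of the previous batch passed along for continuity.
  • Strict Schema Control: Model responses are constrained to a Pydantic schema (structured JSON output).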
  • Smart Extractor: Auto-detects Text/Image and runs PaddleOCR fallback.
  • Premium Exporter: Multi-sheet Excel with merged headers and zebra styling.
  • Math Validator: Automatic debit/credit swap when it reconciles the running balance; rows that still mismatch are flagged for review.
  • LangChain LCEL: Advanced orchestration with traceable components and structured outputs.
main.py
"""
main.py
=======
Bank Statement to Tally-style Excel Orchestrator

This is the central entry point that coordinates the pipeline:
1. Detect document type (Text/Scanned)
2. Extract text (pdfplumber / PaddleOCR)
3. Parse transactions (LLM / Rule-based)
4. Clean & Normalize data
5. Validate balances & detect issues
6. Categorize transactions (Vch Type)
7. Export to Premium Excel
"""

import argparse
import logging
import sys
import time
from pathlib import Path
import json

# Root logging for the CLI: INFO and above, rendered as "HH:MM:SS  LEVEL     message".
logging.basicConfig(
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Named pipeline logger used by the orchestrator below.
logger = logging.getLogger("bank_pipeline")

def _load_env():
    """Load API keys from .env if available."""
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass

def run_pipeline(pdf_path: str, output_path: Path) -> "Path | None":
    """
    Orchestrate the seven parsing stages for one bank-statement PDF.

    Parameters
    ----------
    pdf_path : str
        Path to the input PDF (a ``Path`` is accepted too).
    output_path : Path
        Desired output location. ``.with_suffix(".xlsx")`` is applied, so any
        existing suffix is replaced. A plain ``str`` is accepted too.

    Returns
    -------
    Path | None
        The path of the written ``.xlsx`` file on success. Returns ``None``
        (after logging the reason) when the PDF does not exist, when no text
        could be extracted, or when no transactions were parsed.
    """
    pdf_file = Path(pdf_path)
    # Validate the input before loading .env and importing the pipeline
    # modules, so a mistyped path fails fast with a clear message.
    if not pdf_file.is_file():
        logger.error("❌ PDF not found: %s", pdf_path)
        return None

    _load_env()

    # ── Import modules here (after env is loaded) ─────────────────────────
    from modules.extractor import extract
    from modules.llm_parser import parse_transactions_llm
    from modules.cleaner import clean_transactions
    from modules.validator import validate_transactions
    from modules.categorizer import categorize_transactions
    from modules.exporter import export_to_excel

    pipeline_start = time.perf_counter()

    # ┌──────────────────────────────────────────────┐
    # │  Stage 1 & 2: Detection & Extraction         │
    # └──────────────────────────────────────────────┘
    t0 = time.perf_counter()
    raw_lines = extract(pdf_file)
    logger.info("⏱  Extraction done in %.2fs → %d lines extracted", time.perf_counter() - t0, len(raw_lines))
    if not raw_lines:
        # Nothing to parse; skip the LLM call entirely.
        logger.error("❌ No text could be extracted from the PDF.")
        return None

    # ┌──────────────────────────────────────────────┐
    # │  Stage 3: LLM-First Semantic Parsing         │
    # └──────────────────────────────────────────────┘
    t0 = time.perf_counter()
    llm_data = parse_transactions_llm(raw_lines) or {}

    raw_records = llm_data.get("transactions", [])
    account_info = llm_data.get("account_info", {})
    summary_stats = llm_data.get("summary", {})

    if not raw_records:
        logger.error("❌ LLM Parsing failed. No transactions extracted.")
        return None

    logger.info("⏱  Parsing done in %.2fs → %d records", time.perf_counter() - t0, len(raw_records))

    # ┌──────────────────────────────────────────────┐
    # │  Stage 4: Data Cleaning                      │
    # └──────────────────────────────────────────────┘
    t0 = time.perf_counter()
    cleaned_records = clean_transactions(raw_records)
    logger.info("⏱  Stage 4 done in %.2fs", time.perf_counter() - t0)

    # ┌──────────────────────────────────────────────┐
    # │  Stage 5: Balance Validation & Issues        │
    # └──────────────────────────────────────────────┘
    t0 = time.perf_counter()
    validated_records, validation_issues = validate_transactions(cleaned_records, summary_stats)
    logger.info("⏱  Stage 5 done in %.2fs", time.perf_counter() - t0)

    # ┌──────────────────────────────────────────────┐
    # │  Stage 6: Categorization (AI-Enriched)       │
    # └──────────────────────────────────────────────┘
    t0 = time.perf_counter()
    categorized_records = categorize_transactions(validated_records)
    logger.info("⏱  Stage 6 done in %.2fs", time.perf_counter() - t0)

    # ┌──────────────────────────────────────────────┐
    # │  Stage 7: Premium Excel Export               │
    # └──────────────────────────────────────────────┘
    t0 = time.perf_counter()
    output_xlsx = Path(output_path).with_suffix(".xlsx")
    export_to_excel(
        transactions=categorized_records,
        output_path=output_xlsx,
        account_info=account_info,
        summary_stats=summary_stats,
        validation_issues=validation_issues,
    )
    logger.info("⏱  Stage 7 done in %.2fs", time.perf_counter() - t0)
    logger.info("🏁 Pipeline finished in %.2fs → %s", time.perf_counter() - pipeline_start, output_xlsx)
    return output_xlsx

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Bank Statement PDF → Excel")
    parser.add_argument("pdf", help="Input PDF file path")
    # NOTE(review): --llm and --debug are parsed but not yet consumed by
    # run_pipeline; they currently have no effect.
    parser.add_argument("--llm", action="store_true", help="Force LLM parsing (default is auto)")
    parser.add_argument("--debug", action="store_true", help="Save intermediate files")
    args = parser.parse_args()

    # Build "<stem>_tally.xlsx" explicitly. Passing "<stem>_tally" without a
    # suffix would let Path.with_suffix() inside run_pipeline treat any dot in
    # the stem (e.g. "jan.2024") as the suffix and truncate the file name.
    out = Path(Path(args.pdf).stem + "_tally.xlsx")
    run_pipeline(args.pdf, out)
modules/extractor.py
"""
extractor.py
============
Data Extraction Module

Extracts raw text from PDF using:
  - Primary  : pdfplumber (table-aware extraction)
  - Fallback : PaddleOCR  (for scanned documents)

Each page returns a list of raw text lines preserving
spatial order so the downstream parser can reconstruct rows.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Generator

import pdfplumber

logger = logging.getLogger(__name__)

# ────────────────────────── Layout-Preserving Extraction ──────────────────────
def extract_text_pdf(pdf_path: str | Path) -> list[str]:
    """
    Read every page of a text-based PDF and return its non-blank text lines.

    Each page is extracted with ``layout=True`` so the visual column spacing of
    statement tables survives in the returned strings. If that call raises for
    a page, the page is extracted again with ``x_tolerance=2`` and
    ``y_tolerance=3``. Lines are returned un-stripped and in page order. Empty
    and whitespace-only lines are dropped.
    """
    source = Path(pdf_path)
    collected: list[str] = []

    logger.info("📑 Spatial text extraction: %s", source.name)

    with pdfplumber.open(source) as document:
        page_total = len(document.pages)
        for number, pdf_page in enumerate(document.pages, start=1):
            logger.debug("  → Page %d/%d", number, page_total)

            try:
                page_text = pdf_page.extract_text(layout=True) or ""
            except Exception:
                # Layout mode unsupported/failed for this page: use tight tolerances.
                page_text = pdf_page.extract_text(x_tolerance=2, y_tolerance=3) or ""

            collected.extend(row for row in page_text.splitlines() if row.strip())

    logger.info("✅ Extracted %d text lines.", len(collected))
    return collected


# ────────────────────────── OCR-based extraction ─────────────────────────────
def _ocr_boxes_to_lines(items: list) -> list[str]:
    """Group OCR detections on one page into visual rows, left-to-right.

    Each item is expected as ``[box, (text, score)]``, where *box* is a list of
    ``[x, y]`` corner points. This is the layout indexed by ``item[0]`` and
    ``item[1][0]`` here (PaddleOCR 2.x ``ocr()`` output; re-confirm if the
    library is upgraded). A box joins the current row when its vertical centre
    lies within half a box height of that row's first box. Each row's cell
    texts are ordered by their left edge and joined with two spaces.
    """
    boxes = []
    for item in items:
        text = str(item[1][0]).strip()
        if not text:
            continue
        xs = [pt[0] for pt in item[0]]
        ys = [pt[1] for pt in item[0]]
        top, bottom = min(ys), max(ys)
        boxes.append(((top + bottom) / 2, bottom - top, min(xs), text))

    boxes.sort(key=lambda b: b[0])
    rows: list[list[tuple]] = []
    row_centre = row_height = 0.0
    for centre, height, left, text in boxes:
        if rows and abs(centre - row_centre) <= max(row_height, height) / 2:
            rows[-1].append((left, text))
        else:
            rows.append([(left, text)])
            row_centre, row_height = centre, height

    return ["  ".join(text for _, text in sorted(row)) for row in rows]


def extract_ocr_pdf(pdf_path: str | Path, dpi: int = 200) -> list[str]:
    """
    Extract text from a scanned PDF using PaddleOCR.

    Process:
      1. Render each PDF page to an in-memory PNG (PyMuPDF) at *dpi*.
      2. Run PaddleOCR on each image.
      3. Group detected boxes into visual rows (top-to-bottom, cells
         left-to-right) so each returned line resembles one statement row.

    Parameters
    ----------
    pdf_path : str | Path
    dpi : int, default 200
        Render resolution; PDF user space is 72 units per inch, so the render
        zoom is ``dpi / 72``.

    Returns
    -------
    list[str]
        OCR-extracted lines, in document order.

    Raises
    ------
    ImportError
        If paddleocr, PyMuPDF, numpy or pillow are not installed.

    Notes
    -----
    Requires: paddleocr, PyMuPDF (fitz), numpy, pillow
    """
    pdf_path = Path(pdf_path)
    all_lines: list[str] = []

    logger.info("🔎 Starting OCR extraction (PaddleOCR): %s", pdf_path.name)

    try:
        import fitz  # PyMuPDF
        from paddleocr import PaddleOCR
        import numpy as np
        from PIL import Image
        import io
    except ImportError as exc:
        raise ImportError(
            "OCR dependencies missing. Install: paddleocr, PyMuPDF, pillow"
        ) from exc

    ocr_engine = PaddleOCR(use_angle_cls=True, lang="en", show_log=False)
    # The render matrix is the same for every page; build it once.
    zoom = dpi / 72
    matrix = fitz.Matrix(zoom, zoom)

    doc = fitz.open(str(pdf_path))
    try:
        page_count = len(doc)
        for page_num in range(page_count):
            logger.debug("  → OCR Page %d/%d", page_num + 1, page_count)
            pixmap = doc[page_num].get_pixmap(matrix=matrix, alpha=False)
            img_array = np.array(Image.open(io.BytesIO(pixmap.tobytes("png"))))

            results = ocr_engine.ocr(img_array, cls=True)
            if not results or not results[0]:
                continue  # blank page / nothing detected

            all_lines.extend(_ocr_boxes_to_lines(results[0]))
    finally:
        # Release the document even if rendering or OCR raises mid-way.
        doc.close()

    logger.info("✅ OCR extracted %d lines total.", len(all_lines))
    return all_lines


def _is_scanned_pdf(pdf_path: Path) -> bool:
    """Return True when no page of *pdf_path* carries any character objects.

    One page with characters (``page.chars`` non-empty) is enough to classify
    the whole document as text-based. Only a PDF where every page lacks
    characters is treated as scanned/image-only.
    """
    with pdfplumber.open(pdf_path) as document:
        has_text_layer = any(page.chars for page in document.pages)
    return not has_text_layer


# ────────────────────────── Unified Entry Point ───────────────────────────────
def extract(pdf_path: str | Path) -> list[str]:
    """
    Unified extraction entry point.

    Uses OCR for PDFs with no text layer and fast text extraction otherwise.
    If a PDF *does* contain character objects but text extraction still yields
    no non-blank lines (e.g. a scan with an empty or whitespace-only text
    layer), OCR is tried as a fallback.

    Returns
    -------
    list[str]
        Raw text lines; may be empty.

    Raises
    ------
    ImportError
        Only when the PDF has no text layer and OCR dependencies are missing.
        The text-layer fallback to OCR is best-effort and never raises for
        missing OCR dependencies.
    """
    pdf_path = Path(pdf_path)
    if _is_scanned_pdf(pdf_path):
        logger.info("🔎 Scanned/Image PDF detected. Switching to OCR...")
        return extract_ocr_pdf(pdf_path)

    logger.info("📑 Text-based PDF detected. Using fast extraction...")
    lines = extract_text_pdf(pdf_path)
    if lines:
        return lines

    logger.warning("⚠️  Text layer produced no usable lines; retrying with OCR...")
    try:
        return extract_ocr_pdf(pdf_path)
    except ImportError as exc:
        # Keep the (empty) text result rather than crashing the pipeline.
        logger.warning("⚠️  OCR fallback unavailable: %s", exc)
        return lines
modules/llm_parser.py

"""
llm_parser.py
=============
Production-Grade Semantic LLM Parser Module

Uses Gemini (Pro) with Strict Structured Outputs (Pydantic schemas)
to reconstruct transaction records. Schema-constrained output avoids JSON
parsing errors and reduces (but does not eliminate) hallucinated values; the
prompt is written to work across ICICI, HDFC, SBI, Axis, etc.

Strategy:
  1. Define rigid Pydantic data schemas.
  2. Map spatial text via Gemini 2.5 Pro / 2.0 Flash into these schemas.
  3. Use overlapping context to merge multi-line narrations.
"""

from __future__ import annotations

import logging
import os
import time
from typing import Optional, List
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

# ── Pydantic Schemas for Strict Output ────────────────────────────────────────
# These models are passed (via ExtractedData) as Gemini's `response_schema`.
# They are documented with `#` comments rather than class docstrings because
# Pydantic uses a class docstring as the schema description, which could
# change what is sent to the model.

class AccountInfo(BaseModel):
    # Statement header metadata. Every field defaults to None because a given
    # batch of lines may not contain the statement header at all.
    name: Optional[str] = Field(description="Name of the account holder", default=None)
    bank: Optional[str] = Field(description="Name of the bank (HDFC, ICICI, SBI, etc.)", default=None)
    account_no: Optional[str] = Field(description="The bank account number", default=None)
    # Dates are requested as DD/MM/YYYY but remain free-form strings here.
    period_from: Optional[str] = Field(description="Statement start date (DD/MM/YYYY)", default=None)
    period_to: Optional[str] = Field(description="Statement end date (DD/MM/YYYY)", default=None)

class Transaction(BaseModel):
    # One statement row. `date` and `description` are required. The Field
    # descriptions double as extraction instructions for the model.
    date: str = Field(description="Transaction date (DD/MM/YYYY or original format)")
    description: str = Field(description="Full complete transaction narration/description. Merge any wrapped lines.")
    # Defaults to "" but is Optional, so the model may still return null;
    # downstream consumers must treat None like "".
    ref_no: Optional[str] = Field(description="Cheque or Reference ID. Leave completely blank if none.", default="")
    # Amount sides default to None (not 0.0); downstream code must treat None as 0.
    debit: Optional[float] = Field(description="Amount debited/withdrawn. Provide only number.", default=None)
    credit: Optional[float] = Field(description="Amount credited/deposited. Provide only number.", default=None)
    balance: Optional[float] = Field(description="Remaining balance after transaction.", default=None)

class Summary(BaseModel):
    # Statement-level figures as printed on the statement.
    opening_balance: Optional[float] = Field(description="Opening balance of the statement", default=None)
    closing_balance: Optional[float] = Field(description="Closing balance of the statement", default=None)
    # Stated grand totals. The validator's summary cross-check and the
    # exporter's summary bar read these under the "total_debits" /
    # "total_credits" keys; without these fields the schema never produces
    # them and that cross-check cannot run.
    total_debits: Optional[float] = Field(
        description="Grand total of all debits/withdrawals exactly as printed on the statement "
                    "(not a per-page subtotal). Leave null if not printed.",
        default=None,
    )
    total_credits: Optional[float] = Field(
        description="Grand total of all credits/deposits exactly as printed on the statement "
                    "(not a per-page subtotal). Leave null if not printed.",
        default=None,
    )

class ExtractedData(BaseModel):
    # Top-level response schema for one batch (passed as `response_schema` in
    # parse_transactions_llm). A batch may fill account_info/summary only
    # partially; the parser merges those fields across batches, first
    # non-empty value winning.
    account_info: AccountInfo
    transactions: List[Transaction]
    summary: Summary


# ── Parser Engine ─────────────────────────────────────────────────────────────

def _build_batch_prompt(chunk_text: str, context_tail: str) -> str:
    """Assemble the extraction prompt for one batch of statement lines."""
    parts = [
        "You are a financial parsing engine handling arbitrary bank statements (HDFC, SBI, ICICI, etc.).",
        "Extract transactions strictly into the structured format.",
        "If a transaction narration spans multiple lines, merge them into the single `description` field for that date.",
        "If a row does not contain a date, it is likely a continuation of the previous row's narration. Merge it.",
        "Ensure no leading/trailing spaces.",
    ]
    if context_tail:
        parts.append(f"\n--- PREVIOUS PAGE OVERLAP CONTEXT ---\n{context_tail}\n(Use to continue narrations if split. Do not duplicate rows.)")
    parts.append(f"\n--- CURRENT PAGE TEXT ---\n{chunk_text}")
    return "\n".join(parts)


def _merge_batch_metadata(final_metadata: dict, data: dict) -> None:
    """Fill still-empty account_info / summary fields from one batch's response.

    account_info keeps the first truthy value per key; summary keeps the first
    non-None value per key (so a stated 0.0 balance is retained).
    """
    for k, v in (data.get("account_info") or {}).items():
        if v and not final_metadata["account_info"].get(k):
            final_metadata["account_info"][k] = v
    for k, v in (data.get("summary") or {}).items():
        if v is not None and final_metadata["summary"].get(k) is None:
            final_metadata["summary"][k] = v


def parse_transactions_llm(
    lines: list[str],
    model_id: str = "gemini-2.5-pro",
    batch_size: int = 90,
    max_retries: int = 3,
) -> dict:
    """
    Parse extracted statement lines into transactions using Gemini structured output.

    Lines are sent in batches of *batch_size*. Each batch also receives the
    last 4 lines of the preceding batch as overlap context so narrations split
    across a batch boundary can be merged. Failed batches are retried with
    exponential backoff and skipped after *max_retries* attempts.

    Parameters
    ----------
    lines : list[str]
        Raw text lines from the extractor.
    model_id : str
        Gemini model name.
    batch_size : int
        Lines per request (values below 1 are treated as 1).
    max_retries : int
        Attempts per batch (values below 1 are treated as 1).

    Returns
    -------
    dict
        ``{"account_info": {...}, "transactions": [...], "summary": {...}}``,
        or ``{}`` if google-genai is not installed or the client cannot be
        created.
    """
    import json

    try:
        from google import genai
    except ImportError:
        logger.error("google-genai not installed. Install it to use the parser.")
        return {}

    # GEMINI_API_KEY is accepted as an alternative variable name.
    api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
    if not api_key:
        logger.warning("GOOGLE_API_KEY not found in environment!")
    try:
        client = genai.Client(api_key=api_key)
    except Exception as exc:
        # Mirror the ImportError path: report and return an empty result.
        logger.error("❌ Could not initialise Gemini client: %s", exc)
        return {}

    batch_size = max(1, int(batch_size))
    max_retries = max(1, int(max_retries))

    all_transactions: list = []
    final_metadata = {"account_info": {}, "summary": {}}
    # Overlap buffer so a transaction split across batches can be continued.
    context_tail = ""

    total_lines = len(lines)
    logger.info("🤖 Starting Production-Grade Semantic Parser (%d lines)", total_lines)

    for i in range(0, total_lines, batch_size):
        batch_num = (i // batch_size) + 1
        current_chunk = lines[i : i + batch_size]
        full_prompt = _build_batch_prompt("\n".join(current_chunk), context_tail)

        for attempt in range(max_retries):
            try:
                logger.info("  → Processing Batch #%d (Attempt %d)...", batch_num, attempt + 1)
                response = client.models.generate_content(
                    model=model_id,
                    contents=full_prompt,
                    config={
                        "response_mime_type": "application/json",
                        "response_schema": ExtractedData,
                        "temperature": 0.1,  # Low temp for deterministic extraction
                    },
                )
                # An empty body usually means a safety/other block.
                if not response.text:
                    raise ValueError("Empty response. Possibly blocked.")
                data = json.loads(response.text.strip())
                if not isinstance(data, dict):
                    raise ValueError("Unexpected response shape (expected a JSON object).")
            except Exception as exc:
                logger.warning("  ⚠ Batch %d Attempt %d Failed: %s", batch_num, attempt + 1, exc)
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # exponential backoff: 1s, 2s, ...
                else:
                    # No further attempt follows, so do not sleep here.
                    logger.error("  ❌ Batch %d completely failed. Skipping.", batch_num)
                continue

            _merge_batch_metadata(final_metadata, data)
            all_transactions.extend(data.get("transactions") or [])
            break  # Success, exit retry loop

        # Advance the overlap window whether or not this batch succeeded, so the
        # next batch always sees the lines immediately preceding it.
        context_tail = "\n".join(current_chunk[-4:])

    logger.info("✅ Full Parse Complete. Extracted %d transactions.", len(all_transactions))
    return {
        "account_info": final_metadata["account_info"],
        "transactions": all_transactions,
        "summary": final_metadata["summary"],
    }
modules/validator.py
"""
validator.py
============
Advanced Validation & Correction Engine (Production Grade)

Rules:
  1. Row Balance Check: Prev Balance - Debit + Credit = Current Balance
  2. Auto-Correction: Swap Debit/Credit when doing so fixes the running-balance math.
  3. Duplicate Detection: Same Ref No. on different rows.
  4. Summary Cross-Check: Stated Debits vs Sum of Rows.
  5. Flagging: Mark rows for manual review in Excel.
"""

from __future__ import annotations

import logging
from collections import defaultdict
from typing import Optional

logger = logging.getLogger(__name__)

TOLERANCE: float = 0.05  # acceptable rounding/floating point error in INR

# Reference strings that identify nothing and must not be reported as duplicates
# (compared case-insensitively; all-zero refs are handled separately).
_PLACEHOLDER_REFS = frozenset({"", "-", "NA", "N/A", "NONE", "NULL"})


def _is_meaningful_ref(ref: str) -> bool:
    """Return True if *ref* looks like a real cheque/reference number."""
    return bool(ref) and ref.strip("0") != "" and ref.upper() not in _PLACEHOLDER_REFS


def validate_transactions(records: list[dict], summary_stats: dict = None) -> tuple[list[dict], list[dict]]:
    """
    Validate transaction records for mathematical and structural integrity.

    Rules applied:
      * Running balance: ``prev_balance - debit + credit == balance`` within
        TOLERANCE. The running balance is seeded from
        ``summary_stats["opening_balance"]`` when present, so the first row is
        checked as well. Otherwise the first row with a balance is the anchor.
      * Auto-fix: if swapping debit/credit reconciles a mismatching row, the
        swap is applied, the row is marked ``"FIXED: Swapped Debit/Credit"``
        and an INFO issue is recorded. Otherwise the row is marked
        ``"FLAGGED: Balance Mismatch"`` with a WARNING issue.
      * Duplicate refs: the same meaningful ref_no on several rows gives an
        INFO issue. None, blank, all-zero and placeholder refs are ignored.
      * Summary cross-check: stated ``total_debits`` / ``total_credits``
        versus row sums (tolerance 1.0).

    Input records are not mutated; each output row is a shallow copy.

    Returns:
        tuple (validated_records, list_of_issues)
    """
    summary_stats = summary_stats or {}
    validated: list[dict] = []
    issues: list[dict] = []

    # None → the first row carrying a balance becomes the anchor.
    prev_balance: Optional[float] = summary_stats.get("opening_balance")
    seen_refs = defaultdict(list)

    total_debit_calc = 0.0
    total_credit_calc = 0.0

    logger.info("🔬 Advanced validation starting for %d records...", len(records))

    for idx, rec in enumerate(records):
        rec = dict(rec)  # working copy
        loc = f"Row #{idx+1} ({rec.get('date')})"

        debit = rec.get("debit", 0.0) or 0.0
        credit = rec.get("credit", 0.0) or 0.0
        balance = rec.get("balance")
        # `or ""` so a null ref_no does not become the literal string "None".
        ref = str(rec.get("ref_no") or "").strip()

        # 1. ── DUPLICATE CHECK ────────────────────────────────────────────────
        if _is_meaningful_ref(ref):
            seen_refs[ref].append(idx + 1)

        # 2. ── BALANCE CONTINUITY ─────────────────────────────────────────────
        if prev_balance is None:
            # Anchor balance (no opening balance and no earlier balance seen)
            prev_balance = balance
            rec["validation_message"] = ""
        else:
            expected = round(prev_balance - debit + credit, 2)

            if balance is not None and abs(balance - expected) > TOLERANCE:
                # Mismatch detected → try the debit/credit swap
                swapped_expected = round(prev_balance - credit + debit, 2)
                if abs(balance - swapped_expected) <= TOLERANCE:
                    rec["debit"], rec["credit"] = credit, debit
                    debit, credit = credit, debit
                    rec["validation_message"] = "FIXED: Swapped Debit/Credit"
                    logger.debug("  🔧 %s: Swapped Debit/Credit for math check", loc)
                    # Keep an audit trail of automatic corrections in the report.
                    issues.append({
                        "level": "INFO",
                        "location": loc,
                        "message": "Debit/Credit swapped to reconcile the running balance.",
                    })
                else:
                    rec["validation_message"] = "FLAGGED: Balance Mismatch"
                    issues.append({
                        "level": "WARNING",
                        "location": loc,
                        "message": f"Balance mismatch. Expected {expected}, got {balance}.",
                    })
            else:
                rec["validation_message"] = ""

            # Re-sync on the stated balance so one bad row does not cascade.
            prev_balance = balance if balance is not None else expected

        total_debit_calc += debit
        total_credit_calc += credit
        validated.append(rec)

    # 3. ── DUPLICATE REPORTING ────────────────────────────────────────────────
    for ref, rows in seen_refs.items():
        if len(rows) > 1:
            issues.append({
                "level": "INFO",
                "location": f"Ref {ref}",
                "message": f"Duplicate Ref No. found in rows: {rows}",
            })

    # 4. ── SUMMARY CROSS-CHECK ────────────────────────────────────────────────
    stated_dr = summary_stats.get("total_debits")
    stated_cr = summary_stats.get("total_credits")

    if stated_dr is not None and abs(stated_dr - total_debit_calc) > 1.0:
        issues.append({
            "level": "WARNING",
            "location": "Summary",
            "message": f"Stated Debits ({stated_dr}) != Sum of Rows ({round(total_debit_calc,2)})",
        })
    if stated_cr is not None and abs(stated_cr - total_credit_calc) > 1.0:
        issues.append({
            "level": "WARNING",
            "location": "Summary",
            "message": f"Stated Credits ({stated_cr}) != Sum of Rows ({round(total_credit_calc,2)})",
        })

    logger.info("✅ Validation complete. %d issues flagged.", len(issues))
    return validated, issues
modules/exporter.py
"""
exporter.py
===========
Premium Excel Export Module (Validated & Structured)

Responsibilities:
  1. Generate Tally-ready Excel (.xlsx)
  2. Multi-sheet output:
     - Transactions (Main with Header metadata)
     - Monthly Summary (Analysis)
     - Account Info (Metadata)
     - Validation Engine Report (Flagged issues)
  3. Professional styling: Merged headers, zebra stripes, conditional formatting.
"""

from __future__ import annotations

import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Any

from openpyxl import Workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter

logger = logging.getLogger(__name__)

# ── Styling Constants ────────────────────────────────────────────────────────
# Fill colours as 6-digit RGB hex strings.
C_HEADER  = "1F4E79"  # table header background (dark blue)
C_TITLE   = "2E75B6"  # default sheet-title background (medium blue)
C_ALT     = "D6E4F0"  # alternate-row zebra stripe (very light blue)
C_DEBIT   = "FFE0E0"  # highlighted debit cell (light red)
C_CREDIT  = "E2EFDA"  # highlighted credit cell (light green)
C_SUMMARY = "FFF2CC"  # totals row (pale yellow)
C_INFO    = "DEEAF1"  # summary bar (light sky)
C_ERROR   = "FCE4D6"  # ERROR-level issue rows (light orange)
C_WARN    = "FFEB9C"  # WARNING rows / flagged cells (gold)
C_OK      = "C6EFCE"  # INFO-level issue rows (light green)

# Thin grey border used on every styled cell, on all four edges.
THIN_SIDE = Side(style="thin", color="AAAAAA")
BORDER    = Border(**{edge: THIN_SIDE for edge in ("left", "right", "top", "bottom")})
# Money columns: thousands separators, two decimals.
NUM_FMT   = "#,##0.00"

# ── Helper Styles ────────────────────────────────────────────────────────────
def _apply_header(cell, text: str):
    """Write *text* into *cell* and give it the table-header look.

    Header look: white bold Arial 10 on the C_HEADER fill, centred both ways
    with wrapping, and the shared thin BORDER on all sides.
    """
    header_font = Font(name="Arial", bold=True, size=10, color="FFFFFF")
    header_fill = PatternFill("solid", start_color=C_HEADER)
    centred = Alignment(horizontal="center", vertical="center", wrap_text=True)

    cell.value = text
    cell.font, cell.fill, cell.alignment, cell.border = header_font, header_fill, centred, BORDER

def _apply_data(cell, value: Any, fill: PatternFill, fmt: str = None, align: str = "center"):
    """Write *value* into *cell* with the standard body-cell styling.

    Body styling: Arial 9, the given *fill*, horizontal *align* with vertical
    centring, and the shared thin BORDER. When *fmt* is truthy it becomes the
    cell's number format; otherwise the existing number format is untouched.
    """
    body_font = Font(name="Arial", size=9)
    body_alignment = Alignment(horizontal=align, vertical="center")

    cell.value = value
    cell.font = body_font
    cell.fill = fill
    cell.alignment = body_alignment
    cell.border = BORDER
    if not fmt:
        return
    cell.number_format = fmt

def _create_title_row(ws, text: str, span: int, color: str = C_TITLE):
    """Merge row 1 across columns A.. *span* and render *text* as the sheet title.

    Title styling: white bold Arial 12 on a solid *color* fill, centred both
    ways; row 1 height is set to 26.
    """
    last_col = get_column_letter(span)
    ws.merge_cells(f"A1:{last_col}1")

    title = ws["A1"]
    title.value = text
    title.font = Font(name="Arial", bold=True, size=12, color="FFFFFF")
    title.fill = PatternFill("solid", start_color=color)
    title.alignment = Alignment(horizontal="center", vertical="center")
    ws.row_dimensions[1].height = 26


# ── Main Exporter ────────────────────────────────────────────────────────────
def export_to_excel(transactions: list[dict],
                    output_path: str | Path,
                    account_info: dict = None,
                    summary_stats: dict = None,
                    validation_issues: list[dict] = None):
    """
    Export bank data to a multi-sheet Excel workbook.

    Sheets: "Transactions", "Monthly Summary", "Account Info", and
    "Validation Report" (only when *validation_issues* is non-empty).

    Parameters
    ----------
    transactions : list[dict]
        Transaction rows; ``None`` is treated as an empty list.
    output_path : str | Path
        Destination file. Missing parent directories are created.
    account_info, summary_stats : dict, optional
    validation_issues : list[dict], optional

    Returns
    -------
    Path
        The path the workbook was saved to.

    Raises
    ------
    Exception
        Any error from creating the directory or saving is logged and re-raised.
    """
    output_path = Path(output_path)
    transactions = transactions or []
    wb = Workbook()

    # 1. TRANSACTIONS SHEET
    _write_transactions_sheet(wb.active, transactions, account_info, summary_stats)

    # 2. MONTHLY SUMMARY
    _write_monthly_summary(wb.create_sheet("Monthly Summary"), transactions)

    # 3. ACCOUNT INFO
    _write_account_info(wb.create_sheet("Account Info"), account_info, summary_stats, len(transactions))

    # 4. VALIDATION REPORT
    if validation_issues:
        _write_validation_report(wb.create_sheet("Validation Report"), validation_issues)

    try:
        # wb.save() opens the file directly and fails if the directory is missing.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        wb.save(output_path)
        logger.info("✅ Premium Excel saved: %s", output_path)
    except Exception as e:
        logger.error("❌ Failed to save Excel: %s", e)
        raise
    return output_path


def _write_transactions_sheet(ws, txns, info, summ):
    """
    Render the main "Transactions" sheet.

    Layout:
      * row 1 is the merged title (bank | holder | A/C);
      * row 2 is the summary bar;
      * row 3 holds the column headers;
      * rows 4.. hold one transaction each;
      * a final TOTAL row holds SUM formulas for the Debit and Credit columns.

    Panes are frozen below the header row.

    Parameters
    ----------
    ws : openpyxl worksheet (renamed to "Transactions").
    txns : list of transaction dicts. Keys read: date, description, ref_no,
        debit, credit, balance, validation_message, and optionally vch_type.
    info : account-info dict or None.
    summ : summary dict or None.
    """
    ws.title = "Transactions"
    txns = txns or []
    info = info or {}
    summ = summ or {}

    def _or_dash(value):
        # Missing or explicitly-null values render as "-" rather than "None".
        return "-" if value is None else value

    # Title Header (Bank, Name, Account No.)
    title_parts = []
    if info.get("bank"):
        title_parts.append(info["bank"])
    if info.get("name"):
        title_parts.append(info["name"])
    if info.get("account_no"):
        title_parts.append(f"A/C: {info['account_no']}")

    title_str = "  |  ".join(title_parts) or "Bank Statement"
    _create_title_row(ws, title_str, 8)

    # Summary Bar. Closing balance falls back to the last row's balance.
    closing = summ.get("closing_balance")
    if closing is None and txns:
        closing = txns[-1].get("balance")
    ws.merge_cells("A2:H2")
    ws["A2"].value = (
        f"Opening: {_or_dash(summ.get('opening_balance'))}  |  "
        f"Total Debits: {_or_dash(summ.get('total_debits'))}  |  "
        f"Total Credits: {_or_dash(summ.get('total_credits'))}  |  "
        f"Closing: {_or_dash(closing)}"
    )
    ws["A2"].font  = Font(name="Arial", bold=True, size=10, color="1F4E79")
    ws["A2"].fill  = PatternFill("solid", start_color=C_INFO)
    ws["A2"].alignment = Alignment(horizontal="center", vertical="center")
    ws.row_dimensions[2].height = 20

    # Table Headers
    cols = ["Date", "Particulars/Narration", "Vch Type", "Ref/Chq No.", "Debit (Withdrawal)", "Credit (Deposit)", "Balance", "Flag"]
    for i, name in enumerate(cols, 1):
        _apply_header(ws.cell(3, i), name)
    ws.row_dimensions[3].height = 24

    # Data Rows
    for r_idx, txn in enumerate(txns, 4):
        # An empty amount side may arrive as None; treat it as 0 for comparisons.
        debit = txn.get("debit") or 0.0
        credit = txn.get("credit") or 0.0
        # NOTE(review): assumes an upstream categorizer may store its result
        # under "vch_type"; when absent, derive it from the amount direction.
        vch_type = txn.get("vch_type") or (
            "Payment" if debit > 0 else ("Receipt" if credit > 0 else "Journal")
        )
        msg = txn.get("validation_message") or ""

        bg_color = C_ALT if r_idx % 2 == 0 else "FFFFFF"
        row_fill = PatternFill("solid", start_color=bg_color)

        # Column fills
        debit_fill  = PatternFill("solid", start_color=C_DEBIT if debit > 0 else bg_color)
        credit_fill = PatternFill("solid", start_color=C_CREDIT if credit > 0 else bg_color)
        flag_fill   = PatternFill("solid", start_color=C_WARN if msg else bg_color)

        _apply_data(ws.cell(r_idx, 1), txn.get("date"), row_fill)
        _apply_data(ws.cell(r_idx, 2), txn.get("description"), row_fill, align="left")
        _apply_data(ws.cell(r_idx, 3), vch_type, row_fill)
        _apply_data(ws.cell(r_idx, 4), txn.get("ref_no", ""), row_fill)
        _apply_data(ws.cell(r_idx, 5), debit if debit > 0 else None, debit_fill, fmt=NUM_FMT)
        _apply_data(ws.cell(r_idx, 6), credit if credit > 0 else None, credit_fill, fmt=NUM_FMT)
        _apply_data(ws.cell(r_idx, 7), txn.get("balance"), row_fill, fmt=NUM_FMT)
        _apply_data(ws.cell(r_idx, 8), msg, flag_fill)

        ws.row_dimensions[r_idx].height = 15

    # Totals Row
    last_row = len(txns) + 4
    ws.cell(last_row, 1).value = "TOTAL"
    ws.cell(last_row, 1).font = Font(bold=True)
    if txns:
        ws.cell(last_row, 5).value = f"=SUM(E4:E{last_row - 1})"
        ws.cell(last_row, 6).value = f"=SUM(F4:F{last_row - 1})"
    else:
        # No data rows: avoid a reversed range (E4:E3) that reaches into the header.
        ws.cell(last_row, 5).value = 0
        ws.cell(last_row, 6).value = 0
    for c in range(1, 9):
        ws.cell(last_row, c).fill = PatternFill("solid", start_color=C_SUMMARY)
        ws.cell(last_row, c).border = BORDER
        if c in (5, 6):
            ws.cell(last_row, c).number_format = NUM_FMT

    widths = [12, 60, 12, 18, 18, 18, 18, 10]
    for i, w in enumerate(widths, 1):
        ws.column_dimensions[get_column_letter(i)].width = w

    ws.freeze_panes = "A4"


def _write_monthly_summary(ws, txns):
    """
    Render the "Monthly Summary" sheet.

    Per month: debit and credit totals, net flow (credits − debits), and the
    number of debit and credit rows. Months are listed chronologically; rows
    whose date cannot be parsed are grouped under "Unknown" at the end.
    Recognised dates are day-first (dd-mm-yyyy, dd/mm/yy; 1- or 2-digit day
    and month) and ISO yyyy-mm-dd.
    """
    ws.title = "Monthly Summary"
    _create_title_row(ws, "Monthly Analysis & Cash Flow", 6)
    headers = ["Month", "Total Debits", "Total Credits", "Net Flow", "Dr Count", "Cr Count"]
    for i, h in enumerate(headers, 1):
        _apply_header(ws.cell(2, i), h)

    # Digit look-arounds stop partial matches inside longer numbers, e.g.
    # "24-04-05" inside the ISO date "2024-04-05".
    dmy_re = re.compile(r"(?<!\d)(\d{1,2})[-/](\d{1,2})[-/](\d{4}|\d{2})(?!\d)")
    ymd_re = re.compile(r"(?<!\d)(\d{4})[-/](\d{1,2})[-/](\d{1,2})(?!\d)")

    def month_of(date_text):
        """Return ``(sort_key, label)`` for the month of *date_text*."""
        m = dmy_re.search(date_text)
        if m:
            month, year_txt = int(m.group(2)), m.group(3)
        else:
            m = ymd_re.search(date_text)
            if not m:
                return (1, 0, 0), "Unknown"  # sorts after every real month
            month, year_txt = int(m.group(2)), m.group(1)
        # Two-digit years are assumed to be 20xx for ordering purposes only.
        year = int(year_txt) + (2000 if len(year_txt) == 2 else 0)
        return (0, year, month), f"{month:02d}/{year_txt}"

    monthly = defaultdict(lambda: {"dr": 0.0, "cr": 0.0, "dr_c": 0, "cr_c": 0})
    sort_keys = {}
    for t in txns or []:
        sort_key, label = month_of(str(t.get("date") or ""))
        sort_keys[label] = sort_key

        # An empty amount side may be None; count it as 0.
        dr = t.get("debit") or 0.0
        cr = t.get("credit") or 0.0
        bucket = monthly[label]
        bucket["dr"] += dr
        bucket["cr"] += cr
        if dr > 0:
            bucket["dr_c"] += 1
        if cr > 0:
            bucket["cr_c"] += 1

    ordered = sorted(monthly, key=lambda label: (sort_keys[label], label))
    for r, m in enumerate(ordered, 3):
        v = monthly[m]
        net = v["cr"] - v["dr"]
        fill = PatternFill("solid", start_color=C_ALT if r % 2 == 0 else "FFFFFF")
        _apply_data(ws.cell(r, 1), m, fill)
        _apply_data(ws.cell(r, 2), v["dr"], fill, fmt=NUM_FMT)
        _apply_data(ws.cell(r, 3), v["cr"], fill, fmt=NUM_FMT)
        _apply_data(ws.cell(r, 4), net, fill, fmt=NUM_FMT)
        _apply_data(ws.cell(r, 5), v["dr_c"], fill)
        _apply_data(ws.cell(r, 6), v["cr_c"], fill)

    for i, w in enumerate([15, 20, 20, 20, 12, 12], 1):
        ws.column_dimensions[get_column_letter(i)].width = w


def _write_account_info(ws, info, summ, txn_count):
    """
    Render the "Account Info" sheet as a two-column label/value table.

    Rows start at row 2 (row 1 is the merged title). Missing, null or empty
    values are shown as "N/A"; numeric zero is shown as-is. The period reads
    "N/A" when neither start nor end date is known.
    """
    ws.title = "Account Info"
    _create_title_row(ws, "Account Metadata", 2)
    info = info or {}
    summ = summ or {}

    def _value(source, key):
        value = source.get(key)
        return "N/A" if value is None or value == "" else value

    period_from = info.get("period_from")
    period_to = info.get("period_to")
    if period_from or period_to:
        period = f"{period_from or ''} to {period_to or ''}"
    else:
        period = "N/A"

    rows = [
        ("Account Holder", _value(info, "name")),
        ("Bank Name", _value(info, "bank")),
        ("Account Number", _value(info, "account_no")),
        ("Period", period),
        ("Total Transactions", txn_count),
        ("Opening Balance", _value(summ, "opening_balance")),
        ("Closing Balance", _value(summ, "closing_balance")),
    ]
    for r, (k, v) in enumerate(rows, 2):
        fill = PatternFill("solid", start_color=C_ALT if r % 2 == 0 else "FFFFFF")
        _apply_data(ws.cell(r, 1), k, fill, align="left")
        # Bold the label but keep the Arial 9 body font used by _apply_data.
        ws.cell(r, 1).font = Font(name="Arial", size=9, bold=True)
        _apply_data(ws.cell(r, 2), v, fill, align="left")
    ws.column_dimensions["A"].width = 25
    ws.column_dimensions["B"].width = 50


def _write_validation_report(ws, issues):
    """Render one row per validation issue on the "Validation Report" sheet.

    Columns: Level, Location, Message, Status. Rows are tinted by level
    (ERROR / WARNING / INFO, white for anything else). Status is "Fix Needed"
    for ERROR and "Review" for every other level.
    """
    ws.title = "Validation Report"
    _create_title_row(ws, "Validation Engine Results", 4, color="C62828")
    for col, caption in enumerate(("Level", "Location", "Message", "Status"), start=1):
        _apply_header(ws.cell(2, col), caption)

    fills_by_level = {"ERROR": C_ERROR, "WARNING": C_WARN, "INFO": C_OK}
    for row, issue in enumerate(issues, start=3):
        level = issue.get("level", "INFO")
        row_fill = PatternFill("solid", start_color=fills_by_level.get(level, "FFFFFF"))
        status = "Fix Needed" if level == "ERROR" else "Review"
        cells = (
            (1, level, "center"),
            (2, issue.get("location", ""), "center"),
            (3, issue.get("message", ""), "left"),
            (4, status, "center"),
        )
        for col, value, align in cells:
            _apply_data(ws.cell(row, col), value, row_fill, align=align)

    for col, width in enumerate((12, 22, 65, 16), start=1):
        ws.column_dimensions[get_column_letter(col)].width = width
modules/cleaner.py
"""
cleaner.py
==========
Data Cleaning & Normalization Module

Responsibilities:
  1. Standardize date to dd-mm-yyyy
  2. Strip commas and convert amounts to float
  3. Handle Cr/Dr suffixes common in Indian bank statements
  4. Clean description text (collapse whitespace, remove noise)
  5. Detect and swap debit/credit if amounts are misplaced
  6. Fill None → 0.0 for missing numeric fields
"""

from __future__ import annotations

import logging
import re
from datetime import datetime
from typing import Optional

logger = logging.getLogger(__name__)

# ─────────────────────────── Date format attempts ────────────────────────────
# Tried in order by _parse_date; the first format that parses wins. New
# formats are appended so they never change how an existing format matches.
DATE_FORMATS = [
    "%d-%m-%Y",   # 05-04-2024
    "%d/%m/%Y",   # 05/04/2024
    "%d %b %Y",   # 05 Apr 2024
    "%d-%b-%Y",   # 05-Apr-2024
    "%d/%b/%Y",   # 05/Apr/2024
    "%d %B %Y",   # 05 April 2024
    "%d-%B-%Y",   # 05-April-2024
    "%Y-%m-%d",   # 2024-04-05  (ISO)
    "%Y/%m/%d",   # 2024/04/05
    "%d-%m-%y",   # 05-04-24
    "%d/%m/%y",   # 05/04/24
    "%d-%b-%y",   # 05-Apr-24
    "%b %d, %Y",  # Apr 05, 2024
    "%d.%m.%Y",   # 05.04.2024
    "%d.%m.%y",   # 05.04.24
    "%d %b %y",   # 05 Apr 24
]

# Noise removed from amount strings before float(): currency markers
# (₹, $, "Rs"/"Rs.", "INR"), thousands separators and all whitespace.
# Without the Rs/INR alternatives an amount like "Rs. 1,234" fails to parse
# and silently becomes 0.0 in clean_transactions.
AMOUNT_CLEAN_RE = re.compile(r"INR|Rs\.?|[₹$,\s]", re.IGNORECASE)
# Trailing Cr/Dr marker, optionally followed by a period ("1,000.00 Dr.").
# Group 1 is the marker itself, in whatever case it was written.
CR_DR_RE = re.compile(r"\s*(cr|dr)\.?\s*$", re.IGNORECASE)


# ─────────────────────────── Helpers ─────────────────────────────────────────
def _parse_date(raw: str) -> str:
    """
    Convert any recognized date string to dd-mm-yyyy.

    The explicit ``DATE_FORMATS`` are tried first, in order. If none matches
    and ``python-dateutil`` is installed, it is used as a fallback. Every
    numeric pattern in ``DATE_FORMATS`` is day-first (apart from the
    year-first ISO forms). The fallback therefore also reads ambiguous
    numeric dates day-first, so "05/04/2024 10:15" is 5 April, not 4 May.
    The exception is a string that starts with a four-digit year, which is
    read year-month-day.

    Returns
    -------
    str
        "" for empty / "none" / "null" / "nan" input; otherwise the
        normalized dd-mm-yyyy date, or the stripped original string if no
        parser recognises it (a warning is logged in that case).
    """
    text = str(raw).strip() if raw is not None else ""
    if not text or text.lower() in ("none", "null", "nan"):
        return ""

    # Try the explicit formats first: they are deterministic.
    for fmt in DATE_FORMATS:
        try:
            return datetime.strptime(text, fmt).strftime("%d-%m-%Y")
        except ValueError:
            continue

    # Optional dateutil fallback (not a hard dependency).
    try:
        from dateutil import parser as date_parser
    except ImportError:
        date_parser = None

    if date_parser is not None:
        # dateutil defaults to month-first; only leading 4-digit years are
        # read year-first so that ISO timestamps keep their meaning.
        year_first = re.match(r"\d{4}\D", text) is not None
        try:
            parsed = date_parser.parse(
                text, dayfirst=not year_first, yearfirst=year_first
            )
            return parsed.strftime("%d-%m-%Y")
        except (ValueError, OverflowError):
            pass

    logger.warning("⚠️  Could not parse date: '%s'", text)
    return text  # preserve as-is


def _parse_amount(raw) -> Optional[float]:
    """
    Convert a raw amount to a non-negative float.

    Accepts ints/floats directly, and strings such as "1,234.50",
    "₹ 1,234.50 Cr" or "(1,234.50)". Whatever ``AMOUNT_CLEAN_RE`` matches
    (commas, whitespace, currency symbols) is removed, then a trailing
    Cr/Dr marker (``CR_DR_RE``), then surrounding parentheses.

    Returns
    -------
    Optional[float]
        The absolute value (for debit/credit the sign is implied by column
        placement). None when the input is None, empty, "none",
        unparseable, or not a finite number (NaN / infinity). Callers rely on
        ``or 0.0`` to default missing values, which a truthy NaN would
        otherwise slip past.
    """
    import math  # only this helper needs it

    if raw is None:
        return None

    if isinstance(raw, (int, float)):
        value = float(raw)
    else:
        text = str(raw).strip()
        if not text or text.lower() == "none":
            return None

        text = CR_DR_RE.sub("", AMOUNT_CLEAN_RE.sub("", text)).strip()
        # Accounting-style "(1,234.50)"; the magnitude is what we return.
        if text.startswith("(") and text.endswith(")"):
            text = text[1:-1]
        if not text:
            return None

        try:
            value = float(text)
        except ValueError:
            return None

    return abs(value) if math.isfinite(value) else None


def _clean_description(raw: str) -> str:
    """Collapse multiple spaces, tabs, remove non-printable chars."""
    cleaned = re.sub(r"[\t\r\n]+", " ", raw)
    cleaned = re.sub(r" {2,}", " ", cleaned)
    return cleaned.strip()


# ─────────────────────────── Main Cleaner ────────────────────────────────────
def _parse_signed_balance(raw: str) -> Optional[float]:
    """
    Parse a running-balance value while keeping its sign.

    Unlike the debit/credit columns, a balance can legitimately be negative
    (an overdrawn account). Statements that suffix balances with Cr/Dr use a
    trailing "Dr" for such a debit balance. The magnitude comes from
    ``_parse_amount``. The value is negated when the cleaned text starts
    with "-", is wrapped in parentheses, or carries a trailing "Dr" marker.
    Returns None when ``_parse_amount`` cannot parse the input.
    """
    magnitude = _parse_amount(raw)
    if magnitude is None:
        return None

    text = AMOUNT_CLEAN_RE.sub("", str(raw))
    marker = CR_DR_RE.search(text)
    if marker:
        text = text[: marker.start()]

    negative = (
        text.startswith("-")
        or (text.startswith("(") and text.endswith(")"))
        or (marker is not None and marker.group(1).lower() == "dr")
    )
    # Avoid emitting -0.0 for a zero "Dr" balance.
    return -magnitude if negative and magnitude else magnitude


def clean_transactions(raw_records: list[dict]) -> list[dict]:
    """
    Clean and normalize a list of raw transaction records.

    Parameters
    ----------
    raw_records : list[dict]
        Output of the parsing stage (e.g. parse_transactions_llm or
        parse_with_langchain). ``None`` is treated as an empty list.

    Returns
    -------
    list[dict]
        Cleaned records with keys:
            date (str dd-mm-yyyy, or the raw text if unparseable)
            description (str, "" when missing)
            debit (float, non-negative, 0.0 when missing)
            credit (float, non-negative, 0.0 when missing)
            balance (float, negative for overdrawn/"Dr" balances, or None
                     when missing)
        A record that raises during cleaning is logged and skipped.
    """
    records = raw_records or []
    cleaned: list[dict] = []

    logger.info("🧹 Cleaning %d raw records...", len(records))

    for idx, rec in enumerate(records):
        try:
            date = _parse_date(str(rec.get("date", "")))
            raw_description = rec.get("description")
            # A None narration must become "", not the literal text "None".
            description = _clean_description(
                "" if raw_description is None else str(raw_description)
            )

            debit = _parse_amount(str(rec.get("debit", "")).strip()) or 0.0
            credit = _parse_amount(str(rec.get("credit", "")).strip()) or 0.0
            # Balance keeps its sign; None if truly missing.
            balance = _parse_signed_balance(str(rec.get("balance", "")).strip())

            # Debit/credit swap detection is deliberately not done here:
            # the validator cross-checks amounts against the running balance.
            cleaned.append({
                "date": date,
                "description": description,
                "debit": round(debit, 2),
                "credit": round(credit, 2),
                "balance": round(balance, 2) if balance is not None else None,
            })

        except Exception as exc:
            logger.error("❌ Failed to clean record #%d: %s | Record: %s", idx, exc, rec)
            continue

    logger.info("✅ Cleaning complete. %d/%d records cleaned.", len(cleaned), len(records))
    return cleaned

modules/categorizer.py
"""
categorizer.py
==============
Transaction Categorization Module

Assigns a Tally-compatible Vch Type to each transaction based on:
  1. Rule-based keyword matching (fast, deterministic)
  2. Optional LLM enrichment via Gemini Flash (enabled when the
     USE_LLM_CATEGORIZER environment variable is set to true)

Rule-based categories (summary only: the first matching entry in
CATEGORY_RULES wins, and that list, not this table, defines the match
order and the full keyword sets):
  ┌─────────────────┬──────────────────────┐
  │ Pattern         │ Vch Type             │
  ├─────────────────┼──────────────────────┤
  │ ATM             │ Cash Withdrawal      │
  │ NEFT/IMPS/RTGS  │ Bank Transfer        │
  │ ACH             │ Auto Debit           │
  │ BIL/BPAY/BBPS   │ Bill Payment         │
  │ CMS             │ Investment           │
  │ INT             │ Interest             │
  │ SAL/SALARY      │ Salary               │
  │ UPI             │ UPI Transfer         │
  │ NACH            │ Auto Debit           │
  │ POS/SWIPE       │ POS Purchase         │
  │ CHQ/CHEQUE      │ Cheque               │
  │ TDS             │ Tax Deduction        │
  │ EMI             │ Loan EMI             │
  │ GST             │ Tax Payment          │
  │ FD / RD         │ Investment           │
  │ REFUND/REVERSAL │ Refund               │
  │ INS/LIC/PREMIUM │ Insurance            │
  │ CASH/CSH        │ Cash Deposit         │
  │ (fallback)      │ General              │
  └─────────────────┴──────────────────────┘
"""

from __future__ import annotations

import logging
import os
import re
from typing import Optional

logger = logging.getLogger(__name__)

# ─────────────────────────── Constants ───────────────────────────────────────
# Evaluated once at import time; "true" (any case) enables the Gemini path.
USE_LLM: bool = os.getenv("USE_LLM_CATEGORIZER", "false").lower() == "true"
BATCH_SIZE: int = 30  # transactions per LLM batch call

# ─────────────────────────── Rule Table ──────────────────────────────────────
# Ordered (pattern, Vch Type) pairs. The first pattern that matches a
# description wins, so more specific rules must precede broader ones.
CATEGORY_RULES: list[tuple[re.Pattern, str]] = [
    (re.compile(r"\bATM\b", re.I),                             "Cash Withdrawal"),
    (re.compile(r"\b(NEFT|IMPS|RTGS)\b", re.I),               "Bank Transfer"),
    (re.compile(r"\bNACH\b", re.I),                            "Auto Debit"),
    (re.compile(r"\bACH\b", re.I),                             "Auto Debit"),
    (re.compile(r"\bEMI\b", re.I),                             "Loan EMI"),
    (re.compile(r"\b(BIL|BPAY|BBPS|BILL)\b", re.I),           "Bill Payment"),
    (re.compile(r"\bUPI\b", re.I),                             "UPI Transfer"),
    (re.compile(r"\b(SAL|SALARY|PAYROLL)\b", re.I),           "Salary"),
    (re.compile(r"\bCMS\b", re.I),                             "Investment"),
    (re.compile(r"\b(FD|RD|MUTUAL FUND|MF)\b", re.I),         "Investment"),
    (re.compile(r"\bINT(EREST)?\b", re.I),                     "Interest"),
    (re.compile(r"\b(POS|SWIPE|CARD)\b", re.I),               "POS Purchase"),
    (re.compile(r"\b(CHQ|CHEQUE|CQ)\b", re.I),                "Cheque"),
    (re.compile(r"\bTDS\b", re.I),                             "Tax Deduction"),
    (re.compile(r"\b(GST|IGST|CGST|SGST)\b", re.I),           "Tax Payment"),
    (re.compile(r"\b(REFUND|REVERSAL|REV)\b", re.I),          "Refund"),
    (re.compile(r"\b(INS|INSURANCE|LIC|PREMIUM)\b", re.I),    "Insurance"),
    # Must come right before the generic CASH rule. Branch cash withdrawals
    # also contain "CASH" and would otherwise be labelled "Cash Deposit".
    (re.compile(r"\bCASH[\s/-]*(WDL|WD|WITHDRAW(AL)?)\b", re.I), "Cash Withdrawal"),
    (re.compile(r"\b(CASH|CSH)\b", re.I),                     "Cash Deposit"),
]

FALLBACK_CATEGORY = "General"


# ─────────────────────────── Rule-Based Core ─────────────────────────────────
def _rule_based_category(description: str) -> str:
    """
    Return the Vch Type of the first CATEGORY_RULES pattern that matches.

    Rules are tried in list order, so earlier entries take precedence.
    ``None``/empty descriptions and descriptions that match no rule get
    FALLBACK_CATEGORY. Non-string input is converted with ``str()`` rather
    than crashing ``re.search``.
    """
    if not description:
        return FALLBACK_CATEGORY
    text = str(description)
    return next(
        (category for pattern, category in CATEGORY_RULES if pattern.search(text)),
        FALLBACK_CATEGORY,
    )


# ─────────────────────────── LLM Enrichment (optional) ───────────────────────
def _llm_categorize_batch(descriptions: list[str]) -> list[str]:
    """
    Send a batch of descriptions to Gemini Flash for categorization.

    Returns one category per input description, in input order.

    * The reply must be a numbered list covering every item number 1..N
      ("1." or "1)"). Items are matched by number, not by line position.
      If any number is missing, the whole batch falls back to rules.
    * Each label is normalized to a known Vch Type: the categories used in
      CATEGORY_RULES plus FALLBACK_CATEGORY, matched case-insensitively
      after stripping markdown emphasis and a trailing period. An unknown
      label falls back to the rule-based category for that item only.
    * Any error (missing GOOGLE_API_KEY, SDK not installed, API failure)
      falls back to the rule-based categorizer for the whole batch.
    """
    if not descriptions:
        return []

    try:
        from google import genai

        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise EnvironmentError("GOOGLE_API_KEY not set")
        client = genai.Client(api_key=api_key)

        # One line per item: an embedded newline would break the numbering.
        single_line = [" ".join(str(desc).split()) for desc in descriptions]
        numbered = "\n".join(
            f"{i}. {desc}" for i, desc in enumerate(single_line, 1)
        )

        prompt = f"""You are a Tally accounting assistant.
Categorize each bank transaction description into ONE of these Tally Vch Types:
Cash Withdrawal, Bank Transfer, Auto Debit, Bill Payment, Investment,
Interest, Salary, UPI Transfer, POS Purchase, Cheque, Tax Deduction,
Tax Payment, Loan EMI, Refund, Insurance, Cash Deposit, General.

Return ONLY a numbered list. Same number as input. No extra text.
Format:
1. 
2. 

Transactions:
{numbered}"""

        response = client.models.generate_content(
            model="gemini-2.0-flash",
            contents=prompt
        )
        # response.text can be None (e.g. a blocked response).
        text_response = (response.text or "").strip()

        by_number: dict[int, str] = {}
        for line in text_response.splitlines():
            match = re.match(r"^\s*(\d+)[.)]\s*(.+?)\s*$", line)
            if match:
                by_number.setdefault(int(match.group(1)), match.group(2))

        expected = range(1, len(descriptions) + 1)
        if any(n not in by_number for n in expected):
            logger.warning("LLM returned %d categories for %d inputs — using rule fallback.",
                           len(by_number), len(descriptions))
            return [_rule_based_category(d) for d in descriptions]

        known = {cat for _, cat in CATEGORY_RULES} | {FALLBACK_CATEGORY}
        canonical = {cat.lower(): cat for cat in known}

        categories = []
        for n, desc in zip(expected, descriptions):
            label = by_number[n].strip("*_`\"' ").rstrip(".").strip()
            category = canonical.get(label.lower())
            if category is None:
                logger.warning("LLM returned unknown category %r — using rule fallback.", label)
                category = _rule_based_category(desc)
            categories.append(category)
        return categories

    except Exception as exc:
        # Best-effort enrichment: never let the LLM path break categorization.
        logger.error("LLM categorization failed: %s — falling back to rules.", exc)
        return [_rule_based_category(d) for d in descriptions]


# ─────────────────────────── Main Entry ──────────────────────────────────────
def _llm_enabled() -> bool:
    """
    Return True when LLM categorization is switched on.

    USE_LLM is evaluated once at import time. The environment is re-read
    here so that a USE_LLM_CATEGORIZER value set after this module was
    imported (e.g. loaded from a .env file later) still takes effect.
    """
    return USE_LLM or os.getenv("USE_LLM_CATEGORIZER", "false").lower() == "true"


def categorize_transactions(records: list[dict]) -> list[dict]:
    """
    Assign Vch Type to each transaction record.

    Modifies records in-place (adds "vch_type" field).

    Parameters
    ----------
    records : list[dict]
        Validated transaction records. A missing or None "description" is
        treated as "".

    Returns
    -------
    list[dict]
        Same records with "vch_type" populated.
    """
    use_llm = _llm_enabled()
    logger.info(
        "🏷️  Categorizing %d transactions (LLM=%s)...",
        len(records),
        use_llm,
    )

    descs = [rec.get("description") or "" for rec in records]

    if not use_llm:
        # Pure rule-based — fast path
        for rec, desc in zip(records, descs):
            rec["vch_type"] = _rule_based_category(desc)
    else:
        # LLM enrichment in batches; each batch returns one label per input.
        categories: list[str] = []
        for start in range(0, len(descs), BATCH_SIZE):
            categories.extend(_llm_categorize_batch(descs[start : start + BATCH_SIZE]))

        for rec, cat in zip(records, categories):
            rec["vch_type"] = cat

    logger.info("✅ Categorization complete.")
    return records

langchain_app.py
"""
langchain_app.py
================
LangChain-Powered Bank Statement Orchestrator

This module reimplements the bank statement parsing pipeline using LangChain LCEL.
It provides a robust, modular, and traceable chain for processing financial documents.
"""

import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel, Field

# Import existing specialized modules
from modules.extractor import extract
from modules.cleaner import clean_transactions
from modules.validator import validate_transactions
from modules.categorizer import categorize_transactions
from modules.exporter import export_to_excel

# Configure logging
# Root logging for this script: timestamp, logger name, level, message.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("LangChainBankApp")

# ── Data Schemas (Reusing existing structures) ────────────────────────────────

class AccountInfo(BaseModel):
    # Statement header details. "Unknown" is the placeholder for anything the
    # model could not find; parse_with_langchain compares against it.
    # (No docstring: it would be sent to the model as the schema description.)
    name: str = Field(default="Unknown", description="Name of the account holder")
    bank: str = Field(default="Unknown", description="Name of the bank")
    account_no: str = Field(default="Unknown", description="Bank account number")
    period_from: str = Field(default="Unknown", description="Start date")
    period_to: str = Field(default="Unknown", description="End date")

class Transaction(BaseModel):
    # One statement row as extracted by the model. debit/credit default to
    # 0.0 when the row has no entry in that column.
    date: str = Field(description="Transaction date")
    description: str = Field(description="Full narration")
    ref_no: str = Field(description="Reference/Cheque number", default="")
    debit: float = Field(description="Debit amount", default=0.0)
    credit: float = Field(description="Credit amount", default=0.0)
    # None (not 0.0) when the row shows no running balance, so the cleaner
    # and validator can tell "missing" apart from a genuine zero balance.
    balance: Optional[float] = Field(
        description="Running balance (null if not shown on the row)", default=None
    )

class Summary(BaseModel):
    # Statement-level balances; 0.0 is used when the batch shows neither.
    opening_balance: float = Field(default=0.0, description="Statement opening balance")
    closing_balance: float = Field(default=0.0, description="Statement closing balance")

class ExtractedData(BaseModel):
    # Top-level schema handed to with_structured_output() in get_parser_chain;
    # one instance is returned per batch of statement text.
    # NOTE: pydantic emits fields in declaration order, which becomes the
    # schema property order sent to the model; keep the order stable.
    account_info: AccountInfo
    transactions: List[Transaction]
    summary: Summary

# ── LangChain Components ──────────────────────────────────────────────────────

def get_parser_chain(
    model: str = "gemini-2.5-flash",
    temperature: float = 0.0,
    max_retries: int = 2,
):
    """
    Build the LCEL parsing chain ``prompt | structured_llm``.

    The chain takes ``{"context": str, "text": str}`` and returns an
    ``ExtractedData`` instance via ``with_structured_output``.

    Parameters
    ----------
    model : str
        Gemini model name. Defaults to "gemini-2.5-flash", the model this
        project is documented to use (the previous hard-coded
        "gemini-1.5-pro" is a retired model).
    temperature : float
        Sampling temperature; 0 keeps extraction deterministic.
    max_retries : int
        Retries performed by the LangChain client on transient failures.
    """
    llm = ChatGoogleGenerativeAI(
        model=model,
        temperature=temperature,
        max_retries=max_retries,
    )

    # Use structured output for guaranteed schema compliance
    structured_llm = llm.with_structured_output(ExtractedData)

    prompt = ChatPromptTemplate.from_messages([
        ("system", (
            "You are a specialized financial parser. Extract bank statement data into the provided schema. "
            "Merge multi-line descriptions. If a row lacks a date, it belongs to the previous transaction. "
            "PREVIOUS CONTEXT repeats the last lines of the preceding batch and is for reference only: "
            "use it to complete rows that continue into TARGET TEXT, but never output a transaction "
            "that appears only in PREVIOUS CONTEXT, so no transaction is duplicated across batches."
        )),
        ("human", "--- PREVIOUS CONTEXT ---\n{context}\n\n--- TARGET TEXT ---\n{text}")
    ])

    return prompt | structured_llm

def parse_with_langchain(
    lines: List[str],
    batch_size: int = 80,
    context_lines: int = 5,
) -> Dict[str, Any]:
    """
    Parse extracted text lines in batches through the LangChain chain.

    Lines are sent in consecutive, non-overlapping batches of ``batch_size``.
    The last ``context_lines`` lines of each batch are forwarded to the next
    batch as PREVIOUS CONTEXT, so rows split across a boundary can be
    completed. This happens even when a batch fails. A batch that raises, or
    returns no structured output, is logged and skipped; later batches
    still run.

    Metadata is merged field by field across batches:
      * account_info: first value other than "Unknown"/empty per field;
      * summary: first non-zero opening_balance, last non-zero
        closing_balance.
    Fields that are never found are omitted (``{}`` if nothing is found).

    Parameters
    ----------
    lines : List[str]
        Text lines produced by the extraction stage.
    batch_size : int
        Lines per LLM call; must be positive. Default 80.
    context_lines : int
        Trailing lines forwarded as context; 0 disables context. Default 5.

    Returns
    -------
    Dict[str, Any]
        ``{"transactions": [...], "account_info": {...}, "summary": {...}}``
        where each transaction is a plain dict of ``Transaction`` fields.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    def _as_dict(model: Any) -> Dict[str, Any]:
        # pydantic v2 provides model_dump(); .dict() is the deprecated v1 API.
        dump = getattr(model, "model_dump", None)
        return dump() if callable(dump) else model.dict()

    chain = get_parser_chain()

    all_transactions: List[Dict[str, Any]] = []
    account_info: Dict[str, Any] = {}
    summary: Dict[str, Any] = {}
    context_tail = ""
    total_batches = -(-len(lines) // batch_size)  # ceiling division

    for batch_no, start in enumerate(range(0, len(lines), batch_size), start=1):
        chunk = lines[start : start + batch_size]
        logger.info("Processing batch %d/%d...", batch_no, total_batches)

        try:
            result = chain.invoke({"context": context_tail, "text": "\n".join(chunk)})
        except Exception as e:  # best-effort: one bad batch must not abort the run
            logger.error("Error in batch %d: %s", batch_no, e)
            result = None
        else:
            if result is None:
                logger.warning("Batch %d returned no structured output; skipping.", batch_no)

        # Always advance the context so the next batch never sees stale lines.
        # (chunk[-0:] would be the whole chunk, hence the explicit guard.)
        context_tail = "\n".join(chunk[-context_lines:]) if context_lines > 0 else ""

        if result is None:
            continue

        all_transactions.extend(_as_dict(txn) for txn in result.transactions)

        # Keep the first real value seen for each header field.
        for key, value in _as_dict(result.account_info).items():
            if key not in account_info and value not in (None, "", "Unknown"):
                account_info[key] = value

        # 0.0 is the schema default, i.e. "not shown in this batch".
        opening = result.summary.opening_balance
        closing = result.summary.closing_balance
        if "opening_balance" not in summary and opening:
            summary["opening_balance"] = opening
        if closing:
            summary["closing_balance"] = closing

    return {
        "transactions": all_transactions,
        "account_info": account_info,
        "summary": summary,
    }

# ── Main Pipeline ─────────────────────────────────────────────────────────────

def run_app(pdf_path: str, output_path: Optional[str] = None) -> Path:
    """
    Run the LangChain-based pipeline on one PDF and export it to Excel.

    Stages: native text extraction, LangChain parsing, then the native
    cleaning, validation, categorization and Excel export modules.

    Parameters
    ----------
    pdf_path : str
        Path to the bank-statement PDF.
    output_path : str, optional
        Destination .xlsx path. Defaults to "<pdf stem>_langchain_output.xlsx"
        relative to the current working directory (the previous behaviour).

    Returns
    -------
    Path
        The path passed to the exporter.

    Raises
    ------
    FileNotFoundError
        If ``pdf_path`` is not an existing file.
    """
    load_dotenv()

    t_start = time.perf_counter()
    pdf_file = Path(pdf_path)
    if not pdf_file.is_file():
        raise FileNotFoundError(f"PDF not found: {pdf_file}")
    out_file = (
        Path(output_path)
        if output_path is not None
        else Path(pdf_file.stem + "_langchain_output.xlsx")
    )

    # 1. Extraction (Native)
    logger.info("Stage 1: Extracting text...")
    raw_lines = extract(pdf_file)

    # 2. Parsing (LangChain)
    logger.info("Stage 2: Semantic Parsing via LangChain...")
    parsed_data = parse_with_langchain(raw_lines)
    if not parsed_data["transactions"]:
        logger.warning("No transactions were parsed from %s.", pdf_file)

    # 3. Business Logic (Native)
    logger.info("Stage 3-6: Cleaning, Validating, Categorizing...")
    cleaned = clean_transactions(parsed_data["transactions"])
    validated, issues = validate_transactions(cleaned, parsed_data["summary"])
    categorized = categorize_transactions(validated)

    # 4. Export (Native)
    logger.info("Stage 7: Exporting to Excel...")
    export_to_excel(
        transactions=categorized,
        output_path=out_file,
        account_info=parsed_data["account_info"],
        summary_stats=parsed_data["summary"],
        validation_issues=issues
    )

    duration = time.perf_counter() - t_start
    logger.info("✅ Success! Integrated LangChain Pipeline completed in %.2fs", duration)
    logger.info("📊 Result saved to: %s", out_file)
    return out_file

if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # Report misuse on stderr with a non-zero exit so shells/CI notice it.
        print("Usage: python langchain_app.py <pdf_path>", file=sys.stderr)
        sys.exit(1)
    run_app(sys.argv[1])