#!/usr/bin/env python3
import argparse
import csv
import math
import os
try:
    from PIL import ImageStat
except Exception:  # pragma: no cover
    ImageStat = None
try:
    import cv2
    import numpy as np
except Exception:  # pragma: no cover
    cv2 = None
    np = None
import os
import re
import sys

try:
    import pdfplumber
except Exception as exc:  # pragma: no cover
    print("Error: pdfplumber is required. Install with: pip install -r requirements.txt", file=sys.stderr)
    raise


# Matches one decimal digit. On a str pattern, \d also matches non-ASCII
# Unicode decimal digits. Used with .match() to test each character's text.
DIGIT_RE = re.compile(r"\d")


def median(values):
    """Return the median of *values* as a float, or 0.0 when *values* is empty.

    For an even number of items, the mean of the two middle items is returned.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    half, odd = divmod(len(ordered), 2)
    if odd:
        return float(ordered[half])
    return float(ordered[half - 1] + ordered[half]) / 2.0


def group_chars_into_tokens(chars, line_tolerance=3.5, gap_factor=1.8):
    """
    Merge pdfplumber character dicts whose text starts with a digit into tokens.

    Digit characters are ordered by (top, x0) and bucketed into text lines. A
    character joins the most recent line when its ``top`` lies within
    ``line_tolerance`` of that line's mean ``top``; otherwise it opens a new
    line. Each line is then ordered by ``x0`` and cut wherever the gap between
    one character's ``x1`` and the next character's ``x0`` exceeds
    ``max(median_positive_width * gap_factor, 7.0)``. The median width falls
    back to 5.0 when no positive width exists.

    Returns a list of token dicts (text, x0, x1, top, bottom, size), line by
    line and left to right within each line.
    """
    def top_of(ch):
        return ch.get("top", 0.0)

    def left_of(ch):
        return ch.get("x0", 0.0)

    digits = sorted(
        (ch for ch in chars if DIGIT_RE.match(ch.get("text", ""))),
        key=lambda ch: (top_of(ch), left_of(ch)),
    )
    if not digits:
        return []

    # Greedy line building against the running mean "top" of the open line.
    lines = []
    line_top_sum = 0.0
    for ch in digits:
        if lines and abs(top_of(ch) - line_top_sum / float(len(lines[-1]))) <= line_tolerance:
            lines[-1].append(ch)
            line_top_sum += top_of(ch)
        else:
            lines.append([ch])
            line_top_sum = top_of(ch)

    tokens = []
    for line in lines:
        line.sort(key=left_of)
        widths = [ch.get("x1", 0.0) - ch.get("x0", 0.0) for ch in line]
        typical_width = median([w for w in widths if w > 0]) or 5.0
        # Gaps wider than this start a new token.
        split_gap = max(typical_width * gap_factor, 7.0)
        run = [line[0]]
        for before, after in zip(line, line[1:]):
            if after.get("x0", 0.0) - before.get("x1", 0.0) <= split_gap:
                run.append(after)
            else:
                tokens.append(_chars_to_token(run))
                run = [after]
        tokens.append(_chars_to_token(run))
    return tokens


def _chars_to_token(chars):
    text = "".join(ch.get("text", "") for ch in chars)
    x0 = min(ch.get("x0", 0.0) for ch in chars)
    x1 = max(ch.get("x1", 0.0) for ch in chars)
    top = min(ch.get("top", 0.0) for ch in chars)
    bottom = max(ch.get("bottom", 0.0) for ch in chars)
    sizes = [ch.get("size", None) for ch in chars if ch.get("size", None) is not None]
    size = median(sizes) if sizes else None
    return {"text": text, "x0": x0, "x1": x1, "top": top, "bottom": bottom, "size": size}


def tokens_to_codes(tokens):
    """Keep only tokens whose digits form exactly a 6-digit code.

    Every non-digit is stripped from each token's ``text`` first, so e.g.
    ``"12-34-56"`` yields ``"123456"``. Each result dict carries the code plus
    the token geometry; x0/x1/top/bottom default to 0.0 and size to None.
    """
    stripped = ((re.sub(r"\D", "", tok.get("text", "")), tok) for tok in tokens)
    return [
        {
            "code": digits,
            "x0": tok.get("x0", 0.0),
            "x1": tok.get("x1", 0.0),
            "top": tok.get("top", 0.0),
            "bottom": tok.get("bottom", 0.0),
            "size": tok.get("size", None),
        }
        for digits, tok in stripped
        if len(digits) == 6
    ]


def _filter_codes_by_layout(codes, page_height, top_margin=0.05, bottom_margin=0.05, size_filter="auto"):
    """
    Remove header/footer and outlier font-size tokens.
    - top_margin, bottom_margin: fractions of page height to ignore.
    - size_filter: 'auto' uses median size; 'off' disables size filtering.
    """
    if not codes:
        return codes
    h = float(page_height) if page_height else 0.0
    top_cut = h * float(top_margin)
    bot_cut = h * (1.0 - float(bottom_margin))

    filtered = []
    for c in codes:
        top = float(c.get("top", 0.0))
        bottom = float(c.get("bottom", 0.0))
        if h > 0.0 and (top < top_cut or bottom > bot_cut):
            continue
        filtered.append(c)

    if size_filter != "auto":
        return filtered

    sizes = [c.get("size") for c in filtered if c.get("size")]
    if len(sizes) < 3:
        return filtered
    m = median(sizes)
    low = m * 0.7
    high = m * 1.3
    return [c for c in filtered if (c.get("size") is None or (low <= c.get("size") <= high))]


def _kmeans_1d(points, k, max_iter=50):
    """Simple 1D k-means. Returns (centers, labels)."""
    if k <= 0:
        return [], []
    pts = list(points)
    if not pts:
        return [0.0] * k, []
    pts.sort()
    # Initialize centers at k quantiles
    centers = []
    for i in range(k):
        idx = int(round((i + 0.5) * (len(pts) / float(k)) - 1))
        idx = max(0, min(len(pts) - 1, idx))
        centers.append(float(pts[idx]))
    labels = [0] * len(points)
    for _ in range(max_iter):
        # Assign
        changed = False
        for i, p in enumerate(points):
            nearest = min(range(k), key=lambda j: abs(p - centers[j]))
            if labels[i] != nearest:
                labels[i] = nearest
                changed = True
        # Update
        new_centers = centers[:]
        for j in range(k):
            cluster_pts = [p for p, lbl in zip(points, labels) if lbl == j]
            if cluster_pts:
                new_centers[j] = sum(cluster_pts) / float(len(cluster_pts))
        if not changed or new_centers == centers:
            centers = new_centers
            break
        centers = new_centers
    return centers, labels


def _silhouette_1d(points, centers, labels):
    # Compute average silhouette over points; labels aligned with points
    # Handle degenerate cases
    if not points or not centers:
        return 0.0
    k = len(centers)
    # Build clusters indices
    clusters = {j: [] for j in range(k)}
    for idx, lbl in enumerate(labels):
        clusters.setdefault(lbl, []).append(idx)
    # Precompute distances
    def avg_distance(idx_list_a, idx_list_b):
        if not idx_list_a or not idx_list_b:
            return 0.0
        if idx_list_a is idx_list_b and len(idx_list_a) <= 1:
            return 0.0
        total = 0.0
        count = 0
        for i in idx_list_a:
            for j in idx_list_b:
                if idx_list_a is idx_list_b and i == j:
                    continue
                total += abs(points[i] - points[j])
                count += 1
        return total / float(count) if count else 0.0
    silhouettes = []
    for i, lbl in enumerate(labels):
        same = clusters.get(lbl, [])
        # a: avg intra-cluster distance
        a = avg_distance([i], same)
        # b: min avg distance to other clusters
        b_candidates = []
        for other in range(k):
            if other == lbl:
                continue
            other_idxs = clusters.get(other, [])
            if other_idxs:
                b_candidates.append(avg_distance([i], other_idxs))
        b = min(b_candidates) if b_candidates else 0.0
        m = max(a, b)
        s = ((b - a) / m) if m > 0 else 0.0
        silhouettes.append(s)
    return sum(silhouettes) / float(len(silhouettes)) if silhouettes else 0.0


def _infer_num_columns(points, k_min=2, k_max=10):
    # points: x centers
    n = len(points)
    if n <= 1:
        return 1
    k_max = max(k_min, min(k_max, n))
    best_k = None
    best_score = -1.0
    for k in range(k_min, k_max + 1):
        centers, labels = _kmeans_1d(points, k)
        score = _silhouette_1d(points, centers, labels)
        # Prefer simpler model if scores equal within tiny epsilon
        if score > best_score + 1e-6 or (abs(score - best_score) <= 1e-6 and (best_k is None or k < best_k)):
            best_score = score
            best_k = k
    return best_k or 1


def split_into_columns(codes, page_width, num_cols=5):
    """Assign codes to columns by the x center of each code.

    ``num_cols`` is an int, or "auto" to pick a count with
    ``_infer_num_columns`` (1 when there are no codes).

    When there are at least ``max(2, num)`` codes, 1D k-means on the x
    centers is used and clusters are renumbered left to right. Otherwise the
    page width is cut into ``num`` equal bands and each code goes to the band
    containing its x center, clamped to the first/last band.

    Returns (columns_dict, col_centers_sorted):
    - columns_dict maps column index (0 = leftmost) to its codes sorted by
      ``top``;
    - col_centers_sorted lists each column's x center in the same order.
    """
    x_mids = [(c.get("x0", 0.0) + c.get("x1", 0.0)) / 2.0 for c in codes]

    if num_cols != "auto":
        num = int(num_cols)
    elif x_mids:
        num = _infer_num_columns(x_mids, k_min=2, k_max=10)
    else:
        num = 1

    columns = {i: [] for i in range(num)}
    if x_mids and len(x_mids) >= max(2, num):
        centers, labels = _kmeans_1d(x_mids, num)
        order = sorted(range(num), key=lambda j: centers[j])
        rank_of = {cluster: rank for rank, cluster in enumerate(order)}
        for code, cluster in zip(codes, labels):
            columns[rank_of.get(cluster, 0)].append(code)
        col_centers = [centers[j] for j in order]
    else:
        # Equal-width bands across the page.
        band = float(page_width) / float(max(1, num))
        for code, x_mid in zip(codes, x_mids):
            columns[min(max(int(x_mid // band), 0), num - 1)].append(code)
        col_centers = [(i + 0.5) * band for i in range(num)]

    for bucket in columns.values():
        bucket.sort(key=lambda c: c.get("top", 0.0))
    return columns, col_centers


def parse_rows_per_page(value):
    """argparse ``type=`` converter for ``--rows-per-page``.

    Accepts "auto" (case-insensitive, surrounding whitespace ignored) or a
    positive integer, given as an ``int`` or a decimal string.

    Returns:
        The string "auto", or the positive ``int``.

    Raises:
        argparse.ArgumentTypeError: for any other value. This includes
            zero/negative ints, which were previously accepted silently when
            passed as ``int``. argparse reports it as a usage error.
    """
    error = argparse.ArgumentTypeError("--rows-per-page must be a positive integer or 'auto'")
    # bool is an int subclass; route it through the string path so True/False are rejected.
    if isinstance(value, int) and not isinstance(value, bool):
        if value <= 0:
            raise error
        return value
    text = str(value).strip().lower()
    if text == "auto":
        return "auto"
    try:
        count = int(text)
    except ValueError:
        raise error from None
    if count <= 0:
        raise error
    return count


def _infer_row_tolerance(y_positions):
    """Infer a reasonable row tolerance from sorted Y positions using gaps.
    Uses median of top-k smallest non-zero gaps, scaled a bit.
    """
    if not y_positions or len(y_positions) <= 1:
        return 10.0
    ys = sorted(y_positions)
    gaps = [abs(b - a) for a, b in zip(ys, ys[1:]) if abs(b - a) > 0]
    if not gaps:
        return 10.0
    gaps.sort()
    k = max(1, min(10, len(gaps)))
    core = gaps[:k]
    m = median(core) or 10.0
    # Allow some wiggle; scale up modestly
    return max(6.0, min(24.0, m * 1.6))


def _cluster_rows_by_y(columns, row_tolerance):
    """
    Build common row bands across the page using all tokens' Y positions.
    Returns a tuple (row_centers, assignments) where:
      - row_centers is a sorted list of Y centers (small -> top)
      - assignments is a dict: (row_idx, col_idx) -> list[codes]
    """
    # Collect all y positions (use top)
    y_positions = []
    for col_idx, items in columns.items():
        for c in items:
            y_positions.append(float(c.get("top", 0.0)))
    if not y_positions:
        return [], {}
    if row_tolerance == "auto":
        row_tolerance = _infer_row_tolerance(y_positions)
    y_positions.sort()
    # Greedy 1D clustering by distance threshold
    clusters = [[y_positions[0]]]
    for y in y_positions[1:]:
        if abs(y - clusters[-1][-1]) <= row_tolerance:
            clusters[-1].append(y)
        else:
            clusters.append([y])
    row_centers = [sum(g) / float(len(g)) for g in clusters]

    # Assign tokens to nearest row center if within tolerance
    assignments = {}
    for col_idx, items in columns.items():
        for c in items:
            yc = float(c.get("top", 0.0))
            # find nearest center
            nearest = min(range(len(row_centers)), key=lambda i: abs(yc - row_centers[i]))
            if abs(yc - row_centers[nearest]) <= row_tolerance:
                assignments.setdefault((nearest, col_idx), []).append(c)
            # else drop as outlier (e.g., header)
    return row_centers, assignments


def _midpoints(vals):
    """Return the midpoints between consecutive entries of *vals* (len(vals) - 1 items)."""
    return [(a + b) / 2.0 for a, b in zip(vals[:-1], vals[1:])]


def _pt_box_to_px(x0, y0, x1, y1, scale, img_w, img_h):
    """Convert a box in PDF points to an integer (left, upper, right, lower) pixel box.

    Edges are rounded outward (floor/ceil). left/upper are clamped at 0 and
    right/lower at the image width/height.
    """
    left = int(max(0, math.floor(x0 * scale)))
    right = int(min(img_w, math.ceil(x1 * scale)))
    upper = int(max(0, math.floor(y0 * scale)))
    lower = int(min(img_h, math.ceil(y1 * scale)))
    return left, upper, right, lower


def _centered_square_px(cx0, cy0, cy1, square_w, max_cell_w, scale, img_w, img_h):
    """Last-resort swatch: a pixel box for a ``square_w``-sided square roughly centered in the cell."""
    sx0 = cx0 + max(0.0, (max_cell_w - square_w) / 2.0)
    left = int(max(0, math.floor(sx0 * scale)))
    right = int(min(img_w, math.ceil((sx0 + square_w) * scale)))
    side_px = square_w * scale
    upper = int(max(0, math.floor((cy0 * scale) + max(0.0, ((cy1 - cy0) * scale - side_px) / 2.0))))
    lower = int(min(img_h, math.ceil(upper + side_px)))
    return left, upper, right, lower


def _swatch_bbox_from_vector_images(page_images, cx0, cy0, cy1, region_x1):
    """Return the largest overlap (a box in points) between an embedded page image and
    the region [cx0, region_x1] x [cy0, cy1], or None when no image overlaps it."""
    best_bbox = None
    best_area = 0.0
    for img in page_images:
        try:
            ix0 = float(img.get('x0', 0.0))
            ix1 = float(img.get('x1', 0.0))
            itop = float(img.get('top', 0.0))
            ibot = float(img.get('bottom', 0.0))
        except (AttributeError, TypeError, ValueError):
            continue  # malformed image entry: ignore it
        rx0 = max(cx0, ix0)
        rx1 = min(region_x1, ix1)
        ry0 = max(cy0, itop)
        ry1 = min(cy1, ibot)
        if rx1 <= rx0 or ry1 <= ry0:
            continue
        area = (rx1 - rx0) * (ry1 - ry0)
        if area > best_area:
            best_area = area
            best_bbox = (rx0, ry0, rx1, ry1)
    return best_bbox


def _square_px_from_bbox(bbox, cx0, cy0, cy1, x_limit, scale, img_w, img_h):
    """Square up a point-space bbox around its center and convert it to pixels.

    The square stays right of cx0, below cy0, above cy1 and left of x_limit.
    """
    bx0, by0, bx1, by1 = bbox
    side = max(min(bx1 - bx0, by1 - by0), 1.0)
    mid_x = (bx0 + bx1) / 2.0
    mid_y = (by0 + by1) / 2.0
    sx0 = max(cx0, mid_x - side / 2.0)
    sy0 = max(cy0, mid_y - side / 2.0)
    sx1 = min(x_limit, sx0 + side)
    sy1 = min(cy1, sy0 + side)
    return _pt_box_to_px(sx0, sy0, sx1, sy1, scale, img_w, img_h)


def _swatch_px_from_contours(render_img, lx0, ly0, lx1, ly1):
    """Find a square swatch inside the pixel region with OpenCV contour detection.

    Among external contours with a bounding rect of at least 10x10 px, picks
    the one maximizing aspect_ratio * area. Returns a square pixel box
    (left, upper, right, lower) around its center, or None when cv2/numpy
    are unavailable, the region is too small, or nothing qualifies.
    """
    if cv2 is None or np is None:
        return None
    if not (lx1 > lx0 + 5 and ly1 > ly0 + 5):
        return None
    arr = np.array(render_img.crop((lx0, ly0, lx1, ly1)).convert('L'))
    edges = cv2.Canny(cv2.GaussianBlur(arr, (5, 5), 0), 30, 100)
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) in 4.x; contours is always the second-to-last item.
    contours = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    best_rect = None
    best_score = -1.0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if w < 10 or h < 10:
            continue
        # Prefer near-square patches with a decent area.
        score = (min(w, h) / max(w, h)) * (w * h)
        if score > best_score:
            best_score = score
            best_rect = (x, y, w, h)
    if best_rect is None:
        return None
    x, y, w, h = best_rect
    side = min(w, h)
    mid_x = x + w // 2
    mid_y = y + h // 2
    sx0 = max(0, mid_x - side // 2)
    sy0 = max(0, mid_y - side // 2)
    sx1 = min(arr.shape[1], sx0 + side)
    sy1 = min(arr.shape[0], sy0 + side)
    # Map region-local coordinates back to page-image coordinates.
    return lx0 + sx0, ly0 + sy0, lx0 + sx1, ly0 + sy1


def _swatch_px_from_grid_search(render_img, scale, cx0, scan_x1, cy0, cy1, square_w, square_h):
    """Slide a square_w x square_h window (points) over [cx0, scan_x1] x [cy0, cy1].

    Each window is scored as darkness plus texture: (255 - mean) + 0.5 * stddev
    of its grayscale pixels. Returns the best window's pixel box, or None if
    every window is empty after clamping.
    """
    step_x = max(2.0, square_w * 0.25)
    step_y = max(2.0, square_h * 0.25)
    xs = []
    x = cx0
    while x + square_w <= scan_x1 + 1e-6:
        xs.append(x)
        x += step_x
    xs.append(max(cx0, scan_x1 - square_w))  # always try the window flush against the text side
    ys = []
    y = cy0
    while y + square_h <= cy1 + 1e-6:
        ys.append(y)
        y += step_y
    if not ys:
        ys = [cy0]

    best = None
    best_score = -1.0
    for x_start in xs:
        sx0 = max(cx0, min(scan_x1 - square_w, x_start))
        for y_start in ys:
            sy0 = max(cy0, min(cy1 - square_h, y_start))
            box = _pt_box_to_px(sx0, sy0, sx0 + square_w, sy0 + square_h, scale,
                                render_img.width, render_img.height)
            left_px, upper_px, right_px, lower_px = box
            if right_px <= left_px or lower_px <= upper_px:
                continue
            mean = 255.0
            var = 0.0
            if ImageStat is not None:
                try:
                    st = ImageStat.Stat(render_img.crop(box).convert('L'))
                    mean = float(st.mean[0])
                    var = float(st.var[0])
                except Exception:
                    # Best-effort scoring: an unreadable window is treated as blank white.
                    pass
            score = (255.0 - mean) + 0.5 * math.sqrt(max(var, 0.0))
            if score > best_score:
                best_score = score
                best = box
    return best


def _locate_swatch_px(render_img, scale, page_images, code_x0, cx0, cx1, cy0, cy1,
                      square_w, square_h, max_cell_w, text_margin):
    """Choose the pixel box (left, upper, right, lower) of the colour swatch in one cell.

    The swatch is searched LEFT of the code text (x < code_x0 - text_margin),
    trying in order:
    1. the largest embedded page image overlapping that region, squared up;
    2. OpenCV contour detection (when cv2/numpy are importable);
    3. a sliding-window search for the darkest / most textured square;
    4. a square roughly centered in the cell.
    When at most ~2pt of room exists left of the code, step 4 is used
    directly after step 1.
    """
    img_w, img_h = render_img.width, render_img.height
    text_edge = code_x0 - text_margin

    if page_images:
        region_x1 = max(cx0, min(cx1, text_edge))
        if region_x1 > cx0 + 1.0:
            bbox = _swatch_bbox_from_vector_images(page_images, cx0, cy0, cy1, region_x1)
            if bbox is not None:
                return _square_px_from_bbox(bbox, cx0, cy0, cy1, text_edge, scale, img_w, img_h)

    scan_x1 = max(cx0 + 1.0, min(cx1, text_edge))
    if scan_x1 <= cx0 + 2.0:
        return _centered_square_px(cx0, cy0, cy1, square_w, max_cell_w, scale, img_w, img_h)

    region_px = _pt_box_to_px(cx0, cy0, text_edge, cy1, scale, img_w, img_h)
    box = _swatch_px_from_contours(render_img, *region_px)
    if box is not None and box[2] > box[0] and box[3] > box[1]:
        return box

    box = _swatch_px_from_grid_search(render_img, scale, cx0, scan_x1, cy0, cy1, square_w, square_h)
    if box is not None:
        return box
    return _centered_square_px(cx0, cy0, cy1, square_w, max_cell_w, scale, img_w, img_h)


def process_pdf(pdf_path, rows_per_page=3, num_cols=5, row_tolerance="auto", emit_missing=False,
                top_margin=0.05, bottom_margin=0.05, size_filter="auto", min_codes_per_row=2,
                export_crops_dir=None, dpi=144):
    """Extract 6-digit codes from every page of *pdf_path* in row-major grid order.

    Each page is processed in these steps:
    1. Digit characters are grouped into tokens and reduced to exact 6-digit codes.
    2. Codes are filtered by header/footer margins and font size.
    3. Codes are assigned to columns (``split_into_columns``) and to shared
       row bands (``_cluster_rows_by_y``).
    4. Cells are emitted row by row, left to right. When a cell holds several
       codes, the leftmost (smallest x0) is used.

    Args:
        pdf_path: path passed to ``pdfplumber.open``.
        rows_per_page: keep at most this many row bands per page, starting
            from the top, or "auto" for all of them.
        num_cols: column count, or "auto" to infer it.
        row_tolerance: row clustering tolerance in points, or "auto".
        emit_missing: emit "MISSING" for empty cells and keep sparse rows,
            instead of skipping them.
        top_margin, bottom_margin: fractions of page height treated as
            header/footer.
        size_filter: "auto" drops font-size outliers; any other value
            disables that filter.
        min_codes_per_row: rows with fewer codes are skipped unless
            emit_missing is set.
        export_crops_dir: if set, a PNG crop of the swatch left of each
            found code is saved there (best-effort).
        dpi: render resolution used for crop export.

    Returns:
        (codes, csv_rows):
        - codes: list of code strings in grid order.
        - csv_rows: dicts with keys Column, Row, Code, Page. Row numbers run
          consecutively across pages and count only emitted rows.
    """
    all_codes_in_grid_order = []
    csv_rows = []
    base_row_offset = 0
    scale = float(dpi) / 72.0  # PDF user space is 72 points per inch
    pad = 2.0  # inset from cell edges when cropping, in points
    text_margin = 6.0  # clearance kept between a swatch and its code text, in points

    if export_crops_dir:
        try:
            os.makedirs(export_crops_dir, exist_ok=True)
        except OSError as err:
            # Crop export is optional: keep extracting codes, but say why no crops will appear.
            print(f"Warning: cannot create crop directory {export_crops_dir!r}: {err}", file=sys.stderr)

    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            # Layout-based filtering drops headers/footers and odd font sizes.
            page_codes = _filter_codes_by_layout(
                tokens_to_codes(group_chars_into_tokens(page.chars)),
                page_height=page.height,
                top_margin=top_margin,
                bottom_margin=bottom_margin,
                size_filter=size_filter,
            )

            columns, col_centers = split_into_columns(page_codes, page.width, num_cols=num_cols)
            page_num_cols = max(columns.keys()) + 1 if columns else (num_cols if isinstance(num_cols, int) else 1)
            col_letters = [chr(ord('A') + i) for i in range(page_num_cols)]

            row_centers, assignments = _cluster_rows_by_y(columns, row_tolerance=row_tolerance)
            row_indices = list(range(len(row_centers)))
            if rows_per_page != "auto":
                row_indices = row_indices[: int(rows_per_page)]

            render_img = None
            page_images = []
            if export_crops_dir:
                render_img = page.to_image(resolution=dpi).original
                try:
                    page_images = list(getattr(page, 'images', []) or [])
                except Exception:
                    # Embedded-image lookup is only a hint; fall back to raster heuristics.
                    page_images = []

            # Cell geometry: column/row edges sit halfway between neighbouring centers.
            content_x_min = min((c.get("x0", 0.0) for c in page_codes), default=0.0)
            content_x_max = max((c.get("x1", 0.0) for c in page_codes), default=page.width)
            content_top = min((c.get("top", 0.0) for c in page_codes), default=0.0)
            content_bottom = max((c.get("bottom", 0.0) for c in page_codes), default=page.height)
            col_edges = [content_x_min] + _midpoints(col_centers) + [content_x_max]
            selected_centers = [row_centers[i] for i in row_indices]
            row_edges = ([content_top] + _midpoints(selected_centers) + [content_bottom]
                         if selected_centers else [])

            out_row_counter = 0
            for row_pos, r in enumerate(row_indices):
                total_in_row = sum(len(assignments.get((r, cidx), [])) for cidx in range(page_num_cols))
                # Skip sparse rows entirely unless emit_missing is set.
                if not emit_missing and total_in_row < min_codes_per_row:
                    continue
                row_abs = base_row_offset + out_row_counter + 1

                for col_idx in range(page_num_cols):
                    cell_items = assignments.get((r, col_idx), [])
                    if not cell_items and not emit_missing:
                        continue
                    chosen = min(cell_items, key=lambda c: c.get("x0", 0.0)) if cell_items else None
                    code_val = chosen["code"] if chosen is not None else "MISSING"

                    all_codes_in_grid_order.append(code_val)
                    csv_rows.append({
                        "Column": col_letters[col_idx],
                        "Row": row_abs,
                        "Code": code_val,
                        "Page": page_num,
                    })

                    if chosen is None or render_img is None or not row_edges:
                        continue

                    # Cell bounds in points, inset by pad and clipped to the content area.
                    cx0 = max(content_x_min, col_edges[col_idx] + pad)
                    cx1 = min(content_x_max, col_edges[col_idx + 1] - pad)
                    cy0 = max(content_top, row_edges[row_pos] + pad)
                    cy1 = min(content_bottom, row_edges[row_pos + 1] - pad)
                    square_h = max(1.0, (cy1 - cy0) * 0.95)
                    max_cell_w = max(1.0, (cx1 - cx0) * 0.95)
                    square_w = min(square_h, max_cell_w)

                    box = _locate_swatch_px(
                        render_img, scale, page_images, chosen.get("x0", cx0),
                        cx0, cx1, cy0, cy1, square_w, square_h, max_cell_w, text_margin,
                    )
                    left, upper, right, lower = box
                    if right > left and lower > upper:
                        fname = f"p{page_num:02d}_r{row_abs:02d}_c{col_letters[col_idx]}_{code_val}.png"
                        try:
                            render_img.crop(box).save(os.path.join(export_crops_dir, fname))
                        except OSError:
                            pass  # crop export is best-effort; codes are still returned

                out_row_counter += 1

            base_row_offset += out_row_counter

    return all_codes_in_grid_order, csv_rows


def write_outputs(out_prefix, codes, csv_rows):
    """
    Write the extracted codes to two files named after ``out_prefix``.

    Files produced:
      * ``<out_prefix>.codes.txt`` -- all codes on a single line, joined by
        ", " and terminated by a newline.
      * ``<out_prefix>.codes.csv`` -- a header row ``Column,Row,Code,Page``
        followed by one row per mapping in ``csv_rows`` (via csv.DictWriter,
        so missing keys become empty cells and unknown keys raise ValueError).

    Returns:
        (txt_path, csv_path): the paths of the two files written.
    """
    txt_path = f"{out_prefix}.codes.txt"
    csv_path = f"{out_prefix}.codes.csv"

    with open(txt_path, mode="w", encoding="utf-8") as txt_file:
        txt_file.write(", ".join(codes) + "\n")

    header = ["Column", "Row", "Code", "Page"]
    with open(csv_path, mode="w", encoding="utf-8", newline="") as csv_file:
        dict_writer = csv.DictWriter(csv_file, fieldnames=header)
        dict_writer.writeheader()
        for record in csv_rows:
            dict_writer.writerow(record)

    return txt_path, csv_path


def _parse_cols(value):
    """
    argparse ``type=`` converter for ``--cols``: accept 'auto' or a positive int.

    Returns the integer column count, or the string "auto" (case-insensitive
    input). Raises argparse.ArgumentTypeError for anything else, so a typo such
    as ``--cols 1o`` is reported as a usage error instead of silently falling
    back to automatic column detection.
    """
    text = str(value).strip()
    if text.lower() == "auto":
        return "auto"
    try:
        cols = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid value {value!r}: expected a positive integer or 'auto'"
        ) from None
    if cols <= 0:
        raise argparse.ArgumentTypeError(f"invalid value {value!r}: column count must be positive")
    return cols


def main():
    """
    Command-line entry point: parse arguments, extract codes, write outputs.

    Runs ``process_pdf`` on the given PDF with the parsed options, writes
    ``<prefix>.codes.txt`` and ``<prefix>.codes.csv`` via ``write_outputs``,
    then echoes the comma-separated codes and the written paths to stdout.

    Exit status: 2 (argparse usage error) for invalid option values; 1 if the
    PDF does not exist or the crops directory cannot be created.
    """
    parser = argparse.ArgumentParser(description="Extract 6-digit colour codes from a PDF colour card in grid order.")
    parser.add_argument("pdf", help="Path to input PDF file")
    parser.add_argument("--rows-per-page", dest="rows_per_page", default="auto", type=parse_rows_per_page,
                        help="Rows per page (int or 'auto'; default: auto).")
    # NOTE(review): passed through as a raw string; process_pdf is assumed to
    # parse 'auto' / float itself -- confirm.
    parser.add_argument("--row-tol", dest="row_tolerance", default="auto",
                        help="Y clustering tolerance in points for row grouping (float or 'auto'; default: auto).")
    parser.add_argument("--emit-missing", dest="emit_missing", action="store_true",
                        help="Emit MISSING for empty cells; otherwise skip them.")
    parser.add_argument("--top-margin", dest="top_margin", default=0.05, type=float,
                        help="Top margin (0-0.5) to ignore header area (default: 0.05).")
    parser.add_argument("--bottom-margin", dest="bottom_margin", default=0.05, type=float,
                        help="Bottom margin (0-0.5) to ignore footer area (default: 0.05).")
    parser.add_argument("--size-filter", dest="size_filter", default="auto",
                        help="Font-size filter: 'auto' or 'off' (default: auto).")
    parser.add_argument("--min-codes-per-row", dest="min_codes_per_row", default=2, type=int,
                        help="Drop very sparse rows with <N codes (default: 2).")
    parser.add_argument("--export-crops-dir", dest="export_crops_dir", default=None,
                        help="Directory to save per-cell PNG crops. If omitted, no crops are saved.")
    parser.add_argument("--dpi", dest="dpi", default=144, type=int,
                        help="Render DPI for crop images (default: 144).")
    # argparse also runs string defaults through type=, so "auto" -> "auto".
    parser.add_argument("--cols", dest="num_cols", default="auto", type=_parse_cols,
                        help="Columns per page (int or 'auto'; default: auto).")
    parser.add_argument("--out-prefix", dest="out_prefix", default=None,
                        help="Output prefix (default: input filename without extension)")

    args = parser.parse_args()

    # Enforce the documented ranges up front so bad options fail fast with a
    # usage message rather than producing odd output after a long run.
    for flag, margin in (("--top-margin", args.top_margin), ("--bottom-margin", args.bottom_margin)):
        if not 0.0 <= margin <= 0.5:  # also rejects NaN
            parser.error(f"{flag} must be between 0 and 0.5 (got {margin})")
    if args.dpi <= 0:
        parser.error(f"--dpi must be a positive integer (got {args.dpi})")

    pdf_path = args.pdf
    if not os.path.isfile(pdf_path):
        print(f"Error: file not found: {pdf_path}", file=sys.stderr)
        sys.exit(1)

    # Default prefix: the input's base name without extension (relative to the CWD).
    out_prefix = args.out_prefix or os.path.splitext(os.path.basename(pdf_path))[0]

    if args.export_crops_dir:
        # Create the directory up front: the per-crop save in process_pdf
        # swallows exceptions, so a missing directory would otherwise
        # silently yield no crops at all.
        try:
            os.makedirs(args.export_crops_dir, exist_ok=True)
        except OSError as exc:
            print(f"Error: cannot create crops directory {args.export_crops_dir}: {exc}", file=sys.stderr)
            sys.exit(1)

    codes, csv_rows = process_pdf(
        pdf_path,
        rows_per_page=args.rows_per_page,
        num_cols=args.num_cols,  # already normalized to int or "auto" by _parse_cols
        row_tolerance=args.row_tolerance,
        emit_missing=args.emit_missing,
        top_margin=args.top_margin,
        bottom_margin=args.bottom_margin,
        size_filter=args.size_filter,
        min_codes_per_row=args.min_codes_per_row,
        export_crops_dir=args.export_crops_dir,
        dpi=args.dpi,
    )
    txt_path, csv_path = write_outputs(out_prefix, codes, csv_rows)

    # Also print the comma-separated line to stdout
    print(", ".join(codes))
    print(f"Wrote: {txt_path}")
    print(f"Wrote: {csv_path}")


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()


