#!/usr/bin/env python3
"""
Robust passive Meshtastic packet decoder for raw LoRa payloads received over UDP.

Designed for the GNU Radio chain:

    LoRa Rx tagged-byte output
      -> CRC-gating tagged-stream-to-PDU Embedded Python Block
      -> Socket PDU (UDP client, 127.0.0.1:7355)

Capabilities
------------
* Plaintext Meshtastic application decoding.
* The documented default channel key (one-byte PSK index 1), always tried, plus
  the remaining public one-byte PSK indices as an optional audit.
* Explicitly supplied authorised PSKs in hexadecimal or Base64.
* Strict nested-Protobuf and semantic validation to reduce false positives.
* Correct key assessment: plaintext is reported as "no-encryption".
* Duplicate and alternate-relay-path classification.
* Compact JSONL logging and optional SQLite persistence.
* Persistent node, position, telemetry, and neighbour-link tables.

This tool intentionally does not perform arbitrary password guessing or exhaustive
AES key search. Use only on traffic and keys covered by your authorisation.
"""

from __future__ import annotations

import argparse
import base64
import binascii
import hashlib
import importlib.metadata
import json
import math
import os
import re
import socket
import sqlite3
import struct
import sys
import time
from collections import Counter
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Iterable, Optional, Sequence

from Crypto.Cipher import AES
from google.protobuf.json_format import MessageToDict
from google.protobuf.message import DecodeError, Message
from meshtastic.protobuf import mesh_pb2, portnums_pb2, telemetry_pb2


# Documented default channel key: the expansion of one-byte PSK index 1, and the
# base from which every other short-PSK index is derived (see expand_short_psk).
DEFAULT_MESHTASTIC_KEY = bytes.fromhex("d4f1bb3a20290759f0bcffabcf4e6901")
# Node number with all bits set, denoting broadcast.
BROADCAST_NODE = 0xFFFFFFFF
# Textual node ID: "!" followed by exactly eight hex digits.
NODE_ID_RE = re.compile(r"^![0-9a-fA-F]{8}$")
# Position timestamps before 2015-01-01 UTC are treated as implausible.
MIN_REASONABLE_TIME = int(datetime(2015, 1, 1, tzinfo=timezone.utc).timestamp())
# Largest raw frame (16-byte header plus payload) accepted by decode_packet.
MAX_LORA_PAYLOAD = 255


class ValidationError(ValueError):
    """Raised when a decode is syntactically valid protobuf but semantically implausible.

    Derives from ValueError, so handlers that catch ValueError also catch it.
    """


@dataclass(frozen=True)
class CandidateKey:
    """One immutable, hashable key to try when decoding a packet."""

    label: str  # reported as the packet's decode_mode when this key succeeds
    key: Optional[bytes]  # AES key bytes; None means the payload is plaintext
    assessment: str  # key-strength label reported as key_assessment on success


@dataclass(frozen=True)
class DecoderOptions:
    """Immutable settings passed to every decode_packet call."""

    candidates: tuple[CandidateKey, ...]  # tried in order; first validated decode wins
    channel_names: tuple[str, ...]  # when non-empty, only hash-matching keys are tried
    allow_outer_only: bool  # accept ports that have no dedicated inner decoder
    strict: bool  # raise on implausible values instead of recording warnings
    max_future_seconds: int  # tolerated future skew for position timestamps


@dataclass
class SeenPacket:
    """Mutable record kept by DuplicateTracker for one (sender, packet ID) pair."""

    seen_at: float  # time of the latest sighting; refreshed on every repeat
    raw_hash: str  # truncated SHA-256 of the entire first frame
    payload_hash: str  # truncated SHA-256 of the first frame minus its 16-byte header
    route_signature: tuple[Any, ...]  # caller-supplied routing fingerprint of the first frame
    count: int = 1  # sightings within the current TTL window


class DuplicateTracker:
    """Classify repeat sightings of the same (sender, packet ID) pair.

    A pair last seen more than ``ttl_seconds`` ago is treated as new. Every
    1000th classification prunes entries older than the TTL.
    """

    def __init__(self, ttl_seconds: int = 3600) -> None:
        # Non-positive TTLs are clamped to one second.
        self.ttl_seconds = max(1, ttl_seconds)
        self._seen: dict[tuple[int, int], SeenPacket] = {}
        self._operations = 0

    @staticmethod
    def _duplicate_kind(
        entry: SeenPacket,
        raw_hash: str,
        payload_hash: str,
        route_signature: tuple[Any, ...],
    ) -> str:
        """Name how a repeat frame differs from the first sighting in *entry*."""
        if raw_hash == entry.raw_hash:
            return "exact"
        if payload_hash == entry.payload_hash:
            # Same body; either the route changed or only other header bits did.
            if route_signature != entry.route_signature:
                return "alternate-path"
            return "header-variant"
        return "packet-id-collision-or-corruption"

    def _prune(self, now: float) -> None:
        """Drop entries whose last sighting is older than the TTL."""
        oldest_allowed = now - self.ttl_seconds
        self._seen = {
            identity: entry
            for identity, entry in self._seen.items()
            if entry.seen_at >= oldest_allowed
        }

    def classify(
        self,
        *,
        sender: int,
        packet_id: int,
        raw: bytes,
        route_signature: tuple[Any, ...],
        now: float,
    ) -> tuple[bool, Optional[str], int]:
        """Record a sighting and return ``(is_duplicate, kind, sighting_count)``.

        ``kind`` is None for first sightings; otherwise one of "exact",
        "alternate-path", "header-variant" or
        "packet-id-collision-or-corruption". Hashes are computed over the whole
        frame and over the bytes after the 16-byte header.
        """
        identity = (sender, packet_id)
        full_digest = hashlib.sha256(raw).hexdigest()[:16]
        body_digest = hashlib.sha256(raw[16:]).hexdigest()[:16]
        entry = self._seen.get(identity)

        is_new = entry is None or now - entry.seen_at > self.ttl_seconds
        if is_new:
            self._seen[identity] = SeenPacket(
                seen_at=now,
                raw_hash=full_digest,
                payload_hash=body_digest,
                route_signature=route_signature,
            )
            outcome: tuple[bool, Optional[str], int] = (False, None, 1)
        else:
            entry.count += 1
            entry.seen_at = now
            outcome = (
                True,
                self._duplicate_kind(entry, full_digest, body_digest, route_signature),
                entry.count,
            )

        self._operations += 1
        if not self._operations % 1000:
            self._prune(now)

        return outcome


class Statistics:
    """Running tallies of decoded packets by decode mode, duplicate kind and application."""

    def __init__(self) -> None:
        self.started = time.time()
        self.counters: Counter[str] = Counter()
        self.apps: Counter[str] = Counter()

    def observe(self, packet: dict[str, Any]) -> None:
        """Fold one decoded packet record into the counters."""
        tally = self.counters
        tally["total"] += 1
        tally[packet.get("decode_mode", "unknown")] += 1

        if packet.get("duplicate"):
            tally["duplicates"] += 1
            duplicate_kind = packet.get("duplicate_kind")
            if duplicate_kind:
                tally[f"duplicate:{duplicate_kind}"] += 1

        portnum = packet.get("decoded", {}).get("portnum")
        if portnum:
            self.apps[portnum] += 1

    def snapshot(self) -> dict[str, Any]:
        """Return uptime plus plain-dict copies of the counters."""
        elapsed = time.time() - self.started
        return {
            "uptime_seconds": round(elapsed, 1),
            "counts": dict(self.counters),
            "applications": dict(self.apps),
        }


def utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with millisecond precision."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat(timespec="milliseconds")


def protobuf_version() -> str:
    """Return the installed Meshtastic package version, or "unknown".

    Despite the name, this reports the version of the distribution providing
    the ``meshtastic`` bindings. Both known distribution names are tried in turn.
    """
    distribution_names = ("meshtastic", "meshtastic-python")
    for distribution in distribution_names:
        try:
            return importlib.metadata.version(distribution)
        except importlib.metadata.PackageNotFoundError:
            pass
    return "unknown"


def xor_hash(value: bytes) -> int:
    """Return the XOR of every byte in *value* (0 for empty input).

    This one-byte fold is the building block of Meshtastic's channel hash.
    """
    digest = 0
    for octet in value:
        digest = digest ^ octet
    return digest


def meshtastic_channel_hash(channel_name: str, key: Optional[bytes]) -> int:
    """Return the one-byte channel hash: XOR-fold of the UTF-8 name XOR that of the key.

    A missing key (plaintext channel) contributes 0.
    """
    name_component = xor_hash(channel_name.encode("utf-8"))
    key_component = xor_hash(key) if key else 0
    return name_component ^ key_component


def expand_short_psk(index: int) -> Optional[bytes]:
    """Expand a one-byte Meshtastic PSK index into a full 16-byte AES key.

    Index 0 means "no encryption" and yields None. Index N >= 1 is the default
    key with its last byte increased by N - 1, modulo 256.

    Raises:
        ValueError: if *index* is outside 0..255.
    """
    if not 0 <= index <= 255:
        raise ValueError("short PSK index must be in the range 0..255")
    if index == 0:
        return None
    expanded = bytearray(DEFAULT_MESHTASTIC_KEY)
    expanded[-1] = (expanded[-1] + index - 1) % 256
    return bytes(expanded)


def normalize_psk(raw: bytes) -> Optional[bytes]:
    """Convert a raw channel PSK into the AES key Meshtastic actually uses.

    - Empty: no encryption (None).
    - One byte: short-PSK index, expanded via expand_short_psk.
    - 2..16 bytes: zero-padded to AES-128.
    - 17..32 bytes: zero-padded to AES-256.

    Raises:
        ValueError: if *raw* is longer than 32 bytes.
    """
    size = len(raw)
    if size == 0:
        return None
    if size == 1:
        return expand_short_psk(raw[0])
    if size > 32:
        raise ValueError("Meshtastic PSK must be at most 32 bytes")
    target_length = 16 if size <= 16 else 32
    return raw.ljust(target_length, b"\x00")


def aes_ctr_decrypt(ciphertext: bytes, sender: int, packet_id: int, key: bytes) -> bytes:
    """Decrypt a Meshtastic payload with AES-CTR.

    The 16-byte initial counter block is the packet ID as a little-endian
    uint64, then the sender as a little-endian uint32, then four zero bytes.
    The final four bytes act as a big-endian block counter starting at 0.
    """
    # "<QI" packs the 8-byte packet ID and 4-byte sender with no padding: 12 bytes.
    fixed_prefix = struct.pack("<QI", packet_id, sender)
    engine = AES.new(key, AES.MODE_CTR, nonce=fixed_prefix, initial_value=0)
    return engine.decrypt(ciphertext)


def message_to_dict(message: Message) -> dict[str, Any]:
    """Convert a protobuf message to a plain dict keyed by .proto field names.

    ``preserving_proto_field_name=True`` keeps snake_case keys such as
    ``latitude_i`` rather than the lowerCamelCase JSON names.
    """
    converted = MessageToDict(message, preserving_proto_field_name=True)
    return converted


def parse_message(message: Message, payload: bytes) -> Message:
    """Parse *payload* into *message* in place and return that same instance.

    This lets callers construct and parse in one expression. Parser errors
    propagate unchanged.
    """
    target = message
    target.ParseFromString(payload)
    return target


def parse_outer_data(payload: bytes) -> Optional[mesh_pb2.Data]:
    """Parse *payload* as a Meshtastic ``Data`` message, or return None.

    Returns None when any of these holds:

    - the bytes do not parse;
    - the port number is not a known PortNum;
    - the port is UNKNOWN_APP;
    - the message has neither a payload nor ``want_response`` set.

    The checks exist because a wrong decryption key usually yields garbage
    that fails them.
    """
    candidate = mesh_pb2.Data()
    try:
        candidate.ParseFromString(payload)
        # Raises ValueError for enum values this protobuf build does not know.
        portnums_pb2.PortNum.Name(candidate.portnum)
    except (DecodeError, ValueError):
        return None

    is_unknown_port = candidate.portnum == portnums_pb2.PortNum.UNKNOWN_APP
    is_empty = not candidate.payload and not candidate.want_response
    if is_unknown_port or is_empty:
        return None
    return candidate


def is_sane_text(value: str) -> bool:
    """Return True if *value* looks like a genuine human-written text message.

    Rejects:

    - empty strings;
    - strings longer than 512 characters;
    - strings containing control characters other than CR, LF and TAB.

    ``str.isprintable`` reports some legitimate characters as non-printable.
    These are accepted explicitly:

    - non-ASCII spaces (Unicode category Zs, e.g. NO-BREAK SPACE);
    - the zero-width joiner and non-joiner that glue multi-codepoint emoji
      sequences together.
    """
    import unicodedata

    if not value or len(value) > 512:
        return False
    # ZWNJ (U+200C) and ZWJ (U+200D) are category Cf, so isprintable() is False.
    allowed_extra = "\r\n\t\u200c\u200d"
    return all(
        character.isprintable()
        or character in allowed_extra
        or unicodedata.category(character) == "Zs"
        for character in value
    )


def validate_position(
    message: mesh_pb2.Position,
    *,
    strict: bool,
    max_future_seconds: int,
) -> list[str]:
    """Sanity-check a decoded Position message.

    Coordinates are in units of 1e-7 degrees; a zero value means "absent".

    Raises:
        ValidationError: if exactly one coordinate is set, any of latitude,
            longitude, precision_bits or altitude is out of range, or (strict
            mode only) the timestamp is implausible.

    Returns:
        Warning strings; in non-strict mode an implausible timestamp is
        reported here instead of raising.
    """
    notes: list[str] = []
    lat_i = int(getattr(message, "latitude_i", 0))
    lon_i = int(getattr(message, "longitude_i", 0))

    if (lat_i == 0) != (lon_i == 0):
        raise ValidationError("position contains only one coordinate")
    if lat_i and abs(lat_i) > 900_000_000:
        raise ValidationError(f"latitude_i out of range: {lat_i}")
    if lon_i and abs(lon_i) > 1_800_000_000:
        raise ValidationError(f"longitude_i out of range: {lon_i}")

    bits = int(getattr(message, "precision_bits", 0))
    if bits and (bits < 1 or bits > 32):
        raise ValidationError(f"precision_bits out of range: {bits}")

    alt = int(getattr(message, "altitude", 0))
    if alt and (alt < -1_000 or alt > 50_000):
        raise ValidationError(f"altitude out of range: {alt}")

    stamp = int(getattr(message, "time", 0))
    if not stamp:
        return notes

    latest_allowed = int(time.time()) + max_future_seconds
    if stamp < MIN_REASONABLE_TIME or stamp > latest_allowed:
        problem = f"position time outside expected range: {stamp}"
        if strict:
            raise ValidationError(problem)
        notes.append(problem)
    return notes


def validate_nodeinfo(message: mesh_pb2.User, sender: int, *, strict: bool) -> list[str]:
    """Sanity-check a NodeInfo ``User`` payload against the outer sender.

    Raises:
        ValidationError: if any of these holds:

            - no identity field is set;
            - ``id`` is not "!" plus eight hex digits;
            - a name is implausibly long;
            - the MAC address is not six bytes;
            - (strict mode only) ``id`` differs from *sender*.

    Returns:
        Warning strings; in non-strict mode a mismatching ``id`` is reported
        here instead of raising.
    """
    notes: list[str] = []
    identity_fields = (
        message.id,
        message.long_name,
        message.short_name,
        message.hw_model,
        getattr(message, "public_key", b""),
    )
    if not any(identity_fields):
        raise ValidationError("User message lacks identity fields")

    claimed_id = message.id
    if claimed_id:
        if NODE_ID_RE.fullmatch(claimed_id) is None:
            raise ValidationError(f"invalid node ID format: {claimed_id!r}")
        if int(claimed_id[1:], 16) != sender:
            mismatch = f"NodeInfo ID {claimed_id} does not match outer sender !{sender:08x}"
            if strict:
                raise ValidationError(mismatch)
            notes.append(mismatch)

    if len(message.long_name) > 80:
        raise ValidationError("long_name is implausibly long")
    if len(message.short_name) > 16:
        raise ValidationError("short_name is implausibly long")

    mac = message.macaddr
    if mac and len(mac) != 6:
        raise ValidationError(f"MAC address has {len(mac)} bytes, expected 6")
    return notes


def _finite_number(value: Any) -> Optional[float]:
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        number = float(value)
        return number if math.isfinite(number) else None
    return None


def validate_telemetry_dict(data: dict[str, Any]) -> None:
    """Reject telemetry whose well-known metric values are non-finite or out of range.

    Every nested dict and list is searched. Any key named like a bounded
    metric (e.g. ``voltage``, ``temperature``) must hold a finite number
    inside its bounds. Violations are checked in depth-first, document order,
    and the first one raises.

    Raises:
        ValidationError: naming the first offending key and its value.
    """
    bounds: dict[str, tuple[float, float]] = {
        "battery_level": (0, 255),
        "voltage": (-0.05, 100),
        "channel_utilization": (-0.05, 100.05),
        "air_util_tx": (-0.05, 100.05),
        "temperature": (-100, 150),
        "relative_humidity": (0, 100.05),
        "barometric_pressure": (100, 1_200),
    }

    # Explicit stack of (key, value) pairs. Siblings are pushed in reverse so
    # they pop in document order, giving the same pre-order as a recursive walk.
    pending: list[tuple[Any, Any]] = [(None, data)]
    while pending:
        key, value = pending.pop()
        if key in bounds:
            number = _finite_number(value)
            low, high = bounds[key]
            if number is None or not low <= number <= high:
                raise ValidationError(f"telemetry {key} out of range: {value!r}")
        if isinstance(value, dict):
            pending.extend(reversed(list(value.items())))
        elif isinstance(value, list):
            pending.extend((None, item) for item in reversed(value))


def decode_application(
    data: mesh_pb2.Data,
    *,
    sender: int,
    strict: bool,
    allow_outer_only: bool,
    max_future_seconds: int,
) -> tuple[dict[str, Any], bool, list[str]]:
    """Decode and validate the inner application payload of a ``Data`` message.

    Args:
        data: Parsed outer ``Data`` message.
        sender: Outer-header sender node number. It is cross-checked against
            the IDs embedded in NodeInfo and NeighborInfo payloads.
        strict: When True, identity mismatches and out-of-range timestamps
            raise ValidationError instead of being returned as warnings.
        allow_outer_only: Whether ports without a dedicated decoder below
            count as validated; their payload is reported as hex.
        max_future_seconds: Tolerated future skew for position timestamps.

    Returns:
        ``(result, validated, warnings)``: a JSON-ready dict describing the
        payload, whether the caller should accept this decode, and any
        non-fatal validation warnings.

    Raises:
        ValidationError: a nested payload parsed but is semantically implausible.
        UnicodeDecodeError: a text payload is not valid UTF-8.
        DecodeError: from the protobuf parser when a nested payload is
            malformed (raised via ``parse_message``).
    """
    try:
        port_name = portnums_pb2.PortNum.Name(data.portnum)
    except ValueError:
        # Unknown enum value: no decoder can be chosen, so never validated.
        return ({"portnum": f"PORT_{data.portnum}"}, False, [])

    result: dict[str, Any] = {
        "portnum": port_name,
        "want_response": bool(data.want_response),
    }
    warnings: list[str] = []
    validated = False

    if port_name in ("TEXT_MESSAGE_APP", "ALERT_APP"):
        # Strict UTF-8 decoding: invalid bytes raise rather than being replaced.
        text = data.payload.decode("utf-8", errors="strict")
        if not is_sane_text(text):
            raise ValidationError("text payload is empty, too long, or contains controls")
        result["text"] = text
        validated = True

    elif port_name == "POSITION_APP":
        message = parse_message(mesh_pb2.Position(), data.payload)
        if not message.ListFields():
            raise ValidationError("empty Position message")
        warnings.extend(
            validate_position(
                message,
                strict=strict,
                max_future_seconds=max_future_seconds,
            )
        )
        position = message_to_dict(message)
        # Add human-readable degrees alongside the raw 1e-7-degree integers.
        if getattr(message, "latitude_i", 0):
            position["latitude"] = message.latitude_i * 1e-7
        if getattr(message, "longitude_i", 0):
            position["longitude"] = message.longitude_i * 1e-7
        result["position"] = position
        validated = True

    elif port_name == "NODEINFO_APP":
        message = parse_message(mesh_pb2.User(), data.payload)
        warnings.extend(validate_nodeinfo(message, sender, strict=strict))
        result["nodeinfo"] = message_to_dict(message)
        validated = True

    elif port_name == "TELEMETRY_APP":
        message = parse_message(telemetry_pb2.Telemetry(), data.payload)
        if not message.ListFields():
            raise ValidationError("empty Telemetry message")
        telemetry = message_to_dict(message)
        validate_telemetry_dict(telemetry)
        result["telemetry"] = telemetry
        validated = True

    elif port_name == "ROUTING_APP":
        message = parse_message(mesh_pb2.Routing(), data.payload)
        if not message.ListFields():
            raise ValidationError("empty Routing message")
        result["routing"] = message_to_dict(message)
        validated = True

    # The hasattr guard lets older protobuf bindings without NeighborInfo fall
    # through to the generic outer-data branch instead of failing.
    elif port_name == "NEIGHBORINFO_APP" and hasattr(mesh_pb2, "NeighborInfo"):
        message = parse_message(mesh_pb2.NeighborInfo(), data.payload)
        if not message.ListFields():
            raise ValidationError("empty NeighborInfo message")
        reporter = int(getattr(message, "node_id", 0))
        if reporter and reporter != sender:
            text = f"NeighborInfo reporter !{reporter:08x} differs from sender !{sender:08x}"
            if strict:
                raise ValidationError(text)
            warnings.append(text)
        if len(getattr(message, "neighbors", [])) > 128:
            raise ValidationError("implausibly large NeighborInfo list")

        neighbor_info = message_to_dict(message)
        if reporter:
            neighbor_info["node"] = f"!{reporter:08x}"
        # Rebuild the neighbour list so each entry carries a "!xxxxxxxx" node ID;
        # SqliteStore reads that "node" key when storing neighbor_links.
        neighbors: list[dict[str, Any]] = []
        for neighbor in getattr(message, "neighbors", []):
            node_id = int(getattr(neighbor, "node_id", 0))
            snr = float(getattr(neighbor, "snr", 0.0))
            if node_id == 0:
                raise ValidationError("NeighborInfo contains zero node ID")
            if not -40.0 <= snr <= 30.0:
                raise ValidationError(f"NeighborInfo SNR out of range: {snr}")
            item = message_to_dict(neighbor)
            item["node"] = f"!{node_id:08x}"
            neighbors.append(item)
        neighbor_info["neighbors"] = neighbors
        result["neighbor_info"] = neighbor_info
        validated = True

    # Same guard as above, for bindings lacking RouteDiscovery.
    elif port_name == "TRACEROUTE_APP" and hasattr(mesh_pb2, "RouteDiscovery"):
        message = parse_message(mesh_pb2.RouteDiscovery(), data.payload)
        if not message.ListFields():
            raise ValidationError("empty RouteDiscovery message")
        if len(getattr(message, "route", [])) > 32 or len(getattr(message, "route_back", [])) > 32:
            raise ValidationError("implausibly long traceroute")
        route = message_to_dict(message)
        if getattr(message, "route", None):
            route["route_nodes"] = [f"!{node:08x}" for node in message.route]
        if getattr(message, "route_back", None):
            route["route_back_nodes"] = [f"!{node:08x}" for node in message.route_back]
        result["traceroute"] = route
        validated = True

    else:
        # No inner decoder: only the outer Data parse vouches for this key, so
        # acceptance is opt-in via allow_outer_only.
        result["payload_hex"] = data.payload.hex()
        result["validation"] = "outer-data-only"
        validated = allow_outer_only

    # Request/reply correlation and emoji-reaction metadata are reported for every port.
    if data.request_id:
        result["request_id"] = f"0x{data.request_id:08x}"
    if data.reply_id:
        result["reply_id"] = f"0x{data.reply_id:08x}"
    if data.emoji:
        result["emoji"] = data.emoji

    return result, validated, warnings


def split_env_values(name: str) -> list[str]:
    """Return the non-empty, whitespace-stripped comma-separated items of env var *name*.

    An unset or blank variable yields an empty list.
    """
    text = os.environ.get(name, "").strip()
    if not text:
        return []
    pieces = (piece.strip() for piece in text.split(","))
    return [piece for piece in pieces if piece]


def parse_hex_psk(value: str) -> bytes:
    """Parse a hexadecimal PSK such as "d4f1...", "0xD4F1..." or "d4:f1:...".

    The optional "0x"/"0X" prefix is matched case-insensitively. Colons and
    spaces between byte pairs are ignored.

    Raises:
        argparse.ArgumentTypeError: if the remaining text is not valid hex.
    """
    cleaned = value.strip()
    # str.removeprefix("0x") would miss an upper-case "0X" prefix.
    if cleaned[:2].lower() == "0x":
        cleaned = cleaned[2:]
    cleaned = cleaned.replace(":", "").replace(" ", "")
    try:
        return bytes.fromhex(cleaned)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(f"invalid hexadecimal PSK: {value!r}") from exc


def parse_b64_psk(value: str) -> bytes:
    """Parse a Base64 PSK in either the standard or the URL-safe alphabet.

    Missing "=" padding is tolerated, so unpadded keys copied from channel
    URLs (e.g. "AQ") decode the same as their padded forms ("AQ==").
    Characters outside the alphabet are still rejected.

    Raises:
        argparse.ArgumentTypeError: if the text is not valid Base64.
    """
    text = value.strip().replace("-", "+").replace("_", "/")
    # Restore stripped padding; already-padded input needs none added.
    text += "=" * (-len(text) % 4)
    try:
        return base64.b64decode(text, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise argparse.ArgumentTypeError(f"invalid Base64 PSK: {value!r}") from exc


def assessment_for_authorised_psk(raw: bytes) -> str:
    """Classify a user-supplied PSK (before normalisation) for reporting.

    Returns:
        - ``"no-encryption"``: empty key, or one-byte index 0 (both
          normalise to "no key").
        - ``"public-or-weak"``: any other one-byte short-PSK index, which
          expands from the public default key.
        - ``"authorised-custom"``: full 16- or 32-byte AES keys.
        - ``"authorised-padded-key"``: any other length, which is
          zero-padded.
    """
    size = len(raw)
    # A lone 0x00 is short-PSK index 0, which normalize_psk maps to plaintext.
    if size == 0 or (size == 1 and raw[0] == 0):
        return "no-encryption"
    if size == 1:
        return "public-or-weak"
    if size in (16, 32):
        return "authorised-custom"
    return "authorised-padded-key"


def deduplicate_candidates(candidates: Iterable[CandidateKey]) -> tuple[CandidateKey, ...]:
    """Drop candidates whose key bytes were already offered.

    Input order is preserved. The first candidate for each distinct key wins,
    so its label and assessment are the ones kept.
    """
    first_by_key: dict[Optional[bytes], CandidateKey] = {}
    for candidate in candidates:
        first_by_key.setdefault(candidate.key, candidate)
    return tuple(first_by_key.values())


def build_candidates(args: argparse.Namespace) -> tuple[CandidateKey, ...]:
    """Assemble the ordered, de-duplicated list of keys to try.

    Order:

    1. plaintext;
    2. the public default key;
    3. authorised hex PSKs (CLI first, then ``MESHTASTIC_PSK_HEX``);
    4. authorised Base64 PSKs (CLI first, then ``MESHTASTIC_PSK_B64``);
    5. with ``--audit-public-keys`` or ``MESHTASTIC_AUDIT_PUBLIC_KEYS=1``,
       the remaining public short-PSK indices 2..255.
    """
    pool: list[CandidateKey] = [
        CandidateKey("plaintext", None, "no-encryption"),
        CandidateKey("public-short-psk-1-default", DEFAULT_MESHTASTIC_KEY, "public-or-weak"),
    ]

    # Both value lists are gathered before any parsing, so a malformed entry
    # raises only after all sources have been read.
    supplied_sources = (
        (
            "authorised-hex-psk",
            list(args.psk_hex or []) + split_env_values("MESHTASTIC_PSK_HEX"),
            parse_hex_psk,
        ),
        (
            "authorised-b64-psk",
            list(args.psk_b64 or []) + split_env_values("MESHTASTIC_PSK_B64"),
            parse_b64_psk,
        ),
    )
    for label_prefix, texts, parser in supplied_sources:
        for position, text in enumerate(texts, start=1):
            secret = parser(text)
            pool.append(
                CandidateKey(
                    f"{label_prefix}-{position}",
                    normalize_psk(secret),
                    assessment_for_authorised_psk(secret),
                )
            )

    wants_audit = args.audit_public_keys or os.environ.get("MESHTASTIC_AUDIT_PUBLIC_KEYS") == "1"
    if wants_audit:
        pool.extend(
            CandidateKey(f"public-short-psk-{short_index}", expand_short_psk(short_index), "public-or-weak")
            for short_index in range(2, 256)
        )

    return deduplicate_candidates(pool)


def candidate_matches_channel_hash(
    candidate: CandidateKey,
    channel_names: Sequence[str],
    observed_hash: int,
) -> bool:
    """Return True if *candidate* could have produced the observed channel hash.

    With no channel names configured, filtering is disabled and every
    candidate matches. Otherwise at least one name combined with the
    candidate's key must hash to *observed_hash*.
    """
    if not channel_names:
        return True
    candidate_key = candidate.key
    for name in channel_names:
        if meshtastic_channel_hash(name, candidate_key) == observed_hash:
            return True
    return False


def decode_packet(raw: bytes, options: DecoderOptions) -> dict[str, Any]:
    """Decode one raw Meshtastic LoRa frame into a JSON-serialisable record.

    The frame starts with a 16-byte little-endian header: destination, sender,
    packet ID, flags, channel hash, next hop and relay node. The header fields
    are always reported when the frame length is acceptable.

    Each candidate key in ``options.candidates`` is then tried in order,
    skipping those whose channel hash does not match when channel names are
    configured. The first candidate whose payload parses as ``Data`` and
    passes ``decode_application`` validation wins. Its label becomes
    ``decode_mode``; otherwise ``decode_mode`` is ``malformed`` or
    ``unresolved``.

    A ``hop_start`` of 0 is what firmware predating that header field sends,
    so it is treated as unknown: such frames are not rejected for having
    ``hop_limit > hop_start``, and their ``hops_used`` is None.
    """
    if len(raw) < 17:
        return {
            "decode_mode": "malformed",
            "error": "packet too short",
            "length": len(raw),
            "raw_hex": raw.hex(),
        }
    if len(raw) > MAX_LORA_PAYLOAD:
        return {
            "decode_mode": "malformed",
            "error": f"packet exceeds {MAX_LORA_PAYLOAD} bytes",
            "length": len(raw),
            "raw_hex": raw.hex(),
        }

    destination, sender, packet_id = struct.unpack_from("<III", raw, 0)
    flags, channel_hash, next_hop, relay_node = struct.unpack_from("<BBBB", raw, 12)
    encrypted_or_plain = raw[16:]
    # Flags byte: bits 0-2 hop_limit, bit 3 want_ack, bit 4 via_mqtt, bits 5-7 hop_start.
    hop_limit = flags & 0x07
    hop_start = (flags >> 5) & 0x07

    packet: dict[str, Any] = {
        "length": len(raw),
        "to": f"!{destination:08x}",
        "from": f"!{sender:08x}",
        "packet_id": f"0x{packet_id:08x}",
        "hop_limit": hop_limit,
        "want_ack": bool(flags & 0x08),
        "via_mqtt": bool(flags & 0x10),
        "hop_start": hop_start,
        "channel_hash": f"0x{channel_hash:02x}",
        "next_hop": f"0x{next_hop:02x}",
        "relay_node": f"0x{relay_node:02x}",
    }

    if hop_start:
        # Only a populated hop_start allows a consistency check and a hop count.
        if hop_limit > hop_start:
            packet.update(
                {
                    "decode_mode": "malformed",
                    "error": "hop_limit exceeds hop_start",
                    "payload_hex": encrypted_or_plain.hex(),
                }
            )
            return packet
        packet["hops_used"] = hop_start - hop_limit
    else:
        # Legacy sender without hop_start: the hop count is unknown.
        packet["hops_used"] = None

    if sender == 0:
        packet.update(
            {
                "decode_mode": "malformed",
                "error": "zero sender node ID",
                "payload_hex": encrypted_or_plain.hex(),
            }
        )
        return packet

    attempted = 0
    hash_filtered = 0
    candidate_rejections: list[str] = []

    for candidate in options.candidates:
        if not candidate_matches_channel_hash(candidate, options.channel_names, channel_hash):
            hash_filtered += 1
            continue

        attempted += 1
        candidate_payload = (
            encrypted_or_plain
            if candidate.key is None
            else aes_ctr_decrypt(encrypted_or_plain, sender, packet_id, candidate.key)
        )

        data = parse_outer_data(candidate_payload)
        if data is None:
            continue

        try:
            application, validated, warnings = decode_application(
                data,
                sender=sender,
                strict=options.strict,
                allow_outer_only=options.allow_outer_only,
                max_future_seconds=options.max_future_seconds,
            )
        except (DecodeError, UnicodeDecodeError, ValidationError, ValueError, TypeError) as exc:
            # Keep only a few reasons so one packet cannot bloat the record.
            if len(candidate_rejections) < 3:
                candidate_rejections.append(f"{candidate.label}: {exc}")
            continue

        if not validated:
            continue

        packet["decode_mode"] = candidate.label
        packet["decoded"] = application
        packet["key_assessment"] = candidate.assessment
        packet["candidate_keys_attempted"] = attempted
        if warnings:
            packet["validation_warnings"] = warnings
        if options.channel_names:
            packet["channel_name_candidates"] = list(options.channel_names)
            packet["candidate_keys_hash_filtered"] = hash_filtered
        return packet

    packet["decode_mode"] = "unresolved"
    packet["payload_hex"] = encrypted_or_plain.hex()
    packet["candidate_keys_attempted"] = attempted
    if candidate_rejections:
        packet["candidate_rejections"] = candidate_rejections
    if options.channel_names:
        packet["channel_name_candidates"] = list(options.channel_names)
        packet["candidate_keys_hash_filtered"] = hash_filtered
    return packet


class JsonlWriter:
    """Append packet records to a JSON Lines file.

    With no path, the writer is disabled and every method is a no-op. The
    writer can be used as a context manager, and ``close`` may be called more
    than once. After closing, further writes are ignored.
    """

    def __init__(self, path: Optional[str]) -> None:
        self.path = Path(path).expanduser() if path else None
        self.handle = None
        if self.path:
            self.path.parent.mkdir(parents=True, exist_ok=True)
            # Line buffering flushes each record as soon as it is written.
            self.handle = self.path.open("a", encoding="utf-8", buffering=1)

    def __enter__(self) -> "JsonlWriter":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    def write(self, record: dict[str, Any]) -> None:
        """Append *record* as one compact JSON line with sorted keys."""
        if self.handle:
            self.handle.write(json.dumps(record, separators=(",", ":"), sort_keys=True) + "\n")

    def close(self) -> None:
        """Close the file if open; calling this again is harmless."""
        handle, self.handle = self.handle, None
        if handle:
            handle.close()


class SqliteStore:
    """Optional SQLite persistence for decoded packets and derived tables.

    Every packet is stored in ``packets``. NodeInfo, Position, Telemetry and
    NeighborInfo payloads are also projected into ``nodes``, ``positions``,
    ``telemetry`` and ``neighbor_links``. With an empty path the store is
    disabled, and ``write`` and ``close`` are no-ops.
    """

    def __init__(self, path: Optional[str]) -> None:
        self.connection: Optional[sqlite3.Connection] = None
        if not path:
            return
        database = Path(path).expanduser()
        database.parent.mkdir(parents=True, exist_ok=True)
        self.connection = sqlite3.connect(database)
        # WAL lets readers query the database while this process keeps writing.
        self.connection.execute("PRAGMA journal_mode=WAL")
        self.connection.execute("PRAGMA synchronous=NORMAL")
        self._create_schema()

    def _create_schema(self) -> None:
        """Create tables and indexes if absent; safe to run on an existing database."""
        assert self.connection is not None
        self.connection.executescript(
            """
            CREATE TABLE IF NOT EXISTS packets (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                received_at TEXT NOT NULL,
                sender TEXT,
                destination TEXT,
                packet_id TEXT,
                channel_hash TEXT,
                app TEXT,
                decode_mode TEXT NOT NULL,
                key_assessment TEXT,
                via_mqtt INTEGER,
                hop_start INTEGER,
                hop_limit INTEGER,
                hops_used INTEGER,
                next_hop TEXT,
                relay_node TEXT,
                duplicate INTEGER NOT NULL DEFAULT 0,
                duplicate_kind TEXT,
                packet_json TEXT NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_packets_sender_time
                ON packets(sender, received_at);
            CREATE INDEX IF NOT EXISTS idx_packets_app_time
                ON packets(app, received_at);
            CREATE INDEX IF NOT EXISTS idx_packets_packet_id
                ON packets(sender, packet_id);

            CREATE TABLE IF NOT EXISTS nodes (
                node_id TEXT PRIMARY KEY,
                long_name TEXT,
                short_name TEXT,
                hw_model TEXT,
                role TEXT,
                is_licensed INTEGER,
                is_unmessagable INTEGER,
                first_seen TEXT NOT NULL,
                last_seen TEXT NOT NULL,
                node_json TEXT NOT NULL
            );

            CREATE TABLE IF NOT EXISTS positions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                packet_rowid INTEGER NOT NULL,
                received_at TEXT NOT NULL,
                node_id TEXT NOT NULL,
                latitude REAL,
                longitude REAL,
                altitude REAL,
                precision_bits INTEGER,
                location_source TEXT,
                position_time INTEGER,
                position_json TEXT NOT NULL,
                FOREIGN KEY(packet_rowid) REFERENCES packets(id)
            );
            CREATE INDEX IF NOT EXISTS idx_positions_node_time
                ON positions(node_id, received_at);

            CREATE TABLE IF NOT EXISTS telemetry (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                packet_rowid INTEGER NOT NULL,
                received_at TEXT NOT NULL,
                node_id TEXT NOT NULL,
                telemetry_json TEXT NOT NULL,
                FOREIGN KEY(packet_rowid) REFERENCES packets(id)
            );
            CREATE INDEX IF NOT EXISTS idx_telemetry_node_time
                ON telemetry(node_id, received_at);

            CREATE TABLE IF NOT EXISTS neighbor_links (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                packet_rowid INTEGER NOT NULL,
                received_at TEXT NOT NULL,
                reporter TEXT NOT NULL,
                neighbor TEXT NOT NULL,
                snr REAL,
                FOREIGN KEY(packet_rowid) REFERENCES packets(id)
            );
            CREATE INDEX IF NOT EXISTS idx_neighbor_links_reporter_time
                ON neighbor_links(reporter, received_at);
            """
        )
        self.connection.commit()

    @staticmethod
    def _as_int(value: Any) -> Optional[int]:
        """Coerce *value* to int; None, "" and unconvertible values map to None."""
        if value in (None, ""):
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _as_float(value: Any) -> Optional[float]:
        """Coerce *value* to a finite float; anything else maps to None."""
        if value in (None, ""):
            return None
        try:
            number = float(value)
            return number if math.isfinite(number) else None
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _to_json(value: Any) -> str:
        """Serialise *value* compactly with sorted keys (stable, diff-friendly)."""
        return json.dumps(value, separators=(",", ":"), sort_keys=True)

    def _insert_packet(
        self, connection: sqlite3.Connection, packet: dict[str, Any], app: Any
    ) -> int:
        """Insert the ``packets`` row and return its rowid."""
        cursor = connection.execute(
            """
            INSERT INTO packets (
                received_at, sender, destination, packet_id, channel_hash, app,
                decode_mode, key_assessment, via_mqtt, hop_start, hop_limit,
                hops_used, next_hop, relay_node, duplicate, duplicate_kind,
                packet_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                packet.get("received_at"),
                packet.get("from"),
                packet.get("to"),
                packet.get("packet_id"),
                packet.get("channel_hash"),
                app,
                packet.get("decode_mode"),
                packet.get("key_assessment"),
                int(bool(packet.get("via_mqtt"))),
                packet.get("hop_start"),
                packet.get("hop_limit"),
                packet.get("hops_used"),
                packet.get("next_hop"),
                packet.get("relay_node"),
                int(bool(packet.get("duplicate"))),
                packet.get("duplicate_kind"),
                self._to_json(packet),
            ),
        )
        return int(cursor.lastrowid)

    def _upsert_node(
        self,
        connection: sqlite3.Connection,
        node: dict[str, Any],
        sender: str,
        received_at: str,
    ) -> None:
        """Insert or refresh a ``nodes`` row; first_seen is kept on conflict."""
        connection.execute(
            """
            INSERT INTO nodes (
                node_id, long_name, short_name, hw_model, role,
                is_licensed, is_unmessagable, first_seen, last_seen, node_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(node_id) DO UPDATE SET
                long_name=excluded.long_name,
                short_name=excluded.short_name,
                hw_model=excluded.hw_model,
                role=excluded.role,
                is_licensed=excluded.is_licensed,
                is_unmessagable=excluded.is_unmessagable,
                last_seen=excluded.last_seen,
                node_json=excluded.node_json
            """,
            (
                str(node.get("id") or sender),
                node.get("long_name"),
                node.get("short_name"),
                node.get("hw_model"),
                node.get("role"),
                int(bool(node.get("is_licensed"))),
                int(bool(node.get("is_unmessagable"))),
                received_at,
                received_at,
                self._to_json(node),
            ),
        )

    def _insert_position(
        self,
        connection: sqlite3.Connection,
        rowid: int,
        received_at: str,
        sender: str,
        position: dict[str, Any],
    ) -> None:
        """Insert one ``positions`` row, coercing numeric fields defensively."""
        connection.execute(
            """
            INSERT INTO positions (
                packet_rowid, received_at, node_id, latitude, longitude,
                altitude, precision_bits, location_source, position_time,
                position_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                rowid,
                received_at,
                sender,
                self._as_float(position.get("latitude")),
                self._as_float(position.get("longitude")),
                self._as_float(position.get("altitude")),
                self._as_int(position.get("precision_bits")),
                position.get("location_source"),
                self._as_int(position.get("time")),
                self._to_json(position),
            ),
        )

    def _insert_telemetry(
        self,
        connection: sqlite3.Connection,
        rowid: int,
        received_at: str,
        sender: str,
        telemetry: dict[str, Any],
    ) -> None:
        """Insert one ``telemetry`` row holding the metrics as JSON."""
        connection.execute(
            """
            INSERT INTO telemetry (
                packet_rowid, received_at, node_id, telemetry_json
            ) VALUES (?, ?, ?, ?)
            """,
            (rowid, received_at, sender, self._to_json(telemetry)),
        )

    def _insert_neighbor_links(
        self,
        connection: sqlite3.Connection,
        rowid: int,
        received_at: str,
        sender: str,
        neighbor_info: dict[str, Any],
    ) -> None:
        """Insert one ``neighbor_links`` row per dict entry in ``neighbors``."""
        reporter = str(neighbor_info.get("node") or sender)
        for neighbor in neighbor_info.get("neighbors", []):
            if not isinstance(neighbor, dict):
                continue
            connection.execute(
                """
                INSERT INTO neighbor_links (
                    packet_rowid, received_at, reporter, neighbor, snr
                ) VALUES (?, ?, ?, ?, ?)
                """,
                (
                    rowid,
                    received_at,
                    reporter,
                    neighbor.get("node"),
                    self._as_float(neighbor.get("snr")),
                ),
            )

    def write(self, packet: dict[str, Any]) -> None:
        """Persist *packet* and its derived rows in a single transaction.

        Either every row for the packet is committed, or none is. On error
        the transaction is rolled back and the exception propagates.
        """
        connection = self.connection
        if connection is None:
            return
        decoded = packet.get("decoded") or {}
        received_at = str(packet.get("received_at"))
        sender = str(packet.get("from", ""))

        # The connection context manager commits on success and rolls back on
        # error. Without it, a failing derived-table insert would leave a
        # half-written packet for a later write() or close() to commit.
        with connection:
            rowid = self._insert_packet(connection, packet, decoded.get("portnum"))

            node = decoded.get("nodeinfo")
            if isinstance(node, dict):
                self._upsert_node(connection, node, sender, received_at)

            position = decoded.get("position")
            if isinstance(position, dict):
                self._insert_position(connection, rowid, received_at, sender, position)

            telemetry = decoded.get("telemetry")
            if isinstance(telemetry, dict):
                self._insert_telemetry(connection, rowid, received_at, sender, telemetry)

            neighbor_info = decoded.get("neighbor_info")
            if isinstance(neighbor_info, dict):
                self._insert_neighbor_links(connection, rowid, received_at, sender, neighbor_info)

    def close(self) -> None:
        """Commit and close the connection; later calls are no-ops."""
        connection, self.connection = self.connection, None
        if connection is None:
            return
        try:
            connection.commit()
        finally:
            connection.close()


def _udp_port_arg(text: str) -> int:
    """Parse a ``--port`` value, rejecting anything outside 0-65535.

    ``socket.bind`` raises ``OverflowError`` (not ``OSError``) for such ports,
    so catching them here turns a traceback into a normal usage error.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid port: {text!r}") from None
    if not 0 <= value <= 65535:
        raise argparse.ArgumentTypeError(f"port must be in 0-65535, got {value}")
    return value


def _finite_float_arg(text: str) -> float:
    """Parse a float option, rejecting NaN and infinities.

    ``main`` converts ``--max-future-days`` with ``int(...)``, which raises for
    non-finite values; rejecting them here yields a usage error instead.
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid number: {text!r}") from None
    if not math.isfinite(value):
        raise argparse.ArgumentTypeError(f"value must be finite, got {text!r}")
    return value


def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the UDP Meshtastic decoder.

    Returns:
        An ``ArgumentParser`` whose namespace attributes (``bind``, ``port``,
        ``psk_hex``, ``psk_b64``, ``strict``, ``dedupe_ttl``, ...) are consumed
        by ``main``. Repeatable options default to fresh empty lists.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Decode raw Meshtastic LoRa payloads received as UDP datagrams. "
            "Public-key auditing is limited to Meshtastic's documented "
            "one-byte/default PSK family."
        )
    )
    parser.add_argument(
        "--bind",
        default="127.0.0.1",
        help="local address to bind the UDP listener to",
    )
    parser.add_argument(
        "--port",
        type=_udp_port_arg,
        default=7355,
        help="UDP port to listen on (0-65535)",
    )
    parser.add_argument(
        "--psk-hex",
        action="append",
        default=[],
        metavar="HEX",
        help="authorised PSK in hexadecimal; may be repeated",
    )
    parser.add_argument(
        "--psk-b64",
        action="append",
        default=[],
        metavar="BASE64",
        help="authorised PSK in Base64; may be repeated",
    )
    parser.add_argument(
        "--audit-public-keys",
        action="store_true",
        help="also try Meshtastic's documented one-byte/default PSK family",
    )
    parser.add_argument(
        "--channel-name",
        action="append",
        default=[],
        metavar="NAME",
        help=(
            "channel name used as a decode prefilter; may be repeated and is "
            "merged with MESHTASTIC_CHANNEL_NAME"
        ),
    )
    # NOTE(review): semantics live in DecoderOptions.allow_outer_only; add
    # help text once they are confirmed.
    parser.add_argument("--allow-outer-only", action="store_true")
    parser.add_argument(
        "--strict",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="reject semantically implausible nested payloads (default: true)",
    )
    parser.add_argument(
        "--max-future-days",
        type=_finite_float_arg,
        default=30.0,
        help="maximum accepted future timestamp skew in strict mode",
    )
    parser.add_argument("--jsonl", metavar="PATH", help="append compact packet JSON to PATH")
    parser.add_argument("--sqlite", metavar="PATH", help="persist packets and derived tables")
    parser.add_argument("--compact", action="store_true", help="one JSON object per stdout line")
    parser.add_argument("--store-raw", action="store_true", help="include raw packet hex in every record")
    parser.add_argument(
        "--dedupe-ttl",
        type=int,
        default=3600,
        help="seconds to remember sender/packet IDs for duplicate classification",
    )
    parser.add_argument(
        "--suppress-exact-duplicates",
        action="store_true",
        help="do not print/store byte-identical duplicates; alternate paths are retained",
    )
    parser.add_argument(
        "--stats-interval",
        type=float,
        default=60.0,
        help="emit aggregate statistics to stderr every N seconds; 0 disables",
    )
    parser.add_argument(
        "--show-candidates",
        action="store_true",
        help="print the candidate key list and exit",
    )
    return parser


def emit_stdout(record: dict[str, Any], *, compact: bool) -> None:
    """Write *record* to stdout as key-sorted JSON, flushing immediately.

    When ``compact`` is true the object is rendered on one line with no
    whitespace between tokens (suitable for line-oriented consumers);
    otherwise it is pretty-printed with a two-space indent.
    """
    layout: dict[str, Any] = {"separators": (",", ":")} if compact else {"indent": 2}
    rendered = json.dumps(record, sort_keys=True, **layout)
    print(rendered, flush=True)


def _annotate_duplicate_fields(
    packet: dict[str, Any],
    raw: bytes,
    duplicate_tracker: "DuplicateTracker",
    now: float,
) -> None:
    """Add ``duplicate``/``duplicate_kind``/``duplicate_count`` to *packet*.

    Reads the 16-byte radio header from *raw* and asks *duplicate_tracker* to
    classify it. Datagrams shorter than the header are left unannotated.
    """
    if len(raw) < 16:
        return
    destination, sender, packet_id = struct.unpack_from("<III", raw, 0)
    flags, _channel_hash, next_hop, relay_node = struct.unpack_from("<BBBB", raw, 12)
    # Path-dependent header fields, handed to the tracker for its
    # alternate-relay-path classification. Flag bits follow the Meshtastic
    # header layout: hop limit (bits 0-2), via-MQTT (bit 4), hop start (bits 5-7).
    route_signature = (
        destination,
        flags & 0x07,
        (flags >> 5) & 0x07,
        bool(flags & 0x10),
        next_hop,
        relay_node,
    )
    duplicate, duplicate_kind, duplicate_count = duplicate_tracker.classify(
        sender=sender,
        packet_id=packet_id,
        raw=raw,
        route_signature=route_signature,
        now=now,
    )
    packet["duplicate"] = duplicate
    if duplicate_kind:
        packet["duplicate_kind"] = duplicate_kind
        packet["duplicate_count"] = duplicate_count


def _print_statistics(statistics: "Statistics") -> None:
    """Write one ``"statistics"`` JSON line built from *statistics* to stderr."""
    print(
        json.dumps(
            {
                "event": "statistics",
                "received_at": utc_now_iso(),
                **statistics.snapshot(),
            },
            sort_keys=True,
        ),
        file=sys.stderr,
        flush=True,
    )


def main() -> None:
    """Run the live UDP receiver until interrupted.

    Parses the command line and builds the candidate key list. With
    ``--show-candidates`` it prints that list and returns. Otherwise it binds a
    UDP socket and, for every datagram, decodes it, annotates duplicate
    status, prints the record to stdout and appends it to the optional JSONL
    file and SQLite store. Exact duplicates are skipped when
    ``--suppress-exact-duplicates`` is set. Aggregate statistics go to stderr
    every ``--stats-interval`` seconds, also while no traffic arrives.

    Raises:
        SystemExit: if the UDP socket cannot be bound, or via ``parser.error``
            for invalid key arguments.
    """
    parser = make_parser()
    args = parser.parse_args()

    try:
        candidates = build_candidates(args)
    except (argparse.ArgumentTypeError, ValueError) as exc:
        parser.error(str(exc))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    channel_names = tuple(
        dict.fromkeys([*args.channel_name, *split_env_values("MESHTASTIC_CHANNEL_NAME")])
    )
    options = DecoderOptions(
        candidates=candidates,
        channel_names=channel_names,
        allow_outer_only=args.allow_outer_only,
        strict=args.strict,
        max_future_seconds=max(0, int(args.max_future_days * 86400)),
    )

    if args.show_candidates:
        for candidate in candidates:
            emit_stdout(
                {
                    "label": candidate.label,
                    "key_length": 0 if candidate.key is None else len(candidate.key),
                    "assessment": candidate.assessment,
                },
                compact=True,
            )
        return

    writer = JsonlWriter(args.jsonl)
    database = SqliteStore(args.sqlite)
    duplicate_tracker = DuplicateTracker(args.dedupe_ttl)
    statistics = Statistics()

    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.bind((args.bind, args.port))
    except (OSError, OverflowError) as exc:
        # OverflowError is what bind() raises for a port outside 0-65535.
        sock.close()
        writer.close()
        database.close()
        raise SystemExit(f"Could not bind udp://{args.bind}:{args.port}: {exc}") from exc

    startup = {
        "event": "startup",
        "received_at": utc_now_iso(),
        "udp": f"{args.bind}:{args.port}",
        "candidate_key_count": len(candidates),
        "public_short_psk_audit": args.audit_public_keys
        or os.environ.get("MESHTASTIC_AUDIT_PUBLIC_KEYS") == "1",
        "channel_name_prefilter": list(channel_names),
        "strict": args.strict,
        "protobuf_package_version": protobuf_version(),
        "jsonl": str(Path(args.jsonl).expanduser()) if args.jsonl else None,
        "sqlite": str(Path(args.sqlite).expanduser()) if args.sqlite else None,
    }
    emit_stdout(startup, compact=args.compact)
    writer.write(startup)

    # Non-finite or non-positive intervals never satisfied the elapsed-time
    # check, so they keep meaning "disabled" (and a plain blocking receive).
    periodic_stats = args.stats_interval > 0 and math.isfinite(args.stats_interval)
    last_stats = time.monotonic()
    try:
        while True:
            if periodic_stats:
                # Bound the receive so statistics are still emitted on an idle
                # link; cap the wait so huge intervals cannot overflow the
                # socket timeout.
                remaining = args.stats_interval - (time.monotonic() - last_stats)
                sock.settimeout(min(max(remaining, 0.001), 3600.0))

            try:
                raw, address = sock.recvfrom(4096)
            except socket.timeout:
                raw = None  # Idle: fall through to the statistics check.

            if raw is not None:
                received_wall = time.time()
                received_at = datetime.fromtimestamp(received_wall, timezone.utc).isoformat(
                    timespec="milliseconds"
                )
                try:
                    packet = decode_packet(raw, options)
                except Exception as exc:  # Keep the live receiver alive.
                    packet = {
                        "decode_mode": "decoder-error",
                        "error": f"{type(exc).__name__}: {exc}",
                        "length": len(raw),
                    }

                packet["event"] = "packet"
                packet["received_at"] = received_at
                packet["udp_source"] = f"{address[0]}:{address[1]}"
                _annotate_duplicate_fields(packet, raw, duplicate_tracker, received_wall)

                if args.store_raw:
                    packet["raw_hex"] = raw.hex()

                # Statistics count every datagram, including suppressed ones.
                statistics.observe(packet)

                if not (
                    args.suppress_exact_duplicates
                    and packet.get("duplicate_kind") == "exact"
                ):
                    emit_stdout(packet, compact=args.compact)
                    writer.write(packet)
                    database.write(packet)

            if periodic_stats and time.monotonic() - last_stats >= args.stats_interval:
                _print_statistics(statistics)
                last_stats = time.monotonic()

    except KeyboardInterrupt:
        print("\nStopped.", file=sys.stderr)
        print(json.dumps({"event": "final-statistics", **statistics.snapshot()}, sort_keys=True), file=sys.stderr)
    finally:
        sock.close()
        writer.close()
        database.close()

if __name__ == "__main__":
    # Start the live receiver only when run as a script, not when imported.
    main()
