#!/usr/bin/env python3
"""
Receive raw Meshtastic LoRa payloads from GNU Radio over UDP, decode plaintext
traffic, decrypt traffic with authorised PSKs, and optionally audit Meshtastic's
documented one-byte/public PSK family.

Expected GNU Radio path:
    LoRa Rx tagged-byte output
      -> tagged-stream-to-PDU converter
      -> Socket PDU (UDP Client, 127.0.0.1:7355)

Examples:
    # Plaintext + public default key only
    python meshtastic_udp_decoder_audit.py

    # Also test all documented one-byte/public PSK variants (1..255)
    python meshtastic_udp_decoder_audit.py --audit-public-keys

    # Test an authorised AES key in hex
    python meshtastic_udp_decoder_audit.py \
        --psk-hex 00112233445566778899aabbccddeeff

    # Test an authorised PSK copied from a Meshtastic channel export
    python meshtastic_udp_decoder_audit.py --psk-b64 'BASE64_PSK'

    # Optional channel-name prefilter. A candidate is tried only when its
    # Meshtastic one-byte channel hash matches the packet.
    python meshtastic_udp_decoder_audit.py \
        --audit-public-keys --channel-name LongFast

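Environment variables retained for compatibility. The PSK and channel-name
variables accept comma-separated lists, which are combined with any values
given on the command line:
    MESHTASTIC_PSK_HEX
    MESHTASTIC_PSK_B64
    MESHTASTIC_AUDIT_PUBLIC_KEYS=1
    MESHTASTIC_CHANNEL_NAME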

Use only on channels and keys covered by your authorisation.
"""

from __future__ import annotations

import argparse
import base64
import binascii
import json
import os
import socket
import struct
import sys
from dataclasses import dataclass
from typing import Iterable, Optional, Sequence

from Crypto.Cipher import AES
from google.protobuf.json_format import MessageToDict
from google.protobuf.message import DecodeError, Message
from meshtastic.protobuf import mesh_pb2, portnums_pb2, telemetry_pb2


# The firmware expands one-byte PSK index 1 (the public default channel key)
# to this AES-128 key; expand_short_psk() derives indices 2..255 from it.
DEFAULT_MESHTASTIC_KEY = bytes.fromhex("d4f1bb3a20290759f0bcffabcf4e6901")

# Meshtastic's broadcast destination node number (all 32 bits set).
BROADCAST_NODE = 0xFFFF_FFFF


@dataclass(frozen=True)
class CandidateKey:
    """
    One key to try against every received packet.

    Instances are immutable and hashable (frozen dataclass).
    """

    # Name reported as "decode_mode" when this key decodes a packet.
    label: str
    # Normalised AES key bytes; None means plaintext/no encryption.
    key: Optional[bytes]
    # True for publicly known or weak keys; drives the "key_assessment" output.
    public_or_weak: bool = False


@dataclass(frozen=True)
class DecoderOptions:
    """Immutable per-run settings shared by every decode_packet() call."""

    # Keys to try, in priority order; the first validated decode wins.
    candidates: tuple[CandidateKey, ...]
    # Optional channel names; when non-empty, candidates whose channel hash
    # differs from the packet's header byte are skipped.
    channel_names: tuple[str, ...]
    # Accept a valid outer Data protobuf even when the nested payload type
    # cannot be validated.
    allow_outer_only: bool


def xor_hash(value: bytes) -> int:
    """Fold every byte of *value* together with XOR; empty input yields 0."""
    folded = 0
    for octet in value:
        folded = folded ^ octet
    return folded


def meshtastic_channel_hash(channel_name: str, key: Optional[bytes]) -> int:
    """
    Compute the firmware's one-byte channel hash.

    The result is the XOR-fold of the UTF-8 channel-name bytes, XORed with the
    XOR-fold of the (already expanded) key bytes. A missing key contributes 0.
    """
    name_digest = xor_hash(channel_name.encode("utf-8"))
    key_digest = xor_hash(key if key else b"")
    return name_digest ^ key_digest


def expand_short_psk(index: int) -> Optional[bytes]:
    """
    Expand a one-byte Meshtastic PSK the way the firmware does.

    * 0 -> None (encryption disabled)
    * 1 -> the public default key, unchanged
    * n -> the public default key with (n - 1) added to its last byte, mod 256

    Raises:
        ValueError: *index* is outside 0..255.
    """
    if not 0 <= index <= 255:
        raise ValueError("short PSK index must be in the range 0..255")
    if index == 0:
        return None

    head, tail = DEFAULT_MESHTASTIC_KEY[:-1], DEFAULT_MESHTASTIC_KEY[-1]
    return head + bytes(((tail + index - 1) & 0xFF,))


def normalize_psk(raw: bytes) -> Optional[bytes]:
    """
    Turn a raw channel PSK into the AES key the firmware would actually use.

    * empty -> None (plaintext)
    * one byte -> short-PSK expansion via expand_short_psk()
    * 2..16 bytes -> zero-padded on the right to 16 bytes (AES-128)
    * 17..32 bytes -> zero-padded on the right to 32 bytes (AES-256)

    Raises:
        ValueError: *raw* is longer than 32 bytes.
    """
    size = len(raw)
    if size > 32:
        raise ValueError("Meshtastic PSK must be at most 32 bytes")
    if size == 0:
        return None
    if size == 1:
        return expand_short_psk(raw[0])

    target_length = 16 if size <= 16 else 32
    return raw.ljust(target_length, b"\x00")


def aes_ctr_decrypt(
    ciphertext: bytes,
    sender: int,
    packet_id: int,
    key: bytes,
) -> bytes:
    """
    Decrypt a Meshtastic channel payload with AES-CTR.

    The 16-byte initial counter block is laid out as follows:
    * packet_id as a little-endian uint64
    * sender as a little-endian uint32
    * a 4-byte big-endian block counter starting at zero

    *key* is the normalised AES key from normalize_psk().
    """
    # "<QI" packs 8 + 4 = 12 bytes with no alignment padding.
    nonce = struct.pack("<QI", packet_id, sender)
    cipher = AES.new(key, AES.MODE_CTR, nonce=nonce, initial_value=0)
    return cipher.decrypt(ciphertext)


def parse_outer_data(payload: bytes) -> Optional[mesh_pb2.Data]:
    """
    Try to interpret *payload* as a Meshtastic ``Data`` protobuf.

    Returns the parsed message. Returns None in any of these cases:
    * the bytes do not parse;
    * the port number is absent from the installed ``PortNum`` enum;
    * the port is ``UNKNOWN_APP``;
    * there is neither an application payload nor the want_response flag.

    These cheap checks reject most wrong-key decryptions before the nested
    payload is examined.
    """
    candidate = mesh_pb2.Data()
    try:
        candidate.ParseFromString(payload)
        # Name() raises ValueError for values missing from the enum.
        portnums_pb2.PortNum.Name(candidate.portnum)
    except (DecodeError, ValueError):
        return None

    if candidate.portnum == portnums_pb2.PortNum.UNKNOWN_APP:
        return None
    has_content = bool(candidate.payload) or bool(candidate.want_response)
    return candidate if has_content else None


def message_to_dict(message: Message) -> dict:
    """Convert a protobuf message to a dict keyed by the .proto field names."""
    converted = MessageToDict(message, preserving_proto_field_name=True)
    return converted


def parse_message(message: Message, payload: bytes) -> Message:
    """
    Deserialize *payload* into *message* in place and return the same object.

    Parse errors raised by ``ParseFromString`` propagate to the caller.
    """
    target = message
    target.ParseFromString(payload)
    return target


def decode_application(
    data: mesh_pb2.Data,
    *,
    allow_outer_only: bool,
) -> tuple[dict, bool]:
    """
    Decode a nested application payload.

    Returns (decoded_object, strongly_validated). For key auditing, a candidate
    is accepted only when the nested message can be validated, unless
    --allow-outer-only is explicitly enabled.

    These ports are decoded and validated here:
    * TEXT_MESSAGE_APP and ALERT_APP (strict, non-empty UTF-8)
    * POSITION_APP, NODEINFO_APP, TELEMETRY_APP and ROUTING_APP
    * NEIGHBORINFO_APP and TRACEROUTE_APP, only when the installed protobufs
      define their message types

    Any other port yields its payload as hex. That result counts as validated
    only when *allow_outer_only* is true.

    Raises:
        DecodeError: the nested protobuf fails to parse or is empty/implausible.
        UnicodeDecodeError: a text payload is not valid UTF-8.
    """
    try:
        port_name = portnums_pb2.PortNum.Name(data.portnum)
    except ValueError:
        return ({"portnum": f"PORT_{data.portnum}"}, False)

    result: dict = {
        "portnum": port_name,
        "want_response": bool(data.want_response),
    }
    validated = False

    if port_name in ("TEXT_MESSAGE_APP", "ALERT_APP"):
        # Strict UTF-8 reduces false positives while auditing candidate keys.
        text = data.payload.decode("utf-8", errors="strict")
        if not text:
            raise DecodeError("empty text payload")
        result["text"] = text
        validated = True

    elif port_name == "POSITION_APP":
        message = parse_message(mesh_pb2.Position(), data.payload)
        if not message.ListFields():
            raise DecodeError("empty Position message")
        position = message_to_dict(message)
        # Integer fields are degrees scaled by 1e7.
        if message.latitude_i:
            position["latitude"] = message.latitude_i * 1e-7
        if message.longitude_i:
            position["longitude"] = message.longitude_i * 1e-7
        result["position"] = position
        validated = True

    elif port_name == "NODEINFO_APP":
        message = parse_message(mesh_pb2.User(), data.payload)
        # A real User announcement normally has at least an ID/name/model.
        # public_key is absent from older protobuf releases; getattr avoids an
        # AttributeError that would otherwise abort the whole candidate loop,
        # consistent with the hasattr/getattr guards used further below.
        if not (
            message.id
            or message.long_name
            or message.short_name
            or message.hw_model
            or getattr(message, "public_key", b"")
        ):
            raise DecodeError("User message lacks identity fields")
        result["nodeinfo"] = message_to_dict(message)
        validated = True

    elif port_name == "TELEMETRY_APP":
        message = parse_message(telemetry_pb2.Telemetry(), data.payload)
        if not message.ListFields():
            raise DecodeError("empty Telemetry message")
        result["telemetry"] = message_to_dict(message)
        validated = True

    elif port_name == "ROUTING_APP":
        message = parse_message(mesh_pb2.Routing(), data.payload)
        if not message.ListFields():
            raise DecodeError("empty Routing message")
        result["routing"] = message_to_dict(message)
        validated = True

    elif port_name == "NEIGHBORINFO_APP" and hasattr(mesh_pb2, "NeighborInfo"):
        message = parse_message(mesh_pb2.NeighborInfo(), data.payload)
        if not message.ListFields():
            raise DecodeError("empty NeighborInfo message")
        neighbor_info = message_to_dict(message)
        if getattr(message, "node_id", 0):
            neighbor_info["node"] = f"!{message.node_id:08x}"
        neighbors = []
        for neighbor in getattr(message, "neighbors", []):
            item = message_to_dict(neighbor)
            if getattr(neighbor, "node_id", 0):
                item["node"] = f"!{neighbor.node_id:08x}"
            neighbors.append(item)
        if neighbors:
            neighbor_info["neighbors"] = neighbors
        result["neighbor_info"] = neighbor_info
        validated = True

    elif port_name == "TRACEROUTE_APP" and hasattr(mesh_pb2, "RouteDiscovery"):
        message = parse_message(mesh_pb2.RouteDiscovery(), data.payload)
        if not message.ListFields():
            raise DecodeError("empty RouteDiscovery message")
        route = message_to_dict(message)
        if getattr(message, "route", None):
            route["route_nodes"] = [f"!{node:08x}" for node in message.route]
        if getattr(message, "route_back", None):
            route["route_back_nodes"] = [
                f"!{node:08x}" for node in message.route_back
            ]
        result["traceroute"] = route
        validated = True

    else:
        # Unknown/unsupported application type: only the outer Data parsed.
        result["payload_hex"] = data.payload.hex()
        result["validation"] = "outer-data-only"
        validated = allow_outer_only

    if data.request_id:
        result["request_id"] = f"0x{data.request_id:08x}"
    if data.reply_id:
        result["reply_id"] = f"0x{data.reply_id:08x}"
    if data.emoji:
        result["emoji"] = data.emoji

    return result, validated


def split_env_values(name: str) -> list[str]:
    """
    Read environment variable *name* as a comma-separated list.

    Whitespace is stripped from the whole value and from each item, and empty
    items are dropped. An unset or blank variable yields [].
    """
    text = os.environ.get(name, "").strip()
    if not text:
        return []
    trimmed = (part.strip() for part in text.split(","))
    return [part for part in trimmed if part]


def parse_hex_psk(value: str) -> bytes:
    """
    Parse a PSK written in hexadecimal.

    Accepts an optional ``0x``/``0X`` prefix. Ignores spaces and the common
    byte separators ``:`` and ``-`` (e.g. ``00:11:22`` or ``00-11-22``).

    Raises:
        argparse.ArgumentTypeError: *value* is not valid hexadecimal.
    """
    cleaned = value.strip()
    # Case-insensitive prefix check; slicing also avoids str.removeprefix,
    # which needs Python 3.9.
    if cleaned[:2].lower() == "0x":
        cleaned = cleaned[2:]
    for separator in (":", "-", " "):
        cleaned = cleaned.replace(separator, "")
    try:
        return bytes.fromhex(cleaned)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(f"invalid hexadecimal PSK: {value!r}") from exc


def parse_b64_psk(value: str) -> bytes:
    """
    Parse a PSK written in Base64.

    Accepts the standard and URL-safe alphabets and tolerates missing ``=``
    padding. Strips an optional ``base64:`` prefix, the form the Meshtastic
    CLI uses when setting a channel PSK.

    Raises:
        argparse.ArgumentTypeError: *value* is not valid Base64.
    """
    cleaned = value.strip()
    if cleaned[:7].lower() == "base64:":
        cleaned = cleaned[7:]
    # Map the URL-safe alphabet onto the standard one and restore padding, so
    # strict validation still rejects genuinely malformed input.
    cleaned = cleaned.replace("-", "+").replace("_", "/")
    cleaned += "=" * (-len(cleaned) % 4)
    try:
        return base64.b64decode(cleaned, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise argparse.ArgumentTypeError(f"invalid Base64 PSK: {value!r}") from exc


def deduplicate_candidates(candidates: Iterable[CandidateKey]) -> tuple[CandidateKey, ...]:
    """
    Drop candidates whose key was already seen, keeping the first occurrence.

    Order of first appearance is preserved. Duplicates are judged by key bytes
    only (None counts as one key), so the earlier label wins.
    """
    # dicts preserve insertion order; setdefault keeps the first entry per key.
    first_by_key: dict = {}
    for entry in candidates:
        first_by_key.setdefault(entry.key, entry)
    return tuple(first_by_key.values())


def build_candidates(args: argparse.Namespace) -> tuple[CandidateKey, ...]:
    """
    Assemble the ordered, de-duplicated list of keys tried on every packet.

    Candidates are tried in this order:
    1. plaintext;
    2. the public default key;
    3. authorised hex PSKs (CLI values, then MESHTASTIC_PSK_HEX);
    4. authorised Base64 PSKs (CLI values, then MESHTASTIC_PSK_B64);
    5. public short-PSK indices 2..255, only when auditing is enabled via
       --audit-public-keys or MESHTASTIC_AUDIT_PUBLIC_KEYS=1.

    When two entries normalise to the same key, the earlier one wins.

    A user-supplied PSK is marked public_or_weak when its normalised key is
    plaintext or any member of the one-byte PSK family. This holds even if it
    was not supplied as a single byte, so a pasted 16-byte expansion of a
    public key is not reported as "authorised-custom".

    Raises:
        argparse.ArgumentTypeError: a PSK value is not valid hex/Base64.
        ValueError: a PSK is longer than 32 bytes.
    """
    # Every key the one-byte PSK family can expand to; index 0 contributes
    # None (no encryption).
    public_keys = {expand_short_psk(index) for index in range(256)}

    candidates: list[CandidateKey] = [
        CandidateKey("plaintext", None, public_or_weak=True),
        CandidateKey(
            "public-short-psk-1-default",
            DEFAULT_MESHTASTIC_KEY,
            public_or_weak=True,
        ),
    ]

    hex_values: list[str] = list(args.psk_hex or []) + split_env_values(
        "MESHTASTIC_PSK_HEX"
    )
    b64_values: list[str] = list(args.psk_b64 or []) + split_env_values(
        "MESHTASTIC_PSK_B64"
    )

    for encoding, parse, values in (
        ("hex", parse_hex_psk, hex_values),
        ("b64", parse_b64_psk, b64_values),
    ):
        for index, value in enumerate(values, start=1):
            key = normalize_psk(parse(value))
            candidates.append(
                CandidateKey(
                    f"authorised-{encoding}-psk-{index}",
                    key,
                    public_or_weak=key in public_keys,
                )
            )

    audit_public = args.audit_public_keys or os.environ.get(
        "MESHTASTIC_AUDIT_PUBLIC_KEYS"
    ) == "1"
    if audit_public:
        candidates.extend(
            CandidateKey(
                f"public-short-psk-{index}",
                expand_short_psk(index),
                public_or_weak=True,
            )
            for index in range(2, 256)
        )

    return deduplicate_candidates(candidates)


def candidate_matches_channel_hash(
    candidate: CandidateKey,
    channel_names: Sequence[str],
    observed_hash: int,
) -> bool:
    """
    Decide whether *candidate* should be tried on a packet.

    With no channel names configured, every candidate is allowed. Otherwise
    the candidate qualifies if any configured name, combined with its key,
    reproduces the packet's one-byte channel hash.
    """
    if channel_names:
        key = candidate.key
        return any(
            meshtastic_channel_hash(channel, key) == observed_hash
            for channel in channel_names
        )
    return True


def decode_packet(raw: bytes, options: DecoderOptions) -> dict:
    """
    Decode one raw over-the-air Meshtastic packet into a JSON-friendly dict.

    The first 16 bytes are the unencrypted header: destination, sender and
    packet id as little-endian uint32s, then the flags, channel-hash, next-hop
    and relay-node bytes. The rest is the payload. Candidate keys that pass the
    optional channel-hash prefilter are tried in order, and the first whose
    output decodes and validates is reported. A datagram shorter than 17 bytes
    yields an error dict instead.
    """
    size = len(raw)
    if size < 17:
        return {
            "error": "packet too short",
            "length": size,
            "raw_hex": raw.hex(),
        }

    (
        destination,
        sender,
        packet_id,
        flags,
        channel_hash,
        next_hop,
        relay_node,
    ) = struct.unpack_from("<IIIBBBB", raw, 0)
    body = raw[16:]

    packet: dict = {
        "length": size,
        "to": f"!{destination:08x}",
        "from": f"!{sender:08x}",
        "packet_id": f"0x{packet_id:08x}",
        # Flags byte: bits 0-2 hop limit, bit 3 want-ack, bit 4 via-MQTT,
        # bits 5-7 hop start.
        "hop_limit": flags & 0x07,
        "want_ack": bool(flags & 0x08),
        "via_mqtt": bool(flags & 0x10),
        "hop_start": (flags >> 5) & 0x07,
        "channel_hash": f"0x{channel_hash:02x}",
        "next_hop": f"0x{next_hop:02x}",
        "relay_node": f"0x{relay_node:02x}",
    }

    tried = 0
    skipped_by_hash = 0
    winner = None

    for candidate in options.candidates:
        if not candidate_matches_channel_hash(
            candidate, options.channel_names, channel_hash
        ):
            skipped_by_hash += 1
            continue

        tried += 1
        if candidate.key is None:
            cleartext = body
        else:
            cleartext = aes_ctr_decrypt(body, sender, packet_id, candidate.key)

        outer = parse_outer_data(cleartext)
        if outer is None:
            continue

        try:
            decoded, is_valid = decode_application(
                outer, allow_outer_only=options.allow_outer_only
            )
        except (DecodeError, UnicodeDecodeError, ValueError, TypeError):
            # Wrong keys routinely yield garbage that fails nested parsing.
            continue

        if is_valid:
            winner = (candidate, decoded)
            break

    if winner is None:
        packet["decode_mode"] = "unresolved"
        packet["payload_hex"] = body.hex()
    else:
        chosen, decoded = winner
        packet["decode_mode"] = chosen.label
        packet["decoded"] = decoded
        packet["key_assessment"] = (
            "public-or-weak" if chosen.public_or_weak else "authorised-custom"
        )

    packet["candidate_keys_attempted"] = tried
    if options.channel_names:
        packet["channel_name_candidates"] = list(options.channel_names)
        packet["candidate_keys_hash_filtered"] = skipped_by_hash
    return packet


def make_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser.

    Options are registered in the order --help lists them. --psk-hex,
    --psk-b64 and --channel-name may be repeated; each collects a list that
    defaults to empty.
    """
    parser = argparse.ArgumentParser(
        description="Decode Meshtastic LoRa packets received as raw UDP datagrams from "
        "GNU Radio. Public-key auditing is limited to Meshtastic's "
        "documented one-byte PSK family."
    )

    # (flag, add_argument keyword arguments) in registration order.
    option_specs = (
        ("--bind", dict(default="127.0.0.1")),
        ("--port", dict(type=int, default=7355)),
        (
            "--psk-hex",
            dict(
                action="append",
                default=[],
                metavar="HEX",
                help="authorised channel PSK in hexadecimal; repeatable",
            ),
        ),
        (
            "--psk-b64",
            dict(
                action="append",
                default=[],
                metavar="BASE64",
                help="authorised Meshtastic channel PSK in Base64; repeatable",
            ),
        ),
        (
            "--audit-public-keys",
            dict(
                action="store_true",
                help="try all documented one-byte/public PSK variants 1..255",
            ),
        ),
        (
            "--channel-name",
            dict(
                action="append",
                default=[],
                metavar="NAME",
                help="optional known/authorised channel name; repeatable. Candidate keys "
                "whose one-byte channel hash does not match are skipped",
            ),
        ),
        (
            "--allow-outer-only",
            dict(
                action="store_true",
                help="accept a valid outer Data protobuf even for application types this "
                "script cannot strongly validate; increases false-positive risk",
            ),
        ),
        (
            "--show-candidates",
            dict(
                action="store_true",
                help="print candidate labels and exit without opening the UDP socket",
            ),
        ),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)
    return parser


def main() -> None:
    """
    Command-line entry point.

    Builds the decoder configuration from arguments and environment variables.
    With --show-candidates it prints one JSON line per candidate key and
    returns. Otherwise it listens on UDP and prints one JSON object per
    received datagram until interrupted with Ctrl-C.
    """
    parser = make_parser()
    args = parser.parse_args()

    try:
        candidates = build_candidates(args)
    except (argparse.ArgumentTypeError, ValueError) as exc:
        # parser.error() prints usage plus the message and exits.
        parser.error(str(exc))

    env_channel_names = split_env_values("MESHTASTIC_CHANNEL_NAME")
    # dict.fromkeys removes duplicates while keeping first-seen order.
    channel_names = tuple(dict.fromkeys([*args.channel_name, *env_channel_names]))

    options = DecoderOptions(
        candidates=candidates,
        channel_names=channel_names,
        allow_outer_only=args.allow_outer_only,
    )

    if args.show_candidates:
        for candidate in candidates:
            print(
                json.dumps(
                    {
                        "label": candidate.label,
                        "key_length": 0 if candidate.key is None else len(candidate.key),
                        "assessment": (
                            "public-or-weak"
                            if candidate.public_or_weak
                            else "authorised-custom"
                        ),
                    },
                    sort_keys=True,
                )
            )
        return

    # The context manager closes the socket on every exit path, including a
    # failed bind() that raises SystemExit.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        try:
            sock.bind((args.bind, args.port))
        except OSError as exc:
            raise SystemExit(
                f"Could not bind udp://{args.bind}:{args.port}: {exc}"
            ) from exc

        audit_enabled = any(
            candidate.label.startswith("public-short-psk-")
            and candidate.label != "public-short-psk-1-default"
            for candidate in candidates
        )
        print(
            json.dumps(
                {
                    "status": "listening",
                    "udp": f"{args.bind}:{args.port}",
                    "candidate_key_count": len(candidates),
                    "public_short_psk_audit": audit_enabled,
                    "channel_name_prefilter": list(channel_names),
                    "allow_outer_only": args.allow_outer_only,
                },
                sort_keys=True,
            ),
            flush=True,
        )

        try:
            while True:
                raw, address = sock.recvfrom(4096)
                try:
                    decoded = decode_packet(raw, options)
                except Exception as exc:  # Keep the live receiver running.
                    decoded = {
                        "error": f"{type(exc).__name__}: {exc}",
                        "length": len(raw),
                        "raw_hex": raw.hex(),
                    }
                decoded["udp_source"] = f"{address[0]}:{address[1]}"
                print(json.dumps(decoded, indent=2, sort_keys=True), flush=True)
        except KeyboardInterrupt:
            print("\nStopped.", file=sys.stderr)


if __name__ == "__main__":
    # Start the receiver only when executed as a script, not when imported.
    main()
