"""
Paste this entire file into a GNU Radio Companion Embedded Python Block.

Connection:
    LoRa Rx lower byte-stream output -> this block -> Socket PDU message input

The block forwards only frames whose gr-lora_sdr frame_info tag reports a valid
LoRa payload CRC. Each complete, CRC-valid tagged byte frame becomes one standard
GNU Radio PDU (a metadata dict paired with a u8vector payload). Frames with an
invalid CRC or missing bytes are dropped before they reach UDP.
"""

import numpy as np
from gnuradio import gr
import pmt


class blk(gr.sync_block):
    """CRC-gate gr-lora_sdr tagged payload bytes and emit standard PDUs.

    Input is a uint8 byte stream carrying ``frame_info`` stream tags. Each tag
    opens a frame whose length is the tag's ``pay_len`` entry. Once that many
    bytes have arrived, the frame is handled according to the tag's
    ``crc_valid`` entry:

    * If it is true, the frame is published on the ``pdus`` message port as
      ``(metadata_dict, u8vector)``.
    * Otherwise the frame is dropped.

    Args:
        print_drops: when true, print one ``[DROP] ...`` line for every
            dropped frame (CRC failure, incomplete frame, unusable tag).
    """

    def __init__(self, print_drops=False):
        gr.sync_block.__init__(
            self,
            name="LoRa CRC-valid tagged bytes to PDU",
            in_sig=[np.uint8],
            out_sig=[],
        )
        self.message_port_register_out(pmt.intern("pdus"))
        # Sink block (no output streams): there is nowhere to propagate tags.
        self.set_tag_propagation_policy(gr.TPP_DONT)

        self.print_drops = bool(print_drops)

        # Running statistics; they are only kept as attributes, not published.
        self._valid_frames = 0
        self._invalid_frames = 0
        self._incomplete_frames = 0
        self._malformed_frames = 0

        # Per-frame state (_buffer, _expected_length, _crc_valid, _metadata).
        self._reset_frame()

    @staticmethod
    def _dict_get(info, key_name, default):
        """Return ``info[key_name]``, or *default* if *info* is not a PMT dict
        or lacks the key."""
        key = pmt.intern(key_name)
        if not pmt.is_dict(info) or not pmt.dict_has_key(info, key):
            return default
        return pmt.dict_ref(info, key, default)

    def _reset_frame(self):
        """Forget any in-progress frame: buffer, length, CRC flag, metadata."""
        self._buffer = bytearray()
        # 0 means "no frame open": incoming bytes are ignored until a tag.
        self._expected_length = 0
        self._crc_valid = False
        self._metadata = pmt.make_dict()

    def _log_drop(self, message):
        """Print a ``[DROP]`` line for *message* when ``print_drops`` is set."""
        if self.print_drops:
            print(f"[DROP] {message}", flush=True)

    def _start_frame(self, info):
        """Open a new frame described by the ``frame_info`` tag value *info*.

        A frame still in progress is abandoned and counted as incomplete.
        If ``pay_len`` is missing, unreadable, or outside 1..255, no frame is
        opened. Bytes are then ignored until the next tag, and the tag is
        counted in ``_malformed_frames``.
        """
        if self._expected_length and len(self._buffer) != self._expected_length:
            self._incomplete_frames += 1
            self._log_drop(
                f"incomplete LoRa frame: "
                f"received={len(self._buffer)} expected={self._expected_length}"
            )

        pay_len_value = self._dict_get(info, "pay_len", pmt.from_long(0))
        crc_value = self._dict_get(info, "crc_valid", pmt.PMT_F)

        # A PMT of the wrong type is treated the same as a missing entry.
        try:
            expected_length = int(pmt.to_long(pay_len_value))
        except Exception:
            expected_length = 0

        try:
            crc_valid = bool(pmt.to_bool(crc_value))
        except Exception:
            crc_valid = False

        self._reset_frame()

        if not 0 < expected_length <= 255:
            # Cannot frame the payload; previously this was dropped silently.
            self._malformed_frames += 1
            self._log_drop(f"unusable frame_info pay_len={expected_length}")
            return

        self._expected_length = expected_length
        self._crc_valid = crc_valid
        self._metadata = pmt.dict_add(
            self._metadata,
            pmt.intern("crc_valid"),
            pmt.from_bool(crc_valid),
        )
        self._metadata = pmt.dict_add(
            self._metadata,
            pmt.intern("payload_length"),
            pmt.from_long(expected_length),
        )

    def _finish_frame(self):
        """Publish (CRC valid) or drop (CRC invalid) the completed frame,
        then reset the per-frame state."""
        if self._crc_valid:
            payload = pmt.init_u8vector(len(self._buffer), list(self._buffer))
            self.message_port_pub(
                pmt.intern("pdus"),
                pmt.cons(self._metadata, payload),
            )
            self._valid_frames += 1
        else:
            self._invalid_frames += 1
            self._log_drop(
                f"invalid LoRa payload CRC, length={self._expected_length}"
            )

        self._reset_frame()

    def _consume(self, chunk):
        """Feed a run of untagged bytes *chunk* (uint8 array) to the open frame.

        Only the bytes the frame still needs are taken. If they complete the
        frame, it is finished, and the rest of *chunk* is ignored because no
        frame is open until the next ``frame_info`` tag.
        """
        if self._expected_length <= 0 or len(chunk) == 0:
            return
        missing = self._expected_length - len(self._buffer)
        self._buffer.extend(chunk[:missing].tobytes())
        if len(self._buffer) == self._expected_length:
            self._finish_frame()

    def work(self, input_items, output_items):
        """Consume every input byte, splitting the window at ``frame_info`` tags.

        Returns the number of input items consumed (always all of them).
        """
        samples = input_items[0]
        n_samples = len(samples)
        absolute_start = self.nitems_read(0)
        tags = self.get_tags_in_window(
            0,
            0,
            n_samples,
            pmt.intern("frame_info"),
        )

        tags_by_offset = {}
        for tag in tags:
            relative_offset = int(tag.offset - absolute_start)
            tags_by_offset.setdefault(relative_offset, []).append(tag)

        # Handle the window one segment at a time. Bytes before a tag's offset
        # belong to the previously open frame. The tag then opens a new frame
        # whose first byte is the one at the tag's own offset.
        cursor = 0
        for offset in sorted(tags_by_offset):
            self._consume(samples[cursor:offset])
            for tag in tags_by_offset[offset]:
                self._start_frame(tag.value)
            cursor = offset
        self._consume(samples[cursor:])

        return n_samples
