import ctypes
import dataclasses
import math
import threading
from typing import Any, NamedTuple

import torch

from comfy.quant_ops import QuantizedTensor


class TensorFileSlice(NamedTuple):
    """Location of a tensor's raw bytes inside an open file.

    Read back from a tensor's untyped storage via the
    "_comfy_tensor_file_slice" attribute and consumed by
    read_tensor_file_slice_into().
    """
    # Open file object to read the bytes from (None disables the fast path).
    file_ref: object
    # threading.get_ident() of the thread that owns file_ref; the file's
    # seek position is shared state, so reads must happen on that thread.
    thread_id: int
    # Byte offset of the slice within the file.
    offset: int
    # Number of bytes in the slice.
    size: int


def read_tensor_file_slice_into(tensor, destination):
    """Copy the file bytes backing *tensor* directly into *destination*'s memory.

    Fast path for tensors whose untyped storage carries a
    "_comfy_tensor_file_slice" record: the raw bytes are read from the file
    straight into *destination*'s CPU buffer with readinto(), avoiding an
    intermediate allocation. Returns True on success; returns False whenever
    any precondition fails or the read errors, so callers can fall back to a
    regular copy.
    """

    if isinstance(tensor, QuantizedTensor):
        # Recurse on the packed payload, then fix up the quantization params.
        if not isinstance(destination, QuantizedTensor):
            return False
        if tensor._layout_cls != destination._layout_cls:
            return False

        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
            return False

        # Copy the source params but keep the destination's recorded
        # pre-quantization dtype. NOTE(review): assumes _params supports
        # copy_from() and is a dataclass usable with dataclasses.replace() —
        # confirm against the layout params type in comfy.quant_ops.
        dst_orig_dtype = destination._params.orig_dtype
        destination._params.copy_from(tensor._params, non_blocking=False)
        destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
        return True

    # Plain tensor: the fast path only applies when the storage carries
    # file-slice metadata (a TensorFileSlice).
    info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
    if info is None:
        return False

    file_obj = info.file_ref
    # All preconditions for a raw byte copy:
    # - destination is CPU memory (we write through its data_ptr),
    # - there is a usable file handle,
    # - we are on the thread that owns the handle (seek position is shared),
    # - destination has room for the whole slice,
    # - the slice covers *tensor* exactly: matching byte count, zero storage
    #   offset, contiguous layout.
    if (destination.device.type != "cpu"
            or file_obj is None
            or threading.get_ident() != info.thread_id
            or destination.numel() * destination.element_size() < info.size
            or tensor.numel() * tensor.element_size() != info.size
            or tensor.storage_offset() != 0
            or not tensor.is_contiguous()):
        return False

    if info.size == 0:
        # Empty slice: nothing to read, trivially successful.
        return True

    # Wrap destination's raw memory so readinto() fills it in place
    # (zero-copy: no intermediate bytes object).
    buf_type = ctypes.c_ubyte * info.size
    view = memoryview(buf_type.from_address(destination.data_ptr()))

    try:
        file_obj.seek(info.offset)
        done = 0
        # readinto() may return short reads; loop until the slice is complete.
        while done < info.size:
            try:
                n = file_obj.readinto(view[done:])
            except OSError:
                return False  # I/O failure: let the caller fall back
            if n <= 0:
                return False  # premature EOF
            done += n
        return True
    finally:
        # Release the view explicitly so the ctypes buffer does not keep a
        # dangling reference into destination's memory.
        view.release()

class TensorGeometry(NamedTuple):
    """Shape/dtype stand-in for a tensor whose data is not materialized.

    Mirrors the small subset of the torch.Tensor sizing API (numel /
    element_size) that the planning code in this module relies on.
    """
    # Tensor shape; typically a torch.Size, but any iterable of ints works.
    # (Was annotated `any` — the builtin function — instead of typing.Any.)
    shape: Any
    dtype: torch.dtype

    def element_size(self) -> int:
        """Return the size in bytes of one element of this dtype."""
        if self.dtype == torch.bool:
            # Neither torch.finfo nor torch.iinfo accepts bool; it is
            # stored as one byte per element.
            return 1
        info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
        return info.bits // 8

    def numel(self) -> int:
        """Return the total number of elements implied by the shape."""
        return math.prod(self.shape)

def tensors_to_geometries(tensors, dtype=None):
    """Map a sequence of tensors to lightweight TensorGeometry records.

    None entries and QuantizedTensor entries pass through unchanged. For
    every other tensor the recorded dtype is, in order of precedence: the
    *dtype* override when given, the tensor's `_model_dtype` attribute when
    present, otherwise the tensor's own dtype.
    """
    out = []
    for tensor in tensors:
        if tensor is None or isinstance(tensor, QuantizedTensor):
            out.append(tensor)
            continue
        if dtype is not None:
            chosen = dtype
        else:
            chosen = getattr(tensor, "_model_dtype", tensor.dtype)
        out.append(TensorGeometry(shape=tensor.shape, dtype=chosen))
    return out

def vram_aligned_size(tensor):
    """Return the bytes *tensor* occupies when padded up to 1 KiB alignment.

    Accepts a single tensor, a list of tensors (sizes are summed), a
    QuantizedTensor (its flattened inner tensors are summed), or None
    (0 bytes).
    """
    if isinstance(tensor, list):
        return sum(vram_aligned_size(item) for item in tensor)

    if isinstance(tensor, QuantizedTensor):
        inner_names, _ = tensor.__tensor_flatten__()
        return vram_aligned_size([getattr(tensor, name) for name in inner_names])

    if tensor is None:
        return 0

    alignment = 1024
    raw_bytes = tensor.numel() * tensor.element_size()
    # Ceiling-divide, then scale back up to the next multiple of `alignment`.
    return -(-raw_bytes // alignment) * alignment

def interpret_gathered_like(tensors, gathered):
    """Build views into the 1-D byte buffer *gathered*, shaped like *tensors*.

    Each entry of *tensors* is mapped to a zero-copy view over the next
    region of *gathered*; the running offset advances by vram_aligned_size()
    per inner tensor, so the layout matches a buffer packed with that same
    alignment. None entries yield None; QuantizedTensor entries are rebuilt
    from views over each of their flattened inner tensors.

    Raises ValueError when *gathered* is not a 1-D single-byte tensor, or is
    too small to hold all entries.
    """
    offset = 0
    dest_views = []

    if gathered.dim() != 1 or gathered.element_size() != 1:
        raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

    for tensor in tensors:

        if tensor is None:
            dest_views.append(None)
            continue

        if isinstance(tensor, QuantizedTensor):
            # One template per inner tensor (e.g. packed data + params),
            # keyed by attribute name for __tensor_unflatten__ below.
            inner_tensors, qt_ctx = tensor.__tensor_flatten__()
            templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
        else:
            templates = { "data": tensor }

        actuals = {}
        for attr, template in templates.items():
            size = template.numel() * template.element_size()
            if offset + size > gathered.numel():
                raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
            # Reinterpret the byte slice as the template's dtype and shape —
            # a view, not a copy.
            actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
            # Advance by the padded size, not `size`, to match the packer.
            offset += vram_aligned_size(template)

        if isinstance(tensor, QuantizedTensor):
            # NOTE(review): the trailing (0, 0) arguments mirror torch's
            # __tensor_unflatten__(outer_size, outer_stride) signature and
            # appear unused here — confirm against QuantizedTensor.
            dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
        else:
            dest_views.append(actuals["data"])

    return dest_views

# NOTE(review): flag read elsewhere in the project; presumably gates the
# "aimdo" feature — confirm against callers before documenting further.
aimdo_enabled = False

# Hook installed via set_ram_cache_release_state(); invoked by
# extra_ram_release() with a byte target. None means no hook is registered.
extra_ram_release_callback = None
# Extra bytes of RAM headroom to maintain when releasing cache (kept >= 0
# by set_ram_cache_release_state()).
RAM_CACHE_HEADROOM = 0

def set_ram_cache_release_state(callback, headroom):
    """Register the RAM-cache release hook and its headroom.

    callback: callable later invoked by extra_ram_release() with a byte
        target, or None to disable the hook.
    headroom: desired headroom in bytes; coerced to a non-negative int and
        stored in RAM_CACHE_HEADROOM.
    """
    global RAM_CACHE_HEADROOM, extra_ram_release_callback
    RAM_CACHE_HEADROOM = max(int(headroom), 0)
    extra_ram_release_callback = callback

def extra_ram_release(target):
    """Ask the registered release hook to free *target* bytes of RAM.

    Returns the hook's result, or 0 when no hook is registered.
    """
    hook = extra_ram_release_callback
    return 0 if hook is None else hook(target)
