import math

from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.hitpaw import (
    ImageEnhanceTaskCreateRequest,
    InputVideoModel,
    TaskCreateDataResponse,
    TaskCreateResponse,
    TaskStatusPollRequest,
    TaskStatusResponse,
    VideoEnhanceTaskCreateRequest,
)
from comfy_api_nodes.util import (
    ApiEndpoint,
    download_url_to_image_tensor,
    download_url_to_video_output,
    downscale_image_tensor,
    get_image_dimensions,
    poll_op,
    sync_op,
    upload_image_to_comfyapi,
    upload_video_to_comfyapi,
    validate_video_duration,
)

VIDEO_MODELS_MODELS_MAP = {
    "Portrait Restore Model (1x)": "portrait_restore_1x",
    "Portrait Restore Model (2x)": "portrait_restore_2x",
    "General Restore Model (1x)": "general_restore_1x",
    "General Restore Model (2x)": "general_restore_2x",
    "General Restore Model (4x)": "general_restore_4x",
    "Ultra HD Model (2x)": "ultrahd_restore_2x",
    "Generative Model (1x)": "generative_1x",
}

# Resolution name to target dimension (shorter side) in pixels
RESOLUTION_TARGET_MAP = {
    "720p": 720,
    "1080p": 1080,
    "2K/QHD": 1440,
    "4K/UHD": 2160,
    "8K": 4320,
}

# Square (1:1) resolutions use standard square dimensions
RESOLUTION_SQUARE_MAP = {
    "720p": 720,
    "1080p": 1080,
    "2K/QHD": 1440,
    "4K/UHD": 2048,  # DCI 4K square
    "8K": 4096,  # DCI 8K square
}

# Models with limited resolution support (no 8K)
LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"}

# Resolution options for different model types
RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"]
RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"]

# Maximum output resolution in pixels
MAX_PIXELS_GENERATIVE = 32_000_000
MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000


class HitPawGeneralImageEnhance(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="HitPawGeneralImageEnhance",
            display_name="HitPaw General Image Enhance",
            category="api node/image/HitPaw",
            description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
            f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
            inputs=[
                IO.Combo.Input("model", options=["generative_portrait", "generative"]),
                IO.Image.Input("image"),
                IO.Combo.Input("upscale_factor", options=[1, 2, 4]),
                IO.Boolean.Input(
                    "auto_downscale",
                    default=False,
                    tooltip="Automatically downscale input image if output would exceed the limit.",
                ),
            ],
            outputs=[
                IO.Image.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
                expr="""
                (
                  $prices := {
                    "generative_portrait": {"min": 0.02, "max": 0.06},
                    "generative": {"min": 0.05, "max": 0.15}
                  };
                  $price := $lookup($prices, widgets.model);
                  {
                    "type": "range_usd",
                    "min_usd": $price.min,
                    "max_usd": $price.max
                  }
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        image: Input.Image,
        upscale_factor: int,
        auto_downscale: bool,
    ) -> IO.NodeOutput:
        height, width = get_image_dimensions(image)
        requested_scale = upscale_factor
        output_pixels = height * width * requested_scale * requested_scale
        if output_pixels > MAX_PIXELS_GENERATIVE:
            if auto_downscale:
                input_pixels = width * height
                scale = 1
                max_input_pixels = MAX_PIXELS_GENERATIVE

                for candidate in [4, 2, 1]:
                    if candidate > requested_scale:
                        continue
                    scale_output_pixels = input_pixels * candidate * candidate
                    if scale_output_pixels <= MAX_PIXELS_GENERATIVE:
                        scale = candidate
                        max_input_pixels = None
                        break
                    # Check if we can downscale input by at most 2x to fit
                    downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE)
                    if downscale_ratio <= 2.0:
                        scale = candidate
                        max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate)
                        break

                if max_input_pixels is not None:
                    image = downscale_image_tensor(image, total_pixels=max_input_pixels)
                upscale_factor = scale
            else:
                output_width = width * requested_scale
                output_height = height * requested_scale
                raise ValueError(
                    f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) "
                    f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). "
                    f"Enable auto_downscale or use a smaller input image or a lower upscale factor."
                )

        initial_res = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"),
            response_model=TaskCreateResponse,
            data=ImageEnhanceTaskCreateRequest(
                model_name=f"{model}_{upscale_factor}x",
                img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None),
            ),
            wait_label="Creating task",
            final_label_on_success="Task created",
        )
        if initial_res.code != 200:
            raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
        request_price = initial_res.data.consume_coins / 1000
        final_response = await poll_op(
            cls,
            ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
            data=TaskCreateDataResponse(job_id=initial_res.data.job_id),
            response_model=TaskStatusResponse,
            status_extractor=lambda x: x.data.status,
            price_extractor=lambda x: request_price,
            poll_interval=10.0,
            max_poll_attempts=480,
        )
        return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))


class HitPawVideoEnhance(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        model_options = []
        for model_name in VIDEO_MODELS_MODELS_MAP:
            if model_name in LIMITED_RESOLUTION_MODELS:
                resolutions = RESOLUTIONS_LIMITED
            else:
                resolutions = RESOLUTIONS_FULL
            model_options.append(
                IO.DynamicCombo.Option(
                    model_name,
                    [IO.Combo.Input("resolution", options=resolutions)],
                )
            )

        return IO.Schema(
            node_id="HitPawVideoEnhance",
            display_name="HitPaw Video Enhance",
            category="api node/video/HitPaw",
            description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
            "Prices shown are per second of video.",
            inputs=[
                IO.DynamicCombo.Input("model", options=model_options),
                IO.Video.Input("video"),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
                expr="""
                (
                  $m := $lookup(widgets, "model");
                  $res := $lookup(widgets, "model.resolution");
                  $standard_model_prices := {
                    "original": {"min": 0.01, "max": 0.198},
                    "720p": {"min": 0.01, "max": 0.06},
                    "1080p": {"min": 0.015, "max": 0.09},
                    "2k/qhd": {"min": 0.02, "max": 0.117},
                    "4k/uhd": {"min": 0.025, "max": 0.152},
                    "8k": {"min": 0.033, "max": 0.198}
                  };
                  $ultra_hd_model_prices := {
                    "original": {"min": 0.015, "max": 0.264},
                    "720p": {"min": 0.015, "max": 0.092},
                    "1080p": {"min": 0.02, "max": 0.12},
                    "2k/qhd": {"min": 0.026, "max": 0.156},
                    "4k/uhd": {"min": 0.034, "max": 0.203},
                    "8k": {"min": 0.044, "max": 0.264}
                  };
                  $generative_model_prices := {
                    "original": {"min": 0.015, "max": 0.338},
                    "720p": {"min": 0.008, "max": 0.090},
                    "1080p": {"min": 0.05, "max": 0.15},
                    "2k/qhd": {"min": 0.038, "max": 0.225},
                    "4k/uhd": {"min": 0.056, "max": 0.338}
                  };
                  $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices :
                             $contains($m, "generative") ? $generative_model_prices :
                             $standard_model_prices;
                  $price := $lookup($prices, $res);
                  {
                    "type": "range_usd",
                    "min_usd": $price.min,
                    "max_usd": $price.max,
                    "format": {"approximate": true, "suffix": "/second"}
                  }
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: InputVideoModel,
        video: Input.Video,
    ) -> IO.NodeOutput:
        validate_video_duration(video, min_duration=0.5, max_duration=60 * 60)
        resolution = model["resolution"]
        src_width, src_height = video.get_dimensions()

        if resolution == "original":
            output_width = src_width
            output_height = src_height
        else:
            if src_width == src_height:
                target_size = RESOLUTION_SQUARE_MAP[resolution]
                if target_size < src_width:
                    raise ValueError(
                        f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than "
                        f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'."
                    )
                output_width = target_size
                output_height = target_size
            else:
                min_dimension = min(src_width, src_height)
                target_size = RESOLUTION_TARGET_MAP[resolution]
                if target_size < min_dimension:
                    raise ValueError(
                        f"Selected resolution {resolution} ({target_size}p) is smaller than "
                        f"the input video's shorter dimension ({min_dimension}p). "
                        f"Please select a higher resolution or 'original'."
                    )
                if src_width > src_height:
                    output_height = target_size
                    output_width = int(target_size * (src_width / src_height))
                else:
                    output_width = target_size
                    output_height = int(target_size * (src_height / src_width))
        initial_res = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"),
            response_model=TaskCreateResponse,
            data=VideoEnhanceTaskCreateRequest(
                video_url=await upload_video_to_comfyapi(cls, video),
                resolution=[output_width, output_height],
                original_resolution=[src_width, src_height],
                model_name=VIDEO_MODELS_MODELS_MAP[model["model"]],
            ),
            wait_label="Creating task",
            final_label_on_success="Task created",
        )
        request_price = initial_res.data.consume_coins / 1000
        if initial_res.code != 200:
            raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
        final_response = await poll_op(
            cls,
            ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
            data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
            response_model=TaskStatusResponse,
            status_extractor=lambda x: x.data.status,
            price_extractor=lambda x: request_price,
            poll_interval=10.0,
            max_poll_attempts=320,
        )
        return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))


class HitPawExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            HitPawGeneralImageEnhance,
            HitPawVideoEnhance,
        ]


async def comfy_entrypoint() -> HitPawExtension:
    return HitPawExtension()
