from __future__ import annotations

from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib


class Source:
    custom_node = "custom_node"
    templates = "templates"

class SubgraphEntry(TypedDict):
    source: str
    """
    Source of subgraph - custom_nodes vs templates.
    """
    path: str
    """
    Relative path of the subgraph file.
    For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
    """
    name: str
    """
    Name of subgraph file.
    """
    info: CustomNodeSubgraphEntryInfo
    """
    Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
    """
    data: str

class CustomNodeSubgraphEntryInfo(TypedDict):
    node_pack: str
    """Node pack name."""

class SubgraphManager:
    def __init__(self):
        self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
        self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None

    def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
        """Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
        entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
        entry: SubgraphEntry = {
            "source": source,
            "name": os.path.splitext(os.path.basename(file))[0],
            "path": file,
            "info": {"node_pack": node_pack},
        }
        return entry_id, entry

    async def load_entry_data(self, entry: SubgraphEntry):
        with open(entry['path'], 'r', encoding='utf-8') as f:
            entry['data'] = f.read()
        return entry

    async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
        if entry is None:
            return None
        entry = entry.copy()
        entry.pop('path', None)
        if remove_data:
            entry.pop('data', None)
        return entry

    async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
        entries = entries.copy()
        for key in list(entries.keys()):
            entries[key] = await self.sanitize_entry(entries[key], remove_data)
        return entries

    async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
        """Load subgraphs from custom nodes."""
        if not force_reload and self.cached_custom_node_subgraphs is not None:
            return self.cached_custom_node_subgraphs

        subgraphs_dict: dict[SubgraphEntry] = {}
        for folder in folder_paths.get_folder_paths("custom_nodes"):
            pattern = os.path.join(folder, "*/subgraphs/*.json")
            for file in glob.glob(pattern):
                file = file.replace('\\', '/')
                node_pack = "custom_nodes." + file.split('/')[-3]
                entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
                subgraphs_dict[entry_id] = entry

        self.cached_custom_node_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_blueprint_subgraphs(self, force_reload=False):
        """Load subgraphs from the blueprints directory."""
        if not force_reload and self.cached_blueprint_subgraphs is not None:
            return self.cached_blueprint_subgraphs

        subgraphs_dict: dict[SubgraphEntry] = {}
        blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')

        if os.path.exists(blueprints_dir):
            for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
                file = file.replace('\\', '/')
                entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
                subgraphs_dict[entry_id] = entry

        self.cached_blueprint_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_all_subgraphs(self, loadedModules, force_reload=False):
        """Get all subgraphs from all sources (custom nodes and blueprints)."""
        custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
        blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
        return {**custom_node_subgraphs, **blueprint_subgraphs}

    async def get_subgraph(self, id: str, loadedModules):
        """Get a specific subgraph by ID from any source."""
        entry = (await self.get_all_subgraphs(loadedModules)).get(id)
        if entry is not None and entry.get('data') is None:
            await self.load_entry_data(entry)
        return entry

    def add_routes(self, routes, loadedModules):
        @routes.get("/global_subgraphs")
        async def get_global_subgraphs(request):
            subgraphs_dict = await self.get_all_subgraphs(loadedModules)
            return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))

        @routes.get("/global_subgraphs/{id}")
        async def get_global_subgraph(request):
            id = request.match_info.get("id", None)
            subgraph = await self.get_subgraph(id, loadedModules)
            return web.json_response(await self.sanitize_entry(subgraph))
