import {
    PropsWithChildren,
    useContext,
    useEffect,
    useMemo,
    useState,
} from "react";
import { useReactFlow, type Node, type Edge } from "reactflow";
import DAGContext, { type DAGContextValue } from "./DAGContext";
import { LayoutNodesProps } from "./types";

/**
 * DAGProvider
 * Custom context provider that propagates the selected node's ancestors and
 * descendants (plus the selected source handle) to the entities on the DAG diagram.
 * NOTE: This context must appear _within_ the ReactFlow context provider, because it
 * reads the graph's edges through `useReactFlow()`.
 *
 * @param children - subtree that can read the DAG context.
 * @param selectedNode - the currently selected node, or undefined when nothing is selected.
 * @param setLayoutNodes - callback exposed to consumers via `useSetNodes()`.
 */
const DAGProvider = ({
    children,
    selectedNode,
    setLayoutNodes,
}: PropsWithChildren<{
    selectedNode: Node | undefined;
    setLayoutNodes: (props: LayoutNodesProps) => void;
}>) => {
    const { getEdges } = useReactFlow();
    const [selectedSourceHandle, setSelectedSourceHandle] = useState("");

    // Clear the source-handle filter whenever the selection moves to a different node.
    // NOTE(review): effects run after render, so the first render for a newly selected
    // node still filters its descendants by the previous node's handle.
    useEffect(() => {
        setSelectedSourceHandle("");
    }, [selectedNode?.id]);

    const contextValue: DAGContextValue = useMemo(() => {
        // Read the edges once so both traversals work on the same snapshot.
        const edges = getEdges();
        return {
            setLayoutNodes,
            selectedNode,
            setSelectedSourceHandle,
            selectedSourceHandle,
            ancestors: findAncestors(selectedNode?.id, edges),
            descendants: findDescendants(
                selectedNode?.id,
                edges,
                selectedSourceHandle
            ),
        };
        // NOTE(review): edges are read imperatively via getEdges(), so a change to the
        // edges alone (with the same selection/handle) does not recompute the sets.
        // `setSelectedSourceHandle` is omitted: React guarantees a state setter is stable.
    }, [getEdges, selectedNode, selectedSourceHandle, setLayoutNodes]);

    return (
        <DAGContext.Provider value={contextValue}>
            {children}
        </DAGContext.Provider>
    );
};

/** Reads the value of the DAG context for the calling component. */
export const useDAGContext = () => {
    const dagContext = useContext(DAGContext);
    return dagContext;
};

/**
 * Exposes the `setLayoutNodes` callback stored in the DAG context,
 * wrapped in an object so more members can be added later.
 */
export const useSetNodes = () => {
    const dagContext = useDAGContext();
    return { setLayoutNodes: dagContext.setLayoutNodes };
};

/**
 * Exposes the selection-related slice of the DAG context: the selected node,
 * the selected source handle, and the setter for that handle.
 */
export const useSelectedNode = () => {
    const dagContext = useDAGContext();
    return {
        selectedNode: dagContext.selectedNode,
        selectedSourceHandle: dagContext.selectedSourceHandle,
        setSelectedSourceHandle: dagContext.setSelectedSourceHandle,
    };
};

/**
 * Reports whether a node should be rendered in the "focused" state.
 *
 * @param nodeId - id of the node being rendered (may be undefined).
 * @returns true when nothing is selected anywhere in the DAG; otherwise true only
 *   if `nodeId` is the selected node or one of its ancestors/descendants.
 */
export const useIsFocusedNode = (nodeId: string | undefined): boolean => {
    const { selectedNode, ancestors, descendants } = useDAGContext();

    // With no selection, every node is considered focused.
    if (!selectedNode) {
        return true;
    }

    // A selection exists, so a node without an id can never be part of it.
    if (!nodeId) {
        return false;
    }

    return (
        selectedNode.id === nodeId ||
        ancestors.has(nodeId) ||
        descendants.has(nodeId)
    );
};

/**
 * Breadth-first search collecting every ancestor of a node, i.e. every node from
 * which `startNodeId` can be reached by following edges source -> target.
 *
 * The incoming-edge index is built once up front, so the search is O(V + E)
 * instead of re-scanning the whole edge list for each visited node.
 *
 * @param startNodeId - node whose ancestors are wanted; when undefined or empty
 *   no search is performed.
 * @param edges - all edges of the graph.
 * @returns ancestor ids in BFS discovery order. The start node itself appears only
 *   if it lies on a cycle.
 */
const findAncestors = (
    startNodeId: string | undefined,
    edges: Edge[]
): Set<string> => {
    const visited = new Set<string>();

    if (!startNodeId) {
        return visited;
    }

    // target id -> source ids of its incoming edges, kept in edge order so the
    // discovery order matches a per-node scan of `edges`.
    const sourcesByTarget = new Map<string, string[]>();
    for (const edge of edges) {
        const sources = sourcesByTarget.get(edge.target);
        if (sources) {
            sources.push(edge.source);
        } else {
            sourcesByTarget.set(edge.target, [edge.source]);
        }
    }

    // for...of over an array also visits elements pushed during iteration,
    // which gives a FIFO queue without the O(n) cost of Array#shift.
    const queue: string[] = [startNodeId];
    for (const nodeId of queue) {
        for (const source of sourcesByTarget.get(nodeId) ?? []) {
            if (!visited.has(source)) {
                visited.add(source);
                queue.push(source);
            }
        }
    }

    return visited;
};

/**
 * Breadth-first search collecting every descendant of a node, i.e. every node
 * reachable from `startNodeId` by following edges source -> target.
 *
 * The outgoing-edge index is built once up front, so the search is O(V + E)
 * instead of re-scanning the whole edge list for each visited node.
 *
 * @param startNodeId - node whose descendants are wanted; when undefined or empty
 *   no search is performed.
 * @param edges - all edges of the graph.
 * @param startSourceHandle - when non-empty, only the start node's edges whose
 *   `sourceHandle` equals this value are followed for the first hop. Every later
 *   hop follows all outgoing edges regardless of handle.
 * @returns descendant ids in BFS discovery order. The start node itself appears
 *   only if it is reachable again (i.e. it lies on a cycle).
 */
const findDescendants = (
    startNodeId: string | undefined,
    edges: Edge[],
    startSourceHandle?: string
): Set<string> => {
    const visited = new Set<string>();

    if (!startNodeId) {
        return visited;
    }

    // First hop: the start node's outgoing edges, restricted to the requested handle.
    const queue: string[] = [];
    for (const edge of edges) {
        if (edge.source !== startNodeId) {
            continue;
        }
        if (startSourceHandle && edge.sourceHandle !== startSourceHandle) {
            continue;
        }
        if (!visited.has(edge.target)) {
            visited.add(edge.target);
            queue.push(edge.target);
        }
    }

    // source id -> target ids of its outgoing edges (all handles), kept in edge order.
    const targetsBySource = new Map<string, string[]>();
    for (const edge of edges) {
        const targets = targetsBySource.get(edge.source);
        if (targets) {
            targets.push(edge.target);
        } else {
            targetsBySource.set(edge.source, [edge.target]);
        }
    }

    // Remaining hops. for...of over an array also visits elements pushed during
    // iteration, which gives a FIFO queue without the O(n) cost of Array#shift.
    for (const nodeId of queue) {
        for (const target of targetsBySource.get(nodeId) ?? []) {
            if (!visited.has(target)) {
                visited.add(target);
                queue.push(target);
            }
        }
    }

    return visited;
};

// The provider component is this module's default export; the context hooks are named exports.
export default DAGProvider;
