import { useMemo } from 'react';
import { Edge } from '@xyflow/react';
import {
  HierarchyNode,
  HierarchyPointNode,
  stratify,
  tree,
} from 'd3-hierarchy';

import { ExpandCollapseNode } from './types';

/**
 * Options accepted by `useExpandCollapse`.
 */
export interface UseExpandCollapseOptions {
  /**
   * When false, no tree layout is computed and each node keeps its own
   * `position`. `useExpandCollapse` defaults this to `true`.
   */
  layoutNodes?: boolean;
  /**
   * Horizontal spacing between neighbouring nodes, passed as the first
   * element of d3's `tree().nodeSize`. `useExpandCollapse` defaults this to 220.
   */
  treeWidth?: number;
  /**
   * Vertical spacing between tree levels, passed as the second element of
   * d3's `tree().nodeSize`. `useExpandCollapse` defaults this to 100.
   */
  treeHeight?: number;
}
/**
 * Node colours, chosen in `useExpandCollapse` from each node's parent.
 */
const colors = {
  /** The hierarchy root (a node with no parent). */
  root: '#001f3f',
  /** Nodes whose parent id starts with `process-`. */
  workflowItem: '#0074D9',
  /** Nodes whose parent id contains `-phase-` (and does not start with `process-`). */
  phase: '#7FDBFF',
  /** Every other node. */
  step: '#39CCCC',
};

/**
 * Type guard: true when the node carries numeric `x` and `y` coordinates,
 * i.e. it has been positioned by a d3 layout such as `tree()`.
 *
 * @param pointNode - A hierarchy node that may or may not have been laid out.
 * @returns Whether both `x` and `y` are numbers.
 */
function isHierarchyPointNode(
  pointNode:
    | HierarchyNode<ExpandCollapseNode>
    | HierarchyPointNode<ExpandCollapseNode>
): pointNode is HierarchyPointNode<ExpandCollapseNode> {
  // Unpositioned nodes simply lack these properties, so read them loosely.
  const { x, y } = pointNode as HierarchyPointNode<ExpandCollapseNode>;
  return [x, y].every((coordinate) => typeof coordinate === 'number');
}

/**
 * Vertical distance between consecutive children stacked under a node whose
 * id contains `-phase-`, when the tree layout is applied.
 */
const PHASE_STEP_SPACING = 100;

/**
 * Turns a flat list of React Flow nodes/edges into the currently visible
 * tree, optionally positioned with d3's tree layout.
 *
 * - A node's parent is the `source` of the first edge whose `target` is that
 *   node; the result is handed to d3 `stratify`.
 * - Nodes whose `data.expanded` is falsy have their descendants hidden.
 * - Each returned node gets a fresh `data` object with `expandable` (it has
 *   at least one child in the full tree) and `color` (picked from `colors`
 *   using the parent's id). The input node objects are not mutated.
 * - With `layoutNodes`, children of a `-phase-` node are stacked in a column
 *   below that node instead of being spread horizontally.
 *
 * @param nodes - All nodes, including currently hidden ones.
 * @param edges - Parent → child edges.
 * @param options - Layout options (see `UseExpandCollapseOptions`).
 * @returns The visible nodes (with `type: 'custom'`) and only those edges
 *   whose source and target are both visible.
 */
function useExpandCollapse(
  nodes: ExpandCollapseNode[],
  edges: Edge[],
  {
    layoutNodes = true,
    treeWidth = 220,
    treeHeight = 100,
  }: UseExpandCollapseOptions = {}
): { nodes: ExpandCollapseNode[]; edges: Edge[] } {
  return useMemo(() => {
    // d3's stratify throws ("no root") on an empty array; nothing to show.
    if (nodes.length === 0) {
      return { nodes: [], edges: [] };
    }

    // Index edges once instead of scanning them per node. Keeping the first
    // edge per target preserves the previous `edges.find` semantics.
    const parentOf = new Map<string, string>();
    for (const edge of edges) {
      if (!parentOf.has(edge.target)) {
        parentOf.set(edge.target, edge.source);
      }
    }

    const hierarchy = stratify<ExpandCollapseNode>()
      .id((d) => d.id)
      .parentId((d: ExpandCollapseNode) => parentOf.get(d.id))(nodes);

    // Record expandability *before* pruning, then prune collapsed subtrees.
    // The flags are kept locally (keyed by the input node object) rather than
    // written into the caller's nodes, which may be React props/state.
    const expandable = new Set<ExpandCollapseNode>();
    hierarchy.descendants().forEach((d) => {
      if (d.children?.length) {
        expandable.add(d.data);
      }
      d.children = d.data.data.expanded ? d.children : undefined;
    });

    // Colour depends only on whether the node has a parent and on that
    // parent's id; branch order matters when an id matches several patterns.
    const colorFor = (parentId: string | undefined, hasParent: boolean): string => {
      if (!hasParent) return colors.root;
      if (parentId?.startsWith('process-')) return colors.workflowItem;
      if (parentId?.includes('-phase-')) return colors.phase;
      return colors.step;
    };

    const layout = tree<ExpandCollapseNode>()
      .nodeSize([treeWidth, treeHeight])
      // Children of a `-phase-` node get no horizontal separation: they are
      // stacked vertically below their parent afterwards.
      .separation((a, b) =>
        a.parent?.id?.includes('-phase-') || b.parent?.id?.includes('-phase-')
          ? 0
          : 1
      );

    let root = hierarchy;
    if (layoutNodes) {
      const laidOut = layout(hierarchy);
      // descendants() lists parents before their descendants, so a parent's
      // y is already final when its children are positioned from it.
      laidOut.descendants().forEach((parent) => {
        if (!parent.id?.includes('-phase-')) return;
        parent.children?.forEach((child, index) => {
          child.y = parent.y + (index + 1) * PHASE_STEP_SPACING;
        });
      });
      root = laidOut;
    }

    const visible = root.descendants();
    const visibleIds = new Set(visible.map((d) => d.id));

    return {
      nodes: visible.map((d) => ({
        ...d.data,
        // A new data object, so the input node is never modified.
        data: {
          ...d.data.data,
          expandable: expandable.has(d.data),
          color: colorFor(d.parent?.id, !!d.parent),
        },
        type: 'custom',
        position: isHierarchyPointNode(d)
          ? { x: d.x, y: d.y }
          : d.data.position,
      })),
      // Drop edges that touch a node hidden by a collapsed ancestor.
      edges: edges.filter(
        (edge) => visibleIds.has(edge.source) && visibleIds.has(edge.target)
      ),
    };
  }, [nodes, edges, layoutNodes, treeWidth, treeHeight]);
}

export default useExpandCollapse;
