import { ExtractBranchFromUnion } from "../discriminatedUnionUtils.js";
import { assertNever } from "../errors.js";
import { HqlNode } from "../hql/types.js";

import {
  CalcParseNode,
  CalcParseNodeOrError,
  CalcParseNodeWithoutErrors,
} from "./CalcParser.js";
import { CalcSchema } from "./calcTypes.js";

/**
 * Depth-first, pre-order walk over a calc parse tree (or an HQL tree).
 *
 * The order of visits is:
 * 1. `callback` runs on `root`.
 * 2. When `root` is not an error node and carries a non-null
 *    `additionalErrors` list, `callback` runs on each of those entries. They
 *    are only reported, not descended into.
 * 3. The walk recurses into the node's operands or arguments, left to right.
 *
 * Literal, column, parameter-reference and error nodes have no children.
 */
export function traverseAST<T extends CalcParseNodeOrError | HqlNode>(
  root: T,
  callback: (node: T) => void,
): void {
  // Parent is always reported before anything beneath it.
  callback(root);

  if (
    root.type !== "error" &&
    "additionalErrors" in root &&
    root.additionalErrors != null
  ) {
    for (const attachedError of root.additionalErrors) {
      callback(attachedError as T);
    }
  }

  switch (root.type) {
    case "boolean":
    case "float":
    case "integer":
    case "null":
    case "str":
    case "column":
    case "parameterReference":
    case "error":
      // Leaves: there is nothing beneath these nodes to visit.
      return;
    case "unaryOp":
      traverseAST(root.left as T, callback);
      return;
    case "binaryOp":
      traverseAST(root.left as T, callback);
      traverseAST(root.right as T, callback);
      return;
    case "function":
      root.args.forEach((argument) => traverseAST(argument as T, callback));
      return;
    default:
      // Compile-time exhaustiveness check over every node `type`.
      assertNever(root, root);
  }
}

/**
 * Type guard that reports whether the tree contains no error nodes.
 *
 * Every node reached by `traverseAST` is checked, including entries of any
 * `additionalErrors` lists. When this returns `true`, `root` is narrowed to
 * `CalcParseNodeWithoutErrors`.
 */
export function isErrorlessAst(
  root: CalcParseNodeOrError,
): root is CalcParseNodeWithoutErrors {
  let errorCount = 0;
  traverseAST(root, (visited) => {
    errorCount += visited.type === "error" ? 1 : 0;
  });
  return errorCount === 0;
}

/**
 * Strips parser-only metadata from an error-free parse tree so that what
 * remains represents only the language operations, typed as `CalcSchema`.
 * The stripped metadata is location info, `parent` links and
 * `additionalErrors`.
 *
 * The tree is mutated _in place_. Every node visited by `traverseAST` loses
 * the properties listed in `keysToRemove`. The same root object is then
 * returned, re-typed as `CalcSchema`. Use with caution: the removed metadata
 * is gone for any other holder of references to these nodes.
 *
 * @param root - a parse tree already known to contain no error nodes
 * @returns `root` itself, without the parser metadata
 */
export function cleanAstInPlace(root: CalcParseNodeWithoutErrors): CalcSchema {
  // Compile-time bookkeeping: find the properties a parse node has beyond a
  // plain schema node. Both types are unions, so one representative member
  // ("unaryOp") stands in for the whole union.
  type ParsedUnary = ExtractBranchFromUnion<CalcParseNode, "type", "unaryOp">;
  type SchemaUnary = ExtractBranchFromUnion<CalcSchema, "type", "unaryOp">;
  type ParserOnlyKeys = Exclude<keyof ParsedUnary, keyof SchemaUnary>;

  const keysToRemove = [
    "additionalErrors",
    "start",
    "selfStart",
    "selfEnd",
    "end",
    "parent",
  ] as const;

  // Fails to compile if some parser-only key is missing from `keysToRemove`.
  type UncoveredKeys = Exclude<ParserOnlyKeys, (typeof keysToRemove)[number]>;
  type MustBeNever<T extends never> = T;
  type _AllKeysCovered = MustBeNever<UncoveredKeys>;

  traverseAST(root, (node) => {
    // `delete` needs optional properties, hence the Partial view.
    const mutableNode = node as Partial<CalcParseNode>;
    keysToRemove.forEach((key) => {
      delete mutableNode[key];
    });
  });

  return root as unknown as CalcSchema;
}

/**
 * Whether `offset` falls inside the node's full `[start, end)` span.
 * The start is inclusive and the end is exclusive.
 */
export const nodeContainsOffset = (
  node: CalcParseNodeOrError,
  offset: number,
): boolean => {
  return offset >= node.start && offset < node.end;
};

/**
 * Whether `offset` falls inside the node's own `[selfStart, selfEnd)` span,
 * as opposed to the full `[start, end)` span. The start is inclusive and the
 * end is exclusive.
 */
export const nodeSelfContainsOffset = (
  node: CalcParseNodeOrError,
  offset: number,
): boolean => {
  return offset >= node.selfStart && offset < node.selfEnd;
};

/**
 * Returns the innermost non-error node whose `[start, end)` span contains
 * `offset`, or `undefined` if no such node exists.
 *
 * `traverseAST` reports a node before its children, so the last matching
 * node seen is the deepest one on the path to `offset`. Error nodes are
 * ignored, including those attached via `additionalErrors`.
 * NOTE(review): this assumes sibling spans do not overlap. If they did, the
 * match visited last would win.
 */
export function getNodeAtOffset(
  root: CalcParseNodeOrError,
  offset: number,
): CalcParseNode | undefined {
  // Annotated explicitly. An unannotated `let x = undefined` is auto-typed,
  // and the only other assignment happens inside the callback. Control-flow
  // analysis cannot see that assignment, so the checker would otherwise treat
  // the returned value as always `undefined`.
  let matchingNode: CalcParseNode | undefined;
  traverseAST(root, (node) => {
    if (node.type !== "error" && nodeContainsOffset(node, offset)) {
      matchingNode = node;
    }
  });
  return matchingNode;
}

/**
 * Walks from `startNode` up through its `parent` links and returns the first
 * node that satisfies `predicate`.
 *
 * `startNode` itself is tested first, so a matching start node is returned
 * as-is. Returns `undefined` when no node on the path to the root matches.
 * If `predicate` is a type guard, the return type is narrowed accordingly.
 */
export function getClosestParentNode<S extends CalcParseNode>(
  startNode: CalcParseNode,
  predicate: (node: CalcParseNode) => node is S,
): S | undefined;
export function getClosestParentNode(
  startNode: CalcParseNode,
  predicate: (node: CalcParseNode) => boolean,
): CalcParseNode | undefined;
export function getClosestParentNode(
  startNode: CalcParseNode,
  predicate: (node: CalcParseNode) => boolean,
): CalcParseNode | undefined {
  for (
    let current: CalcParseNode | undefined = startNode;
    current != null;
    current = current.parent
  ) {
    if (predicate(current)) {
      return current;
    }
  }
  return undefined;
}

/**
 * Collects, in `traverseAST` visiting order, every node whose `type` equals
 * `type`. The result is typed as the matching branch of the node union.
 */
export function getAllNodesOfType<
  T extends CalcParseNodeOrError | HqlNode,
  U extends T["type"],
>(root: T, type: U): ExtractBranchFromUnion<T, "type", U>[] {
  type Match = ExtractBranchFromUnion<T, "type", U>;

  const matches: Match[] = [];
  traverseAST(root, (candidate) => {
    if (candidate.type !== type) {
      return;
    }
    matches.push(candidate as Match);
  });
  return matches;
}

type CalcFunctionUsage = { [functionName: string]: number };
/**
 * Tallies how often each function, binary operator and unary operator occurs
 * in the tree.
 *
 * Keys have the form `FUNC: <name>`, `BINARYOP: <op>` or `UNARYOP: <op>`.
 * All other node types are not counted.
 */
export function countFunctionUsage(
  root: CalcParseNodeOrError,
): CalcFunctionUsage {
  const usage: CalcFunctionUsage = {};
  const increment = (key: string): void => {
    usage[key] = (usage[key] ?? 0) + 1;
  };

  traverseAST(root, (node) => {
    switch (node.type) {
      case "function":
        increment(`FUNC: ${node.name}`);
        break;
      case "binaryOp":
        increment(`BINARYOP: ${node.op}`);
        break;
      case "unaryOp":
        increment(`UNARYOP: ${node.op}`);
        break;
      default:
        // Not an operation we track.
        break;
    }
  });
  return usage;
}
