import {
  DataConnectionType,
  DataSourceDatabaseId,
  DataSourceSchemaId,
  DataSourceTableId,
} from "@hex/common";
import { keyBy } from "lodash";
import { useCallback, useMemo } from "react";

import {
  importedOrgDataConnectionIdsSelector,
  useDataConnectionHexVersionLinksSelector,
} from "../../../../hex-version-multiplayer/state-hooks/dataConnectionHexVersionLinkStateHooks";
import { ORG_ID } from "../../../../orgs";
import { useProjectContext } from "../../../../util/projectContext";
import { useSessionContext } from "../../../../util/sessionContext";
import { DataConnectionSelectItem } from "../../../data/selector/DataConnectionSelectItemRenderer";

import {
  SqlCellSchemaFragment,
  useGetOrgDataConnectionsForSqlCellQuery,
  useGetSchemasForSqlCellQuery,
  useGetSharedProjectDataConnectionsForSqlCellQuery,
} from "./SqlLogicCell.generated";

/**
 * Runs the SQL-cell schemas query and indexes the returned schemas by their
 * `connectionType` field.
 *
 * The lookup object is memoized on the query's `data`, so it is only rebuilt
 * when that value changes.
 *
 * @returns An object mapping each schema's `connectionType` to the schema
 *   fragment. While the query has produced no data, `keyBy` receives
 *   `undefined` and yields an empty object.
 */
export function useSchemaMap(): Record<string, SqlCellSchemaFragment> {
  const schemasQuery = useGetSchemasForSqlCellQuery();
  const schemaData = schemasQuery.data;

  return useMemo(() => {
    const schemas = schemaData?.dataConnectionSchemas;
    return keyBy(schemas, "connectionType");
  }, [schemaData]);
}

/**
 * Collects the data connections selectable from a SQL cell.
 *
 * Two queries are issued: one for the shared project data connections of the
 * current hex version and one for the org data connections of `ORG_ID` (with
 * the current app session id). Their results are concatenated — shared
 * project connections first, then org connections — and every entry is tagged
 * with an `imported` flag.
 *
 * @returns `loading` is true while either query is still loading;
 *   `connectionItems` is the combined list, where `imported` is true when the
 *   connection's id appears in the imported org data connection ids read from
 *   the hex-version link state.
 */
export function useConnectionItems(): {
  loading: boolean;
  connectionItems: DataConnectionSelectItem[];
} {
  const { hexVersionId } = useProjectContext();
  const { appSessionId } = useSessionContext();

  const sharedQuery = useGetSharedProjectDataConnectionsForSqlCellQuery({
    variables: { hexVersionId },
  });

  const importedOrgDataConnectionIds = useDataConnectionHexVersionLinksSelector(
    { selector: importedOrgDataConnectionIdsSelector },
  );

  const orgQuery = useGetOrgDataConnectionsForSqlCellQuery({
    variables: { orgId: ORG_ID, appSessionId },
  });

  // Pull out the exact values the memo depends on. The `?? []` fallbacks are
  // applied inside the memo so a fresh empty array is not created per render.
  const sharedConnections = sharedQuery.data?.sharedProjectDataConnections;
  const orgConnections = orgQuery.data?.orgDataConnections;
  const sharedConnectionsLoading = sharedQuery.loading;
  const orgDataLoading = orgQuery.loading;

  return useMemo(() => {
    const allConnections = [
      ...(sharedConnections ?? []),
      ...(orgConnections ?? []),
    ];

    const connectionItems: DataConnectionSelectItem[] = allConnections.map(
      (connection) => {
        const imported = importedOrgDataConnectionIds.includes(connection.id);
        return { ...connection, imported };
      },
    );

    return {
      loading: sharedConnectionsLoading || orgDataLoading,
      connectionItems,
    };
  }, [
    sharedConnections,
    importedOrgDataConnectionIds,
    orgConnections,
    orgDataLoading,
    sharedConnectionsLoading,
  ]);
}

/**
 * Arguments accepted by the callback that `useTableNodeIdPathGetter` returns.
 */
interface TableNodeIdPathGetterArgs {
  /** Connection type; used as the key into the schema map. */
  connectionType: DataConnectionType;
  /** Database id; only included in the path for multi-database connections. */
  databaseId: DataSourceDatabaseId;
  /** Schema id; always included in the path. */
  schemaId: DataSourceSchemaId;
  /** Table id; always the last element of the path. */
  tableId: DataSourceTableId;
}

/**
 * Returns a memoized function that builds the node-id path for a table.
 *
 * The returned function looks up the schema for the given connection type in
 * the map from `useSchemaMap` and reads its
 * `connectionMetadata.multiDatabase` flag:
 *
 * - flag is `null`/`undefined` (including when no schema exists for the
 *   connection type): returns `null`;
 * - flag is true: returns `[databaseId, schemaId, tableId]`;
 * - flag is false: returns `[schemaId, tableId]`.
 *
 * The callback identity changes only when the schema map changes.
 */
export function useTableNodeIdPathGetter(): (
  args: TableNodeIdPathGetterArgs,
) => string[] | null {
  const schemasByConnectionType = useSchemaMap();

  return useCallback(
    ({
      connectionType,
      databaseId,
      schemaId,
      tableId,
    }: TableNodeIdPathGetterArgs) => {
      const schema = schemasByConnectionType[connectionType];
      const isMultiDatabase = schema?.connectionMetadata.multiDatabase;

      // Without a known flag the path shape cannot be decided.
      if (isMultiDatabase == null) {
        return null;
      }

      if (isMultiDatabase) {
        return [databaseId, schemaId, tableId];
      }

      return [schemaId, tableId];
    },
    [schemasByConnectionType],
  );
}
