import { Literal, Static, Union } from "runtypes";

import { DataConnectionTypeLiteral } from "../dataConnectionTypes";
import { assertNever } from "../errors";
import { HqlAggregationFunction } from "../hql/types.js";
import { getNormalEnum } from "../runtypeEnums";

/**
 * Runtype for a SQL dialect identifier: every alternative of
 * `DataConnectionTypeLiteral`, plus the literal `"duckdb"` which is added
 * here explicitly.
 */
export const SqlDialectLiteral = Union(
  // eslint-disable-next-line tree-shaking/no-side-effects-in-initialization
  ...DataConnectionTypeLiteral.alternatives,
  Literal("duckdb"),
);
/** Static (compile-time) type described by `SqlDialectLiteral`. */
export type SqlDialect = Static<typeof SqlDialectLiteral>;
/**
 * Value-level companion of the `SqlDialect` type, produced by `getNormalEnum`.
 * NOTE(review): presumably an enum-like object keyed by dialect name — confirm
 * against `getNormalEnum` in `../runtypeEnums`.
 */
export const SqlDialect = getNormalEnum(SqlDialectLiteral);
/** Every `SqlDialect` except `"transform"`. */
export type PushdownSqlDialect = Exclude<SqlDialect, "transform">;

/**
 * Per-dialect flag that drives both the `SqlGenDialect` type and the
 * `isSqlGenDialect` guard: `true` entries are the sql-gen dialects.
 *
 * `satisfies Record<SqlDialect, boolean>` makes the compiler reject a missing
 * or misspelled dialect key while keeping each value's literal `true`/`false`
 * type, which `GetTrueKeys` relies on to pick out the `true` entries.
 */
const SUPPORTS_SQLGEN = {
  // Dialects flagged as sql-gen dialects.
  alloydb: true,
  athena: true,
  bigquery: true,
  clickhouse: true,
  cloudsql__mysql: true,
  cloudsql__postgres: true,
  databricks: true,
  duckdb: true,
  mariadb: true,
  motherduck: true,
  mysql: true,
  postgres: true,
  redshift: true,
  snowflake: true,

  // Dialects that are not sql-gen dialects.
  cloudsql: false,
  cloudsql__sqlserver: false,
  db2: false,
  dremio: false,
  materialize: false,
  prestodb: false,
  spark: false,
  sqlserver: false,
  starburst: false,
  transform: false,
  trino: false,
} satisfies Record<SqlDialect, boolean>;

/**
 * Union of the keys of `T` whose property type is the literal `true`.
 *
 * Each key maps to itself when its value type extends `true` and to `never`
 * otherwise; indexing the mapped type by `keyof T` then collapses it to the
 * union of the surviving keys.
 */
type GetTrueKeys<T> = {
  [K in keyof T]: T[K] extends true ? K : never;
}[keyof T];

/** The dialects whose entry in `SUPPORTS_SQLGEN` is `true`. */
export type SqlGenDialect = GetTrueKeys<typeof SUPPORTS_SQLGEN>;

/**
 * Type guard that looks `dialect` up in `SUPPORTS_SQLGEN` and, when the flag
 * is `true`, narrows it to `SqlGenDialect`.
 *
 * @param dialect - The dialect to check.
 * @returns The `SUPPORTS_SQLGEN` flag for `dialect`.
 */
export const isSqlGenDialect = (
  dialect: SqlDialect,
): dialect is SqlGenDialect => {
  const supported: boolean = SUPPORTS_SQLGEN[dialect];
  return supported;
};

/**
 * This error type should be thrown from pushdown SQL generation to indicate
 * known/allowed gaps in functionality due to SQL dialect differences. For
 * example, this error is thrown for the median aggregation with MySQL/MariaDB
 * since they do not support computing medians.
 *
 * `name` is set to `"DialectUnsupportedError"` so that logs, `toString()` and
 * stack traces identify the error type instead of showing a plain `Error`.
 */
export class DialectUnsupportedError extends Error {
  /**
   * @param message - Human-readable description of the unsupported feature.
   */
  constructor(message: string) {
    super(message);
    // `Error` subclasses inherit `name === "Error"` unless it is overridden.
    this.name = "DialectUnsupportedError";
  }
}
/**
 * Matches any single identifier-quoting character: the double quote, `[`,
 * `]` or the backtick. These are exactly the delimiters that
 * `getQuotingCharacters` can return.
 *
 * The pattern carries the `g` flag, so this shared instance keeps its
 * `lastIndex` between `test()`/`exec()` calls; `String.prototype.replace`
 * and `replaceAll` are not affected by that state.
 */
// eslint-disable-next-line no-useless-escape -- need to escape the brackets
export const QUOTING_CHARACTERS_REGEX = /["\[\]`]/g;

/**
 * Returns the `[open, close]` delimiter pair used to quote identifiers in the
 * given SQL dialect.
 *
 * Be sure to keep this in sync with `hex_shared/sql/dialects.py` in
 * python-shared!
 *
 * @param dialect - The dialect whose quoting characters are wanted.
 * @returns A fresh two-element tuple: opening delimiter, closing delimiter.
 */
export function getQuotingCharacters(dialect: SqlDialect): [string, string] {
  const doubleQuote = `"`;
  const backtick = "`";

  switch (dialect) {
    // Double-quoted identifiers.
    case "alloydb":
    case "athena":
    case "clickhouse":
    case "cloudsql":
    case "cloudsql__postgres":
    case "db2":
    case "dremio":
    case "duckdb":
    case "materialize":
    case "motherduck":
    case "postgres":
    case "prestodb":
    case "redshift":
    case "snowflake":
    case "starburst":
    case "transform":
    case "trino":
      return [doubleQuote, doubleQuote];

    // Backtick-quoted identifiers.
    case "bigquery":
    case "cloudsql__mysql":
    case "databricks":
    case "mariadb":
    case "mysql":
    case "spark":
      return [backtick, backtick];

    // Square-bracket-quoted identifiers.
    case "cloudsql__sqlserver":
    case "sqlserver":
      return ["[", "]"];

    // Every dialect is handled above, so this branch is only reached for a
    // value outside `SqlDialect`; it is handed to `assertNever`.
    default:
      assertNever(dialect, dialect);
  }
}

/**
 * Wraps `identifier` in the quoting delimiters that `getQuotingCharacters`
 * returns for `dialect`.
 *
 * NOTE(review): the identifier is inserted verbatim — delimiter characters
 * already present inside it are neither escaped nor removed here. Presumably
 * callers sanitize beforehand; confirm against call sites.
 *
 * @param identifier - The raw identifier text to quote.
 * @param dialect - The dialect that determines the delimiters.
 * @returns The opening delimiter, the identifier, then the closing delimiter.
 */
export function quoteSqlIdentifier(
  identifier: string,
  dialect: SqlDialect,
): string {
  const delimiters = getQuotingCharacters(dialect);
  return delimiters[0] + identifier + delimiters[1];
}

/** Dialects for which every aggregate function is supported. */
const SQLGEN_ALL_AGG_DIALECTS: readonly SqlDialect[] = [
  "alloydb",
  "clickhouse",
  "cloudsql__postgres",
  "databricks",
  "duckdb",
  "motherduck",
  "postgres",
  "snowflake",
];

/** Dialects for which every aggregate function except median is supported. */
const SQLGEN_NO_MEDIAN_DIALECTS: readonly SqlDialect[] = [
  "athena",
  "bigquery",
  "cloudsql__mysql",
  "mariadb",
  "mysql",
  "redshift",
];

/**
 * Reports whether sql-gen supports the aggregation `agg` for `dialect`.
 *
 * @param agg - The aggregation function being requested.
 * @param dialect - The target SQL dialect.
 * @returns `true` when the dialect supports all aggregates, `agg !== "Median"`
 *   when it supports all but median, and `false` for any other dialect.
 */
export function sqlGenSupportsAggForDialect(
  agg: HqlAggregationFunction,
  dialect: SqlDialect,
): boolean {
  if (SQLGEN_ALL_AGG_DIALECTS.includes(dialect)) {
    return true;
  }

  if (SQLGEN_NO_MEDIAN_DIALECTS.includes(dialect)) {
    return agg !== "Median";
  }

  // Dialects not supported by sql-gen
  return false;
}
