import { capitalize } from "lodash";
import {
  Array,
  Boolean,
  Literal,
  Null,
  Number,
  Optional,
  Record,
  Static,
  String,
  Union,
} from "runtypes";

import {
  ParamReference,
  SourceLocation,
  SourceRange,
} from "./cellReferencesV1";

// begin -- keep in sync with jinjasql/sqlsafe.py

/**
 * Case-changing transform. `applyTransform` maps "lower" to
 * `String.prototype.toLowerCase`, "upper" to `String.prototype.toUpperCase`,
 * and "capitalize" to lodash `capitalize`.
 */
export const CaseTransform = Record({
  type: Union(Literal("lower"), Literal("upper"), Literal("capitalize")),
});
export type CaseTransform = Static<typeof CaseTransform>;

/**
 * String replacement transform: `applyTransform` replaces every occurrence of
 * `old` in the input with `new`.
 */
export const ReplaceTransform = Record({
  type: Literal("replace"),
  old: String,
  new: String,
});
export type ReplaceTransform = Static<typeof ReplaceTransform>;

/** A single, non-composite transform: either a case change or a replacement. */
export const SimpleTransform = Union(CaseTransform, ReplaceTransform);
export type SimpleTransform = Static<typeof SimpleTransform>;

/**
 * A sequence of simple transforms that `applyTransform` applies in order, each
 * one receiving the previous one's output. Elements must be
 * `SimpleTransform`s, so composites cannot be nested.
 */
export const CompositeTransform = Record({
  type: Literal("composite"),
  transforms: Array(SimpleTransform),
});
export type CompositeTransform = Static<typeof CompositeTransform>;

/** Every transform shape that `applyTransform` knows how to evaluate. */
export const KnownTransform = Union(SimpleTransform, CompositeTransform);
export type KnownTransform = Static<typeof KnownTransform>;

/**
 * Apply a known transform to a string. This is intended to stay in sync with
 * the transform definitions in jinjasql/sqlsafe.py.
 *
 * @param value - The input string.
 * @param transform - The transform to apply. When omitted (or null), `value`
 *   is returned unchanged.
 * @returns The transformed string.
 * @throws Error if `transform.type` is not a known transform type. This can
 *   only happen when data that was not validated against `KnownTransform` is
 *   passed in.
 */
export function applyTransform(
  value: string,
  transform?: KnownTransform,
): string {
  if (transform == null) {
    return value;
  }
  switch (transform.type) {
    case "lower":
      return value.toLowerCase();
    case "upper":
      return value.toUpperCase();
    case "capitalize":
      // lodash `capitalize` upper-cases the first character and lower-cases
      // the rest.
      return capitalize(value);
    case "replace": {
      // Pass a replacer function so `new` is inserted verbatim. A plain string
      // replacement would interpret `$&`, `$1`, `$$`, etc. as special patterns,
      // which diverges from literal string replacement.
      const replacement = transform.new;
      return value.replaceAll(transform.old, () => replacement);
    }
    case "composite":
      // Apply each sub-transform in order, feeding each output into the next.
      return transform.transforms.reduce(applyTransform, value);
    default: {
      // Unreachable for well-typed input. Guards against unvalidated data so we
      // never silently return `undefined` from a `string`-typed function.
      const unexpected: { type: string } = transform;
      throw new Error(`Unknown transform type: ${unexpected.type}`);
    }
  }
}

/**
 * Marker for a transform that is not one of the `KnownTransform` variants.
 * It carries no details beyond its tag and is not accepted by `applyTransform`.
 * NOTE(review): presumably emitted by jinjasql/sqlsafe.py for filters it cannot
 * model; confirm there.
 */
export const UnknownTransform = Record({
  type: Literal("unknown"),
});
export type UnknownTransform = Static<typeof UnknownTransform>;

/** Any transform that may be attached to a `SqlSafeOutput`: known or unknown. */
export const SqlSafeTransform = Union(KnownTransform, UnknownTransform);
export type SqlSafeTransform = Static<typeof SqlSafeTransform>;

/** One output of a sqlsafe expression. */
export const SqlSafeOutput = Record({
  // NOTE(review): presumably the name of the referenced value; confirm in
  // jinjasql/sqlsafe.py.
  name: String,
  // Location of this output in the source.
  range: SourceRange,
  // Transform recorded for this output, if any. It may be an
  // `UnknownTransform`, which `applyTransform` does not accept.
  transform: Optional(SqlSafeTransform),
});
export type SqlSafeOutput = Static<typeof SqlSafeOutput>;

// end -- keep in sync with jinjasql/sqlsafe.py

// begin -- keep in sync with jinjasql/parse.py

/** A sqlsafe expression: its source range and the outputs it contains. */
export const SqlSafeExpression = Record({
  range: SourceRange,
  outputs: Array(SqlSafeOutput),
});
export type SqlSafeExpression = Static<typeof SqlSafeExpression>;

/**
 * Reference to a sqlsafe expression, usable as a segment of a
 * `DynamicTableReference`.
 * NOTE(review): `index` presumably indexes into `JinjaSqlReferences.sqlsafe`;
 * confirm against jinjasql/parse.py.
 */
export const SqlSafeTableReference = Record({
  type: Literal("sqlsafe"),
  index: Number,
});
export type SqlSafeTableReference = Static<typeof SqlSafeTableReference>;

/**
 * A table reference whose name (`name`) is known statically, together with
 * the source locations where it is referenced.
 */
export const StaticTableReference = Record({
  type: Literal("static"),
  name: String,
  // Each element of locations maps to a full reference to `name`, but that
  // reference may be split up across multiple source locations.
  locations: Array(Array(SourceLocation)),
});
export type StaticTableReference = Static<typeof StaticTableReference>;

/**
 * A table reference built from a sequence of segments. Each segment is either
 * a static piece or a reference to a sqlsafe expression.
 * NOTE(review): presumably the segments are concatenated in order to form the
 * table name; confirm against jinjasql/parse.py.
 */
export const DynamicTableReference = Record({
  type: Literal("dynamic"),
  segments: Array(Union(StaticTableReference, SqlSafeTableReference)),
});
export type DynamicTableReference = Static<typeof DynamicTableReference>;

/** A table referenced by the SQL, discriminated by `type`: static or dynamic. */
export const TableReference = Union(
  StaticTableReference,
  DynamicTableReference,
);
export type TableReference = Static<typeof TableReference>;

/**
 * Runtype for the references extracted from a Jinja-templated SQL source.
 * It is intended to stay in sync with jinjasql/parse.py.
 */
export const JinjaSqlReferences = Record({
  // Params referenced from the Jinja template.
  jinjaReferencedParams: Array(ParamReference),
  // sqlsafe expressions found in the template.
  sqlsafe: Array(SqlSafeExpression),
  // Table references found in the SQL.
  tables: Array(TableReference),
  // NOTE(review): presumably true when the source is a single SQL statement;
  // confirm against jinjasql/parse.py.
  singleStatement: Boolean,
  // NOTE(review): presumably true when every statement is a SELECT; confirm.
  onlySelects: Boolean,
  // Either null or an error message; presumably null when parsing succeeded.
  parseError: Union(Null, String),
});
// Static type companion, matching every other runtype in this file.
export type JinjaSqlReferences = Static<typeof JinjaSqlReferences>;

// end -- keep in sync with jinjasql/parse.py
