import { GridPointLabel } from 'pages/projects/inspection-sheet/interfaces'
import { Matrix4, Vector2, Vector2Like, Vector3 } from 'three'

import { PointArray } from 'interfaces/attribute'
import { MinimumAreaBoundary } from 'interfaces/diagram'
import {
  GridOrderDirection,
  GridType,
  InspectionItemGrid,
  IntervalsAxisConfigValue,
  IntervalsConfig,
} from 'interfaces/inspectionItemGrid'
import { Polygon } from 'interfaces/shape'

import { linspace, millimeterToMeter, transposeMatrix } from 'services/Util'

import { GRID_POINT_MINIMUM_INTERVAL } from './constants'

/**
 * An edge described as a local frame: a center point, three named axes, and a transform.
 *
 * NOTE(review): this interface is not used in the visible part of this file, so the field notes
 * below are inferred from the field names only — confirm against the code that builds these objects.
 */
export interface EdgeVector {
  /** Identifier of this edge vector. */
  id: string
  /** Presumably the id of the bottom plane/shape this edge belongs to — TODO confirm. */
  bottomId: string
  /** Reference point of the frame (presumably the edge's center — TODO confirm). */
  center: Vector3
  /** "Up" axis: a direction, an associated point, and a scalar distance (semantics to be confirmed). */
  up: {
    direction: Vector3
    point: Vector3
    distance: number
  }
  /** "Right" axis; same shape as `up`. */
  right: {
    direction: Vector3
    point: Vector3
    distance: number
  }
  /** "Forward" axis; same shape as `up`. */
  forward: {
    direction: Vector3
    point: Vector3
    distance: number
  }
  /** Transform associated with this frame (presumably local-to-world — TODO confirm). */
  transform: Matrix4
}

/**
 * Lay out a center-based grid: on each axis, points are spaced exactly `interval.value` apart and
 * centered on `max - offset`.
 *
 * Points are built row by row (short axis outer, long axis inner). They are transposed to
 * column order when `orderDirection` is vertical.
 *
 * @param interval Intervals config; per axis this reads `interval.max`, `interval.value`, `offset` and `total`.
 * @param for3D When true, the short-axis (y) coordinate is negated.
 * @returns Grid points as `[long, short, 0]` triples. Returns an empty list when either axis has a
 *          falsy interval value or total; this also avoids dividing by a zero step.
 */
const generateCenterBasedGridPoints = (
  interval: Omit<IntervalsConfig, 'topPlaneId' | 'bottomPlaneId'>,
  for3D: boolean,
): PointArray[] => {
  const { longAxis, shortAxis } = interval

  const hasLayout =
    Boolean(longAxis.interval.value) &&
    Boolean(shortAxis.interval.value) &&
    Boolean(longAxis.total) &&
    Boolean(shortAxis.total)
  if (!hasLayout) return []

  // Evenly spaced positions, one `step` apart, whose midpoint is `max - offset`.
  const centeredPositions = (max: number, step: number, offset: number) => {
    const span = max * 2
    const count = Math.ceil((span - step) / step)
    const shift = (count / 2 - 1) * step
    const first = (span - step) / 2.0 - shift - offset
    const last = (span + step) / 2.0 + shift - offset
    return { count, positions: linspace(first, last, count) }
  }

  const long = centeredPositions(longAxis.interval.max, longAxis.interval.value, longAxis.offset)
  const short = centeredPositions(shortAxis.interval.max, shortAxis.interval.value, shortAxis.offset)
  const ySign = for3D ? -1 : 1

  const points: PointArray[] = []
  for (const b of short.positions) {
    for (const a of long.positions) {
      points.push([a, b * ySign, 0])
    }
  }

  // Points are produced row-major (horizontal order); vertical order needs them column-major.
  return interval.orderDirection === GridOrderDirection.Vertical
    ? transposeMatrix(points, short.count, long.count)
    : points
}

/**
 * Lay out an edge-based grid: on each axis, `pointCount.value` points run from one edge margin
 * (`edgeDistance.value`) to the other. A single point is placed at the middle of the axis.
 *
 * Every position is shifted by the axis `offset`. Points are built row by row (short axis outer,
 * long axis inner) and transposed to column order when `orderDirection` is vertical.
 *
 * @param interval Intervals config; per axis this reads `total`, `edgeDistance.value`, `pointCount.value` and `offset`.
 * @param for3D When true, the short-axis (y) coordinate is negated.
 * @returns Grid points as `[long, short, 0]` triples.
 */
const generateEdgeBasedGridPoints = (
  interval: Omit<IntervalsConfig, 'topPlaneId' | 'bottomPlaneId'>,
  for3D: boolean,
): PointArray[] => {
  const { longAxis, shortAxis } = interval
  const ySign = for3D ? -1 : 1

  // One point sits at the middle of the axis; otherwise points span margin .. total - margin.
  const positionsAlong = (total: number, margin: number, count: number): number[] =>
    count === 1 ? [total / 2] : [...linspace(margin, total - margin, count)]

  const longPositions = positionsAlong(longAxis.total, longAxis.edgeDistance.value, longAxis.pointCount.value)
  const shortPositions = positionsAlong(shortAxis.total, shortAxis.edgeDistance.value, shortAxis.pointCount.value)

  const points: PointArray[] = []
  shortPositions.forEach((b) => {
    longPositions.forEach((a) => {
      points.push([a - longAxis.offset, (b - shortAxis.offset) * ySign, 0])
    })
  })

  // Points are produced row-major (horizontal order); vertical order needs them column-major.
  return interval.orderDirection === GridOrderDirection.Vertical
    ? transposeMatrix(points, shortAxis.pointCount.value, longAxis.pointCount.value)
    : points
}

/**
 * Generate the grid points for a working grid, dispatching on `interval.type`.
 *
 * Points are built as `[long, short, 0]` in the grid's own axes (long axis → x, short axis → y).
 *
 * @param interval Intervals config; `type` selects the center-based or edge-based layout.
 * @param for3D When true, the short-axis (y) coordinate is negated. Defaults to `false`.
 * @returns Grid points, or an empty list for an unrecognised grid type.
 */
export const generateGridPoints = (
  interval: Omit<IntervalsConfig, 'topPlaneId' | 'bottomPlaneId'>,
  for3D = false,
): PointArray[] => {
  switch (interval.type) {
    case GridType.CenterBased:
      return generateCenterBasedGridPoints(interval, for3D)
    case GridType.EdgeBased:
      return generateEdgeBasedGridPoints(interval, for3D)
    default:
      return []
  }
}

/**
 * Calculate the maximum number of points that fit on an axis between its two edge margins.
 *
 * The margins are `edgeDistance` wide, and neighbouring points may be no closer together than
 * `GRID_POINT_MINIMUM_INTERVAL`. So the usable span `total - 2 * edgeDistance` holds
 * `floor((span + GRID_POINT_MINIMUM_INTERVAL) / GRID_POINT_MINIMUM_INTERVAL)` points.
 *
 * @param total Total length of the axis.
 * @param edgeDistance Distance from each edge to the outermost point.
 * @returns Maximum number of points, never negative. Margins that overlap by at least one minimum
 *          interval would otherwise produce a negative "maximum".
 */
export const calculateMaxPointCount = (total: number, edgeDistance: number): number =>
  Math.max(0, Math.floor((total - edgeDistance * 2 + GRID_POINT_MINIMUM_INTERVAL) / GRID_POINT_MINIMUM_INTERVAL))

/**
 * Calculate the spacing between consecutive points spread evenly across an axis.
 *
 * The usable span `total - 2 * edgeDistance` is divided into `pointCount - 1` gaps, and the result
 * is rounded to the nearest integer.
 *
 * @param total Total length of the axis.
 * @param edgeDistance Distance from each edge to the outermost point.
 * @param pointCount Number of points to place.
 * @returns Interval between points. Returns 0 when `pointCount` is 1 or less: a single point has no
 *          neighbour, and 0 or negative counts would divide by a non-positive number and yield a
 *          negative interval.
 */
export const calculateInterval = (total: number, edgeDistance: number, pointCount: number): number =>
  pointCount <= 1 ? 0 : Math.round((total - edgeDistance * 2) / (pointCount - 1))

/**
 * Check that an axis setting's `value` lies within its inclusive `[min, max]` range.
 *
 * @param config Axis setting with `min`, `max` and `value`.
 * @returns True when `min <= value <= max`; false otherwise, including when any of them is NaN.
 */
export const validateAxisConfigValue = (config: IntervalsAxisConfigValue): boolean => {
  const { min, max, value } = config
  return min <= value && value <= max
}

/**
 * Check that the settings relevant to the config's grid type are within their min/max ranges.
 *
 * - Center-based grids check the `interval` setting of both axes.
 * - Edge-based grids check `pointCount` and `edgeDistance` of both axes.
 *
 * @param interval Interval configuration to validate.
 * @returns True if the relevant settings are valid; false for an unrecognised grid type.
 */
export const validateIntervalConfig = (interval: IntervalsConfig): boolean => {
  const { longAxis, shortAxis } = interval

  switch (interval.type) {
    case GridType.CenterBased:
      return validateAxisConfigValue(longAxis.interval) && validateAxisConfigValue(shortAxis.interval)
    case GridType.EdgeBased:
      return [longAxis.pointCount, shortAxis.pointCount, longAxis.edgeDistance, shortAxis.edgeDistance].every(
        validateAxisConfigValue,
      )
    default:
      return false
  }
}

/**
 * Classify on which side of `center` the segment `p1`–`p2` lies.
 *
 * Both endpoints are taken relative to `center` and projected (dot product) onto `up`, then
 * onto `right`:
 * - Both strictly positive on `up` gives `'top'`; both strictly negative gives `'bottom'`.
 * - Otherwise the same test on `right` gives `'right'` / `'left'`.
 *
 * @param p1 First endpoint.
 * @param p2 Second endpoint.
 * @param center Reference point.
 * @param up Axis checked first.
 * @param right Axis checked second.
 * @returns The side, or `null` when neither axis puts both endpoints strictly on one side.
 */
export const isWhichPlacement = (
  p1: Vector3,
  p2: Vector3,
  center: Vector3,
  up: Vector3,
  right: Vector3,
): 'top' | 'bottom' | 'right' | 'left' | null => {
  const fromCenter1 = new Vector3().subVectors(p1, center)
  const fromCenter2 = new Vector3().subVectors(p2, center)

  // 1 / -1 when both endpoints are strictly on the positive / negative side of `axis`, else 0.
  const sideOf = (axis: Vector3): 1 | -1 | 0 => {
    const d1 = fromCenter1.dot(axis)
    const d2 = fromCenter2.dot(axis)
    if (d1 > 0 && d2 > 0) return 1
    if (d1 < 0 && d2 < 0) return -1
    return 0
  }

  const vertical = sideOf(up)
  if (vertical !== 0) return vertical > 0 ? 'top' : 'bottom'

  const horizontal = sideOf(right)
  if (horizontal !== 0) return horizontal > 0 ? 'right' : 'left'

  return null
}

/**
 * Build the edge-margin quads along the four sides of a minimum-area boundary.
 *
 * Consecutive `minBoundary.vertices` form the edges, with the last vertex connecting back to the
 * first. The two longest edges are the long-axis edges and the next two are the short-axis edges.
 *
 * Each edge becomes a quad by offsetting it along its neighbouring edge, away from the shared
 * corner; for a rectangular boundary this points into the rectangle. An edge's offset width comes
 * from the *other* axis' edge distance, converted with `millimeterToMeter`.
 *
 * @param interval Intervals config; only the long/short axis `edgeDistance.value` are read.
 * @param minBoundary Boundary whose vertices define the edges. Presumably the 4 corners of a
 *                    rectangle; other shapes are not handled meaningfully (TODO confirm callers).
 * @param topShape Shape whose `shape_id` is copied into the result.
 * @returns The short- and long-axis quads as lists of `[x, y]` corners, an always-empty `innerBox`,
 *          and the shape id. Returns `null` when the boundary has fewer than 4 edges, or when no
 *          long-axis edge shares a vertex with the first point of the first short-axis edge.
 */
export const generateEdgePlanes = (
  interval: IntervalsConfig,
  minBoundary: MinimumAreaBoundary,
  topShape: Polygon,
): {
  shortAxis: [number, number][][]
  longAxis: [number, number][][]
  innerBox: [number, number][]
  shapeId: string
} | null => {
  // Between the 4 vertices, find the edge for short and long axis.
  // Edge i runs from vertex i to vertex i + 1 (wrapping to vertex 0); edges are sorted longest first.
  const vertVecs = minBoundary.vertices.map((v) => new Vector2(...v))
  const distances = vertVecs
    .map((vertex, index) => {
      const now = vertex
      const next = vertVecs[(index + 1) % vertVecs.length]
      return {
        points: [now, next],
        distance: now.distanceTo(next),
        // Placeholder; assigned below once the neighbouring edge is known.
        extrudeDirection: new Vector2(),
      }
    })
    .sort((a, b) => b.distance - a.distance)

  // First two are the long axis, last two are the short axis
  const longAxisEdges = distances.slice(0, 2)
  const shortAxisEdges = distances.slice(2, 4)

  // Fewer than 4 vertices means there are not enough edges to classify.
  if (longAxisEdges.length !== 2 || shortAxisEdges.length !== 2) return null

  // Find which long edge touches shortConnected.points[0], the first vertex of the first short-axis edge
  const shortConnected = shortAxisEdges[0]
  const longShortConnector = longAxisEdges.reduce<{
    edgeIndex: number
    posIndex: number
    common: Vector2
    extrudeDirection: Vector2
  } | null>((prev, edge, index) => {
    // If already found, return
    if (prev !== null) return prev

    if (edge.points[0].equals(shortConnected.points[0])) {
      return {
        edgeIndex: index,
        posIndex: 0,
        common: edge.points[0],
        // Along the long edge, away from the shared corner
        extrudeDirection: new Vector2().subVectors(edge.points[1], edge.points[0]).normalize(),
      }
    }

    if (edge.points[1].equals(shortConnected.points[0])) {
      return {
        edgeIndex: index,
        posIndex: 1,
        common: edge.points[1],
        // Along the long edge, away from the shared corner
        extrudeDirection: new Vector2().subVectors(edge.points[0], edge.points[1]).normalize(),
      }
    }

    return prev
  }, null)

  if (!longShortConnector) return null

  // Short edges extrude along the connecting long edge; the opposite short edge extrudes the other way.
  shortConnected.extrudeDirection = longShortConnector.extrudeDirection
  shortAxisEdges[1].extrudeDirection = longShortConnector.extrudeDirection.clone().negate()

  // The corner shared with the long edge is always shortConnected.points[0], the vertex matched above.
  // The long edge therefore extrudes from that corner toward the short edge's other end.
  // Indexing the short edge with `posIndex` (an index into the *long* edge) would flip this
  // outward whenever the match was on the long edge's first point.
  const longConnected = longAxisEdges[longShortConnector.edgeIndex]
  longConnected.extrudeDirection = shortConnected.points[1].clone().sub(shortConnected.points[0]).normalize()
  longAxisEdges[(longShortConnector.edgeIndex + 1) % 2].extrudeDirection = longConnected.extrudeDirection
    .clone()
    .negate()

  // ## Width of the box on the edges
  // An important thing to remember, the edge planes are labelled by their placement, either on the long or short axis.
  // However, their size depends on the other axis.
  // eg. The width of the short axis edge plane depends on setting of the long axis edge distance.
  const shortAxisWidth = millimeterToMeter(interval.longAxis.edgeDistance.value)
  const longAxisWidth = millimeterToMeter(interval.shortAxis.edgeDistance.value)

  // Generate edge planes: each quad is [p0, p1, p1 + offset, p0 + offset]
  const edgePlanes = {
    shortAxis: [
      [
        shortAxisEdges[0].points[0],
        shortAxisEdges[0].points[1],
        shortAxisEdges[0].points[1]
          .clone()
          .add(shortAxisEdges[0].extrudeDirection.clone().multiplyScalar(shortAxisWidth)),
        shortAxisEdges[0].points[0]
          .clone()
          .add(shortAxisEdges[0].extrudeDirection.clone().multiplyScalar(shortAxisWidth)),
      ],
      [
        shortAxisEdges[1].points[0],
        shortAxisEdges[1].points[1],
        shortAxisEdges[1].points[1]
          .clone()
          .add(shortAxisEdges[1].extrudeDirection.clone().multiplyScalar(shortAxisWidth)),
        shortAxisEdges[1].points[0]
          .clone()
          .add(shortAxisEdges[1].extrudeDirection.clone().multiplyScalar(shortAxisWidth)),
      ],
    ],
    longAxis: [
      [
        longAxisEdges[0].points[0],
        longAxisEdges[0].points[1],
        longAxisEdges[0].points[1].clone().add(longAxisEdges[0].extrudeDirection.clone().multiplyScalar(longAxisWidth)),
        longAxisEdges[0].points[0].clone().add(longAxisEdges[0].extrudeDirection.clone().multiplyScalar(longAxisWidth)),
      ],
      [
        longAxisEdges[1].points[0],
        longAxisEdges[1].points[1],
        longAxisEdges[1].points[1].clone().add(longAxisEdges[1].extrudeDirection.clone().multiplyScalar(longAxisWidth)),
        longAxisEdges[1].points[0].clone().add(longAxisEdges[1].extrudeDirection.clone().multiplyScalar(longAxisWidth)),
      ],
    ],
  }
  return {
    shortAxis: edgePlanes.shortAxis.map((points) => points.map((p) => p.toArray())),
    longAxis: edgePlanes.longAxis.map((points) => points.map((p) => p.toArray())),
    innerBox: [],
    shapeId: topShape.shape_id,
  }
}

/**
 * Apply saved inspection-item grid settings on top of an intervals config.
 *
 * A saved value replaces the corresponding value in `oldConfig` only when it is truthy. Missing,
 * `null`, `0` or `''` keep the old value. Each axis' `pointCount.max` is recomputed from its
 * `total` and the merged edge distance.
 *
 * `oldConfig` is left untouched: the axis objects and their nested setting objects are copied
 * before being updated.
 *
 * @param oldConfig Config to start from.
 * @param savedConfig Saved grid settings to apply.
 * @returns A new merged config.
 */
export const mergeIntervalsWithSavedConfig = (
  oldConfig: IntervalsConfig,
  savedConfig: InspectionItemGrid,
): IntervalsConfig => {
  // `{ ...oldConfig }` alone would share the nested axis objects with the caller, so the
  // assignments below would silently rewrite `oldConfig` as well. Copy every mutated level.
  const merged: IntervalsConfig = {
    ...oldConfig,
    longAxis: {
      ...oldConfig.longAxis,
      interval: { ...oldConfig.longAxis.interval },
      edgeDistance: { ...oldConfig.longAxis.edgeDistance },
      pointCount: { ...oldConfig.longAxis.pointCount },
    },
    shortAxis: {
      ...oldConfig.shortAxis,
      interval: { ...oldConfig.shortAxis.interval },
      edgeDistance: { ...oldConfig.shortAxis.edgeDistance },
      pointCount: { ...oldConfig.shortAxis.pointCount },
    },
  }

  // NOTE(review): `||` treats a saved 0 as "not set" and keeps the old value (e.g. a 0 mm edge
  // distance). Switch to `??` if 0 can be a legitimate saved value.
  merged.type = savedConfig.grid_type || merged.type
  merged.orderDirection = savedConfig.grid_order_direction || merged.orderDirection
  merged.longAxis.interval.value = savedConfig.intervals?.long_axis || merged.longAxis.interval.value
  merged.shortAxis.interval.value = savedConfig.intervals?.short_axis || merged.shortAxis.interval.value
  merged.longAxis.edgeDistance.value = savedConfig.distances_from_edge?.long_axis || merged.longAxis.edgeDistance.value
  merged.shortAxis.edgeDistance.value =
    savedConfig.distances_from_edge?.short_axis || merged.shortAxis.edgeDistance.value

  // The point-count limits depend on the (possibly updated) edge distances.
  merged.longAxis.pointCount.max = calculateMaxPointCount(merged.longAxis.total, merged.longAxis.edgeDistance.value)
  merged.shortAxis.pointCount.max = calculateMaxPointCount(merged.shortAxis.total, merged.shortAxis.edgeDistance.value)
  merged.longAxis.pointCount.value = savedConfig.number_of_points?.long_axis || merged.longAxis.pointCount.value
  merged.shortAxis.pointCount.value = savedConfig.number_of_points?.short_axis || merged.shortAxis.pointCount.value

  return merged
}

/**
 * Find the most frequent value in an array.
 *
 * Ties between equally frequent values follow `Object.entries` key order: non-negative integer
 * values first, in ascending order, then all other values in first-seen order.
 *
 * @param arr Array of numbers.
 * @returns The most frequent value. Returns `null` when `arr` is empty, or when it holds more than
 *          one distinct value and none repeats. A single distinct value is always returned.
 */
export const mostCommonValue = (arr: number[]): number | null => {
  if (arr.length === 0) return null

  // Counts are keyed by each value's string form. Number -> string conversion round-trips exactly,
  // so Number() below recovers the original value; parseInt would truncate (1.5 -> 1, 1e21 -> 1).
  const counts = arr.reduce<{ [key: string]: number }>((acc, val) => {
    acc[val] = val in acc ? acc[val] + 1 : 1
    return acc
  }, {})

  const sorted = Object.entries(counts).sort((a, b) => b[1] - a[1])

  // Can't call it common if there's only 1, unless it's the only value
  if (sorted.length === 1) return Number(sorted[0][0])
  if (sorted[0][1] === 1) return null

  return Number(sorted[0][0])
}

/**
 * Find the points on each corner of the grid (top-left, top-right, bottom-left, bottom-right).
 *
 * Coordinates are rounded to integers before comparing. A point replaces a corner's current
 * candidate only when it is at least as far toward that corner on both axes, so on ties the
 * later point wins.
 *
 * An important note is that this function assumes HTMLCanvasElement's coordinate system
 * with inverted y-axis (smaller y is "top").
 *
 * @param points Grid points to search.
 * @returns Each corner's rounded point and its index in `points`, or `null` when `points` is empty.
 */
export const findAllCorners = (
  points: Vector2Like[],
): {
  topLeft: { point: Vector2Like; index: number }
  topRight: { point: Vector2Like; index: number }
  bottomLeft: { point: Vector2Like; index: number }
  bottomRight: { point: Vector2Like; index: number }
} | null => {
  // No points means no corners. Carrying on would report `undefined` points at index 0.
  if (points.length === 0) return null

  const rounded = points.map((p) => ({ x: Math.round(p.x), y: Math.round(p.y) }))
  let topLeft: { point: Vector2Like; index: number } = { point: rounded[0], index: 0 }
  let topRight: { point: Vector2Like; index: number } = { point: rounded[0], index: 0 }
  let bottomLeft: { point: Vector2Like; index: number } = { point: rounded[0], index: 0 }
  let bottomRight: { point: Vector2Like; index: number } = { point: rounded[0], index: 0 }

  rounded.forEach((point, index) => {
    if (point.x <= topLeft.point.x && point.y <= topLeft.point.y) {
      topLeft = { point, index }
    }
    if (point.x >= topRight.point.x && point.y <= topRight.point.y) {
      topRight = { point, index }
    }
    if (point.x <= bottomLeft.point.x && point.y >= bottomLeft.point.y) {
      bottomLeft = { point, index }
    }
    if (point.x >= bottomRight.point.x && point.y >= bottomRight.point.y) {
      bottomRight = { point, index }
    }
  })

  return { topLeft, topRight, bottomLeft, bottomRight }
}

type GridPointDirection = '+x' | '-x' | '+y' | '-y'

/**
 * Find how the grid points are laid out on a grid.
 *
 * The layout is inferred from which corner holds the first point (index 0), and from which of that
 * corner's two neighbouring corners is reached first. A single-column grid has coinciding left and
 * right corners; it is always reported column-first, whichever end it starts from.
 *
 * An important note is that this function assumes HTMLCanvasElement's coordinate system
 * with inverted y-axis.
 *
 * @param corners Corners of the grid, as returned by `findAllCorners`.
 * @returns An array of 2 values: the direction the points advance in first, followed by the
 *          second direction. eg: ['+x', '-y'] means the grid is laid out from left to right and
 *          bottom to top.
 */
export const findGridDirection = (
  corners: NonNullable<ReturnType<typeof findAllCorners>>,
): [GridPointDirection, GridPointDirection] => {
  const { topLeft, topRight, bottomLeft, bottomRight } = corners
  const isSingleCol = topRight.index === topLeft.index && bottomRight.index === bottomLeft.index

  // top left
  if (topLeft.index === 0) {
    // If the top-right comes before the bottom-left, it's a left-to-right, top-to-bottom grid
    if (topRight.index <= bottomLeft.index && !isSingleCol) {
      return ['+x', '+y']
    }

    // Otherwise, it's a top-to-bottom, left-to-right grid
    return ['+y', '+x']
  }

  // top right
  if (topRight.index === 0) {
    // If the top-left comes before the bottom-right, it's a right-to-left, top-to-bottom grid
    if (topLeft.index <= bottomRight.index && !isSingleCol) {
      return ['-x', '+y']
    }

    // Otherwise, it's a top-to-bottom, right-to-left grid
    return ['+y', '-x']
  }

  // bottom left
  if (bottomLeft.index === 0) {
    // If the bottom-right comes before the top-left, it's a left-to-right, bottom-to-top grid.
    // A single column must be excluded: its bottom-right is the same point as the bottom-left
    // (index 0), which would otherwise always report it row-first, unlike the top-anchored cases.
    if (bottomRight.index <= topLeft.index && !isSingleCol) {
      return ['+x', '-y']
    }

    // Otherwise, it's a bottom-to-top, left-to-right grid
    return ['-y', '+x']
  }

  // bottom right
  // If the bottom-left comes before the top-right, it's a right-to-left, bottom-to-top grid
  if (bottomLeft.index <= topRight.index && !isSingleCol) {
    return ['-x', '-y']
  }

  // Otherwise, it's a bottom-to-top, right-to-left grid
  return ['-y', '-x']
}

/**
 * Get the y offset shared by the points of a single row.
 *
 * @param points Grid points of the row. Must be non-empty: an empty list throws when the first point is read.
 * @returns The most common y value (per `mostCommonValue`), or the first point's y when no y repeats.
 */
export const getSingleRowOffset = (points: Vector2Like[]): number =>
  mostCommonValue(points.map(({ y }) => y)) ?? points[0].y

/**
 * Get the x offset shared by the points of a single column.
 *
 * @param points Grid points of the column. Must be non-empty: an empty list throws when the first point is read.
 * @returns The most common x value (per `mostCommonValue`), or the first point's x when no x repeats.
 */
export const getSingleColOffset = (points: Vector2Like[]): number =>
  mostCommonValue(points.map(({ x }) => x)) ?? points[0].x

/**
 * Collect the distinct row (y) and column (x) positions of a set of grid points.
 *
 * Each entry's first two coordinates are taken as (x, y), swapped when `orientation` is
 * `'vertical'`. They are then multiplied by `scale` and floored. Positions keep first-seen order.
 *
 * @param intervalGridPoints Grid points; only indices 0 and 1 of each entry are read.
 * @param orientation `'vertical'` swaps the two coordinates; any other value keeps them as-is.
 * @param scale Factor applied to both coordinates before flooring.
 * @returns Distinct floored y values as `rows` and distinct floored x values as `cols`.
 */
export const generateGridLinesFromIntervalGridPoints = (
  intervalGridPoints: number[][],
  orientation: 'vertical' | 'horizontal' | undefined,
  scale: number,
): { rows: number[]; cols: number[] } => {
  const swapAxes = orientation === 'vertical'
  const rowPositions = new Set<number>()
  const colPositions = new Set<number>()

  for (const coords of intervalGridPoints) {
    const scaled = new Vector2(coords[swapAxes ? 1 : 0], coords[swapAxes ? 0 : 1]).multiplyScalar(scale)
    rowPositions.add(Math.floor(scaled.y))
    colPositions.add(Math.floor(scaled.x))
  }

  return {
    rows: Array.from(rowPositions),
    cols: Array.from(colPositions),
  }
}

/**
 * For each row, compute the rounded x-distance from every point to the point before it.
 *
 * @param gridColumnized 2-D array of grid points; each inner array is treated as one ordered row.
 * @returns One array per input row holding `row.length - 1` rounded differences (empty for rows
 *          with fewer than two points).
 */
export const calculateColumnDiff = (gridColumnized: GridPointLabel[][]) =>
  gridColumnized.map((row) => row.slice(1).map((point, i) => Math.round(point.x - row[i].x)))
