import {
  Cell,
  flexRender,
  HeaderGroup,
  Row,
  RowData,
  type Table as TableType,
} from '@tanstack/react-table'
import useStyles from './Table.styles'
import React, { useCallback, useRef } from 'react'
import { useVirtualizer } from '@tanstack/react-virtual'

/**
 * Props for {@link Table}.
 *
 * @typeParam T - Row data type of the TanStack table being rendered.
 */
export type TableProps<T extends RowData> = {
  /** TanStack table instance whose header groups, visible leaf columns and row model are rendered. */
  table: TableType<T>
  /** Called with the row when a `dragstart` event fires on that row's `<tr>`. */
  onDragStart?: (row: Row<T>) => void
  /** Called with the row when a `dragend` event fires on that row's `<tr>`. */
  onDragEnd?: (row: Row<T>) => void
  /**
   * Called with the row when it is clicked. When provided it takes precedence
   * over `selectable` click handling, and rows get the `clickable` style.
   */
  onRowClick?: (row: Row<T>) => void
  /**
   * When true and `onRowClick` is not provided, row clicks are routed to the
   * row's selection handling. Defaults to `false`.
   */
  selectable?: boolean
  /**
   * When true, body rows get the HTML `draggable` attribute (further gated
   * per row by `enableRowDragging`). Defaults to `false`.
   */
  draggable?: boolean
  /** Extra class name(s) applied to the `<table>` element. */
  className?: string
  /**
   * Element rendered directly below the leaf (last) header row. It is cloned
   * with an injected `labelOffset` prop (see `totalsRowColumnsOffset`).
   */
  totalsRow?: React.ReactElement
  /**
   * Index of the leaf header whose start position (`header.getStart()`) is
   * passed to `totalsRow` as `labelOffset`. When falsy (including `0`),
   * `labelOffset` is `null`.
   */
  totalsRowColumnsOffset?: number
  /**
   * Enables horizontal column virtualization: only columns in (or near) the
   * scroll viewport are rendered, with spacer cells for the rest.
   * Defaults to `false`.
   */
  virtualized?: boolean
  /** Returns true for rows that should receive the `marked` style. */
  getRowMarked?: (row: Row<T>) => boolean
  /**
   * Per-row dragging predicate, combined with `draggable`. When omitted every
   * row is treated as draggable (subject to `draggable`).
   */
  enableRowDragging?: (row: Row<T>) => boolean
}

/**
 * Generic table renderer for a TanStack `table` instance.
 *
 * Renders every header group (optionally followed by `totalsRow` beneath the
 * leaf header row) and one body row per entry in `table.getRowModel().rows`.
 * When `virtualized` is set, only the leaf columns the column virtualizer
 * reports as visible are rendered. Spacer cells on either side stand in for
 * the off-screen columns so the horizontal scroll extent stays correct.
 *
 * Row interaction:
 * - `onRowClick`, when supplied, handles clicks. Otherwise, if `selectable`
 *   is true, a click toggles the row's selection.
 * - Rows are HTML5-draggable when `draggable` is true and
 *   `enableRowDragging(row)` (default `true`) allows it.
 */
const Table = <T extends RowData>({
  table,
  onDragStart,
  onDragEnd,
  onRowClick,
  selectable = false,
  draggable = false,
  className,
  totalsRow,
  totalsRowColumnsOffset,
  virtualized = false,
  getRowMarked,
  enableRowDragging,
}: TableProps<T>) => {
  const { styles, cx } = useStyles()
  const tableContainerRef = useRef<HTMLDivElement>(null)
  const { rows } = table.getRowModel()
  const visibleColumns = table.getVisibleLeafColumns()
  const headerGroups = table.getHeaderGroups()

  // Hooks must not be called conditionally, so the virtualizer is always
  // created. Its output is only consumed when `virtualized` is true.
  const columnVirtualizer = useVirtualizer({
    count: visibleColumns.length,
    estimateSize: (index) => visibleColumns[index].getSize(),
    getScrollElement: () => tableContainerRef.current,
    horizontal: true,
    overscan: 3,
  })
  const virtualColumns = columnVirtualizer.getVirtualItems()

  // Combined size of the columns left/right of the rendered window. These
  // stay undefined when the virtualizer reports no items.
  let virtualPaddingLeft: number | undefined
  let virtualPaddingRight: number | undefined

  if (virtualColumns.length > 0) {
    virtualPaddingLeft = virtualColumns[0]?.start ?? 0
    virtualPaddingRight =
      columnVirtualizer.getTotalSize() -
      (virtualColumns[virtualColumns.length - 1]?.end ?? 0)
  }

  const handleRowClick = useCallback(
    (row: Row<T>) => {
      if (onRowClick) {
        onRowClick(row)
      } else if (selectable) {
        // `getToggleSelectedHandler()` only *returns* an event handler.
        // Calling it without invoking the result never changed selection,
        // so toggle the row's selected state directly.
        row.toggleSelected()
      }
    },
    [onRowClick, selectable],
  )

  // Spacer cells that occupy the width of the non-rendered columns. A zero
  // or undefined padding renders nothing.
  const virtualLeftColumn = virtualPaddingLeft ? (
    <th style={{ display: 'flex', width: virtualPaddingLeft }} />
  ) : null
  const virtualRightColumn = virtualPaddingRight ? (
    <th style={{ display: 'flex', width: virtualPaddingRight }} />
  ) : null

  const renderHeaderCells = (headerGroup: HeaderGroup<T>, isLeaf: boolean) => {
    if (virtualized) {
      if (isLeaf) {
        // The leaf header row maps 1:1 onto visible leaf columns, so it can
        // be windowed with the same virtual indices as the body cells.
        return (
          <>
            {virtualLeftColumn}
            {virtualColumns.map((vc) => {
              const header = headerGroup.headers[vc.index]
              return (
                <th
                  key={header.id}
                  id={'head-' + header.id}
                  colSpan={header.colSpan}
                  style={{
                    ...header.column.columnDef.meta?.style,
                    width:
                      header.column.columnDef.meta?.width || header.getSize(),
                  }}
                >
                  {header.isPlaceholder
                    ? null
                    : flexRender(
                        header.column.columnDef.header,
                        header.getContext(),
                      )}
                </th>
              )
            })}
            {virtualRightColumn}
          </>
        )
      }
      // Upper (grouping) header rows span several leaf columns, so they are
      // not windowed. They are rendered in full with their measured size.
      return headerGroup.headers.map((header) => (
        <th
          key={header.id}
          id={'head-' + header.id}
          colSpan={header.colSpan}
          style={{ width: header.getSize() }}
        >
          {header.isPlaceholder
            ? null
            : flexRender(header.column.columnDef.header, header.getContext())}
        </th>
      ))
    }
    return headerGroup.headers.map((header) => {
      return (
        <th
          key={header.id}
          id={'head-' + header.id}
          colSpan={header.colSpan}
          style={{
            ...header.column.columnDef.meta?.style,
            width: header.column.columnDef.meta?.width,
          }}
        >
          {header.isPlaceholder
            ? null
            : flexRender(header.column.columnDef.header, header.getContext())}
        </th>
      )
    })
  }

  const renderRowCells = (visibleCells: Cell<T, unknown>[]) => {
    if (virtualized) {
      return (
        <>
          {virtualLeftColumn}
          {virtualColumns.map((vc) => {
            const cell = visibleCells[vc.index]
            return (
              <td
                key={cell.id}
                id={`id-${cell.id}`}
                style={{
                  ...cell.column.columnDef.meta?.style,
                  width: cell.column.getSize(),
                }}
              >
                {flexRender(cell.column.columnDef.cell, cell.getContext())}
              </td>
            )
          })}
          {virtualRightColumn}
        </>
      )
    }

    return visibleCells.map((cell) => {
      return (
        <td
          key={cell.id}
          id={`id-${cell.id}`}
          style={cell.column.columnDef.meta?.style}
        >
          {flexRender(cell.column.columnDef.cell, cell.getContext())}
        </td>
      )
    })
  }

  return (
    <div
      className={cx(styles.container)}
      ref={tableContainerRef}
      id="table-container"
    >
      <table
        className={cx(
          styles.table,
          virtualized && styles.virtualizedTable,
          className,
        )}
      >
        <thead>
          {headerGroups.map((headerGroup) => {
            const isLeaf = headerGroup.depth === headerGroups.length - 1
            // An out-of-range offset would otherwise throw while rendering.
            // Fall back to `null`, the same value used when no offset is set.
            const totalRowColumnsOffsetSize =
              isLeaf && totalsRowColumnsOffset
                ? (headerGroup.headers[totalsRowColumnsOffset]?.getStart() ??
                  null)
                : null
            return (
              <React.Fragment key={headerGroup.id}>
                <tr id={headerGroup.id}>
                  {renderHeaderCells(headerGroup, isLeaf)}
                </tr>
                {isLeaf &&
                  totalsRow &&
                  React.cloneElement(totalsRow, {
                    labelOffset: totalRowColumnsOffsetSize,
                  })}
              </React.Fragment>
            )
          })}
        </thead>
        <tbody>
          {rows.map((row) => {
            const visibleCells = row.getVisibleCells()
            const isMarked = getRowMarked?.(row)
            // Rows are draggable by default unless a predicate says otherwise.
            const enableDragging = enableRowDragging
              ? enableRowDragging(row)
              : true

            return (
              <tr
                key={row.id}
                id={`rowid-${row.id}`}
                draggable={draggable && enableDragging}
                onDragStart={() => onDragStart?.(row)}
                onDragEnd={() => onDragEnd?.(row)}
                onClick={() => handleRowClick(row)}
                className={cx(
                  row.getIsSelected() ? styles.selected : '',
                  isMarked ? styles.marked : '',
                  onRowClick ? styles.clickable : '',
                )}
              >
                {renderRowCells(visibleCells)}
              </tr>
            )
          })}
        </tbody>
      </table>
    </div>
  )
}

export default Table
