All files / web/src/app/vision-training/train/components DistributionHeatmap.tsx

0% Statements 0/153
0% Branches 0/1
0% Functions 0/1
0% Lines 0/153

Press n or j to go to the next uncovered block; press b, p, or k to go to the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154                                                                                                                                                                                                                                                                                                                   
'use client'

import { useState } from 'react'
import { css } from '../../../../../styled-system/css'

/** Props accepted by the DistributionHeatmap component. */
interface DistributionHeatmapProps {
  /**
   * Number of training images for each digit, keyed by the digit itself (0-9).
   * A digit with no entry is displayed as having 0 images.
   */
  digitCounts: { [digit: number]: number }
  /** Counts strictly below this value are flagged as insufficient (component default: 3). */
  minThreshold?: number
  /** Render smaller cells with tighter spacing, e.g. on mobile (component default: false). */
  compact?: boolean
}

/**
 * Compact heatmap showing how training images are distributed across digits 0-9.
 *
 * - Blue shade encodes each digit's count relative to the largest count among
 *   digits 0-9: the most-populated digits use `blue.500`, the least-populated
 *   use `blue.900`.
 * - Red marks digits with fewer than `minThreshold` images.
 * - Hovering (or clicking/tapping) a cell shows a tooltip with the exact count
 *   and, for red cells, how many more images are needed.
 */
export function DistributionHeatmap({
  digitCounts,
  minThreshold = 3,
  compact = false,
}: DistributionHeatmapProps) {
  const [hoveredDigit, setHoveredDigit] = useState<number | null>(null)

  const digits = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  // Count displayed for a digit; a missing (or otherwise falsy, e.g. NaN) entry
  // is shown as 0.
  const getCount = (digit: number) => digitCounts[digit] || 0

  // Normalize against the largest count that is actually displayed, so keys
  // outside 0-9 (or NaN values) in `digitCounts` cannot skew the shading of
  // the ten cells. The floor of 1 keeps the division in getIntensity
  // well-defined when every count is 0.
  const maxCount = Math.max(...digits.map(getCount), 1)

  // Relative count, clamped to at most 1. maxCount >= 1, so no zero check is needed.
  const getIntensity = (count: number) => Math.min(count / maxCount, 1)

  // Panda color token for a cell's background. Insufficient digits are always
  // red; otherwise the blue shade steps through five intensity bands, and a
  // hovered cell uses the next token down the palette (e.g. blue.500 -> blue.400).
  // NOTE(review): Panda CSS generates classes at build time, and a value that
  // comes back from a function call may not be statically extractable. Confirm
  // these tokens (notably 'red.500/70') are emitted, e.g. via `staticCss`.
  const getCellColor = (count: number, isInsufficient: boolean, isHovered: boolean) => {
    if (isInsufficient) {
      return isHovered ? 'red.400' : 'red.500/70'
    }

    const intensity = getIntensity(count)
    if (intensity > 0.8) return isHovered ? 'blue.400' : 'blue.500'
    if (intensity > 0.6) return isHovered ? 'blue.500' : 'blue.600'
    if (intensity > 0.4) return isHovered ? 'blue.600' : 'blue.700'
    if (intensity > 0.2) return isHovered ? 'blue.700' : 'blue.800'
    return isHovered ? 'blue.800' : 'blue.900'
  }

  return (
    <div data-component="distribution-heatmap">
      {/* Heatmap grid */}
      <div
        className={css({
          display: 'flex',
          gap: compact ? '2px' : '3px',
        })}
      >
        {digits.map((digit) => {
          const count = getCount(digit)
          const isInsufficient = count < minThreshold
          const isHovered = hoveredDigit === digit

          // NOTE(review): onClick reads `hoveredDigit` from this render rather
          // than through a functional updater. Because onMouseEnter already
          // selects the cell, clicking a hovered cell hides its tooltip. Confirm
          // this is the intended click/tap behavior before changing either handler.
          return (
            <div
              key={digit}
              data-digit={digit}
              onMouseEnter={() => setHoveredDigit(digit)}
              onMouseLeave={() => setHoveredDigit(null)}
              onClick={() => setHoveredDigit(hoveredDigit === digit ? null : digit)}
              className={css({
                position: 'relative',
                display: 'flex',
                flexDirection: 'column',
                alignItems: 'center',
                cursor: 'pointer',
                transition: 'transform 0.1s ease',
                _hover: { transform: 'scale(1.1)' },
              })}
            >
              {/* Color cell */}
              <div
                className={css({
                  width: compact ? '20px' : '24px',
                  height: compact ? '16px' : '20px',
                  borderRadius: 'sm',
                  bg: getCellColor(count, isInsufficient, isHovered),
                  border: '1px solid',
                  borderColor: isInsufficient
                    ? 'red.400/50'
                    : isHovered
                      ? 'blue.400/50'
                      : 'transparent',
                  transition: 'all 0.15s ease',
                })}
              />
              {/* Digit label */}
              <span
                className={css({
                  fontSize: compact ? '2xs' : 'xs',
                  color: isInsufficient ? 'red.400' : 'gray.500',
                  fontWeight: isInsufficient ? 'bold' : 'normal',
                  mt: '2px',
                  fontFamily: 'mono',
                })}
              >
                {digit}
              </span>

              {/* Tooltip with the exact count, shown for the selected cell */}
              {isHovered && (
                <div
                  className={css({
                    position: 'absolute',
                    bottom: '100%',
                    left: '50%',
                    transform: 'translateX(-50%)',
                    mb: 1,
                    px: 2,
                    py: 1,
                    bg: 'gray.800',
                    color: 'gray.100',
                    fontSize: 'xs',
                    borderRadius: 'md',
                    whiteSpace: 'nowrap',
                    zIndex: 10,
                    boxShadow: 'lg',
                    border: '1px solid',
                    borderColor: 'gray.700',
                  })}
                >
                  <span className={css({ fontWeight: 'bold' })}>{count}</span>
                  <span className={css({ color: 'gray.400' })}>
                    {count === 1 ? ' image' : ' images'}
                  </span>
                  {isInsufficient && (
                    <span className={css({ color: 'red.400', ml: 1 })}>
                      (need {minThreshold - count} more)
                    </span>
                  )}
                </div>
              )}
            </div>
          )
        })}
      </div>
    </div>
  )
}

// Default export for callers that import the heatmap without braces; it is the same function as the named export.
export { DistributionHeatmap as default }