import Konva from "konva";
import { ImageSize, Vector2d } from "MarkupTypes";
import { useCallback, useMemo, useState } from "react";

/**
 * React hook that manages the zoom level and pan offset of a Konva stage
 * showing an image inside a canvas.
 *
 * @param stage - The Konva stage to transform. While it is `null`,
 *   `resetZoom` does nothing and `scaleStage` returns `{ x: 1, y: 1 }`.
 * @param canvasSize - Dimensions of the visible canvas.
 * @param imageSize - Dimensions of the displayed image.
 * @returns A memoized object with:
 *   - `resetZoom()`: applies the zoom at which the whole image fits in the
 *     canvas and centers it.
 *   - `scaleStage(scaleBy)`: multiplies the stage's current zoom by
 *     `scaleBy`, keeping the stage point under the pointer in place, and
 *     returns the resulting 2-D scale.
 *   - `scale`: the zoom last applied through this hook. It starts at 1 and is
 *     not read back from the stage.
 */
export default function useZoom(
  stage: Konva.Stage | null,
  canvasSize: ImageSize,
  imageSize: ImageSize
) {
  const [scale, setScale] = useState(1);

  // "Contain" fit: the largest zoom at which the whole image is visible.
  const fitZoom = Math.min(
    canvasSize.width / imageSize.width,
    canvasSize.height / imageSize.height
  );
  // A zero-sized image makes the ratio Infinity (or NaN when both sides are 0),
  // and a zero-sized canvas makes it 0. Either would put a non-finite or
  // degenerate transform on the stage, so fall back to 1:1 instead.
  const initialZoom = Number.isFinite(fitZoom) && fitZoom > 0 ? fitZoom : 1;

  // Offset that centers the fitted image in the canvas. Because the fit uses
  // the smaller ratio, the scaled image never exceeds the canvas. The
  // difference is therefore non-negative up to rounding, and `Math.abs`
  // keeps it so.
  const initialOffset = useMemo(
    () => ({
      x: Math.ceil(
        Math.abs(canvasSize.width - imageSize.width * initialZoom) / 2
      ),
      y: Math.ceil(
        Math.abs(canvasSize.height - imageSize.height * initialZoom) / 2
      ),
    }),
    [
      canvasSize.height,
      canvasSize.width,
      imageSize.height,
      imageSize.width,
      initialZoom,
    ]
  );

  /** Sets a uniform zoom and position on the stage, records the zoom, and redraws. */
  const applyToStage = useCallback(
    (zoom: number, offset: Vector2d) => {
      if (stage === null) {
        return;
      }
      setScale(zoom);
      stage.scale({ x: zoom, y: zoom });
      stage.position(offset);
      stage.batchDraw();
    },
    [stage]
  );

  const resetZoom = useCallback(
    () => applyToStage(initialZoom, initialOffset),
    [applyToStage, initialOffset, initialZoom]
  );

  const scaleStage = useCallback(
    (scaleBy: number): Vector2d => {
      if (stage === null) {
        return { x: 1, y: 1 };
      }

      const cursor = stage.getPointerPosition();
      const oldScale = stage.scaleX();

      // Without a pointer there is no anchor to zoom around. A zero, negative
      // or non-finite factor would collapse or flip the stage, or produce NaN
      // positions (the formula below divides by `scaleBy`). In these cases,
      // leave the stage untouched and report its current scale.
      if (cursor === null || !(Number.isFinite(scaleBy) && scaleBy > 0)) {
        return { x: oldScale, y: oldScale };
      }

      // Map the pointer from screen space into stage-local coordinates, using
      // the translation (m[4], m[5]) and scale (m[0], m[3]) entries of the
      // stage transform.
      const transformation = stage.getTransform();

      const relativeCursor = {
        x: (cursor.x - transformation.m[4]) / transformation.m[0],
        y: (cursor.y - transformation.m[5]) / transformation.m[3],
      };

      const newScale = oldScale * scaleBy;

      const stagePosition = stage.position();
      const scale2d = { x: newScale, y: newScale };

      // Shift the stage so the local point under the pointer stays under the
      // pointer after rescaling.
      const position = {
        x:
          stagePosition.x -
          (relativeCursor.x - relativeCursor.x / scaleBy) * newScale,
        y:
          stagePosition.y -
          (relativeCursor.y - relativeCursor.y / scaleBy) * newScale,
      };

      applyToStage(newScale, position);
      return scale2d;
    },
    [applyToStage, stage]
  );

  const result = useMemo(
    () => ({ resetZoom, scaleStage, scale }),
    [resetZoom, scale, scaleStage]
  );

  return result;
}
