import { MarkupLabel } from "MarkupTypes";
import { useMemo } from "react";
import { useQueryImageClassificationMarkupsQuery } from "../../../../../../ts-clients/query";
import { VectorMarkup, VectorMarkupType } from "../types";

/**
 * Classifies how a label's state was obtained.
 *
 * A user annotation takes precedence over a prediction. A label that has
 * neither is reported as "unmarked".
 *
 * @param predicted - true when a prediction exists for the label
 * @param userAnnotated - true when the markup holds an annotation for the label
 * @returns the mode to display for the label
 */
const getMode = (
  predicted: boolean,
  userAnnotated: boolean
): VectorMarkupType => {
  const hasAnyMark = predicted || userAnnotated;
  if (!hasAnyMark) {
    return "unmarked";
  }
  // At least one flag is set here; annotation wins over prediction.
  return userAnnotated ? "userAnnotated" : "predicted";
};

/**
 * Loads the classification markup and predictions of one image and merges
 * them with `labelList` into a per-label vector.
 *
 * @param imageId - id of the image to query
 * @param networkId - network id passed to the classification markups query
 * @param thresholdForActiveLabel - activation threshold in percent; a predicted
 *   label is active when its probability exceeds `thresholdForActiveLabel / 100`
 * @param labelList - labels to produce entries for, in output order
 * @returns
 *   - `loadedVectorList`: one entry per label in `labelList`, extended with
 *     `active`, `mode`, `probability` (scaled by 100, 0 when not predicted)
 *     and `created: 0`
 *   - `markupId`: id of the image's first ClassificationMarkup, or "" if none
 *   - `lastPredictionTime`: the greatest `createdAt` among the image's
 *     predictions, or "" if there are none
 *   - `subset`: the markup's subset, or "unknown-subset" if there is no markup
 */
export default function useLoadVectorData(
  imageId: string,
  networkId: string,
  thresholdForActiveLabel: number,
  labelList: MarkupLabel[]
) {
  const { data: markupData } = useQueryImageClassificationMarkupsQuery({
    variables: {
      imageIds: [imageId],
      networkId,
    },
    fetchPolicy: "no-cache",
  });

  // Exactly one image id is queried, so only the first result matters.
  const image = markupData?.queryImage?.[0];

  // Latest createdAt across all predictions. Values are compared with `>`.
  // NOTE(review): assumes a lexicographically ordered format (e.g. ISO-8601).
  const lastPredictionTime = useMemo(
    () =>
      image?.predictions
        .map((p) => p.createdAt)
        .reduce((latest, cur) => (cur > latest ? cur : latest), "") ?? "",
    [image]
  );

  // Per-label predictions from every ClassificationPrediction on the image.
  const predictions = useMemo(
    () =>
      image?.predictions.flatMap((p) =>
        p.__typename === "ClassificationPrediction" ? p.predictions : []
      ) ?? [],
    [image]
  );

  // First ClassificationMarkup on the image, if any. flatMap narrows the type.
  const markup = useMemo(
    () =>
      image?.markups.flatMap((m) =>
        m.__typename === "ClassificationMarkup" ? [m] : []
      )[0],
    [image]
  );

  // labelId -> probability. The first occurrence of a label wins, which is
  // what a linear `find` over `predictions` would return.
  const probabilityByLabelId = useMemo(() => {
    const byLabel = new Map<string, number>();
    for (const p of predictions) {
      if (!byLabel.has(p.label.id)) {
        byLabel.set(p.label.id, p.probability);
      }
    }
    return byLabel;
  }, [predictions]);

  // Ids of labels that have any annotation (regardless of its value).
  const annotatedLabelIds = useMemo(
    () => new Set(markup?.annotations.map((a) => a.label.id)),
    [markup?.annotations]
  );

  const markupId = markup?.id ?? "";

  // Labels that start out active. When any prediction exists, activation is
  // driven purely by the probability threshold. Otherwise annotations with
  // value 1 are active.
  // NOTE(review): with predictions present, user annotations are ignored here,
  // although getMode gives annotations precedence. Confirm this is intended.
  const activeLabelIds = useMemo(() => {
    if (predictions.length > 0) {
      const cutoff = thresholdForActiveLabel / 100;
      return new Set(
        predictions
          .filter((p) => p.probability > cutoff)
          .map((p) => p.label.id)
      );
    }
    return new Set(
      markup?.annotations
        .filter((a) => a.value === 1)
        .map((a) => a.label.id)
    );
  }, [predictions, markup?.annotations, thresholdForActiveLabel]);

  const loadedVectorList: VectorMarkup[] = useMemo(
    () =>
      labelList.map(
        (label): VectorMarkup => ({
          ...label,
          active: activeLabelIds.has(label.id),
          mode: getMode(
            probabilityByLabelId.has(label.id),
            annotatedLabelIds.has(label.id)
          ),
          probability: (probabilityByLabelId.get(label.id) ?? 0) * 100,
          created: 0,
        })
      ),
    [labelList, activeLabelIds, probabilityByLabelId, annotatedLabelIds]
  );

  return {
    loadedVectorList,
    markupId,
    lastPredictionTime,
    subset: markup?.subset ?? "unknown-subset",
  };
}
