import { first } from "lodash";
import { useEffect } from "react";
import {
  useQueryNetworkDetailsForDatasetQuery,
  DenkjobType,
  DenkjobState,
} from "../../../../ts-clients/query";

import { formatDate } from "../../../../utils/helpers";
import { DenkBoxNetworkState } from "./types";

/**
 * Tells whether a job state counts as "active".
 *
 * A job is treated as active while it is starting, running or stopping;
 * `null` (no job state known) and every other state yield `false`.
 */
const isRunning = (state: DenkjobState | null) => {
  switch (state) {
    case DenkjobState.Stopping:
    case DenkjobState.Starting:
    case DenkjobState.Running:
      return true;
    default:
      return false;
  }
};

/** Minimum number of training markups required by `stats.hasSufficientData`. */
const MIN_TRAINING_MARKUPS = 10;
/** Minimum number of validation markups required by `stats.hasSufficientData`. */
const MIN_VALIDATION_MARKUPS = 2;

/**
 * Queries the network details of a dataset and derives display-ready state
 * from the first network module found in the result.
 *
 * Only pipeline modules whose `__typename` is one of the three network types
 * (image segmentation, classification, object detection) are considered; the
 * first such module across all returned datasets is used.
 *
 * @param datasetId - When set, restricts the query to the dataset with this
 *   id; otherwise no dataset filter is sent.
 * @param networkId - When set, restricts the query to the network with this
 *   id; otherwise no network filter is sent.
 * @param pollInterval - Forwarded to the query hook. When truthy (and the
 *   query is not skipped) an additional refetch is issued whenever the value
 *   changes.
 * @param skip - Forwarded to the query hook; also suppresses the manual
 *   refetch.
 * @returns Class labels, sorted metrics of the first snapshot, per-job-type
 *   state, aggregated stats and basic module metadata. Fields fall back to
 *   empty arrays, `0` or `undefined` while no network module is available.
 */
export default function useNetworkDetails(
  datasetId?: string,
  networkId?: string,
  pollInterval?: number,
  skip?: boolean
) {
  const { data, refetch } = useQueryNetworkDetailsForDatasetQuery({
    variables: {
      filter: datasetId ? { id: { eq: datasetId } } : undefined,
      networkFilter: networkId ? { id: { eq: networkId } } : undefined,
    },
    fetchPolicy: "cache-and-network",
    pollInterval,
    skip,
  });

  useEffect(() => {
    // Honour `skip` here as well, so that a caller who skipped the query does
    // not trigger a request through this manual refetch.
    if (pollInterval && !skip) {
      refetch();
    }
  }, [pollInterval, refetch, skip]);

  const datasets = data?.queryDataset ?? [];

  // Collect every network-type module of every dataset into one flat list.
  const networkModules = datasets.flatMap(
    (dataset) =>
      dataset?.pipeline?.modules?.flatMap((pipelineModule) => {
        if (
          pipelineModule &&
          (pipelineModule.__typename ===
            "PipelineModuleNetworkImageSegmentation" ||
            pipelineModule.__typename ===
              "PipelineModuleNetworkImageClassification" ||
            pipelineModule.__typename ===
              "PipelineModuleNetworkImageObjectDetection")
        ) {
          return pipelineModule;
        }
        // Returning an empty array drops the element from the flatMap result.
        return [];
      }) ?? []
  );

  const networkModule = networkModules.length > 0 ? networkModules[0] : null;

  const networkType = networkModule?.__typename;

  const denkjobs = networkModule?.denkjobs || [];

  // The first snapshot is the single source for metrics, download URL and
  // the last-updated timestamp.
  const firstSnapshot = first(networkModule?.snapshots);

  // Copy before sorting so the query result itself is never mutated.
  // Order: metrics without a label first, then labelled ones by `label.idx`.
  const metrics = firstSnapshot
    ? [...firstSnapshot.metrics].sort((a, b) => {
        if (a.label && b.label) return a.label.idx - b.label.idx;
        if (!a.label && b.label) return -1;
        if (a.label && !b.label) return 1;
        return 0;
      })
    : [];

  const trainingCount = networkModule?.trainingMarkups?.count ?? 0;
  const validationCount = networkModule?.validationMarkups?.count ?? 0;

  const networkExists = metrics.length > 0;

  const networkState: DenkBoxNetworkState = {
    testing: null,
    prediction: null,
    training: null,
    runningSince: null,
  };

  denkjobs.forEach((denkjob) => {
    if (!denkjob) return;

    switch (denkjob.jobType) {
      case DenkjobType.DenKpredictor:
        networkState.prediction = denkjob.state ?? null;
        break;
      case DenkjobType.DenKtester:
        networkState.testing = denkjob.state ?? null;
        break;
      case DenkjobType.DenKtrainer:
        networkState.training = denkjob.state ?? null;
        break;
      default:
        // Jobs of any other type do not contribute to the network state.
        return;
    }

    // Every job of a known type overwrites `runningSince`, independent of its
    // state, so the value reflects the last such job in the list.
    networkState.runningSince = formatDate(denkjob.createdAt).timeAgo;
  });

  const testingRunning = isRunning(networkState.testing);
  const trainingRunning = isRunning(networkState.training);
  const predictionRunning = isRunning(networkState.prediction);

  const hasSufficientData =
    trainingCount >= MIN_TRAINING_MARKUPS &&
    validationCount >= MIN_VALIDATION_MARKUPS;

  const classLabels = networkModule?.classLabels || [];

  const stats = {
    trainingCount,
    validationCount,
    trainingRunning,
    testingRunning,
    predictionRunning,
    hasSufficientData,
    networkExists,
  };

  const lastUpdated = firstSnapshot?.updatedAt;

  return {
    classLabels,
    metrics,
    networkState,
    stats,
    downloadURL: firstSnapshot?.onnxBinary.url,
    networkType,
    id: networkModule?.id,
    name: networkModule?.name,
    description: networkModule?.description,
    trainingCount,
    validationCount,
    snapshotCount: networkModule?.snapshotsAggregate?.count || 0,
    lastSnapshotTime: formatDate(lastUpdated),
  };
}
/** Shape of the object returned by the `useNetworkDetails` hook. */
export type NetworkDetailsType = ReturnType<typeof useNetworkDetails>;
/** Type of the `metrics` field of {@link NetworkDetailsType}. */
export type NetworkDetailsMetricType = NetworkDetailsType["metrics"];
