import {
  PipelineModuleNetworkImageClassificationProps as Props,
  BaseParams,
} from "ModuleTypes";
import { useState } from "react";
import {
  useAddPipelineModuleNetworkImageClassificationMutation,
  AddPipelineModuleNetworkImageClassificationInput,
} from "../../ts-clients/command";
import { pipelineModuleNetworkImageClassificationDefaults } from "./defaults";
import { blankNetworkModuleMagicId } from "../../features/networks3/AddNetwork/RealAddNetwork/blankNetwork";

/**
 * Hook holding the form state for a new network image-classification
 * pipeline module and exposing a `create` action that submits it.
 *
 * NOTE(review): the hook name mentions "CCChecker" while the body works with
 * the NetworkImageClassification props/mutation — possibly a copy-paste
 * leftover; the name is kept so existing imports continue to work.
 *
 * @returns
 *  - `create`: submits the current module data through the add-module mutation.
 *  - `moduleData` / `setModuleData`: the current form state and its setter.
 *  - `setBaseParams`: updates only `moduleName` and `moduleDescription`.
 */
export default function useCreatePipelineModuleCCChecker() {
  const [moduleData, setModuleData] = useState<Props>(
    pipelineModuleNetworkImageClassificationDefaults()
  );
  const [m] = useAddPipelineModuleNetworkImageClassificationMutation();

  /**
   * Sends the current module data to the add-module mutation.
   *
   * @param datasetId Dataset to attach the module to. When falsy, the input
   *   is sent with `datasetId: undefined`, and
   *   `selectedPipelineModuleOutputIDs` is sent only if it is non-empty.
   * @returns The `datasetId` reported by the mutation result, or `""` when
   *   the result carries none.
   */
  const create = async (datasetId: string | null) => {
    // Derive a sanitized copy instead of writing to `moduleData`: it is React
    // state, and mutating it in place would silently change what the form
    // holds (e.g. on a retry after a failed request) without a re-render.
    const filteredModuleData: Props = {
      ...moduleData,
      // Labels with an empty name are not submitted.
      classLabels: moduleData.classLabels.filter((c) => c.name !== ""),
      // The "blank network" magic id is a client-side placeholder; it is
      // submitted as `null`.
      pretrainedNetworkModuleId:
        moduleData.pretrainedNetworkModuleId === blankNetworkModuleMagicId
          ? null
          : moduleData.pretrainedNetworkModuleId,
    };

    const input: AddPipelineModuleNetworkImageClassificationInput = !datasetId
      ? {
          ...filteredModuleData,
          datasetId: undefined,
          // An empty selection is sent as `undefined` rather than `[]`.
          selectedPipelineModuleOutputIDs:
            filteredModuleData.selectedPipelineModuleOutputIDs.length > 0
              ? filteredModuleData.selectedPipelineModuleOutputIDs
              : undefined,
        }
      : {
          ...filteredModuleData,
          datasetId,
        };

    const result = await m({
      variables: {
        input,
      },
    });

    return (
      result.data?.addPipelineModuleNetworkImageClassification?.datasetId ?? ""
    );
  };

  /** Updates the module's name and description, leaving other fields as-is. */
  const setBaseParams = ({ moduleDescription, moduleName }: BaseParams) => {
    setModuleData((md) => ({ ...md, moduleName, moduleDescription }));
  };

  return { create, moduleData, setModuleData, setBaseParams };
}
