diff --git a/ui/src/components/modals/add-new-model/modal-context-provider.jsx b/ui/src/components/modals/add-new-model/modal-context-provider.jsx
index 6762e0c..6aba9bb 100644
--- a/ui/src/components/modals/add-new-model/modal-context-provider.jsx
+++ b/ui/src/components/modals/add-new-model/modal-context-provider.jsx
@@ -1,4 +1,4 @@
-import { DataTypeEnum, ModelTypeEnum } from '@State/models/constants';
+import { DataTypeEnum } from '@State/models/constants';
import useFormbit from '@radicalbit/formbit';
import {
createContext,
@@ -18,7 +18,7 @@ function ModalContextProvider({ children }) {
const [isMaximize, setIsMaximize] = useState(false);
const useFormbitStepOne = useFormbit({
- initialValues: { modelType: ModelTypeEnum.BINARY_CLASSIFICATION, dataType: DataTypeEnum.TABULAR },
+ initialValues: { dataType: DataTypeEnum.TABULAR },
yup: schemaStepOne,
});
diff --git a/ui/src/components/modals/add-new-model/step-four/form-fields.jsx b/ui/src/components/modals/add-new-model/step-four/form-fields.jsx
index f2f65be..7686c05 100644
--- a/ui/src/components/modals/add-new-model/step-four/form-fields.jsx
+++ b/ui/src/components/modals/add-new-model/step-four/form-fields.jsx
@@ -1,9 +1,9 @@
+import { ModelTypeEnum } from '@Src/store/state/models/constants';
import {
FormField,
Select,
Tooltip,
} from '@radicalbit/radicalbit-design-system';
-import { ModelTypeEnum } from '@Src/store/state/models/constants';
import { useModalContext } from '../modal-context-provider';
function Target() {
@@ -196,15 +196,18 @@ function Prediction() {
}
function Probability() {
- const { useFormbit } = useModalContext();
+ const { useFormbit, useFormbitStepOne } = useModalContext();
+
const {
form, error, write, remove,
} = useFormbit;
-
const probabilities = useGetProbabilities();
const predictionName = form?.prediction?.name;
const value = form?.predictionProba?.name;
+ const { form: formStepOne } = useFormbitStepOne;
+ const { modelType } = formStepOne;
+
const handleOnChange = (val) => {
if (val === undefined) {
remove('predictionProba');
@@ -218,6 +221,14 @@ function Probability() {
}
};
+ if (modelType === ModelTypeEnum.REGRESSION) {
+ return (
+
+
+
+ );
+ }
+
return (
{
@@ -295,10 +306,10 @@ const useGetPredictions = () => {
return form.outputs.filter(({ type }) => predictionValidTypes[modelType].includes(type));
};
-const binaryClassificationProbabilityValidTypes = {
+const probabilityValidTypes = {
[ModelTypeEnum.BINARY_CLASSIFICATION]: ['float', 'double'],
[ModelTypeEnum.MULTI_CLASSIFICATION]: ['float', 'double'],
- [ModelTypeEnum.REGRESSION]: ['float', 'double'],
+ [ModelTypeEnum.REGRESSION]: [],
};
const useGetProbabilities = () => {
const { useFormbitStepOne, useFormbit } = useModalContext();
@@ -307,7 +318,7 @@ const useGetProbabilities = () => {
const { form: formStepOne } = useFormbitStepOne;
const { modelType } = formStepOne;
- return form.outputs.filter(({ type }) => binaryClassificationProbabilityValidTypes[modelType].includes(type));
+ return form.outputs.filter(({ type }) => probabilityValidTypes[modelType].includes(type));
};
const timestampValidTypes = {
diff --git a/ui/src/components/modals/add-new-model/step-one/form-fields.jsx b/ui/src/components/modals/add-new-model/step-one/form-fields.jsx
index 5839be5..e17dd8f 100644
--- a/ui/src/components/modals/add-new-model/step-one/form-fields.jsx
+++ b/ui/src/components/modals/add-new-model/step-one/form-fields.jsx
@@ -39,9 +39,22 @@ function Name() {
}
function ModelType() {
+ const { useFormbit } = useModalContext();
+ const { form, write } = useFormbit;
+
+ const handleOnChange = (value) => {
+ write('modelType', value);
+ };
+
return (
- {ModelTypeEnumLabel[ModelTypeEnum.BINARY_CLASSIFICATION]}
+
);
}