diff --git a/package-lock.json b/package-lock.json
index 715eb9f9..80ad739d 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -8341,9 +8341,9 @@
}
},
"node_modules/vite": {
- "version": "4.5.3",
- "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.3.tgz",
- "integrity": "sha512-kQL23kMeX92v3ph7IauVkXkikdDRsYMGTVl5KY2E9OY4ONLvkHf04MDTbnfo6NKxZiDLWzVpP5oTa8hQD8U3dg==",
+ "version": "4.5.5",
+ "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.5.tgz",
+ "integrity": "sha512-ifW3Lb2sMdX+WU91s3R0FyQlAyLxOzCSCP37ujw0+r5POeHPwe6udWVIElKQq8gk3t7b8rkmvqC6IHBpCff4GQ==",
"dependencies": {
"esbuild": "^0.18.10",
"postcss": "^8.4.27",
@@ -14557,9 +14557,9 @@
}
},
"vite": {
- "version": "4.5.3",
- "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.3.tgz",
- "integrity": "sha512-kQL23kMeX92v3ph7IauVkXkikdDRsYMGTVl5KY2E9OY4ONLvkHf04MDTbnfo6NKxZiDLWzVpP5oTa8hQD8U3dg==",
+ "version": "4.5.5",
+ "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.5.tgz",
+ "integrity": "sha512-ifW3Lb2sMdX+WU91s3R0FyQlAyLxOzCSCP37ujw0+r5POeHPwe6udWVIElKQq8gk3t7b8rkmvqC6IHBpCff4GQ==",
"requires": {
"esbuild": "^0.18.10",
"fsevents": "~2.3.2",
diff --git a/src/components/GraphVisualisation/PointPreview.jsx b/src/components/Common/PointPreview.jsx
similarity index 92%
rename from src/components/GraphVisualisation/PointPreview.jsx
rename to src/components/Common/PointPreview.jsx
index 0313025d..c4e717fb 100644
--- a/src/components/GraphVisualisation/PointPreview.jsx
+++ b/src/components/Common/PointPreview.jsx
@@ -11,7 +11,6 @@ const PointPreview = ({ point }) => {
const [loading] = React.useState(false);
const conditions = [];
const payloadSchema = {};
- const onConditionChange = () => {};
if (!point) {
return null;
@@ -48,7 +47,6 @@ const PointPreview = ({ point }) => {
@@ -67,7 +65,7 @@ const PointPreview = ({ point }) => {
}}
/>
-
+
>
)}
diff --git a/src/components/FilterEditorWindow/config/RequestFromCode.js b/src/components/FilterEditorWindow/config/RequestFromCode.js
deleted file mode 100644
index de2354e3..00000000
--- a/src/components/FilterEditorWindow/config/RequestFromCode.js
+++ /dev/null
@@ -1,165 +0,0 @@
-import { bigIntJSON } from '../../../common/bigIntJSON';
-import { axiosInstance } from '../../../common/axios';
-
-function parseDataToRequest(reqBody) {
- // Validate color_by
- if (reqBody.color_by) {
- const colorBy = reqBody.color_by;
-
- if (typeof colorBy === 'string') {
- // Parse into payload variant
- reqBody.color_by = {
- payload: colorBy,
- };
- } else {
- // Check we only have one of the options: payload, or discover_score
- const options = [colorBy.payload, colorBy.discover_score];
- const optionsCount = options.filter((option) => option).length;
- if (optionsCount !== 1) {
- return {
- reqBody: reqBody,
- error: '`color_by`: Only one of `payload`, or `discover_score` can be used',
- };
- }
-
- // Put search arguments in main request body
- if (colorBy.discover_score) {
- reqBody = {
- ...reqBody,
- ...colorBy.discover_score,
- };
- }
- }
- }
-
- // Set with_vector name
- if (reqBody.vector_name) {
- reqBody.with_vector = [reqBody.vector_name];
- return {
- reqBody: reqBody,
- error: null,
- };
- } else if (!reqBody.vector_name) {
- reqBody.with_vector = true;
- return {
- reqBody: reqBody,
- error: null,
- };
- }
-}
-export async function requestFromCode(dataRaw, collectionName) {
- const data = parseDataToRequest(dataRaw);
- // Sending request
- const colorBy = data.reqBody.color_by;
- if (colorBy?.payload) {
- return await actionFromCode(collectionName, data, 'scroll');
- }
- if (colorBy?.discover_score) {
- return discoverFromCode(collectionName, data);
- }
- return await actionFromCode(collectionName, data, 'scroll');
-}
-
-async function actionFromCode(collectionName, data, action) {
- try {
- const response = await axiosInstance({
- method: 'POST',
- url: `collections/${collectionName}/points/${action || 'scroll'}`,
- data: data.reqBody,
- });
- response.data.color_by = data.reqBody.color_by;
- response.data.vector_name = data.reqBody.vector_name;
- response.data.result.points = response.data.result.points.filter((point) => Object.keys(point.vector).length > 0);
- return {
- data: response.data,
- error: null,
- };
- } catch (err) {
- return {
- data: null,
- error: err.response?.data?.status ? err.response?.data?.status : err,
- };
- }
-}
-
-async function discoverFromCode(collectionName, data) {
- // Do 20/80 split. 20% of the points will be returned with the query
- // and 80 % will be returned with random sampling
- const queryLimit = Math.floor(data.reqBody.limit * 0.2);
- const randomLimit = data.reqBody.limit - queryLimit;
- data.reqBody.limit = queryLimit;
- data.reqBody.with_payload = true;
-
- const queryResponse = await actionFromCode(collectionName, data, 'discover');
- if (queryResponse.error) {
- return {
- data: null,
- error: queryResponse.error,
- };
- }
-
- // Add tag to know which points were returned by the query
- queryResponse.data.result = queryResponse.data.result.map((point) => ({
- ...point,
- from_query: true,
- }));
-
- // Get "random" points ids.
- // There is no sampling endpoint in Qdrant yet, so for now we just scroll excluding the previous results
- const idsToExclude = queryResponse.data.result.map((point) => point.id);
-
- const originalFilter = data.reqBody.filter;
- const mustNotFilter = [{ has_id: idsToExclude }];
- data.reqBody.filter = originalFilter || {};
- data.reqBody.filter.must_not = mustNotFilter.concat(data.reqBody.filter.must_not ?? []);
-
- data.reqBody.limit = randomLimit;
- const randomResponse = await actionFromCode(collectionName, data, 'scroll');
- if (randomResponse.error) {
- return {
- data: null,
- error: randomResponse.error,
- };
- }
-
- // Then score these random points
- const idsToInclude = randomResponse.data.result.points.map((point) => point.id);
- const mustFilter = [{ has_id: idsToInclude }];
- data.reqBody.filter = originalFilter || {};
- data.reqBody.filter.must = mustFilter.concat(data.reqBody.filter.must || []);
-
- const scoredRandomResponse = await actionFromCode(collectionName, data, 'discover');
- if (scoredRandomResponse.error) {
- return {
- data: null,
- error: scoredRandomResponse.error,
- };
- }
-
- // Concat both results
- const points = queryResponse.data.result.concat(scoredRandomResponse.data.result);
-
- return {
- data: {
- ...queryResponse.data,
- result: {
- points: points,
- },
- },
- error: null,
- };
-}
-
-export function codeParse(codeText) {
- // Parse JSON
- if (codeText) {
- try {
- return bigIntJSON.parse(codeText);
- } catch (e) {
- return {
- reqBody: codeText,
- error: 'Fix the position brackets to run & check the json',
- };
- }
- }
-}
diff --git a/src/components/FilterEditorWindow/index.jsx b/src/components/FilterEditorWindow/index.jsx
index 89044508..02956258 100644
--- a/src/components/FilterEditorWindow/index.jsx
+++ b/src/components/FilterEditorWindow/index.jsx
@@ -6,7 +6,7 @@ import { useClient } from '../../context/client-context';
import { useTheme } from '@mui/material/styles';
import { autocomplete } from './config/Autocomplete';
import { useSnackbar } from 'notistack';
-import { codeParse } from './config/RequestFromCode';
+import { bigIntJSON } from '../../common/bigIntJSON';
import './editor.css';
import EditorCommon from '../EditorCommon';
@@ -30,6 +30,20 @@ const CodeEditorWindow = ({ onChange, code, onChangeResult, customRequestSchema,
[]
);
+ function codeParse(codeText) {
+ // Parse JSON
+ if (codeText) {
+ try {
+ return bigIntJSON.parse(codeText);
+ } catch (e) {
+ return {
+ reqBody: codeText,
+ error: 'Fix the position brackets to run & check the json',
+ };
+ }
+ }
+ }
+
function onRun(codeText) {
const data = codeParse(codeText);
if (data.error) {
diff --git a/src/components/Points/DataGridList.jsx b/src/components/Points/DataGridList.jsx
index c58905f4..847e210e 100644
--- a/src/components/Points/DataGridList.jsx
+++ b/src/components/Points/DataGridList.jsx
@@ -85,7 +85,9 @@ export const DataGridList = function ({ data = {}, specialCases = {}, onConditio
if (conditions.find((c) => c.key === filter.key && c.value === filter.value)) {
return;
}
- onConditionChange([...conditions, filter]);
+ if (typeof onConditionChange === 'function') {
+ onConditionChange([...conditions, filter]);
+ }
}}
>
diff --git a/src/components/Points/PointVectors.jsx b/src/components/Points/PointVectors.jsx
index 9bec5f98..4042d12b 100644
--- a/src/components/Points/PointVectors.jsx
+++ b/src/components/Points/PointVectors.jsx
@@ -34,6 +34,7 @@ const Vectors = memo(function Vectors({ point, onConditionChange }) {
return (
+ Vectors:
{Object.keys(vectors).map((key) => {
return (
@@ -81,13 +82,17 @@ const Vectors = memo(function Vectors({ point, onConditionChange }) {
-
+ {typeof onConditionChange !== 'function' ? null : (
+
+ )}
);
@@ -98,7 +103,7 @@ const Vectors = memo(function Vectors({ point, onConditionChange }) {
Vectors.propTypes = {
point: PropTypes.object.isRequired,
- onConditionChange: PropTypes.func.isRequired,
+ onConditionChange: PropTypes.func,
};
export default Vectors;
diff --git a/src/components/VisualizeChart/ImageTooltip.jsx b/src/components/VisualizeChart/ImageTooltip.jsx
deleted file mode 100644
index 2b0d4e2f..00000000
--- a/src/components/VisualizeChart/ImageTooltip.jsx
+++ /dev/null
@@ -1,100 +0,0 @@
-import { toFont } from 'chart.js/helpers';
-import React from 'react';
-import { createRoot } from 'react-dom/client';
-import { flushSync } from 'react-dom';
-
-const DEFAULT_BORDER_COLOR = '#333333';
-
-export function imageTooltip(context) {
- // Tooltip Element
- let tooltipEl = document.getElementById('chartjs-tooltip');
-
- // Create element on first render
- if (!tooltipEl) {
- tooltipEl = document.createElement('div');
- tooltipEl.id = 'chartjs-tooltip';
- tooltipEl.appendChild(document.createElement('table'));
- document.body.appendChild(tooltipEl);
- }
-
- // Hide if no tooltip
- const tooltipModel = context.tooltip;
- if (tooltipModel.opacity === 0) {
- tooltipEl.style.opacity = 0;
- return;
- }
-
- // Set caret Position
- tooltipEl.classList.remove('above', 'below', 'no-transform');
- if (tooltipModel.yAlign) {
- tooltipEl.classList.add(tooltipModel.yAlign);
- } else {
- tooltipEl.classList.add('no-transform');
- }
-
- // Set content
- if (tooltipModel.body) {
- const bodyLines = tooltipModel.body[0].lines;
-
- const imageSrc = tooltipModel.dataPoints[0].dataset.data[tooltipModel.dataPoints[0].dataIndex].point.payload?.image;
-
- const borderColor = tooltipModel.labelColors[0]?.borderColor || DEFAULT_BORDER_COLOR;
-
- const child = (
-
- {imageSrc && (
-
- )}
-
- {bodyLines.map((line, i) => (
-
- {line}
-
- ))}
-
-
- );
-
- // Render html to insert in tooltip
- const tableRoot = tooltipEl.querySelector('table');
- const root = createRoot(tableRoot);
- flushSync(() => {
- root.render(child);
- });
- }
-
- const position = context.chart.canvas.getBoundingClientRect();
- const bodyFont = toFont(tooltipModel.options.bodyFont);
-
- // Display, position, and set styles for font
- tooltipEl.style.opacity = 1;
- tooltipEl.style.position = 'absolute';
- tooltipEl.style.left = position.left + window.scrollX + tooltipModel.caretX + 'px';
- tooltipEl.style.top = position.top + window.scrollY + tooltipModel.caretY + 'px';
- tooltipEl.style.font = bodyFont.string;
- tooltipEl.style.padding = tooltipModel.padding + 'px ' + tooltipModel.padding + 'px';
- tooltipEl.style.pointerEvents = 'none';
-}
diff --git a/src/components/VisualizeChart/ViewPointModal.jsx b/src/components/VisualizeChart/ViewPointModal.jsx
deleted file mode 100644
index fa2063a0..00000000
--- a/src/components/VisualizeChart/ViewPointModal.jsx
+++ /dev/null
@@ -1,72 +0,0 @@
-import React from 'react';
-import PropTypes from 'prop-types';
-import { Box, Button, Dialog, DialogContent, DialogActions, DialogTitle, Paper, Typography } from '@mui/material';
-import { alpha } from '@mui/material';
-import { useTheme } from '@mui/material/styles';
-import { DataGridList } from '../Points/DataGridList';
-import { CopyButton } from '../Common/CopyButton';
-import { bigIntJSON } from '../../common/bigIntJSON';
-
-const ViewPointModal = (props) => {
- const theme = useTheme();
- const { openViewPoints, setOpenViewPoints, viewPoints } = props;
-
- return (
- <>
-
- >
- );
-};
-
-ViewPointModal.propTypes = {
- openViewPoints: PropTypes.bool.isRequired,
- setOpenViewPoints: PropTypes.func.isRequired,
- viewPoints: PropTypes.array.isRequired,
-};
-
-export default ViewPointModal;
diff --git a/src/components/VisualizeChart/index.jsx b/src/components/VisualizeChart/index.jsx
index e88c7e7f..ed52430e 100644
--- a/src/components/VisualizeChart/index.jsx
+++ b/src/components/VisualizeChart/index.jsx
@@ -1,99 +1,182 @@
import Chart from 'chart.js/auto';
-import chroma from 'chroma-js';
import get from 'lodash/get';
import { useSnackbar } from 'notistack';
import PropTypes from 'prop-types';
-import React, { useEffect, useState } from 'react';
-import ViewPointModal from './ViewPointModal';
-import { imageTooltip } from './ImageTooltip';
-import { bigIntJSON } from '../../common/bigIntJSON';
+import React, { useEffect } from 'react';
+import { generateColorBy, generateSizeBy } from './renderBy';
+import { useTheme } from '@mui/material/styles';
-const SCORE_GRADIENT_COLORS = ['#EB5353', '#F9D923', '#36AE7C'];
+// Dark red
+const LIGHT_SELECTOR_COLOR = 'rgba(255, 0, 0, 0.5)';
+// White
+const DARK_SELECTOR_COLOR = 'rgba(245, 245, 245, 0.8)';
-const VisualizeChart = ({ scrollResult, algorithm = null }) => {
+// Transparent color for points
+const DEFAULT_BORDER_COLOR = 'rgba(0, 0, 0, 0)';
+
+function intoDatasets(
+ points, // array of original points, which contain payloads
+ data, // list of compressed coordinates
+ colors, // list of colors for each point to be displayed
+ sizes, // list of sizes for each point to be displayed
+ groupBy = null // payload field to group by
+) {
+ const defaultConfig = {
+ pointHitRadius: 1,
+ hoverRadius: 7,
+ };
+
+ if (!groupBy) {
+ // No grouping
+ return [
+ {
+ label: 'Data',
+ data,
+ offsets: Array.from({ length: data.length }, (_, i) => i),
+ pointBackgroundColor: [...colors],
+ // Use transparent border color for points
+ pointBorderColor: Array.from({ length: colors.length }, () => DEFAULT_BORDER_COLOR),
+ ...defaultConfig,
+ },
+ ];
+ }
+
+ const groups = {};
+
+ points.forEach((point, index) => {
+ let group = get(point.payload, groupBy) + ''; // Convert to string, even if it's an o
+
+ if (!group) {
+ // If specified field is not present in the payload, fallback to 'Unknown'
+ group = 'Unknown';
+ }
+
+ if (!groups[group]) {
+ groups[group] = {
+ label: group,
+ data: [],
+ offsets: [],
+ pointBackgroundColor: [],
+ pointBorderColor: [],
+ pointRadius: [],
+ ...defaultConfig,
+ };
+ }
+
+ groups[group].data.push(data[index]);
+ groups[group].offsets.push(index);
+ groups[group].pointBackgroundColor.push(colors[index]);
+ groups[group].pointBorderColor.push(DEFAULT_BORDER_COLOR);
+ groups[group].pointRadius.push(sizes[index]);
+ });
+
+ // Convert groups object to array, and sort by label
+ return Object.values(groups).sort((a, b) => a.label.localeCompare(b.label));
+}
+
+const VisualizeChart = ({
+ requestResult, // Raw output of the request from qdrant client
+ visualizationParams, // Parameters, as specified by the user in the input editor
+ activePoint, // currently selected point (with hover)
+ setActivePoint, // callback to set new active point
+}) => {
const { enqueueSnackbar } = useSnackbar();
- const [openViewPoints, setOpenViewPoints] = useState(false);
- const [viewPoints, setViewPoint] = useState([]);
+
+ // Id of the currently selected point
+ // Used to prevent multiple updates of the chart on hover
+ // And for switching colors of the selected point
+ let selectedPointLocation = null;
+
+ const theme = useTheme();
+
+ function getSelectionColor() {
+ return theme.palette.mode === 'light' ? LIGHT_SELECTOR_COLOR : DARK_SELECTOR_COLOR;
+ }
useEffect(() => {
- if (!scrollResult.data && !scrollResult.error) {
+ if (!requestResult.points) {
return;
}
- if (scrollResult.error) {
- enqueueSnackbar(`Visualization Unsuccessful, error: ${bigIntJSON.stringify(scrollResult.error)}`, {
- variant: 'error',
- });
+ const points = requestResult.points;
+ const colorBy = visualizationParams?.color_by;
- return;
- } else if (!scrollResult.data?.result?.points.length) {
- enqueueSnackbar(`Visualization Unsuccessful, error: No data returned`, {
- variant: 'error',
- });
- return;
- }
+ // Initialize data with random points in range [0, 1]
+ const data = points.map(() => ({
+ x: Math.random(),
+ y: Math.random(),
+ }));
- const dataset = [];
- const colorBy = scrollResult.data.color_by;
+ // This reference values should be used to rollback the color of the previously selected point
+ const pointColors = generateColorBy(points, colorBy);
+ const sizes = generateSizeBy(points);
- let labelby = null;
- if (colorBy?.payload) {
- labelby = colorBy.payload;
- // Color and label by payload field
- if (get(scrollResult.data.result?.points[0]?.payload, labelby) === undefined) {
- enqueueSnackbar(`Visualization Unsuccessful, error: Color by field ${labelby} does not exist`, {
- variant: 'error',
- });
- return;
+ const payloadField = typeof colorBy === 'string' ? colorBy : colorBy?.payload;
+ const useLegend = !!payloadField;
+
+ const datasets = intoDatasets(points, data, pointColors, sizes, payloadField);
+
+ const handlePointHover = (chart) => {
+ if (!chart.tooltip?._active) return;
+ if (chart.tooltip?._active.length === 0) return;
+
+ const lastActive = chart.tooltip._active.length - 1;
+
+ const selectedElement = chart.tooltip._active[lastActive];
+
+ const datasetIndex = selectedElement.datasetIndex;
+ const pointIndex = selectedElement.index;
+
+ const offsets = chart.data.datasets[datasetIndex].offsets;
+ const pointOffset = offsets[pointIndex];
+ const selectedPoint = points[pointOffset];
+
+ // Check if the same point is already selected
+ // To prevent recurrant updates of the chart
+ if (selectedPoint.id === activePoint?.id) {
+ selectedPointLocation = {
+ offset: pointOffset,
+ datasetIndex,
+ pointIndex,
+ };
+ return selectedPoint;
+ }
+ if (pointOffset === selectedPointLocation?.offset) {
+ return selectedPoint;
}
- scrollResult.data.labelByArrayUnique = [
- ...new Set(scrollResult.data.result?.points?.map((point) => get(point.payload, labelby))),
- ];
- scrollResult.data.labelByArrayUnique.forEach((label) => {
- dataset.push({
- label: label,
- data: [],
- });
- });
- } else if (colorBy?.discover_score) {
- // Color by discover score
- const scores = scrollResult.data.result?.points.map((point) => point.score);
- const minScore = Math.min(...scores);
- const maxScore = Math.max(...scores);
-
- const colorScale = chroma.scale(SCORE_GRADIENT_COLORS);
- const scoreColors = scores.map((score) => {
- const normalizedScore = (score - minScore) / (maxScore - minScore);
- return colorScale(normalizedScore).hex();
- });
- const pointRadii = scrollResult.data.result?.points.map((point) => {
- if (point.from_query) {
- return 4;
- } else {
- return 3;
- }
- });
+ const oldPointLocation = selectedPointLocation;
+
+ selectedPointLocation = {
+ offset: pointOffset,
+ datasetIndex,
+ pointIndex,
+ };
+
+ // Reset color of the previously selected point
+ if (oldPointLocation) {
+ const targetColor = pointColors[oldPointLocation.offset];
+ chart.data.datasets[oldPointLocation.datasetIndex].pointBackgroundColor[oldPointLocation.pointIndex] =
+ targetColor;
+
+ chart.data.datasets[oldPointLocation.datasetIndex].pointBorderColor[oldPointLocation.pointIndex] =
+ DEFAULT_BORDER_COLOR;
+ }
+
+ chart.data.datasets[datasetIndex].pointBackgroundColor[pointIndex] = getSelectionColor();
+ chart.data.datasets[datasetIndex].pointBorderColor[pointIndex] = getSelectionColor();
+
+ setActivePoint(selectedPoint);
+
+ chart.update();
+ return selectedPoint;
+ };
- dataset.push({
- label: 'Discover scores',
- pointBackgroundColor: scoreColors,
- pointBorderColor: scoreColors,
- pointRadius: pointRadii,
- data: [],
- });
- } else {
- // No special coloring
- dataset.push({
- label: 'Data',
- data: [],
- });
- }
const ctx = document.getElementById('myChart');
const myChart = new Chart(ctx, {
type: 'scatter',
data: {
- datasets: dataset,
+ datasets: datasets,
},
options: {
responsive: true,
@@ -109,55 +192,31 @@ const VisualizeChart = ({ scrollResult, algorithm = null }) => {
display: false,
},
},
+ interaction: {
+ mode: 'nearest', // Show tooltip for the nearest point
+ intersect: false, // Show even if not directly hovering over a point
+ },
plugins: {
tooltip: {
// only use custom tooltip if color by is not discover score
- enabled: !colorBy?.discover_score,
- external: (colorBy?.discover_score && imageTooltip) || undefined,
+ enabled: true,
usePointStyle: true,
+ position: 'nearest',
+ intersect: true,
callbacks: {
label: (context) => {
- const payload = bigIntJSON
- .stringify(context.dataset.data[context.dataIndex].point.payload, null, 1)
- .split('\n');
-
- if (colorBy?.discover_score) {
- const id = context.dataset.data[context.dataIndex].point.id;
- const score = context.dataset.data[context.dataIndex].point.score;
-
- return [`id: ${id}`, `score: ${score}`];
- } else {
- return payload;
- }
+ const selectedPoint = handlePointHover(context.chart);
+ if (!selectedPoint) return '';
+ const id = selectedPoint.id;
+ return `Point ${id}`;
},
},
},
legend: {
- display: !!labelby,
+ display: useLegend,
},
},
},
- plugins: [
- {
- id: 'myEventCatcher',
- beforeEvent(chart, args) {
- const event = args.event;
- if (event.type === 'click') {
- if (chart.tooltip._active.length > 0) {
- const activePoints = chart.tooltip._active.map((point) => {
- return {
- id: point.element.$context.raw.point.id,
- payload: point.element.$context.raw.point.payload,
- vector: point.element.$context.raw.point.vector,
- };
- });
- setViewPoint(activePoints);
- setOpenViewPoints(true);
- }
- }
- },
- },
- ],
});
const worker = new Worker(new URL('./worker.js', import.meta.url), {
@@ -170,19 +229,24 @@ const VisualizeChart = ({ scrollResult, algorithm = null }) => {
variant: 'error',
});
} else if (m.data.result && m.data.result.length > 0) {
- m.data.result.forEach((dataset, index) => {
+ const reducedPonts = m.data.result;
+
+ const datasets = intoDatasets(points, reducedPonts, pointColors, sizes, payloadField);
+
+ datasets.forEach((dataset, index) => {
myChart.data.datasets[index].data = dataset.data;
});
+
myChart.update();
} else {
enqueueSnackbar(`Visualization Unsuccessful, error: Unexpected Error Occured`, { variant: 'error' });
}
};
- if (scrollResult.data.result?.points?.length > 0) {
+ if (requestResult.points?.length > 0) {
worker.postMessage({
- ...scrollResult.data,
- algorithm: algorithm,
+ result: requestResult,
+ params: visualizationParams,
});
}
@@ -190,19 +254,20 @@ const VisualizeChart = ({ scrollResult, algorithm = null }) => {
myChart.destroy();
worker.terminate();
};
- }, [scrollResult]);
+ }, [requestResult]);
return (
<>
-
>
);
};
VisualizeChart.propTypes = {
- scrollResult: PropTypes.object.isRequired,
- algorithm: PropTypes.string,
+ requestResult: PropTypes.object.isRequired,
+ visualizationParams: PropTypes.object.isRequired,
+ activePoint: PropTypes.object,
+ setActivePoint: PropTypes.func,
};
export default VisualizeChart;
diff --git a/src/components/VisualizeChart/renderBy.js b/src/components/VisualizeChart/renderBy.js
new file mode 100644
index 00000000..7ab611ae
--- /dev/null
+++ b/src/components/VisualizeChart/renderBy.js
@@ -0,0 +1,97 @@
+import chroma from 'chroma-js';
+
+const SCORE_GRADIENT_COLORS = ['#EB5353', '#F9D923', '#36AE7C'];
+const BACKGROUND_COLOR = '#36A2EB';
+
+const PALLETE = [
+ '#3366CC',
+ '#DC3912',
+ '#FF9900',
+ '#109618',
+ '#990099',
+ '#3B3EAC',
+ '#0099C6',
+ '#DD4477',
+ '#66AA00',
+ '#B82E2E',
+ '#316395',
+ '#994499',
+ '#22AA99',
+ '#AAAA11',
+ '#6633CC',
+ '#E67300',
+ '#8B0707',
+ '#329262',
+ '#5574A6',
+ '#651067',
+];
+
+// const SELECTED_BORDER_COLOR = '#881177';
+
+function colorByPayload(payloadValue, colored) {
+ if (colored[payloadValue]) {
+ return colored[payloadValue];
+ }
+
+ const nextColorIndex = Object.keys(colored).length % PALLETE.length;
+
+ colored[payloadValue] = PALLETE[nextColorIndex];
+
+ return PALLETE[nextColorIndex];
+}
+
+// This function generates an array of colors for each point in the chart.
+// There are following options available for colorBy:
+//
+// - None: all points will have the same color
+// - typeof = "string": color points based on the source field
+// - {"payload": "field_name"}: color points based on the payload field
+// - {"discover_score": { ... } }: color points based on the discover score
+// - {"query": { ... }}: color points based on the query score
+
+export function generateColorBy(points, colorBy = null) {
+ // Points example:
+ // [
+ // { id: 0, payload: { field_name: 1 }, score: 0.5, vector: [0.1, 0.2, ....] },
+ // { id: 1, payload: { field_name: 2 }, score: 0.6, vector: [0.3, 0.4, ....] },
+ // ...
+ // ]
+
+ if (!colorBy) {
+ return Array.from({ length: points.length }, () => BACKGROUND_COLOR); // Default color
+ }
+
+ // If `colorBy` is a string, interpret as a field name
+ if (typeof colorBy === 'string') {
+ colorBy = { payload: colorBy };
+ }
+
+ if (colorBy.payload) {
+ const valuesToColor = {};
+
+ return points.map((point) => {
+ const payloadValue = point.payload[colorBy.payload];
+ if (!payloadValue) {
+ return BACKGROUND_COLOR;
+ }
+ return colorByPayload(payloadValue, valuesToColor);
+ });
+ }
+
+ if (colorBy.discover_score || colorBy.query) {
+ const scores = points.map((point) => point.score);
+ const minScore = Math.min(...scores);
+ const maxScore = Math.max(...scores);
+
+ const colorScale = chroma.scale(SCORE_GRADIENT_COLORS);
+ return scores.map((score) => {
+ const normalizedScore = (score - minScore) / (maxScore - minScore);
+ return colorScale(normalizedScore).hex();
+ });
+ }
+}
+
+export function generateSizeBy(points) {
+ // ToDo: Intoroduce size differentiation later
+ return points.map(() => 3);
+}
diff --git a/src/components/VisualizeChart/requestData.js b/src/components/VisualizeChart/requestData.js
new file mode 100644
index 00000000..7deba485
--- /dev/null
+++ b/src/components/VisualizeChart/requestData.js
@@ -0,0 +1,33 @@
+/* eslint-disable camelcase */
+export function requestData(
+ qdrantClient,
+ collectionName,
+ { limit, filter = null, vector_name = null, color_by = null }
+) {
+ // Based on the input parameters, we need to decide what kind of request we need to send
+ // By default we should do scroll request
+ // But if we have color_by field which uses query, it should be used instead
+
+ if (color_by?.query) {
+ const query = {
+ query: color_by.query,
+ limit: limit,
+ filter: filter,
+ with_vector: vector_name ? [vector_name] : true,
+ with_payload: true,
+ };
+
+ return qdrantClient.query(collectionName, query);
+ }
+
+ // It it's not a query, we should do a scroll request
+
+ const scrollQuery = {
+ limit: limit,
+ filter: filter,
+ with_vector: vector_name ? [vector_name] : true,
+ with_payload: true,
+ };
+
+ return qdrantClient.scroll(collectionName, scrollQuery);
+}
diff --git a/src/components/VisualizeChart/worker.js b/src/components/VisualizeChart/worker.js
index 27e3b3cf..8899d76f 100644
--- a/src/components/VisualizeChart/worker.js
+++ b/src/components/VisualizeChart/worker.js
@@ -4,113 +4,101 @@ import get from 'lodash/get';
const MESSAGE_INTERVAL = 200;
+function getVectorType(vector) {
+ if (Array.isArray(vector)) {
+ if (Array.isArray(vector[0])) {
+ return 'multivector';
+ }
+ return 'vector';
+ }
+ if (typeof vector === 'object') {
+ if (vector.indices) {
+ return 'sparse';
+ }
+ return 'named';
+ }
+ return 'unknown';
+}
+
self.onmessage = function (e) {
let now = new Date().getTime();
- const algorithm = e?.data?.algorithm || 'TSNE';
+ const params = e?.data?.params || {};
+
+ const algorithm = params.algorithm || 'TSNE';
- const data1 = e.data;
const data = [];
- if (data1?.result?.points?.length === 0) {
+ const points = e.data?.result?.points;
+ const vectorName = params.vector_name;
+
+ if (!points || points.length === 0) {
self.postMessage({
data: [],
error: 'No data found',
});
return;
- } else if (data1?.result?.points?.length === 1) {
+ }
+
+ if (points.length === 1) {
self.postMessage({
data: [],
error: 'cannot perform tsne on single point',
});
return;
- } else if (typeof data1?.result?.points[0].vector.length === 'number') {
- data1?.result?.points?.forEach((point) => {
- data.push(point.vector);
- });
- } else if (typeof data1?.result?.points[0].vector === 'object') {
- if (data1.vector_name === undefined) {
- self.postMessage({
- data: [],
- error: 'No vector name found, select a valid vector_name',
- });
- return;
- } else if (data1?.result?.points[0].vector[data1?.vector_name] === undefined) {
- self.postMessage({
- data: [],
- error: 'No vector found with name ' + data1?.vector_name,
- });
- return;
- } else if (data1?.result?.points[0].vector[data1?.vector_name]) {
- if (!Array.isArray(data1?.result?.points[0].vector[data1?.vector_name])) {
- self.postMessage({
- data: [],
- error: 'Vector visualization is not supported for sparse vector',
- });
- return;
- }
- data1?.result?.points?.forEach((point) => {
- data.push(point.vector[data1?.vector_name]);
- });
+ }
+
+ for (let i = 0; i < points.length; i++) {
+ if (!vectorName) {
+ // Work with default vector
+ data.push(points[i]?.vector);
} else {
+ // Work with named vector
+ data.push(get(points[i]?.vector, vectorName));
+ }
+ }
+
+ // Validate data
+
+ for (let i = 0; i < data.length; i++) {
+ const vector = data[i];
+ const vectorType = getVectorType(vector);
+
+ if (vectorType === 'vector') {
+ continue;
+ }
+
+ if (vectorType === 'named') {
self.postMessage({
data: [],
- error: 'Unexpected Error Occurred',
+ error: 'Please select a valid vector name, default vector is not defined',
});
return;
}
- } else {
+
self.postMessage({
data: [],
- error: 'Unexpected Error Occurred',
+ error: 'Vector visualization is not supported for vector type: ' + vectorType,
});
return;
}
+
if (data.length) {
const D = new druid[algorithm](data, {}); // ex params = { perplexity : 50,epsilon :5}
const next = D.generator(); // default = 500 iterations
- let i = {};
- for (i of next) {
+
+ let reducedPoints = [];
+ for (reducedPoints of next) {
if (Date.now() - now > MESSAGE_INTERVAL) {
now = Date.now();
- self.postMessage({ result: getDataset(data1, i), error: null });
+ self.postMessage({ result: getDataset(reducedPoints), error: null });
}
}
- self.postMessage({ result: getDataset(data1, i), error: null });
+ self.postMessage({ result: getDataset(reducedPoints), error: null });
}
};
-function getDataset(data, reducedPoint) {
- const dataset = [];
- const labelby = data.color_by?.payload;
- if (labelby) {
- data.labelByArrayUnique.forEach((label) => {
- dataset.push({
- label: label,
- data: [],
- });
- });
-
- data.result?.points?.forEach((point, index) => {
- const label = get(point.payload, labelby);
- dataset[data.labelByArrayUnique.indexOf(label)].data.push({
- x: reducedPoint[index][0],
- y: reducedPoint[index][1],
- point: point,
- });
- });
- } else {
- dataset.push({
- label: 'data',
- data: [],
- });
- data.result?.points?.forEach((point, index) => {
- dataset[0].data.push({
- x: reducedPoint[index][0],
- y: reducedPoint[index][1],
- point: point,
- });
- });
- }
- return dataset;
+function getDataset(reducedPoints) {
+ // Convert [[x1, y1], [x2, y2] ] to [ { x: x1, y: y1 }, { x: x2, y: y2 } ]
+ return reducedPoints.map((point) => ({ x: point[0], y: point[1] }));
}
diff --git a/src/pages/Graph.jsx b/src/pages/Graph.jsx
index 502cca3e..4ee6b6d5 100644
--- a/src/pages/Graph.jsx
+++ b/src/pages/Graph.jsx
@@ -6,7 +6,7 @@ import { useTheme } from '@mui/material/styles';
import { Panel, PanelGroup, PanelResizeHandle } from 'react-resizable-panels';
import GraphVisualisation from '../components/GraphVisualisation/GraphVisualisation';
import { useWindowResize } from '../hooks/windowHooks';
-import PointPreview from '../components/GraphVisualisation/PointPreview';
+import PointPreview from '../components/Common/PointPreview';
import CodeEditorWindow from '../components/FilterEditorWindow';
import { useClient } from '../context/client-context';
import { getFirstPoint } from '../lib/graph-visualization-helpers';
diff --git a/src/pages/Visualize.jsx b/src/pages/Visualize.jsx
index 1e518a4c..b7b244d1 100644
--- a/src/pages/Visualize.jsx
+++ b/src/pages/Visualize.jsx
@@ -7,10 +7,19 @@ import { Panel, PanelGroup, PanelResizeHandle } from 'react-resizable-panels';
import FilterEditorWindow from '../components/FilterEditorWindow';
import VisualizeChart from '../components/VisualizeChart';
import { useWindowResize } from '../hooks/windowHooks';
-import { requestFromCode } from '../components/FilterEditorWindow/config/RequestFromCode';
+import PointPreview from '../components/Common/PointPreview';
+import { useClient } from '../context/client-context';
+import { requestData } from '../components/VisualizeChart/requestData';
+import { useSnackbar } from 'notistack';
const query = `
+// Try me!
+
+{
+ "limit": 500
+}
+
// Specify request parameters to select data for visualization.
//
// Available parameters:
@@ -24,33 +33,14 @@ const query = `
// - 'color_by': specify score or payload field to use for coloring points.
// How to use:
//
-// "color_by": "field_name"
-//
-// or
-//
// "color_by": {
// "payload": "field_name"
// }
//
-// or
-//
-// "color_by": {
-// "discover_score": {
-// "target": 42,
-// "context": [{"positive": 1, "negative": 0}]
-// }
-// }
-//
// - 'vector_name': specify which vector to use for visualization
// if there are multiple.
//
// - 'algorithm': specify algorithm to use for visualization. Available options: 'TSNE', 'UMAP'.
-//
-// Minimal example:
-
-{
- "limit": 500
-}
`;
@@ -58,27 +48,34 @@ const defaultResult = {};
function Visualize() {
const theme = useTheme();
+ const { client: qdrantClient } = useClient();
const [code, setCode] = useState(query);
+
+ // Contains the raw output of the request of QdrantClient
const [result, setResult] = useState(defaultResult);
- const [algorithm, setAlgorithm] = useState('TSNE');
+ const [visualizationParams, setVisualizationParams] = useState({});
+ const { enqueueSnackbar } = useSnackbar();
// const [errorMessage, setErrorMessage] = useState(null); // todo: use or remove
const navigate = useNavigate();
const params = useParams();
const [visualizeChartHeight, setVisualizeChartHeight] = useState(0);
const VisualizeChartWrapper = useRef(null);
const { height } = useWindowResize();
+ const [activePoint, setActivePoint] = useState(null);
useEffect(() => {
setVisualizeChartHeight(height - VisualizeChartWrapper.current?.offsetTop);
}, [height, VisualizeChartWrapper]);
const onEditorCodeRun = async (data, collectionName) => {
- if (data?.algorithm) {
- setAlgorithm(data.algorithm);
- }
+ setVisualizationParams(data);
- const result = await requestFromCode(data, collectionName);
- setResult(result);
+ try {
+ const result = await requestData(qdrantClient, collectionName, data);
+ setResult(result);
+ } catch (e) {
+ enqueueSnackbar(`Request error: ${e.message}`, { variant: 'error' });
+ }
};
const filterRequestSchema = (vectorNames) => ({
@@ -110,8 +107,33 @@ function Visualize() {
},
color_by: {
description: 'Color points by this field',
- type: 'string',
- nullable: true,
+ anyOf: [
+ {
+ type: 'string', // Name of the field to use for coloring
+ },
+ {
+ description: 'field name',
+ type: 'object',
+ properties: {
+ payload: {
+ description: 'Name of the field to use for coloring',
+ type: 'string',
+ },
+ },
+ },
+ {
+ description: 'query',
+ type: 'object',
+ properties: {
+ query: {
+ $ref: '#/components/schemas/QueryInterface',
+ },
+ },
+ },
+ {
+ nullable: true,
+ },
+ ],
},
algorithm: {
description: 'Algorithm to use for visualization',
@@ -159,7 +181,12 @@ function Visualize() {
-
+
@@ -181,12 +208,41 @@ function Visualize() {
-
+
+
+
+
+
+
+ ⋯
+
+
+
+ {activePoint && }
+
+