diff --git a/package-lock.json b/package-lock.json index 715eb9f9..80ad739d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -8341,9 +8341,9 @@ } }, "node_modules/vite": { - "version": "4.5.3", - "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.3.tgz", - "integrity": "sha512-kQL23kMeX92v3ph7IauVkXkikdDRsYMGTVl5KY2E9OY4ONLvkHf04MDTbnfo6NKxZiDLWzVpP5oTa8hQD8U3dg==", + "version": "4.5.5", + "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.5.tgz", + "integrity": "sha512-ifW3Lb2sMdX+WU91s3R0FyQlAyLxOzCSCP37ujw0+r5POeHPwe6udWVIElKQq8gk3t7b8rkmvqC6IHBpCff4GQ==", "dependencies": { "esbuild": "^0.18.10", "postcss": "^8.4.27", @@ -14557,9 +14557,9 @@ } }, "vite": { - "version": "4.5.3", - "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.3.tgz", - "integrity": "sha512-kQL23kMeX92v3ph7IauVkXkikdDRsYMGTVl5KY2E9OY4ONLvkHf04MDTbnfo6NKxZiDLWzVpP5oTa8hQD8U3dg==", + "version": "4.5.5", + "resolved": "https://registry.npmjs.org/vite/-/vite-4.5.5.tgz", + "integrity": "sha512-ifW3Lb2sMdX+WU91s3R0FyQlAyLxOzCSCP37ujw0+r5POeHPwe6udWVIElKQq8gk3t7b8rkmvqC6IHBpCff4GQ==", "requires": { "esbuild": "^0.18.10", "fsevents": "~2.3.2", diff --git a/src/components/GraphVisualisation/PointPreview.jsx b/src/components/Common/PointPreview.jsx similarity index 92% rename from src/components/GraphVisualisation/PointPreview.jsx rename to src/components/Common/PointPreview.jsx index 0313025d..c4e717fb 100644 --- a/src/components/GraphVisualisation/PointPreview.jsx +++ b/src/components/Common/PointPreview.jsx @@ -11,7 +11,6 @@ const PointPreview = ({ point }) => { const [loading] = React.useState(false); const conditions = []; const payloadSchema = {}; - const onConditionChange = () => {}; if (!point) { return null; @@ -48,7 +47,6 @@ const PointPreview = ({ point }) => { @@ -67,7 +65,7 @@ const PointPreview = ({ point }) => { }} /> - + )} diff --git a/src/components/FilterEditorWindow/config/RequestFromCode.js b/src/components/FilterEditorWindow/config/RequestFromCode.js deleted file mode 100644 index de2354e3..00000000 --- a/src/components/FilterEditorWindow/config/RequestFromCode.js +++ /dev/null @@ -1,165 +0,0 @@ -import { bigIntJSON } from '../../../common/bigIntJSON'; -import { axiosInstance } from '../../../common/axios'; - -function parseDataToRequest(reqBody) { - // Validate color_by - if (reqBody.color_by) { - const colorBy = reqBody.color_by; - - if (typeof colorBy === 'string') { - // Parse into payload variant - reqBody.color_by = { - payload: colorBy, - }; - } else { - // Check we only have one of the options: payload, or discover_score - const options = [colorBy.payload, colorBy.discover_score]; - const optionsCount = options.filter((option) => option).length; - if (optionsCount !== 1) { - return { - reqBody: reqBody, - error: '`color_by`: Only one of `payload`, or `discover_score` can be used', - }; - } - - // Put search arguments in main request body - if (colorBy.discover_score) { - reqBody = { - ...reqBody, - ...colorBy.discover_score, - }; - } - } - } - - // Set with_vector name - if (reqBody.vector_name) { - reqBody.with_vector = [reqBody.vector_name]; - return { - reqBody: reqBody, - error: null, - }; - } else if (!reqBody.vector_name) { - reqBody.with_vector = true; - return { - reqBody: reqBody, - error: null, - }; - } -} -export async function requestFromCode(dataRaw, collectionName) { - const data = parseDataToRequest(dataRaw); - // Sending request - const colorBy = data.reqBody.color_by; - if (colorBy?.payload) { - return await actionFromCode(collectionName, data, 'scroll'); - } - if (colorBy?.discover_score) { - return discoverFromCode(collectionName, data); - } - return await actionFromCode(collectionName, data, 'scroll'); -} - -async function actionFromCode(collectionName, data, action) { - try { - const response = await axiosInstance({ - method: 'POST', - url: `collections/${collectionName}/points/${action || 'scroll'}`, - data: data.reqBody, - }); - response.data.color_by = data.reqBody.color_by; - response.data.vector_name = data.reqBody.vector_name; - response.data.result.points = response.data.result.points.filter((point) => Object.keys(point.vector).length > 0); - return { - data: response.data, - error: null, - }; - } catch (err) { - return { - data: null, - error: err.response?.data?.status ? err.response?.data?.status : err, - }; - } -} - -async function discoverFromCode(collectionName, data) { - // Do 20/80 split. 20% of the points will be returned with the query - // and 80 % will be returned with random sampling - const queryLimit = Math.floor(data.reqBody.limit * 0.2); - const randomLimit = data.reqBody.limit - queryLimit; - data.reqBody.limit = queryLimit; - data.reqBody.with_payload = true; - - const queryResponse = await actionFromCode(collectionName, data, 'discover'); - if (queryResponse.error) { - return { - data: null, - error: queryResponse.error, - }; - } - - // Add tag to know which points were returned by the query - queryResponse.data.result = queryResponse.data.result.map((point) => ({ - ...point, - from_query: true, - })); - - // Get "random" points ids. - // There is no sampling endpoint in Qdrant yet, so for now we just scroll excluding the previous results - const idsToExclude = queryResponse.data.result.map((point) => point.id); - - const originalFilter = data.reqBody.filter; - const mustNotFilter = [{ has_id: idsToExclude }]; - data.reqBody.filter = originalFilter || {}; - data.reqBody.filter.must_not = mustNotFilter.concat(data.reqBody.filter.must_not ?? []); - - data.reqBody.limit = randomLimit; - const randomResponse = await actionFromCode(collectionName, data, 'scroll'); - if (randomResponse.error) { - return { - data: null, - error: randomResponse.error, - }; - } - - // Then score these random points - const idsToInclude = randomResponse.data.result.points.map((point) => point.id); - const mustFilter = [{ has_id: idsToInclude }]; - data.reqBody.filter = originalFilter || {}; - data.reqBody.filter.must = mustFilter.concat(data.reqBody.filter.must || []); - - const scoredRandomResponse = await actionFromCode(collectionName, data, 'discover'); - if (scoredRandomResponse.error) { - return { - data: null, - error: scoredRandomResponse.error, - }; - } - - // Concat both results - const points = queryResponse.data.result.concat(scoredRandomResponse.data.result); - - return { - data: { - ...queryResponse.data, - result: { - points: points, - }, - }, - error: null, - }; -} - -export function codeParse(codeText) { - // Parse JSON - if (codeText) { - try { - return bigIntJSON.parse(codeText); - } catch (e) { - return { - reqBody: codeText, - error: 'Fix the position brackets to run & check the json', - }; - } - } -} diff --git a/src/components/FilterEditorWindow/index.jsx b/src/components/FilterEditorWindow/index.jsx index 89044508..02956258 100644 --- a/src/components/FilterEditorWindow/index.jsx +++ b/src/components/FilterEditorWindow/index.jsx @@ -6,7 +6,7 @@ import { useClient } from '../../context/client-context'; import { useTheme } from '@mui/material/styles'; import { autocomplete } from './config/Autocomplete'; import { useSnackbar } from 'notistack'; -import { codeParse } from './config/RequestFromCode'; +import { bigIntJSON } from '../../common/bigIntJSON'; import './editor.css'; import EditorCommon from '../EditorCommon'; @@ -30,6 +30,20 @@ const CodeEditorWindow = ({ onChange, code, onChangeResult, customRequestSchema, [] ); + function codeParse(codeText) { + // Parse JSON + if (codeText) { + try { + return bigIntJSON.parse(codeText); + } catch (e) { + return { + reqBody: codeText, + error: 'Fix the position brackets to run & check the json', + }; + } + } + } + function onRun(codeText) { const data = codeParse(codeText); if (data.error) { diff --git a/src/components/Points/DataGridList.jsx b/src/components/Points/DataGridList.jsx index c58905f4..847e210e 100644 --- a/src/components/Points/DataGridList.jsx +++ b/src/components/Points/DataGridList.jsx @@ -85,7 +85,9 @@ export const DataGridList = function ({ data = {}, specialCases = {}, onConditio if (conditions.find((c) => c.key === filter.key && c.value === filter.value)) { return; } - onConditionChange([...conditions, filter]); + if (typeof onConditionChange === 'function') { + onConditionChange([...conditions, filter]); + } }} > diff --git a/src/components/Points/PointVectors.jsx b/src/components/Points/PointVectors.jsx index 9bec5f98..4042d12b 100644 --- a/src/components/Points/PointVectors.jsx +++ b/src/components/Points/PointVectors.jsx @@ -34,6 +34,7 @@ const Vectors = memo(function Vectors({ point, onConditionChange }) { return ( + Vectors: {Object.keys(vectors).map((key) => { return ( @@ -81,13 +82,17 @@ const Vectors = memo(function Vectors({ point, onConditionChange }) { - + {typeof onConditionChange !== 'function' ? null : ( + + )} ); @@ -98,7 +103,7 @@ const Vectors = memo(function Vectors({ point, onConditionChange }) { Vectors.propTypes = { point: PropTypes.object.isRequired, - onConditionChange: PropTypes.func.isRequired, + onConditionChange: PropTypes.func, }; export default Vectors; diff --git a/src/components/VisualizeChart/ImageTooltip.jsx b/src/components/VisualizeChart/ImageTooltip.jsx deleted file mode 100644 index 2b0d4e2f..00000000 --- a/src/components/VisualizeChart/ImageTooltip.jsx +++ /dev/null @@ -1,100 +0,0 @@ -import { toFont } from 'chart.js/helpers'; -import React from 'react'; -import { createRoot } from 'react-dom/client'; -import { flushSync } from 'react-dom'; - -const DEFAULT_BORDER_COLOR = '#333333'; - -export function imageTooltip(context) { - // Tooltip Element - let tooltipEl = document.getElementById('chartjs-tooltip'); - - // Create element on first render - if (!tooltipEl) { - tooltipEl = document.createElement('div'); - tooltipEl.id = 'chartjs-tooltip'; - tooltipEl.appendChild(document.createElement('table')); - document.body.appendChild(tooltipEl); - } - - // Hide if no tooltip - const tooltipModel = context.tooltip; - if (tooltipModel.opacity === 0) { - tooltipEl.style.opacity = 0; - return; - } - - // Set caret Position - tooltipEl.classList.remove('above', 'below', 'no-transform'); - if (tooltipModel.yAlign) { - tooltipEl.classList.add(tooltipModel.yAlign); - } else { - tooltipEl.classList.add('no-transform'); - } - - // Set content - if (tooltipModel.body) { - const bodyLines = tooltipModel.body[0].lines; - - const imageSrc = tooltipModel.dataPoints[0].dataset.data[tooltipModel.dataPoints[0].dataIndex].point.payload?.image; - - const borderColor = tooltipModel.labelColors[0]?.borderColor || DEFAULT_BORDER_COLOR; - - const child = ( -
- {imageSrc && ( - - )} -
- {bodyLines.map((line, i) => ( - - {line} - - ))} -
-
- ); - - // Render html to insert in tooltip - const tableRoot = tooltipEl.querySelector('table'); - const root = createRoot(tableRoot); - flushSync(() => { - root.render(child); - }); - } - - const position = context.chart.canvas.getBoundingClientRect(); - const bodyFont = toFont(tooltipModel.options.bodyFont); - - // Display, position, and set styles for font - tooltipEl.style.opacity = 1; - tooltipEl.style.position = 'absolute'; - tooltipEl.style.left = position.left + window.scrollX + tooltipModel.caretX + 'px'; - tooltipEl.style.top = position.top + window.scrollY + tooltipModel.caretY + 'px'; - tooltipEl.style.font = bodyFont.string; - tooltipEl.style.padding = tooltipModel.padding + 'px ' + tooltipModel.padding + 'px'; - tooltipEl.style.pointerEvents = 'none'; -} diff --git a/src/components/VisualizeChart/ViewPointModal.jsx b/src/components/VisualizeChart/ViewPointModal.jsx deleted file mode 100644 index fa2063a0..00000000 --- a/src/components/VisualizeChart/ViewPointModal.jsx +++ /dev/null @@ -1,72 +0,0 @@ -import React from 'react'; -import PropTypes from 'prop-types'; -import { Box, Button, Dialog, DialogContent, DialogActions, DialogTitle, Paper, Typography } from '@mui/material'; -import { alpha } from '@mui/material'; -import { useTheme } from '@mui/material/styles'; -import { DataGridList } from '../Points/DataGridList'; -import { CopyButton } from '../Common/CopyButton'; -import { bigIntJSON } from '../../common/bigIntJSON'; - -const ViewPointModal = (props) => { - const theme = useTheme(); - const { openViewPoints, setOpenViewPoints, viewPoints } = props; - - return ( - <> - setOpenViewPoints(false)} - scroll={'paper'} - maxWidth={'lg'} - fullWidth={true} - > - Selected Points - - {viewPoints.length > 0 ? ( - <> - - {viewPoints.map((point, index) => ( - - - - Point {point.id} - - - - - - - - - ))} - - - ) : ( - - no points selected - - )} - - - - - - - ); -}; - -ViewPointModal.propTypes = { - openViewPoints: PropTypes.bool.isRequired, - setOpenViewPoints: PropTypes.func.isRequired, - viewPoints: PropTypes.array.isRequired, -}; - -export default ViewPointModal; diff --git a/src/components/VisualizeChart/index.jsx b/src/components/VisualizeChart/index.jsx index e88c7e7f..ed52430e 100644 --- a/src/components/VisualizeChart/index.jsx +++ b/src/components/VisualizeChart/index.jsx @@ -1,99 +1,182 @@ import Chart from 'chart.js/auto'; -import chroma from 'chroma-js'; import get from 'lodash/get'; import { useSnackbar } from 'notistack'; import PropTypes from 'prop-types'; -import React, { useEffect, useState } from 'react'; -import ViewPointModal from './ViewPointModal'; -import { imageTooltip } from './ImageTooltip'; -import { bigIntJSON } from '../../common/bigIntJSON'; +import React, { useEffect } from 'react'; +import { generateColorBy, generateSizeBy } from './renderBy'; +import { useTheme } from '@mui/material/styles'; -const SCORE_GRADIENT_COLORS = ['#EB5353', '#F9D923', '#36AE7C']; +// Dark red +const LIGHT_SELECTOR_COLOR = 'rgba(255, 0, 0, 0.5)'; +// White +const DARK_SELECTOR_COLOR = 'rgba(245, 245, 245, 0.8)'; -const VisualizeChart = ({ scrollResult, algorithm = null }) => { +// Transparent color for points +const DEFAULT_BORDER_COLOR = 'rgba(0, 0, 0, 0)'; + +function intoDatasets( + points, // array of original points, which contain payloads + data, // list of compressed coordinates + colors, // list of colors for each point to be displayed + sizes, // list of sizes for each point to be displayed + groupBy = null // payload field to group by +) { + const defaultConfig = { + pointHitRadius: 1, + hoverRadius: 7, + }; + + if (!groupBy) { + // No grouping + return [ + { + label: 'Data', + data, + offsets: Array.from({ length: data.length }, (_, i) => i), + pointBackgroundColor: [...colors], + // Use transparent border color for points + pointBorderColor: Array.from({ length: colors.length }, () => DEFAULT_BORDER_COLOR), + ...defaultConfig, + }, + ]; + } + + const groups = {}; + + points.forEach((point, index) => { + let group = get(point.payload, groupBy) + ''; // Convert to string, even if it's an o + + if (!group) { + // If specified field is not present in the payload, fallback to 'Unknown' + group = 'Unknown'; + } + + if (!groups[group]) { + groups[group] = { + label: group, + data: [], + offsets: [], + pointBackgroundColor: [], + pointBorderColor: [], + pointRadius: [], + ...defaultConfig, + }; + } + + groups[group].data.push(data[index]); + groups[group].offsets.push(index); + groups[group].pointBackgroundColor.push(colors[index]); + groups[group].pointBorderColor.push(DEFAULT_BORDER_COLOR); + groups[group].pointRadius.push(sizes[index]); + }); + + // Convert groups object to array, and sort by label + return Object.values(groups).sort((a, b) => a.label.localeCompare(b.label)); +} + +const VisualizeChart = ({ + requestResult, // Raw output of the request from qdrant client + visualizationParams, // Parameters, as specified by the user in the input editor + activePoint, // currently selected point (with hover) + setActivePoint, // callback to set new active point +}) => { const { enqueueSnackbar } = useSnackbar(); - const [openViewPoints, setOpenViewPoints] = useState(false); - const [viewPoints, setViewPoint] = useState([]); + + // Id of the currently selected point + // Used to prevent multiple updates of the chart on hover + // And for switching colors of the selected point + let selectedPointLocation = null; + + const theme = useTheme(); + + function getSelectionColor() { + return theme.palette.mode === 'light' ? LIGHT_SELECTOR_COLOR : DARK_SELECTOR_COLOR; + } useEffect(() => { - if (!scrollResult.data && !scrollResult.error) { + if (!requestResult.points) { return; } - if (scrollResult.error) { - enqueueSnackbar(`Visualization Unsuccessful, error: ${bigIntJSON.stringify(scrollResult.error)}`, { - variant: 'error', - }); + const points = requestResult.points; + const colorBy = visualizationParams?.color_by; - return; - } else if (!scrollResult.data?.result?.points.length) { - enqueueSnackbar(`Visualization Unsuccessful, error: No data returned`, { - variant: 'error', - }); - return; - } + // Initialize data with random points in range [0, 1] + const data = points.map(() => ({ + x: Math.random(), + y: Math.random(), + })); - const dataset = []; - const colorBy = scrollResult.data.color_by; + // This reference values should be used to rollback the color of the previously selected point + const pointColors = generateColorBy(points, colorBy); + const sizes = generateSizeBy(points); - let labelby = null; - if (colorBy?.payload) { - labelby = colorBy.payload; - // Color and label by payload field - if (get(scrollResult.data.result?.points[0]?.payload, labelby) === undefined) { - enqueueSnackbar(`Visualization Unsuccessful, error: Color by field ${labelby} does not exist`, { - variant: 'error', - }); - return; + const payloadField = typeof colorBy === 'string' ? colorBy : colorBy?.payload; + const useLegend = !!payloadField; + + const datasets = intoDatasets(points, data, pointColors, sizes, payloadField); + + const handlePointHover = (chart) => { + if (!chart.tooltip?._active) return; + if (chart.tooltip?._active.length === 0) return; + + const lastActive = chart.tooltip._active.length - 1; + + const selectedElement = chart.tooltip._active[lastActive]; + + const datasetIndex = selectedElement.datasetIndex; + const pointIndex = selectedElement.index; + + const offsets = chart.data.datasets[datasetIndex].offsets; + const pointOffset = offsets[pointIndex]; + const selectedPoint = points[pointOffset]; + + // Check if the same point is already selected + // To prevent recurrant updates of the chart + if (selectedPoint.id === activePoint?.id) { + selectedPointLocation = { + offset: pointOffset, + datasetIndex, + pointIndex, + }; + return selectedPoint; + } + if (pointOffset === selectedPointLocation?.offset) { + return selectedPoint; } - scrollResult.data.labelByArrayUnique = [ - ...new Set(scrollResult.data.result?.points?.map((point) => get(point.payload, labelby))), - ]; - scrollResult.data.labelByArrayUnique.forEach((label) => { - dataset.push({ - label: label, - data: [], - }); - }); - } else if (colorBy?.discover_score) { - // Color by discover score - const scores = scrollResult.data.result?.points.map((point) => point.score); - const minScore = Math.min(...scores); - const maxScore = Math.max(...scores); - - const colorScale = chroma.scale(SCORE_GRADIENT_COLORS); - const scoreColors = scores.map((score) => { - const normalizedScore = (score - minScore) / (maxScore - minScore); - return colorScale(normalizedScore).hex(); - }); - const pointRadii = scrollResult.data.result?.points.map((point) => { - if (point.from_query) { - return 4; - } else { - return 3; - } - }); + const oldPointLocation = selectedPointLocation; + + selectedPointLocation = { + offset: pointOffset, + datasetIndex, + pointIndex, + }; + + // Reset color of the previously selected point + if (oldPointLocation) { + const targetColor = pointColors[oldPointLocation.offset]; + chart.data.datasets[oldPointLocation.datasetIndex].pointBackgroundColor[oldPointLocation.pointIndex] = + targetColor; + + chart.data.datasets[oldPointLocation.datasetIndex].pointBorderColor[oldPointLocation.pointIndex] = + DEFAULT_BORDER_COLOR; + } + + chart.data.datasets[datasetIndex].pointBackgroundColor[pointIndex] = getSelectionColor(); + chart.data.datasets[datasetIndex].pointBorderColor[pointIndex] = getSelectionColor(); + + setActivePoint(selectedPoint); + + chart.update(); + return selectedPoint; + }; - dataset.push({ - label: 'Discover scores', - pointBackgroundColor: scoreColors, - pointBorderColor: scoreColors, - pointRadius: pointRadii, - data: [], - }); - } else { - // No special coloring - dataset.push({ - label: 'Data', - data: [], - }); - } const ctx = document.getElementById('myChart'); const myChart = new Chart(ctx, { type: 'scatter', data: { - datasets: dataset, + datasets: datasets, }, options: { responsive: true, @@ -109,55 +192,31 @@ const VisualizeChart = ({ scrollResult, algorithm = null }) => { display: false, }, }, + interaction: { + mode: 'nearest', // Show tooltip for the nearest point + intersect: false, // Show even if not directly hovering over a point + }, plugins: { tooltip: { // only use custom tooltip if color by is not discover score - enabled: !colorBy?.discover_score, - external: (colorBy?.discover_score && imageTooltip) || undefined, + enabled: true, usePointStyle: true, + position: 'nearest', + intersect: true, callbacks: { label: (context) => { - const payload = bigIntJSON - .stringify(context.dataset.data[context.dataIndex].point.payload, null, 1) - .split('\n'); - - if (colorBy?.discover_score) { - const id = context.dataset.data[context.dataIndex].point.id; - const score = context.dataset.data[context.dataIndex].point.score; - - return [`id: ${id}`, `score: ${score}`]; - } else { - return payload; - } + const selectedPoint = handlePointHover(context.chart); + if (!selectedPoint) return ''; + const id = selectedPoint.id; + return `Point ${id}`; }, }, }, legend: { - display: !!labelby, + display: useLegend, }, }, }, - plugins: [ - { - id: 'myEventCatcher', - beforeEvent(chart, args) { - const event = args.event; - if (event.type === 'click') { - if (chart.tooltip._active.length > 0) { - const activePoints = chart.tooltip._active.map((point) => { - return { - id: point.element.$context.raw.point.id, - payload: point.element.$context.raw.point.payload, - vector: point.element.$context.raw.point.vector, - }; - }); - setViewPoint(activePoints); - setOpenViewPoints(true); - } - } - }, - }, - ], }); const worker = new Worker(new URL('./worker.js', import.meta.url), { @@ -170,19 +229,24 @@ const VisualizeChart = ({ scrollResult, algorithm = null }) => { variant: 'error', }); } else if (m.data.result && m.data.result.length > 0) { - m.data.result.forEach((dataset, index) => { + const reducedPonts = m.data.result; + + const datasets = intoDatasets(points, reducedPonts, pointColors, sizes, payloadField); + + datasets.forEach((dataset, index) => { myChart.data.datasets[index].data = dataset.data; }); + myChart.update(); } else { enqueueSnackbar(`Visualization Unsuccessful, error: Unexpected Error Occured`, { variant: 'error' }); } }; - if (scrollResult.data.result?.points?.length > 0) { + if (requestResult.points?.length > 0) { worker.postMessage({ - ...scrollResult.data, - algorithm: algorithm, + result: requestResult, + params: visualizationParams, }); } @@ -190,19 +254,20 @@ const VisualizeChart = ({ scrollResult, algorithm = null }) => { myChart.destroy(); worker.terminate(); }; - }, [scrollResult]); + }, [requestResult]); return ( <> - ); }; VisualizeChart.propTypes = { - scrollResult: PropTypes.object.isRequired, - algorithm: PropTypes.string, + requestResult: PropTypes.object.isRequired, + visualizationParams: PropTypes.object.isRequired, + activePoint: PropTypes.object, + setActivePoint: PropTypes.func, }; export default VisualizeChart; diff --git a/src/components/VisualizeChart/renderBy.js b/src/components/VisualizeChart/renderBy.js new file mode 100644 index 00000000..7ab611ae --- /dev/null +++ b/src/components/VisualizeChart/renderBy.js @@ -0,0 +1,97 @@ +import chroma from 'chroma-js'; + +const SCORE_GRADIENT_COLORS = ['#EB5353', '#F9D923', '#36AE7C']; +const BACKGROUND_COLOR = '#36A2EB'; + +const PALLETE = [ + '#3366CC', + '#DC3912', + '#FF9900', + '#109618', + '#990099', + '#3B3EAC', + '#0099C6', + '#DD4477', + '#66AA00', + '#B82E2E', + '#316395', + '#994499', + '#22AA99', + '#AAAA11', + '#6633CC', + '#E67300', + '#8B0707', + '#329262', + '#5574A6', + '#651067', +]; + +// const SELECTED_BORDER_COLOR = '#881177'; + +function colorByPayload(payloadValue, colored) { + if (colored[payloadValue]) { + return colored[payloadValue]; + } + + const nextColorIndex = Object.keys(colored).length % PALLETE.length; + + colored[payloadValue] = PALLETE[nextColorIndex]; + + return PALLETE[nextColorIndex]; +} + +// This function generates an array of colors for each point in the chart. +// There are following options available for colorBy: +// +// - None: all points will have the same color +// - typeof = "string": color points based on the source field +// - {"payload": "field_name"}: color points based on the payload field +// - {"discover_score": { ... } }: color points based on the discover score +// - {"query": { ... }}: color points based on the query score + +export function generateColorBy(points, colorBy = null) { + // Points example: + // [ + // { id: 0, payload: { field_name: 1 }, score: 0.5, vector: [0.1, 0.2, ....] }, + // { id: 1, payload: { field_name: 2 }, score: 0.6, vector: [0.3, 0.4, ....] }, + // ... + // ] + + if (!colorBy) { + return Array.from({ length: points.length }, () => BACKGROUND_COLOR); // Default color + } + + // If `colorBy` is a string, interpret as a field name + if (typeof colorBy === 'string') { + colorBy = { payload: colorBy }; + } + + if (colorBy.payload) { + const valuesToColor = {}; + + return points.map((point) => { + const payloadValue = point.payload[colorBy.payload]; + if (!payloadValue) { + return BACKGROUND_COLOR; + } + return colorByPayload(payloadValue, valuesToColor); + }); + } + + if (colorBy.discover_score || colorBy.query) { + const scores = points.map((point) => point.score); + const minScore = Math.min(...scores); + const maxScore = Math.max(...scores); + + const colorScale = chroma.scale(SCORE_GRADIENT_COLORS); + return scores.map((score) => { + const normalizedScore = (score - minScore) / (maxScore - minScore); + return colorScale(normalizedScore).hex(); + }); + } +} + +export function generateSizeBy(points) { + // ToDo: Intoroduce size differentiation later + return points.map(() => 3); +} diff --git a/src/components/VisualizeChart/requestData.js b/src/components/VisualizeChart/requestData.js new file mode 100644 index 00000000..7deba485 --- /dev/null +++ b/src/components/VisualizeChart/requestData.js @@ -0,0 +1,33 @@ +/* eslint-disable camelcase */ +export function requestData( + qdrantClient, + collectionName, + { limit, filter = null, vector_name = null, color_by = null } +) { + // Based on the input parameters, we need to decide what kind of request we need to send + // By default we should do scroll request + // But if we have color_by field which uses query, it should be used instead + + if (color_by?.query) { + const query = { + query: color_by.query, + limit: limit, + filter: filter, + with_vector: vector_name ? [vector_name] : true, + with_payload: true, + }; + + return qdrantClient.query(collectionName, query); + } + + // It it's not a query, we should do a scroll request + + const scrollQuery = { + limit: limit, + filter: filter, + with_vector: vector_name ? [vector_name] : true, + with_payload: true, + }; + + return qdrantClient.scroll(collectionName, scrollQuery); +} diff --git a/src/components/VisualizeChart/worker.js b/src/components/VisualizeChart/worker.js index 27e3b3cf..8899d76f 100644 --- a/src/components/VisualizeChart/worker.js +++ b/src/components/VisualizeChart/worker.js @@ -4,113 +4,101 @@ import get from 'lodash/get'; const MESSAGE_INTERVAL = 200; +function getVectorType(vector) { + if (Array.isArray(vector)) { + if (Array.isArray(vector[0])) { + return 'multivector'; + } + return 'vector'; + } + if (typeof vector === 'object') { + if (vector.indices) { + return 'sparse'; + } + return 'named'; + } + return 'unknown'; +} + self.onmessage = function (e) { let now = new Date().getTime(); - const algorithm = e?.data?.algorithm || 'TSNE'; + const params = e?.data?.params || {}; + + const algorithm = params.algorithm || 'TSNE'; - const data1 = e.data; const data = []; - if (data1?.result?.points?.length === 0) { + const points = e.data?.result?.points; + const vectorName = params.vector_name; + + if (!points || points.length === 0) { self.postMessage({ data: [], error: 'No data found', }); return; - } else if (data1?.result?.points?.length === 1) { + } + + if (points.length === 1) { self.postMessage({ data: [], error: 'cannot perform tsne on single point', }); return; - } else if (typeof data1?.result?.points[0].vector.length === 'number') { - data1?.result?.points?.forEach((point) => { - data.push(point.vector); - }); - } else if (typeof data1?.result?.points[0].vector === 'object') { - if (data1.vector_name === undefined) { - self.postMessage({ - data: [], - error: 'No vector name found, select a valid vector_name', - }); - return; - } else if (data1?.result?.points[0].vector[data1?.vector_name] === undefined) { - self.postMessage({ - data: [], - error: 'No vector found with name ' + data1?.vector_name, - }); - return; - } else if (data1?.result?.points[0].vector[data1?.vector_name]) { - if (!Array.isArray(data1?.result?.points[0].vector[data1?.vector_name])) { - self.postMessage({ - data: [], - error: 'Vector visualization is not supported for sparse vector', - }); - return; - } - data1?.result?.points?.forEach((point) => { - data.push(point.vector[data1?.vector_name]); - }); + } + + for (let i = 0; i < points.length; i++) { + if (!vectorName) { + // Work with default vector + data.push(points[i]?.vector); } else { + // Work with named vector + data.push(get(points[i]?.vector, vectorName)); + } + } + + // Validate data + + for (let i = 0; i < data.length; i++) { + const vector = data[i]; + const vectorType = getVectorType(vector); + + if (vectorType === 'vector') { + continue; + } + + if (vectorType === 'named') { self.postMessage({ data: [], - error: 'Unexpected Error Occurred', + error: 'Please select a valid vector name, default vector is not defined', }); return; } - } else { + self.postMessage({ data: [], - error: 'Unexpected Error Occurred', + error: 'Vector visualization is not supported for vector type: ' + vectorType, }); return; } + if (data.length) { const D = new druid[algorithm](data, {}); // ex params = { perplexity : 50,epsilon :5} const next = D.generator(); // default = 500 iterations - let i = {}; - for (i of next) { + + let reducedPoints = []; + for (reducedPoints of next) { if (Date.now() - now > MESSAGE_INTERVAL) { now = Date.now(); - self.postMessage({ result: getDataset(data1, i), error: null }); + self.postMessage({ result: getDataset(reducedPoints), error: null }); } } - self.postMessage({ result: getDataset(data1, i), error: null }); + self.postMessage({ result: getDataset(reducedPoints), error: null }); } }; -function getDataset(data, reducedPoint) { - const dataset = []; - const labelby = data.color_by?.payload; - if (labelby) { - data.labelByArrayUnique.forEach((label) => { - dataset.push({ - label: label, - data: [], - }); - }); - - data.result?.points?.forEach((point, index) => { - const label = get(point.payload, labelby); - dataset[data.labelByArrayUnique.indexOf(label)].data.push({ - x: reducedPoint[index][0], - y: reducedPoint[index][1], - point: point, - }); - }); - } else { - dataset.push({ - label: 'data', - data: [], - }); - data.result?.points?.forEach((point, index) => { - dataset[0].data.push({ - x: reducedPoint[index][0], - y: reducedPoint[index][1], - point: point, - }); - }); - } - return dataset; +function getDataset(reducedPoints) { + // Convert [[x1, y1], [x2, y2] ] to [ { x: x1, y: y1 }, { x: x2, y: y2 } ] + return reducedPoints.map((point) => ({ x: point[0], y: point[1] })); } diff --git a/src/pages/Graph.jsx b/src/pages/Graph.jsx index 502cca3e..4ee6b6d5 100644 --- a/src/pages/Graph.jsx +++ b/src/pages/Graph.jsx @@ -6,7 +6,7 @@ import { useTheme } from '@mui/material/styles'; import { Panel, PanelGroup, PanelResizeHandle } from 'react-resizable-panels'; import GraphVisualisation from '../components/GraphVisualisation/GraphVisualisation'; import { useWindowResize } from '../hooks/windowHooks'; -import PointPreview from '../components/GraphVisualisation/PointPreview'; +import PointPreview from '../components/Common/PointPreview'; import CodeEditorWindow from '../components/FilterEditorWindow'; import { useClient } from '../context/client-context'; import { getFirstPoint } from '../lib/graph-visualization-helpers'; diff --git a/src/pages/Visualize.jsx b/src/pages/Visualize.jsx index 1e518a4c..b7b244d1 100644 --- a/src/pages/Visualize.jsx +++ b/src/pages/Visualize.jsx @@ -7,10 +7,19 @@ import { Panel, PanelGroup, PanelResizeHandle } from 'react-resizable-panels'; import FilterEditorWindow from '../components/FilterEditorWindow'; import VisualizeChart from '../components/VisualizeChart'; import { useWindowResize } from '../hooks/windowHooks'; -import { requestFromCode } from '../components/FilterEditorWindow/config/RequestFromCode'; +import PointPreview from '../components/Common/PointPreview'; +import { useClient } from '../context/client-context'; +import { requestData } from '../components/VisualizeChart/requestData'; +import { useSnackbar } from 'notistack'; const query = ` +// Try me! + +{ + "limit": 500 +} + // Specify request parameters to select data for visualization. // // Available parameters: @@ -24,33 +33,14 @@ const query = ` // - 'color_by': specify score or payload field to use for coloring points. // How to use: // -// "color_by": "field_name" -// -// or -// // "color_by": { // "payload": "field_name" // } // -// or -// -// "color_by": { -// "discover_score": { -// "target": 42, -// "context": [{"positive": 1, "negative": 0}] -// } -// } -// // - 'vector_name': specify which vector to use for visualization // if there are multiple. // // - 'algorithm': specify algorithm to use for visualization. Available options: 'TSNE', 'UMAP'. -// -// Minimal example: - -{ - "limit": 500 -} `; @@ -58,27 +48,34 @@ const defaultResult = {}; function Visualize() { const theme = useTheme(); + const { client: qdrantClient } = useClient(); const [code, setCode] = useState(query); + + // Contains the raw output of the request of QdrantClient const [result, setResult] = useState(defaultResult); - const [algorithm, setAlgorithm] = useState('TSNE'); + const [visualizationParams, setVisualizationParams] = useState({}); + const { enqueueSnackbar } = useSnackbar(); // const [errorMessage, setErrorMessage] = useState(null); // todo: use or remove const navigate = useNavigate(); const params = useParams(); const [visualizeChartHeight, setVisualizeChartHeight] = useState(0); const VisualizeChartWrapper = useRef(null); const { height } = useWindowResize(); + const [activePoint, setActivePoint] = useState(null); useEffect(() => { setVisualizeChartHeight(height - VisualizeChartWrapper.current?.offsetTop); }, [height, VisualizeChartWrapper]); const onEditorCodeRun = async (data, collectionName) => { - if (data?.algorithm) { - setAlgorithm(data.algorithm); - } + setVisualizationParams(data); - const result = await requestFromCode(data, collectionName); - setResult(result); + try { + const result = await requestData(qdrantClient, collectionName, data); + setResult(result); + } catch (e) { + enqueueSnackbar(`Request error: ${e.message}`, { variant: 'error' }); + } }; const filterRequestSchema = (vectorNames) => ({ @@ -110,8 +107,33 @@ function Visualize() { }, color_by: { description: 'Color points by this field', - type: 'string', - nullable: true, + anyOf: [ + { + type: 'string', // Name of the field to use for coloring + }, + { + description: 'field name', + type: 'object', + properties: { + payload: { + description: 'Name of the field to use for coloring', + type: 'string', + }, + }, + }, + { + description: 'query', + type: 'object', + properties: { + query: { + $ref: '#/components/schemas/QueryInterface', + }, + }, + }, + { + nullable: true, + }, + ], }, algorithm: { description: 'Algorithm to use for visualization', @@ -159,7 +181,12 @@ function Visualize() { - + @@ -181,12 +208,41 @@ function Visualize() { - + + + + + + + ⋯ + + + + {activePoint && } + +