Skip to content

Commit

Permalink
Optimize modal sample tagging (#5417) (#5513)
Browse files Browse the repository at this point in the history
* tag counts with respect to modal sample/group

* group issues

* simplify, some cleanup

* revert slice view change

* fix e2e
  • Loading branch information
benjaminpkane authored Feb 25, 2025
1 parent e2a660c commit 5273935
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 38 deletions.
11 changes: 8 additions & 3 deletions app/packages/core/src/components/Actions/Tagger.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,11 @@ const useTagCallback = (
fos.isOrderedDynamicGroup
);

const slices = await snapshot.getPromise(fos.currentSlices(modal));
const mode = await snapshot.getPromise(groupStatistics(modal));
const currentSlices = await snapshot.getPromise(
fos.currentSlices(modal)
);
const slices = await snapshot.getPromise(fos.groupSlices);
const { samples } = await getFetchFunction()("POST", "/tag", {
...tagParameters({
activeFields: await snapshot.getPromise(
Expand All @@ -363,9 +367,10 @@ const useTagCallback = (
isGroup && !isNonNestedDynamicGroup
? {
id: modal ? await snapshot.getPromise(groupId) : null,
slices,
mode,
currentSlices,
slice: await snapshot.getPromise(fos.groupSlice),
mode: await snapshot.getPromise(groupStatistics(modal)),
slices,
}
: null,
modal,
Expand Down
30 changes: 13 additions & 17 deletions app/packages/core/src/components/Actions/utils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ export const tagStatistics = selectorFamily<
get(isGroup) && get(fos.groupField)
? {
id: modal ? get(groupId) : null,
slices: get(fos.currentSlices(modal)),
slice: get(fos.currentSlice(modal)),
currentSlices: get(fos.currentSlices(modal)),
mode: get(groupStatistics(modal)),
slice: get(fos.currentSlice(modal)),
slices: get(fos.groupSlices),
}
: null,
hiddenLabels: get(fos.hiddenLabelsArray),
Expand Down Expand Up @@ -131,18 +132,7 @@ export const tagStats = selectorFamily<
get:
({ modal, labels }) =>
({ get }) => {
const data = Object.keys(
get(
labels
? fos.labelTagCounts({ modal: false, extended: false })
: fos.sampleTagCounts({ modal: false, extended: false })
)
).map((t) => [t, 0]);

return {
...Object.fromEntries(data),
...get(tagStatistics({ modal, labels })).tags,
};
return get(tagStatistics({ modal, labels })).tags;
},
});

Expand All @@ -166,6 +156,7 @@ export const tagParameters = ({
activeFields: string[];
groupData: {
id: string | null;
currentSlices: string[] | null;
slice: string | null;
slices: string[] | null;
mode: "group" | "slice";
Expand All @@ -174,8 +165,11 @@ export const tagParameters = ({
sampleId: string | null;
}) => {
const shouldShowCurrentSample =
params.modal && selectedSamples.size == 0 && hiddenLabels.length == 0;
params.modal && selectedSamples.size === 0 && hiddenLabels.length === 0;
const groups = groupData?.mode === "group";
if (groupData && !groups) {
groupData.slices = groupData.currentSlices;
}

const getSampleIds = () => {
if (shouldShowCurrentSample && !groups) {
Expand All @@ -186,17 +180,19 @@ export const tagParameters = ({
return [...new Set(selectedLabels.map((l) => l.sampleId))];
}
return [sampleId];
} else if (selectedSamples.size) {
}
if (selectedSamples.size) {
return [...selectedSamples];
}

return null;
};

return {
...params,
label_fields: activeFields,
target_labels: targetLabels,
slices: !groups ? groupData?.slices : null,
slices: groupData?.slices,
slice: groupData?.slice,
group_id: params.modal ? groupData?.id : null,
sample_ids: getSampleIds(),
Expand Down
28 changes: 18 additions & 10 deletions app/packages/state/src/recoil/aggregations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ import { graphQLSelectorFamily } from "recoil-relay";
import type { ResponseFrom } from "../utils";
import { refresher } from "./atoms";
import * as filterAtoms from "./filters";
import { currentSlices, groupId, groupSlice, groupStatistics } from "./groups";
import {
currentSlices,
groupId,
groupSlice,
groupSlices,
groupStatistics,
} from "./groups";
import { sidebarSampleId } from "./modal";
import { RelayEnvironmentKey } from "./relay";
import * as schemaAtoms from "./schema";
Expand Down Expand Up @@ -76,7 +82,7 @@ export const aggregationQuery = graphQLSelectorFamily<
paths,
mixed,
sampleIds,
slices: mixed ? null : get(currentSlices(modal)), // when mixed, slice is not needed
slices: mixed ? get(groupSlices) : get(currentSlices(modal)),
slice: get(groupSlice),
view: customView ? customView : !root ? get(viewAtoms.view) : [],
};
Expand Down Expand Up @@ -150,27 +156,29 @@ export const modalAggregationPaths = selectorFamily({
const isFramesPath = frames.some((p) => params.path.startsWith(p));
let paths = isFramesPath
? frames
: get(schemaAtoms.labelFields({ space: State.SPACE.SAMPLE })).map(
(path) => get(schemaAtoms.expandPath(path))
);
: [
...get(schemaAtoms.labelFields({ space: State.SPACE.SAMPLE })).map(
(path) => get(schemaAtoms.expandPath(path))
),
];

paths = paths
.sort()
.flatMap((p) => get(schemaAtoms.modalFilterFields(p)));

const numeric = get(schemaAtoms.isNumericField(params.path));
if (!isFramesPath && !numeric) {
// the modal currently requires a 'tags' aggregation
paths = ["tags", ...paths];
}

if (params.mixed || get(groupId)) {
paths = [
...paths.filter((p) => {
const n = get(schemaAtoms.isNumericField(p));
return numeric ? n : !n;
}),
];

if (!numeric && !isFramesPath) {
// the modal currently requires a 'tags' aggregation
paths = ["tags", ...paths];
}
}

return paths;
Expand Down
4 changes: 2 additions & 2 deletions app/packages/state/src/recoil/modal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export const modalLooker = atom<Lookers | null>({
dangerouslyAllowMutability: true,
});

export const sidebarSampleId = selector({
export const sidebarSampleId = selector<null | string>({
key: "sidebarSampleId",
get: ({ get }) => {
if (get(shouldRenderImaVidLooker(true))) {
Expand All @@ -41,7 +41,7 @@ export const sidebarSampleId = selector({

if (!isPlaying && !isSeeking && thisFrameNumber && sample) {
// is the type incorrect? fix me
const id = sample?.id || sample?._id || sample?.sample?._id;
const id = sample?.id || sample?._id || (sample?.sample?._id as string);
if (id) {
return id;
}
Expand Down
18 changes: 16 additions & 2 deletions app/packages/state/src/recoil/pathData/tags.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { selectorFamily } from "recoil";
import { aggregation } from "../aggregations";
import { groupStatistics } from "../groups";
import * as schemaAtoms from "../schema";

export const labelTagCounts = selectorFamily<
Expand All @@ -11,7 +12,14 @@ export const labelTagCounts = selectorFamily<
({ modal, extended }) =>
({ get }) => {
const data = get(schemaAtoms.labelPaths({})).map((path) =>
get(aggregation({ extended, modal, path: `${path}.tags`, mixed: true }))
get(
aggregation({
extended,
modal,
path: `${path}.tags`,
mixed: get(groupStatistics(modal)) === "group",
})
)
);

const result = {};
Expand Down Expand Up @@ -45,7 +53,13 @@ export const sampleTagCounts = selectorFamily<
get:
(params) =>
({ get }) => {
const data = get(aggregation({ ...params, path: "tags" }));
const data = get(
aggregation({
...params,
path: "tags",
mixed: get(groupStatistics(params.modal)) === "group",
})
);
if (data.__typename !== "StringAggregation") {
throw new Error("unexpected");
}
Expand Down
4 changes: 0 additions & 4 deletions fiftyone/server/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,6 @@ async def aggregate_resolver(
if form.sample_ids:
view = fov.make_optimized_select_view(view, form.sample_ids)

if form.mixed and view.media_type == fom.GROUP and view.group_slices:
view = view.select_group_slices(_force_mixed=True)
view = fosv.get_extended_view(view, form.filters)

if form.hidden_labels:
view = view.exclude_labels(
[
Expand Down

0 comments on commit 5273935

Please sign in to comment.