Skip to content

Commit

Permalink
allow non numbers in y axis but force aggregate with count
Browse files Browse the repository at this point in the history
  • Loading branch information
vieiralucas committed Feb 3, 2025
1 parent 531374e commit 2d0084c
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 89 deletions.
49 changes: 15 additions & 34 deletions apps/api/src/python/visualizations-v2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,47 +132,30 @@ def _briefer_create_visualization(df, options):
y_group_by = series["groupBy"]["name"] if series["groupBy"] else None
grouping_columns = ["_grouped"] + ([y_group_by] if y_group_by else [])
aggregate_func = series["aggregateFunction"] or "count"
df[series["id"]] = df[series["column"]["name"]]
if pd.api.types.is_datetime64_any_dtype(df[options["xAxis"]["name"]]):
if options["xAxisGroupFunction"] and series["aggregateFunction"]:
if options["xAxisGroupFunction"]:
freq = freqs.get(options["xAxisGroupFunction"], "s")
y_axis_agg_func = series["aggregateFunction"]
datetime_agg_funcs = set(["count", "mean", "median"])
if pd.api.types.is_datetime64_any_dtype(df[series["column"]["name"]]):
if y_axis_agg_func not in datetime_agg_funcs:
y_axis_agg_func = "count"
# Group by the specified frequency and aggregate the values
df["_grouped"] = df[options["xAxis"]["name"]].dt.to_period(freq).dt.start_time
df = df.groupby(grouping_columns, as_index=False).agg({
series["column"]["name"]: y_axis_agg_func
series["id"]: aggregate_func
}).reset_index()
elif options["xAxisGroupFunction"]:
freq = freqs.get(options["xAxisGroupFunction"], "s")
df[options["xAxis"]["name"]] = pd.to_datetime(df[options["xAxis"]["name"]])
# Group by the specified frequency
df[options["xAxis"]["name"]] = df[options["xAxis"]["name"]].dt.to_period(freq).dt.start_time
else:
# just group by values who are the same
df["_grouped"] = df[options["xAxis"]["name"]]
df = df.groupby(grouping_columns, as_index=False).agg({
series["column"]["name"]: "count"
}).reset_index()
elif series["aggregateFunction"]:
y_axis_agg_func = series["aggregateFunction"]
datetime_agg_funcs = set(["count", "mean", "median"])
if pd.api.types.is_datetime64_any_dtype(df[series["column"]["name"]]):
if y_axis_agg_func not in datetime_agg_funcs:
y_axis_agg_func = "count"
df["_grouped"] = df[options["xAxis"]["name"]]
df = df.groupby(grouping_columns, as_index=False).agg({
series["column"]["name"]: y_axis_agg_func
series["id"]: "count"
}).reset_index()
else:
df["_grouped"] = df[options["xAxis"]["name"]]
df = df.groupby(grouping_columns, as_index=False).agg({
series["id"]: aggregate_func
}).reset_index()
return df
Expand Down Expand Up @@ -412,12 +395,14 @@ def _briefer_create_visualization(df, options):
for g_option in series.get("groups") or []:
group_options[g_option["group"]] = g_option
y_name = series["id"]
for group in groups:
color_index += 1
dataset_index = len(data["dataset"])
g_options = group_options.get(group) if group else {}
dimensions = [series["column"]["name"]]
dimensions = [y_name]
if options["xAxis"]:
dimensions.insert(0, options["xAxis"]["name"])
Expand All @@ -430,11 +415,7 @@ def _briefer_create_visualization(df, options):
if group and row[series["groupBy"]["name"]] != group:
continue
y_name = series["column"]["name"]
y_value = row[y_name]
# if y_value is not a number set data_y_axis["type"] to category
if type(y_value) not in [int, float]:
data_y_axis["type"] = "category"
row_data = {}
if options["xAxis"]:
Expand Down Expand Up @@ -467,7 +448,7 @@ def _briefer_create_visualization(df, options):
"z": i,
"encode": {
"x": options["xAxis"]["name"],
"y": series["column"]["name"],
"y": y_name,
}
}
if chart_type == "line":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ function VisualizationControls(props: Props) {
{
axisName: null,
column: null,
aggregateFunction: null,
aggregateFunction: 'sum',
colorBy: null,
chartType: null,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ function VisualizationControlsV2(props: Props) {
{
id: uuidv4(),
column: null,
aggregateFunction: null,
aggregateFunction: 'sum',
groupBy: null,
chartType: null,
name: null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,19 @@ interface Props {
onAddYAxis?: () => void
}

const isNumberType = (column: DataFrameColumn | null) =>
NumpyNumberTypes.safeParse(column?.type).success

export function getAggFunction(
defaultChartType: ChartType,
series: SeriesV2,
column: DataFrameColumn | null
): AggregateFunction | null {
const chartType = series.chartType ?? defaultChartType

if (series.aggregateFunction !== null || !column || !isNumberType(column)) {
return series.aggregateFunction
): AggregateFunction {
if (!column) {
return series.aggregateFunction ?? 'sum'
}

if (
chartType === 'groupedColumn' ||
chartType === 'stackedColumn' ||
chartType === 'hundredPercentStackedColumn' ||
chartType === 'line' ||
chartType === 'area' ||
chartType === 'hundredPercentStackedArea'
) {
return 'sum'
if (NumpyNumberTypes.safeParse(column.type).success) {
return series.aggregateFunction ?? 'sum'
}

return null
return 'count'
}

function YAxisPickerV2(props: Props) {
Expand All @@ -70,11 +57,7 @@ function YAxisPickerV2(props: Props) {
? {
...s,
column,
aggregateFunction: getAggFunction(
props.defaultChartType,
s,
column
),
aggregateFunction: getAggFunction(s, column),
}
: s
),
Expand All @@ -86,20 +69,7 @@ function YAxisPickerV2(props: Props) {
)

const onChangeAggregateFunction = useCallback(
(aggregateFunction: string | null, index: number) => {
if (!aggregateFunction) {
props.onChange(
{
...props.yAxis,
series: props.yAxis.series.map((s, i) =>
i === index ? { ...s, aggregateFunction: null } : s
),
},
props.index
)
return
}

(aggregateFunction: string, index: number) => {
const func = AggregateFunction.safeParse(aggregateFunction)
if (func.success) {
props.onChange(
Expand Down Expand Up @@ -160,7 +130,7 @@ function YAxisPickerV2(props: Props) {
{
id: uuidv4(),
column: null,
aggregateFunction: null,
aggregateFunction: 'sum',
groupBy: null,
chartType: null,
name: null,
Expand Down Expand Up @@ -190,8 +160,11 @@ function YAxisPickerV2(props: Props) {

const columns = useMemo(
() =>
(props.dataframe?.columns ?? []).filter(
(c) => NumpyNumberTypes.safeParse(c.type).success
(props.dataframe?.columns ?? []).filter((c) =>
props.defaultChartType === 'trend' ||
props.defaultChartType === 'number'
? NumpyNumberTypes.safeParse(c.type).success
: true
),
[props.dataframe, props.defaultChartType]
)
Expand Down Expand Up @@ -322,20 +295,20 @@ function YAxisPickerV2(props: Props) {
s.column.type
).success
? [
{ name: 'None', value: null },
{ name: 'Sum', value: 'sum' },
{ name: 'Average', value: 'mean' },
{ name: 'Median', value: 'median' },
{ name: 'Min', value: 'min' },
{ name: 'Max', value: 'max' },
{ name: 'Count', value: 'count' },
]
: [
{ name: 'None', value: null },
{ name: 'Count', value: 'count' },
]
: [{ name: 'Count', value: 'count' }]
}
onChange={(agg) => onChangeAggregateFunction(agg, i)}
onChange={(agg) => {
if (agg) {
onChangeAggregateFunction(agg, i)
}
}}
disabled={!props.dataframe || !props.isEditable}
/>
{props.defaultChartType !== 'trend' &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,7 @@ function VisualizationBlockV2(props: Props) {
const groupBy =
df.columns.find((c) => c.name === s.groupBy?.name) ?? null
const aggregateFunction = column
? getAggFunction(
s.chartType ?? attrs.input.chartType,
s,
column
)
? getAggFunction(s, column)
: null
return {
...s,
Expand Down
2 changes: 1 addition & 1 deletion packages/editor/src/blocks/visualization-v2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ function getYAxes(input: VisualizationV2BlockInput): YAxisV2[] {
{
id: uuidv4(),
column: null,
aggregateFunction: null,
aggregateFunction: 'sum',
groupBy: null,
chartType: null,
name: null,
Expand Down

0 comments on commit 2d0084c

Please sign in to comment.