Skip to content

Commit

Permalink
feat(SIP-85): OAuth2 for databases
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed Mar 23, 2024
1 parent f274c47 commit 17e0ee3
Show file tree
Hide file tree
Showing 34 changed files with 1,233 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

import React from 'react';
import * as reduxHooks from 'react-redux';
import { Provider } from 'react-redux';
import { createStore } from 'redux';
import { render, fireEvent, waitFor } from '@testing-library/react';
import '@testing-library/jest-dom';
import { ThemeProvider, supersetTheme } from '@superset-ui/core';
import OAuth2RedirectMessage from 'src/components/ErrorMessage/OAuth2RedirectMessage';
import { ErrorTypeEnum } from 'src/components/ErrorMessage/types';
import { reRunQuery } from 'src/SqlLab/actions/sqlLab';
import { triggerQuery } from 'src/components/Chart/chartAction';
import { onRefresh } from 'src/dashboard/actions/dashboardState';

// Mock the Redux store
const mockStore = createStore(() => ({
sqlLab: {
queries: { 'query-id': { sql: 'SELECT * FROM table' } },
queryEditors: [{ id: 'editor-id', latestQueryId: 'query-id' }],
tabHistory: ['editor-id'],
},
explore: {
slice: { slice_id: 123 },
},
charts: { '1': {}, '2': {} },
dashboardInfo: { id: 'dashboard-id' },
}));

// Mock actions
jest.mock('src/SqlLab/actions/sqlLab', () => ({
reRunQuery: jest.fn(),
}));

jest.mock('src/components/Chart/chartAction', () => ({
triggerQuery: jest.fn(),
}));

jest.mock('src/dashboard/actions/dashboardState', () => ({
onRefresh: jest.fn(),
}));

// Mock useDispatch
const mockDispatch = jest.fn();
jest.spyOn(reduxHooks, 'useDispatch').mockReturnValue(mockDispatch);

// Mock global window functions
const mockOpen = jest.spyOn(window, 'open').mockImplementation(() => null);
const mockAddEventListener = jest.spyOn(window, 'addEventListener');
const mockRemoveEventListener = jest.spyOn(window, 'removeEventListener');

// Mock window.postMessage
const originalPostMessage = window.postMessage;

beforeEach(() => {
window.postMessage = jest.fn();
});

afterEach(() => {
window.postMessage = originalPostMessage;
});

function simulateMessageEvent(data: any, origin: string) {
const messageEvent = new MessageEvent('message', { data, origin });
window.dispatchEvent(messageEvent);
}

const defaultProps = {
error: {
error_type: ErrorTypeEnum.OAUTH2_REDIRECT,
message: "You don't have permission to access the data.",
extra: {
url: 'https://example.com',
tab_id: 'tabId',
redirect_uri: 'https://redirect.example.com',
},
level: 'warning',
},
source: 'sqllab',
};

const setup = (overrides = {}) => (
<ThemeProvider theme={supersetTheme}>
<Provider store={mockStore}>
<OAuth2RedirectMessage {...defaultProps} {...overrides} />;
</Provider>
</ThemeProvider>
);

describe('OAuth2RedirectMessage Component', () => {
it('renders without crashing and displays the correct initial UI elements', () => {
const { getByText } = render(setup());

expect(getByText(/Authorization needed/i)).toBeInTheDocument();
expect(getByText(/provide authorization/i)).toBeInTheDocument();
});

it('opens a new window with the correct URL when the link is clicked', () => {
const { getByText } = render(setup());

const linkElement = getByText(/provide authorization/i);
fireEvent.click(linkElement);

expect(mockOpen).toHaveBeenCalledWith('https://example.com', '_blank');
});

it('cleans up the message event listener on unmount', () => {
const { unmount } = render(setup());

expect(mockAddEventListener).toHaveBeenCalled();
unmount();
expect(mockRemoveEventListener).toHaveBeenCalled();
});

it('dispatches reRunQuery action when a message with correct tab ID is received for SQL Lab', async () => {
render(setup());

simulateMessageEvent({ tabId: 'tabId' }, 'https://redirect.example.com');

await waitFor(() => {
expect(reRunQuery).toHaveBeenCalledWith({ sql: 'SELECT * FROM table' });
});
});

it('dispatches triggerQuery action for explore source upon receiving a correct message', async () => {
render(setup({ source: 'explore' }));

simulateMessageEvent({ tabId: 'tabId' }, 'https://redirect.example.com');

await waitFor(() => {
expect(triggerQuery).toHaveBeenCalledWith(true, 123);
});
});

it('dispatches onRefresh action for dashboard source upon receiving a correct message', async () => {
render(setup({ source: 'dashboard' }));

simulateMessageEvent({ tabId: 'tabId' }, 'https://redirect.example.com');

await waitFor(() => {
expect(onRefresh).toHaveBeenCalledWith(
['1', '2'],
true,
0,
'dashboard-id',
);
});
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import React, { useEffect, useRef } from 'react';
import { useDispatch, useSelector } from 'react-redux';
import { QueryEditor, SqlLabRootState } from 'src/SqlLab/types';
import { ExplorePageState } from 'src/explore/types';
import { RootState } from 'src/dashboard/types';
import { reRunQuery } from 'src/SqlLab/actions/sqlLab';
import { triggerQuery } from 'src/components/Chart/chartAction';
import { onRefresh } from 'src/dashboard/actions/dashboardState';
import { QueryResponse, t } from '@superset-ui/core';

import { ErrorMessageComponentProps } from './types';
import ErrorAlert from './ErrorAlert';

interface OAuth2RedirectExtra {
url: string;
tab_id: string;
redirect_uri: string;
}

/*
* Component for starting OAuth2 dance.
*
* When a user without credentials tries to access a database that supports OAuth2, the
* backend will raise an exception with the custom error `OAUTH2_REDIRECT`. This will
* cause the frontend to display this component, which informs the user that they need
* to authenticate in order to access the data.
*
* The component has a URL that is used to start the OAuth2 dance for the given
* database. When the user clicks the link a new browser tab will open, where they can
* authorize Superset to access the data. Once authorization is successfull the user will
* be redirected back to Superset, and their personal access token is stored, so it can
* be used in subsequent connections. If a refresh token is also present in the response,
* it will also be stored.
*
* After the token has been stored, the opened tab will send a message to the original
* tab and close itself. This component, running on the original tab, will listen for
* message events, and once it receives the success message from the opened tab it will
* re-run the query for the user, be it in SQL Lab, Explore, or a dashboard. In order to
* communicate securely, both tabs share a "tab ID", which is a UUID that is generated
* by the backend and sent from the opened tab to the original tab. For extra security,
* we also check that the source of the message is the opened tab via a ref.
*/
function OAuth2RedirectMessage({
error,
source,
}: ErrorMessageComponentProps<OAuth2RedirectExtra>) {
const oAuthTab = useRef<Window | null>(null);
const { extra, level } = error;

// store a reference to the OAuth2 browser tab, so we can check that the success
// message is coming from it
const handleOAuthClick = (event: React.MouseEvent<HTMLAnchorElement>) => {
event.preventDefault();
oAuthTab.current = window.open(extra.url, '_blank');
};

// state needed for re-running the SQL Lab query
const queries = useSelector<
SqlLabRootState,
Record<string, QueryResponse & { inLocalStorage?: boolean }>
>(state => state.sqlLab.queries);
const queryEditors = useSelector<SqlLabRootState, QueryEditor[]>(
state => state.sqlLab.queryEditors,
);
const tabHistory = useSelector<SqlLabRootState, string[]>(
state => state.sqlLab.tabHistory,
);
const qe = queryEditors.find(
qe => qe.id === tabHistory[tabHistory.length - 1],
);
const query = qe?.latestQueryId ? queries[qe.latestQueryId] : null;

// state needed for triggering the chart in Explore
const chartId = useSelector<ExplorePageState, number | undefined>(
state => state.explore?.slice?.slice_id,
);

// state needed for refreshing dashboard
const chartList = useSelector<RootState, string[]>(state =>
Object.keys(state.charts),
);
const dashboardId = useSelector<RootState, number | undefined>(
state => state.dashboardInfo?.id,
);

const dispatch = useDispatch();

useEffect(() => {
/* Listen for messages from the OAuth2 tab.
*
* After OAuth2 is successfull the opened tab will send a message before
* closing itself. Once we receive the message we can retrigger the
* original query in SQL Lab, explore, or in a dashboard.
*/
const redirectUrl = new URL(extra.redirect_uri);
const handleMessage = (event: MessageEvent) => {
if (
event.origin === redirectUrl.origin &&
event.data.tabId === extra.tab_id &&
event.source === oAuthTab.current
) {
if (source === 'sqllab' && query) {
dispatch(reRunQuery(query));
} else if (source === 'explore' && chartId) {
dispatch(triggerQuery(true, chartId));
} else if (source === 'dashboard') {
dispatch(onRefresh(chartList, true, 0, dashboardId));
}
}
};
window.addEventListener('message', handleMessage);

return () => {
window.removeEventListener('message', handleMessage);
};
}, [
source,
extra.redirect_uri,
extra.tab_id,
dispatch,
query,
chartId,
chartList,
dashboardId,
]);

const body = (
<p>
This database uses OAuth2 for authentication. Please click the link above
to grant Apache Superset permission to access the data. Your personal
access token will be stored encrypted and used only for queries run by
you.
</p>
);
const subtitle = (
<>
You need to{' '}
<a
href={extra.url}
onClick={handleOAuthClick}
target="_blank"
rel="noreferrer"
>
provide authorization
</a>{' '}
in order to run this query.
</>
);

return (
<ErrorAlert
title={t('Authorization needed')}
subtitle={subtitle}
level={level}
source={source}
body={body}
/>
);
}

export default OAuth2RedirectMessage;
2 changes: 2 additions & 0 deletions superset-frontend/src/components/ErrorMessage/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ export const ErrorTypeEnum = {
QUERY_SECURITY_ACCESS_ERROR: 'QUERY_SECURITY_ACCESS_ERROR',
MISSING_OWNERSHIP_ERROR: 'MISSING_OWNERSHIP_ERROR',
DASHBOARD_SECURITY_ACCESS_ERROR: 'DASHBOARD_SECURITY_ACCESS_ERROR',
OAUTH2_REDIRECT: 'OAUTH2_REDIRECT',
OAUTH2_REDIRECT_ERROR: 'OAUTH2_REDIRECT_ERROR',

// Other errors
BACKEND_TIMEOUT_ERROR: 'BACKEND_TIMEOUT_ERROR',
Expand Down
5 changes: 5 additions & 0 deletions superset-frontend/src/setup/setupErrorMessages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import DatabaseErrorMessage from 'src/components/ErrorMessage/DatabaseErrorMessa
import MarshmallowErrorMessage from 'src/components/ErrorMessage/MarshmallowErrorMessage';
import ParameterErrorMessage from 'src/components/ErrorMessage/ParameterErrorMessage';
import DatasetNotFoundErrorMessage from 'src/components/ErrorMessage/DatasetNotFoundErrorMessage';
import OAuth2RedirectMessage from 'src/components/ErrorMessage/OAuth2RedirectMessage';

import setupErrorMessagesExtra from './setupErrorMessagesExtra';

Expand Down Expand Up @@ -149,5 +150,9 @@ export default function setupErrorMessages() {
ErrorTypeEnum.MARSHMALLOW_ERROR,
MarshmallowErrorMessage,
);
errorMessageComponentRegistry.registerValue(
ErrorTypeEnum.OAUTH2_REDIRECT,
OAuth2RedirectMessage,
);
setupErrorMessagesExtra();
}
1 change: 0 additions & 1 deletion superset/commands/chart/data/get_data_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def run(self, **kwargs: Any) -> dict[str, Any]:
except CacheLoadError as ex:
raise ChartDataCacheLoadError(ex.message) from ex

# TODO: QueryContext should support SIP-40 style errors
for query in payload["queries"]:
if query.get("error"):
raise ChartDataQueryFailedError(
Expand Down
Loading

0 comments on commit 17e0ee3

Please sign in to comment.