Skip to content

Commit

Permalink
Add set_access_token method and the related tests
Browse files Browse the repository at this point in the history
  • Loading branch information
slvrtrn committed Dec 23, 2024
1 parent 52a566c commit 9a7c871
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 3 deletions.
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ The supported method of passing ClickHouse server settings is to prefix such arg
## Unreleased
### Improvement
- Added support for JWT authentication (ClickHouse Cloud feature).
It can be set via the `access_token` client configuration option for both sync and async clients.
NB: do not mix access token and username/password credentials in the configuration; the client will throw an error if both are set.
It can be set via the `access_token` client configuration option for both sync and async clients.
The token can also be updated via the `set_access_token` method in the existing client instance.
NB: do not mix access token and username/password credentials in the configuration;
the client will throw an error if both are set.

## 0.8.11, 2024-12-21
### Improvement
Expand Down
8 changes: 7 additions & 1 deletion clickhouse_connect/driver/asyncclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(self, *, client: Client, executor_threads: int = 0):
executor_threads = min(32, (os.cpu_count() or 1) + 4) # Mimic the default behavior
self.executor = ThreadPoolExecutor(max_workers=executor_threads)


def set_client_setting(self, key, value):
"""
Set a clickhouse setting for the client after initialization. If a setting is not recognized by ClickHouse,
Expand All @@ -48,6 +47,13 @@ def get_client_setting(self, key) -> Optional[str]:
"""
return self.client.get_client_setting(key=key)

def set_access_token(self, access_token: str):
"""
Set the ClickHouse access token for the client
:param access_token: Access token string
"""
self.client.set_access_token(access_token)

def min_version(self, version_str: str) -> bool:
"""
Determine whether the connected server is at least the submitted version
Expand Down
7 changes: 7 additions & 0 deletions clickhouse_connect/driver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ def get_client_setting(self, key) -> Optional[str]:
:return: The string value of the setting, if it exists, or None
"""

@abstractmethod
def set_access_token(self, access_token: str):
"""
Set the ClickHouse access token for the client
:param access_token: Access token string
"""

# pylint: disable=unused-argument,too-many-locals
def query(self,
query: Optional[str] = None,
Expand Down
6 changes: 6 additions & 0 deletions clickhouse_connect/driver/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,12 @@ def set_client_setting(self, key, value):
def get_client_setting(self, key) -> Optional[str]:
return self.params.get(key)

def set_access_token(self, access_token: str):
auth_header = self.headers.get('Authorization')
if auth_header and not auth_header.startswith('Bearer'):
raise ProgrammingError('Cannot set access token when a different auth type is used')
self.headers['Authorization'] = f'Bearer {access_token}'

def _prep_query(self, context: QueryContext):
final_query = super()._prep_query(context)
if context.is_insert:
Expand Down
73 changes: 73 additions & 0 deletions tests/integration_tests/test_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ def test_jwt_auth_sync_client(test_config: TestConfig):
assert result == [(True,)]


def test_jwt_auth_sync_client_set_access_token(test_config: TestConfig):
if not test_config.cloud:
pytest.skip('Skipping JWT test in non-Cloud mode')

access_token = make_access_token()
client = create_client(
host=test_config.host,
port=test_config.port,
access_token=access_token,
)

# Should still work after the override
access_token = make_access_token()
client.set_access_token(access_token)

result = client.query(query=CHECK_CLOUD_MODE_QUERY).result_set
assert result == [(True,)]


def test_jwt_auth_sync_client_config_errors():
with pytest.raises(ProgrammingError):
create_client(
Expand All @@ -41,6 +60,23 @@ def test_jwt_auth_sync_client_config_errors():
)


def test_jwt_auth_sync_client_set_access_token_errors(test_config: TestConfig):
if not test_config.cloud:
pytest.skip('Skipping JWT test in non-Cloud mode')

client = create_client(
host=test_config.host,
port=test_config.port,
username=test_config.username,
password=test_config.password,
)

# Can't use JWT with username/password
access_token = make_access_token()
with pytest.raises(ProgrammingError):
client.set_access_token(access_token)


@pytest.mark.asyncio
async def test_jwt_auth_async_client(test_config: TestConfig):
if not test_config.cloud:
Expand All @@ -56,6 +92,25 @@ async def test_jwt_auth_async_client(test_config: TestConfig):
assert result == [(True,)]


@pytest.mark.asyncio
async def test_jwt_auth_async_client_set_access_token(test_config: TestConfig):
if not test_config.cloud:
pytest.skip('Skipping JWT test in non-Cloud mode')

access_token = make_access_token()
client = await create_async_client(
host=test_config.host,
port=test_config.port,
access_token=access_token,
)

access_token = make_access_token()
client.set_access_token(access_token)

result = (await client.query(query=CHECK_CLOUD_MODE_QUERY)).result_set
assert result == [(True,)]


@pytest.mark.asyncio
async def test_jwt_auth_async_client_config_errors():
with pytest.raises(ProgrammingError):
Expand All @@ -76,6 +131,24 @@ async def test_jwt_auth_async_client_config_errors():
)


@pytest.mark.asyncio
async def test_jwt_auth_async_client_set_access_token_errors(test_config: TestConfig):
if not test_config.cloud:
pytest.skip('Skipping JWT test in non-Cloud mode')

client = await create_async_client(
host=test_config.host,
port=test_config.port,
username=test_config.username,
password=test_config.password,
)

# Can't use JWT with username/password
access_token = make_access_token()
with pytest.raises(ProgrammingError):
client.set_access_token(access_token)


CHECK_CLOUD_MODE_QUERY = "SELECT value='1' FROM system.settings WHERE name='cloud_mode'"
JWT_SECRET_ENV_KEY = 'CLICKHOUSE_CONNECT_TEST_JWT_SECRET'

Expand Down

0 comments on commit 9a7c871

Please sign in to comment.