Add token refresh mechanism for Execution API (#59553) #60108
Conversation
As per my understanding this was removed in #55506 to use a middleware that refreshes the token. Are you running an instance with only the Execution API, separate from the api-server? Could this middleware approach be extended for task-sdk calls too?

Hi @tirkarthi, I took a stab at extending that pattern in #60197, handling expired tokens transparently in JWTBearer + middleware so no client-side changes are needed. Would love your thoughts on it. Totally happy to go with whichever approach the team feels is better!

Would love to hear @ashb's or @amoghrajesh's opinion on this one.
ashb left a comment:
We can't do this approach. It lets any Execution API token be resurrected, which fundamentally breaks lots of security assumptions -- it amounts to having tokens that never expire. That is bad.

Instead, what we should do is generate a new token (i.e. one with an extra/different set of JWT claims) that is only valid for the /run endpoint and valid for longer (say 24 hours, made configurable), and this is what gets sent in the workload.

The /run endpoint would then set a header to give the running task a "short-lived" token (basically the one we have right now) that is usable on the rest of the Execution API. This approach is safer, as the existing controls in the /run endpoint already prevent a task being run more than once, which should also protect against "resurrecting" an expired token and using it to access things like connections etc. And we should validate that the token used on all endpoints but /run is explicitly lacking this new claim.
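As a rough illustration of that flow (a sketch only -- the scope value, helper names, and generate() signature here are assumptions, not the final API):

```python
# Sketch of the proposed two-token exchange. Assumes a JWTGenerator whose
# generate() accepts a dict of claim overrides, as suggested later in this review.
from datetime import datetime, timezone

QUEUE_SCOPE = "ExecuteTaskWorkload"  # hypothetical scope claim value


def issue_workload_token(generator, ti_id: str, queue_expiry: int = 86400) -> str:
    """Long-lived token embedded in the workload; accepted only by /run."""
    now = int(datetime.now(tz=timezone.utc).timestamp())
    return generator.generate({"sub": ti_id, "exp": now + queue_expiry, "scope": QUEUE_SCOPE})


def exchange_on_run(generator, claims: dict) -> str:
    """/run validates the long-lived token, then hands the task a short-lived
    execution token (the kind that exists today) for the rest of the API."""
    if claims.get("scope") != QUEUE_SCOPE:
        raise PermissionError("only workload-scoped tokens may call /run")
    return generator.generate({"sub": claims["sub"]})
```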
ashb left a comment:
Much better approach, and on the right track, thanks. Some changes though:

- "queue" is not the right thing to use, as these tokens could be used for executing other workloads soon (for instance, we have already talked about wanting Dag-level callbacks to be executed on the workers, not in the dag processor, which would be done by adding a new workload type alongside ExecuteTaskWorkload). So maybe we have "scope": "ExecuteTaskWorkload"? See the sketch after this list.
- A little bit of refactoring is needed before we are ready to merge this.
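The sketch mentioned above -- just the claim rename, with the constant name being illustrative:

```python
# Illustrative: the scope names the workload type rather than the executor queue.
TOKEN_SCOPE_EXECUTE_TASK = "ExecuteTaskWorkload"

claims = {
    # ... standard claims as in the diff below (jti, iss, aud, nbf, exp, iat, sub) ...
    "scope": TOKEN_SCOPE_EXECUTE_TASK,
}
```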
```python
def generate_queue_token(self, sub: str) -> str:
    """
    Generate a long-lived queue token for task workloads.

    Queue tokens have a special 'scope' claim that restricts them to the /run endpoint only.
    They are valid for longer (default 24h) to survive queue wait times.
    """
    from airflow.configuration import conf

    queue_expiry = conf.getint("execution_api", "jwt_queue_token_expiration_time", fallback=86400)
    now = int(datetime.now(tz=timezone.utc).timestamp())

    claims = {
        "jti": uuid.uuid4().hex,
        "iss": self.issuer,
        "aud": self.audience,
        "nbf": now,
        "exp": now + queue_expiry,
        "iat": now,
        "sub": sub,
        "scope": TOKEN_SCOPE_QUEUE,
    }

    if claims["iss"] is None:
        del claims["iss"]
    if claims["aud"] is None:
        del claims["aud"]

    headers = {"alg": self.algorithm}
    if self._private_key:
        headers["kid"] = self.kid
    return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers)
```
I don't think we need a whole new function for this -- the existing generate() could work already by doing:

`generator.generate({"sub": sub, "exp": now + queue_expiry})`

If you think it's worth "packaging" that up, then make it call self.generate -- don't essentially duplicate it.
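If it is packaged up, a minimal sketch of the thin wrapper (assuming module-level datetime/timezone imports and the conf option from the diff above, and that generate() merges claim overrides):

```python
# Sketch: delegate to self.generate() instead of duplicating its body.
def generate_queue_token(self, sub: str) -> str:
    from airflow.configuration import conf

    queue_expiry = conf.getint("execution_api", "jwt_queue_token_expiration_time", fallback=86400)
    now = int(datetime.now(tz=timezone.utc).timestamp())
    return self.generate({"sub": sub, "exp": now + queue_expiry, "scope": TOKEN_SCOPE_QUEUE})
```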
```python
claims = await validator.avalidated_claims(creds.credentials, validators)

# Reject queue-scoped tokens - they can only be used on the /run endpoint
# Only check if scope claim is present (allows backwards compatibility with tests)
```
If this back compat is just for tests then we don't need it.
```python
# Set a dummy JWT secret so the lifespan can create JWT services without failing.
if not conf.get("api_auth", "jwt_secret", fallback=None):
    conf.set("api_auth", "jwt_secret", "in-process-test-secret-key")
```
Never ever ever do this in production/runtime code. The risk of it being picked up and every install in the world having a signing key of "in-process-test-secret-key" is too large.

I know this is just the InProcess class, but I'm still worried about doing this, doubly so as config is process-global.
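One possible mitigation, sketched here as an assumption rather than what the PR does: fall back to a random per-process secret, so no two installs can ever share a well-known key:

```python
# Sketch: random per-process fallback secret. Still process-global (the
# reviewer's concern stands), but never identical across installs.
# Assumes `conf` is airflow.configuration.conf, as in the diff above.
import secrets

if not conf.get("api_auth", "jwt_secret", fallback=None):
    conf.set("api_auth", "jwt_secret", secrets.token_urlsafe(32))
```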
```python
# Create a mock container that provides mock JWT services
mock_jwt_generator = MagicMock(spec=JWTGenerator)
mock_jwt_generator.generate.return_value = "mock-execution-token"

mock_jwt_validator = AsyncMock(spec=JWTValidator)
mock_jwt_validator.avalidated_claims.return_value = {"sub": "test", "exp": 9999999999}

class MockContainer:
    """A mock svcs container that returns mock services."""

    async def aget(self, svc_type):
        if svc_type is JWTGenerator:
            return mock_jwt_generator
        if svc_type is JWTValidator:
            return mock_jwt_validator
        raise ValueError(f"Unknown service type: {svc_type}")

async def mock_container_dep():
    return MockContainer()

self._app.dependency_overrides[DepContainer.dependency] = mock_container_dep
```
What is this doing here? This looks like test-only code.
Why do we need a mock svcs container? Why not just use a real svcs container?
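For comparison, a real container might look roughly like this (a sketch assuming the svcs library's Registry/Container API; real_jwt_generator and real_jwt_validator are placeholders for however the real services are built):

```python
# Sketch: register real services in a real svcs registry instead of mocking.
import svcs

registry = svcs.Registry()
registry.register_value(JWTGenerator, real_jwt_generator)
registry.register_value(JWTValidator, real_jwt_validator)

async def container_dep():
    return svcs.Container(registry)

self._app.dependency_overrides[DepContainer.dependency] = container_dep
```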
```python
# Wait for lifespan to complete before returning the transport
lifespan_started.wait(timeout=5.0)
```
Why do we need this?
I'd rather we wrote the JWTBearer and JWTBearerQueueScope deps in a layered approach (deps can depend on each other, or use subclassing) - that way we only have to write the bulk of the validation once and just tweak the behaviour.

I think the layered approach would be best: a base JWTBearer dep that does the basic validation but knows nothing of the presence/absence of the queue scope, and then two deps that consume the returned TIToken to do the next layer.
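A minimal sketch of that layering with FastAPI dependencies (assuming TIToken exposes its decoded claims; the scope value and error details are illustrative):

```python
from fastapi import Depends, HTTPException

QUEUE_SCOPE = "ExecuteTaskWorkload"  # hypothetical claim value

async def jwt_bearer(token: TIToken = Depends(JWTBearer())) -> TIToken:
    """Base layer: signature/expiry validation only; knows nothing of scopes."""
    return token

async def jwt_bearer_execution(token: TIToken = Depends(jwt_bearer)) -> TIToken:
    """For every endpoint except /run: reject queue-scoped tokens."""
    if token.claims.get("scope") == QUEUE_SCOPE:
        raise HTTPException(status_code=403, detail="queue token not valid here")
    return token

async def jwt_bearer_queue(token: TIToken = Depends(jwt_bearer)) -> TIToken:
    """For /run only: require the queue scope."""
    if token.claims.get("scope") != QUEUE_SCOPE:
        raise HTTPException(status_code=403, detail="queue-scoped token required")
    return token
```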
Any reason we needed to change this order?
```python
        HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"},
    },
    response_model_exclude_unset=True,
    dependencies=[JWTBearerQueueDep],
```
Since this is the one and only place we use this dep (and it's also the only place we ever want to use it), I think it would be better if we moved JWTBearerQueueDep into this file.
```yaml
      the /run endpoint, which then issues a short-lived execution token.
      This should be set long enough to cover the maximum expected queue wait time.
    version_added: 3.1.0
```
Suggested change:

```diff
-    version_added: 3.1.0
+    version_added: 3.1.7
```
Tasks waiting in the Celery queue may have their JWT tokens expire before execution starts. This adds a token refresh endpoint that allows the supervisor to refresh expired tokens before task execution.

Fixes: #53713
Summary
Fixes #59553 - Tasks waiting in the Celery queue fail when JWT tokens expire before execution starts.
Implements a two-token mechanism for task execution to prevent token expiration while tasks wait in executor queues.