diff --git a/cosec-webflux/src/main/kotlin/me/ahoo/cosec/webflux/ReactiveInjectSecurityContextWebFilter.kt b/cosec-webflux/src/main/kotlin/me/ahoo/cosec/webflux/ReactiveInjectSecurityContextWebFilter.kt index 1a5dcce8..a9198365 100644 --- a/cosec-webflux/src/main/kotlin/me/ahoo/cosec/webflux/ReactiveInjectSecurityContextWebFilter.kt +++ b/cosec-webflux/src/main/kotlin/me/ahoo/cosec/webflux/ReactiveInjectSecurityContextWebFilter.kt @@ -20,6 +20,7 @@ import org.springframework.web.server.ServerWebExchange import org.springframework.web.server.WebFilter import org.springframework.web.server.WebFilterChain import reactor.core.publisher.Mono +import reactor.kotlin.core.publisher.toMono import reactor.util.context.Context /** @@ -33,17 +34,19 @@ class ReactiveInjectSecurityContextWebFilter( ) : WebFilter, Ordered { override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono { - return Mono.defer { - try { - val securityContext = securityContextParser.parse(exchange) - exchange.setSecurityContext(securityContext) - return@defer chain.filter(exchange) - .contextWrite { ctx: Context -> ctx.put(SecurityContext.KEY, securityContext) } - } catch (ignored: Throwable) { - // ignored - } - chain.filter(exchange) + try { + val securityContext = securityContextParser.parse(exchange) + exchange.mutate() + .principal(securityContext.principal.toMono()) + .build().let { + exchange.setSecurityContext(securityContext) + return chain.filter(it) + .contextWrite { ctx: Context -> ctx.put(SecurityContext.KEY, securityContext) } + } + } catch (ignored: Throwable) { + // ignored } + return chain.filter(exchange) } override fun getOrder(): Int { diff --git a/cosec-webflux/src/test/kotlin/me/ahoo/cosec/webflux/ReactiveInjectSecurityContextWebFilterTest.kt b/cosec-webflux/src/test/kotlin/me/ahoo/cosec/webflux/ReactiveInjectSecurityContextWebFilterTest.kt index 669b06f0..5f5c9485 100644 --- a/cosec-webflux/src/test/kotlin/me/ahoo/cosec/webflux/ReactiveInjectSecurityContextWebFilterTest.kt +++ b/cosec-webflux/src/test/kotlin/me/ahoo/cosec/webflux/ReactiveInjectSecurityContextWebFilterTest.kt @@ -18,7 +18,10 @@ import io.mockk.just import io.mockk.mockk import io.mockk.runs import io.mockk.verify +import me.ahoo.cosec.authorization.Authorization +import me.ahoo.cosec.authorization.AuthorizeResult import me.ahoo.cosec.context.SecurityContext +import me.ahoo.cosec.context.request.RequestTenantIdParser import me.ahoo.cosec.jwt.Jwts import me.ahoo.cosec.webflux.ServerWebExchanges.setSecurityContext import org.hamcrest.MatcherAssert.assertThat @@ -28,6 +31,7 @@ import org.springframework.core.Ordered import org.springframework.web.server.ServerWebExchange import org.springframework.web.server.WebFilterChain import reactor.core.publisher.Mono +import reactor.kotlin.core.publisher.toMono internal class ReactiveInjectSecurityContextWebFilterTest { @@ -37,12 +41,20 @@ internal class ReactiveInjectSecurityContextWebFilterTest { assertThat(filter.order, equalTo(Ordered.HIGHEST_PRECEDENCE)) val exchange = mockk() { every { request.headers.getFirst(Jwts.AUTHORIZATION_KEY) } returns null + every { request.headers.getFirst(RequestTenantIdParser.TENANT_ID_KEY) } returns "tenantId" + every { request.path.value() } returns "/path" + every { request.methodValue } returns "GET" every { setSecurityContext(any()) } just runs + every { + mutate() + .principal(any()) + .build() + } returns this } - val chain = mockk() { + val filterChain = mockk { every { filter(exchange) } returns Mono.empty() } - filter.filter(exchange, chain).block() + filter.filter(exchange, filterChain).block() verify { exchange.setSecurityContext(SecurityContext.ANONYMOUS) } diff --git a/gradle.properties b/gradle.properties index 1c9aacfd..73a4b296 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,7 +11,7 @@ # limitations under the License. # group=me.ahoo.cosec -version=1.0.2 +version=1.0.3 description=RBAC-based And Policy-based Multi-Tenant Reactive Security Framework website=https://github.com/Ahoo-Wang/CoSec issues=https://github.com/Ahoo-Wang/CoSec/issues