Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proxy ticket service and proxy ticket validation #105

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ The following CAS features are currently implemented:

The following features are **missing**:
* SAML request/response [CAS 3.0 - optional]
* Proxy ticket service and proxy ticket validation [CAS 2.0]

The following features are out of scope:
* Long-Term Tickets - Remember-Me [CAS 3.0 - optional]
Expand Down
19 changes: 8 additions & 11 deletions src/main/java/org/keycloak/protocol/cas/CASLoginProtocol.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,20 @@
import org.apache.http.HttpEntity;
import org.jboss.logging.Logger;
import org.keycloak.common.util.KeycloakUriBuilder;
import org.keycloak.common.util.Time;
import org.keycloak.events.Details;
import org.keycloak.events.EventBuilder;
import org.keycloak.events.EventType;
import org.keycloak.forms.login.LoginFormsProvider;
import org.keycloak.models.*;
import org.keycloak.protocol.LoginProtocol;
import org.keycloak.protocol.cas.endpoints.AbstractValidateEndpoint;
import org.keycloak.protocol.cas.utils.LogoutHelper;
import org.keycloak.protocol.oidc.utils.OAuth2Code;
import org.keycloak.protocol.oidc.utils.OAuth2CodeParser;
import org.keycloak.services.ErrorPage;
import org.keycloak.services.managers.ResourceAdminManager;
import org.keycloak.sessions.AuthenticationSessionModel;

import java.io.IOException;
import java.net.URI;
import java.util.UUID;

public class CASLoginProtocol implements LoginProtocol {
private static final Logger logger = Logger.getLogger(CASLoginProtocol.class);
Expand All @@ -35,11 +32,17 @@ public class CASLoginProtocol implements LoginProtocol {
public static final String GATEWAY_PARAM = "gateway";
public static final String TICKET_PARAM = "ticket";
public static final String FORMAT_PARAM = "format";
public static final String PGTURL_PARAM = "pgtUrl";
public static final String TARGET_SERVICE_PARAM = "targetService";
public static final String PGT_PARAM = "pgt";

public static final String TICKET_RESPONSE_PARAM = "ticket";
public static final String SAMLART_RESPONSE_PARAM = "SAMLart";

public static final String SERVICE_TICKET_PREFIX = "ST-";
public static final String PROXY_GRANTING_TICKET_IOU_PREFIX = "PGTIOU-";
public static final String PROXY_GRANTING_TICKET_PREFIX = "PGT-";
public static final String PROXY_TICKET_PREFIX = "PT-";
public static final String SESSION_SERVICE_TICKET = "service_ticket";

public static final String LOGOUT_REDIRECT_URI = "CAS_LOGOUT_REDIRECT_URI";
Expand Down Expand Up @@ -98,15 +101,9 @@ public Response authenticated(AuthenticationSessionModel authSession, UserSessio
String service = authSession.getRedirectUri();
//TODO validate service

OAuth2Code codeData = new OAuth2Code(UUID.randomUUID().toString(),
Time.currentTime() + userSession.getRealm().getAccessCodeLifespan(),
null, null, authSession.getRedirectUri(), null, null,
userSession.getId());
String code = OAuth2CodeParser.persistCode(session, clientSession, codeData);

KeycloakUriBuilder uriBuilder = KeycloakUriBuilder.fromUri(service);

String loginTicket = SERVICE_TICKET_PREFIX + code;
String loginTicket = AbstractValidateEndpoint.getST(session, clientSession, service);

if (authSession.getClientNotes().containsKey(CASLoginProtocol.TARGET_PARAM)) {
// This was a SAML 1.1 auth request so return the ticket ID as "SAMLart" instead of "ticket"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package org.keycloak.protocol.cas;

import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.UriBuilder;

import org.keycloak.events.EventBuilder;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
Expand Down Expand Up @@ -51,13 +51,12 @@ public Object serviceValidate() {

@Path("proxyValidate")
public Object proxyValidate() {
//TODO implement
return serviceValidate();
return new ProxyValidateEndpoint(session, realm, event);
}

@Path("proxy")
public Object proxy() {
return Response.serverError().entity("Not implemented").build();
return new ProxyEndpoint(session, realm, event);
}

@Path("p3/serviceValidate")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,39 @@
import org.keycloak.events.Details;
import org.keycloak.events.Errors;
import org.keycloak.events.EventBuilder;
import org.keycloak.common.util.Time;
import org.keycloak.models.*;
import org.keycloak.protocol.ProtocolMapper;
import org.keycloak.protocol.cas.CASLoginProtocol;
import org.keycloak.protocol.cas.mappers.CASAttributeMapper;
import org.keycloak.protocol.cas.representations.CASErrorCode;
import org.keycloak.protocol.cas.utils.CASValidationException;
import org.keycloak.protocol.oidc.utils.OAuth2CodeParser;
import org.keycloak.protocol.oidc.utils.OAuth2Code;
import org.keycloak.protocol.oidc.utils.RedirectUtils;
import org.keycloak.services.managers.AuthenticationManager;
import org.keycloak.services.managers.UserSessionCrossDCManager;
import org.keycloak.services.util.DefaultClientSessionContext;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.HttpResponse;
import org.apache.http.impl.client.HttpClientBuilder;

public abstract class AbstractValidateEndpoint {
protected final Logger logger = Logger.getLogger(getClass());
private static final Pattern DOT = Pattern.compile("\\.");
protected KeycloakSession session;
protected RealmModel realm;
protected EventBuilder event;
protected ClientModel client;
protected AuthenticatedClientSessionModel clientSession;
protected String pgtIou;

public AbstractValidateEndpoint(KeycloakSession session, RealmModel realm, EventBuilder event) {
this.session = session;
Expand Down Expand Up @@ -74,52 +84,80 @@ protected void checkClient(String service) {
session.getContext().setClient(client);
}

protected void checkTicket(String ticket, boolean requireReauth) {
protected void checkTicket(String ticket, String prefix, boolean requireReauth) {
if (ticket == null) {
event.error(Errors.INVALID_CODE);
throw new CASValidationException(CASErrorCode.INVALID_REQUEST, "Missing parameter: " + CASLoginProtocol.TICKET_PARAM, Response.Status.BAD_REQUEST);
}
if (!ticket.startsWith(CASLoginProtocol.SERVICE_TICKET_PREFIX)) {

if (!ticket.startsWith(prefix)) {
event.error(Errors.INVALID_CODE);
throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Malformed service ticket", Response.Status.BAD_REQUEST);
}
alexandrerw marked this conversation as resolved.
Show resolved Hide resolved

String code = ticket.substring(CASLoginProtocol.SERVICE_TICKET_PREFIX.length());
boolean isReusable = ticket.startsWith(CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX);

OAuth2CodeParser.ParseResult parseResult = OAuth2CodeParser.parseCode(session, code, realm, event);
if (parseResult.isIllegalCode()) {
String[] parsed = DOT.split(ticket.substring(prefix.length()), 3);
if (parsed.length != 3) {
event.error(Errors.INVALID_CODE);
throw new CASValidationException(CASErrorCode.INVALID_TICKET_SPEC, "Invalid format of the code", Response.Status.BAD_REQUEST);
}

String codeUUID = parsed[0];
String userSessionId = parsed[1];
String clientUUID = parsed[2];

event.detail(Details.CODE_ID, userSessionId);
event.session(userSessionId);

// Attempt to use same code twice should invalidate existing clientSession
AuthenticatedClientSessionModel clientSession = parseResult.getClientSession();
if (clientSession != null) {
clientSession.detachFromUserSession();
// Retrieve UserSession
UserSessionModel userSession = new UserSessionCrossDCManager(session).getUserSessionWithClient(realm, userSessionId, clientUUID);
if (userSession == null) {
// Needed to track if code is invalid
userSession = session.sessions().getUserSession(realm, userSessionId);
if (userSession == null) {
event.error(Errors.USER_SESSION_NOT_FOUND);
throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
}
}

clientSession = userSession.getAuthenticatedClientSessionByClient(clientUUID);
if (clientSession == null) {
event.error(Errors.INVALID_CODE);
throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
}

clientSession = parseResult.getClientSession();
SingleUseObjectProvider codeStore = session.singleUseObjects();
Map<String, String> codeDataSerialized = isReusable ? codeStore.get(prefix + codeUUID) : codeStore.remove(prefix + codeUUID);

if (parseResult.isExpiredCode()) {
// Either code not available
if (codeDataSerialized == null) {
event.error(Errors.INVALID_CODE);
throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
}

OAuth2Code codeData = OAuth2Code.deserializeCode(codeDataSerialized);

String persistedUserSessionId = codeData.getUserSessionId();
if (!userSessionId.equals(persistedUserSessionId)) {
event.error(Errors.INVALID_CODE);
throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code not valid", Response.Status.BAD_REQUEST);
}

// Finally doublecheck if code is not expired
int currentTime = Time.currentTime();
if (currentTime > codeData.getExpiration()) {
event.error(Errors.EXPIRED_CODE);
throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Code is expired", Response.Status.BAD_REQUEST);
}

clientSession.setNote(CASLoginProtocol.SESSION_SERVICE_TICKET, ticket);
clientSession.setNote(prefix, ticket);

if (requireReauth && AuthenticationManager.isSSOAuthentication(clientSession)) {
event.error(Errors.SESSION_EXPIRED);
throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Interactive authentication was requested but not performed", Response.Status.BAD_REQUEST);
}

UserSessionModel userSession = clientSession.getUserSession();

if (userSession == null) {
event.error(Errors.USER_SESSION_NOT_FOUND);
throw new CASValidationException(CASErrorCode.INVALID_TICKET, "User session not found", Response.Status.BAD_REQUEST);
}

UserModel user = userSession.getUser();
if (user == null) {
event.error(Errors.USER_NOT_FOUND);
Expand All @@ -133,15 +171,45 @@ protected void checkTicket(String ticket, boolean requireReauth) {
event.user(userSession.getUser());
event.session(userSession.getId());

if (!client.getClientId().equals(clientSession.getClient().getClientId())) {
event.error(Errors.INVALID_CODE);
throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Auth error", Response.Status.BAD_REQUEST);
if (client == null) {
client = clientSession.getClient();
} else {
if (!client.getClientId().equals(clientSession.getClient().getClientId())) {
event.error(Errors.INVALID_CODE);
throw new CASValidationException(CASErrorCode.INVALID_SERVICE, "Invalid service", Response.Status.BAD_REQUEST);
}
}

if (!AuthenticationManager.isSessionValid(realm, userSession)) {
event.error(Errors.USER_SESSION_NOT_FOUND);
throw new CASValidationException(CASErrorCode.INVALID_TICKET, "Session not active", Response.Status.BAD_REQUEST);
}

}

protected void createProxyGrant(String pgtUrl) {
if ( RedirectUtils.verifyRedirectUri(session, pgtUrl, client) == null ) {
event.error(Errors.INVALID_REQUEST);
throw new CASValidationException(CASErrorCode.INVALID_PROXY_CALLBACK, "Proxy callback is invalid", Response.Status.BAD_REQUEST);
}

String pgtIou = getPGTIOU();
String pgtId = getPGT(session, clientSession, pgtUrl);

try {
HttpResponse response = HttpClientBuilder.create().build().execute(
new HttpGet(new URIBuilder(pgtUrl).setParameter("pgtIou",pgtIou).setParameter("pgtId",pgtId).build())
);

if (response.getStatusLine().getStatusCode() != 200) {
throw new Exception();
}

this.pgtIou = pgtIou;
} catch (Exception e) {
event.error(Errors.INVALID_REQUEST);
throw new CASValidationException(CASErrorCode.PROXY_CALLBACK_ERROR, "Proxy callback returned an error", Response.Status.BAD_REQUEST);
}
}

protected Map<String, Object> getUserAttributes() {
Expand All @@ -160,4 +228,40 @@ protected Map<String, Object> getUserAttributes() {
}
return attributes;
}

protected String getPGTIOU()
{
return CASLoginProtocol.PROXY_GRANTING_TICKET_IOU_PREFIX + UUID.randomUUID().toString();
}

protected String getPGT(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String pgtUrl)
{
return persistedTicket(pgtUrl, CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX);
}

protected String getPT(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String targetService)
{
return persistedTicket(targetService, CASLoginProtocol.PROXY_TICKET_PREFIX);
}

protected String getST(String redirectUri)
{
return persistedTicket(redirectUri, CASLoginProtocol.SERVICE_TICKET_PREFIX);
}

public static String getST(KeycloakSession session, AuthenticatedClientSessionModel clientSession, String redirectUri)
{
ValidateEndpoint vp = new ValidateEndpoint(session,null,null);
vp.clientSession = clientSession;
return vp.getST(redirectUri);
}

protected String persistedTicket(String redirectUriParam, String prefix)
{
String key = UUID.randomUUID().toString();
UserSessionModel userSession = clientSession.getUserSession();
OAuth2Code codeData = new OAuth2Code(key, Time.currentTime() + userSession.getRealm().getAccessCodeLifespan(), null, null, redirectUriParam, null, null, userSession.getId());
session.singleUseObjects().put(prefix + key, clientSession.getUserSession().getRealm().getAccessCodeLifespan(), codeData.serializeCode());
return prefix + key + "." + clientSession.getUserSession().getId() + "." + clientSession.getClient().getId();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.keycloak.protocol.cas.endpoints;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.MultivaluedMap;
import jakarta.ws.rs.core.Response;
import org.jboss.resteasy.annotations.cache.NoCache;
import org.keycloak.events.EventBuilder;
import org.keycloak.events.EventType;
import org.keycloak.models.*;
import org.keycloak.protocol.cas.CASLoginProtocol;
import org.keycloak.protocol.cas.representations.CASServiceResponse;
import org.keycloak.protocol.cas.utils.CASValidationException;
import org.keycloak.protocol.cas.utils.ContentTypeHelper;
import org.keycloak.protocol.cas.utils.ServiceResponseHelper;

public class ProxyEndpoint extends AbstractValidateEndpoint {

public ProxyEndpoint(KeycloakSession session, RealmModel realm, EventBuilder event) {
super(session, realm, event);
}

@GET
@NoCache
public Response build() {
MultivaluedMap<String, String> params = session.getContext().getUri().getQueryParameters();
String targetService = params.getFirst(CASLoginProtocol.TARGET_SERVICE_PARAM);
String pgt = params.getFirst(CASLoginProtocol.PGT_PARAM);

event.event(EventType.CODE_TO_TOKEN);

try {
checkSsl();
checkRealm();
checkTicket(pgt, CASLoginProtocol.PROXY_GRANTING_TICKET_PREFIX, false);
event.success();
return successResponse(getPT(this.session, clientSession, targetService));
} catch (CASValidationException e) {
return errorResponse(e);
}
}

protected Response successResponse(String pt) {
CASServiceResponse serviceResponse = ServiceResponseHelper.createProxySuccess(pt);
return prepare(Response.Status.OK, serviceResponse);
}

protected Response errorResponse(CASValidationException e) {
CASServiceResponse serviceResponse = ServiceResponseHelper.createProxyFailure(e.getError(), e.getErrorDescription());
return prepare(e.getStatus(), serviceResponse);
}

private Response prepare(Response.Status status, CASServiceResponse serviceResponse) {
MediaType responseMediaType = new ContentTypeHelper(session.getContext().getUri()).selectResponseType();
return ServiceResponseHelper.createResponse(status, responseMediaType, serviceResponse);
}
}
Loading
Loading