001/******************************************************************************* 002 * Copyright 2018 The MIT Internet Trust Consortium 003 * 004 * Portions copyright 2011-2013 The MITRE Corporation 005 * 006 * Licensed under the Apache License, Version 2.0 (the "License"); 007 * you may not use this file except in compliance with the License. 008 * You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 *******************************************************************************/ 018/** 019 * 020 */ 021package org.mitre.openid.connect.assertion; 022 023import java.io.IOException; 024import java.text.ParseException; 025 026import javax.servlet.FilterChain; 027import javax.servlet.ServletException; 028import javax.servlet.http.HttpServletRequest; 029import javax.servlet.http.HttpServletResponse; 030 031import org.springframework.security.authentication.BadCredentialsException; 032import org.springframework.security.core.Authentication; 033import org.springframework.security.core.AuthenticationException; 034import org.springframework.security.oauth2.common.exceptions.BadClientCredentialsException; 035import org.springframework.security.oauth2.provider.error.OAuth2AuthenticationEntryPoint; 036import org.springframework.security.web.AuthenticationEntryPoint; 037import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; 038import org.springframework.security.web.authentication.AuthenticationFailureHandler; 039import org.springframework.security.web.authentication.AuthenticationSuccessHandler; 040import org.springframework.security.web.util.matcher.RequestMatcher; 041 042import com.google.common.base.Strings; 043import com.nimbusds.jwt.JWT; 044import com.nimbusds.jwt.JWTParser; 045 046/** 047 * Filter to check client authentication via JWT Bearer assertions. 048 * 049 * @author jricher 050 * 051 */ 052public class JWTBearerClientAssertionTokenEndpointFilter extends AbstractAuthenticationProcessingFilter { 053 054 private AuthenticationEntryPoint authenticationEntryPoint = new OAuth2AuthenticationEntryPoint(); 055 056 public JWTBearerClientAssertionTokenEndpointFilter(RequestMatcher additionalMatcher) { 057 super(new ClientAssertionRequestMatcher(additionalMatcher)); 058 // If authentication fails the type is "Form" 059 ((OAuth2AuthenticationEntryPoint) authenticationEntryPoint).setTypeName("Form"); 060 } 061 062 @Override 063 public void afterPropertiesSet() { 064 super.afterPropertiesSet(); 065 setAuthenticationFailureHandler(new AuthenticationFailureHandler() { 066 @Override 067 public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, 068 AuthenticationException exception) throws IOException, ServletException { 069 if (exception instanceof BadCredentialsException) { 070 exception = new BadCredentialsException(exception.getMessage(), new BadClientCredentialsException()); 071 } 072 authenticationEntryPoint.commence(request, response, exception); 073 } 074 }); 075 setAuthenticationSuccessHandler(new AuthenticationSuccessHandler() { 076 @Override 077 public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, 078 Authentication authentication) throws IOException, ServletException { 079 // no-op - just allow filter chain to continue to token endpoint 080 } 081 }); 082 } 083 084 /** 085 * Pull the assertion out of the request and send it up to the auth manager for processing. 086 */ 087 @Override 088 public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response) throws AuthenticationException, IOException, ServletException { 089 090 // check for appropriate parameters 091 String assertionType = request.getParameter("client_assertion_type"); 092 String assertion = request.getParameter("client_assertion"); 093 094 try { 095 JWT jwt = JWTParser.parse(assertion); 096 097 String clientId = jwt.getJWTClaimsSet().getSubject(); 098 099 Authentication authRequest = new JWTBearerAssertionAuthenticationToken(jwt); 100 101 return this.getAuthenticationManager().authenticate(authRequest); 102 } catch (ParseException e) { 103 throw new BadCredentialsException("Invalid JWT credential: " + assertion); 104 } 105 } 106 107 @Override 108 protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, 109 FilterChain chain, Authentication authResult) throws IOException, ServletException { 110 super.successfulAuthentication(request, response, chain, authResult); 111 chain.doFilter(request, response); 112 } 113 114 private static class ClientAssertionRequestMatcher implements RequestMatcher { 115 116 private RequestMatcher additionalMatcher; 117 118 public ClientAssertionRequestMatcher(RequestMatcher additionalMatcher) { 119 this.additionalMatcher = additionalMatcher; 120 } 121 122 @Override 123 public boolean matches(HttpServletRequest request) { 124 // check for appropriate parameters 125 String assertionType = request.getParameter("client_assertion_type"); 126 String assertion = request.getParameter("client_assertion"); 127 128 if (Strings.isNullOrEmpty(assertionType) || Strings.isNullOrEmpty(assertion)) { 129 return false; 130 } else if (!assertionType.equals("urn:ietf:params:oauth:client-assertion-type:jwt-bearer")) { 131 return false; 132 } 133 134 return additionalMatcher.matches(request); 135 } 136 137 } 138 139 140 141}