001/*******************************************************************************
002 * Copyright 2018 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018/**
019 *
020 */
021package org.mitre.openid.connect.filter;
022
023import static org.mitre.openid.connect.request.ConnectRequestParameters.ERROR;
024import static org.mitre.openid.connect.request.ConnectRequestParameters.LOGIN_HINT;
025import static org.mitre.openid.connect.request.ConnectRequestParameters.LOGIN_REQUIRED;
026import static org.mitre.openid.connect.request.ConnectRequestParameters.MAX_AGE;
027import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT;
028import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT_LOGIN;
029import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT_NONE;
030import static org.mitre.openid.connect.request.ConnectRequestParameters.PROMPT_SEPARATOR;
031import static org.mitre.openid.connect.request.ConnectRequestParameters.STATE;
032
033import java.io.IOException;
034import java.net.URISyntaxException;
035import java.util.Date;
036import java.util.HashMap;
037import java.util.List;
038import java.util.Map;
039
040import javax.servlet.FilterChain;
041import javax.servlet.ServletException;
042import javax.servlet.ServletRequest;
043import javax.servlet.ServletResponse;
044import javax.servlet.http.HttpServletRequest;
045import javax.servlet.http.HttpServletResponse;
046import javax.servlet.http.HttpSession;
047
048import org.apache.http.client.utils.URIBuilder;
049import org.mitre.oauth2.model.ClientDetailsEntity;
050import org.mitre.oauth2.service.ClientDetailsEntityService;
051import org.mitre.openid.connect.service.LoginHintExtracter;
052import org.mitre.openid.connect.service.impl.RemoveLoginHintsWithHTTP;
053import org.mitre.openid.connect.web.AuthenticationTimeStamper;
054import org.slf4j.Logger;
055import org.slf4j.LoggerFactory;
056import org.springframework.beans.factory.annotation.Autowired;
057import org.springframework.security.core.Authentication;
058import org.springframework.security.core.context.SecurityContextHolder;
059import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
060import org.springframework.security.oauth2.provider.AuthorizationRequest;
061import org.springframework.security.oauth2.provider.OAuth2RequestFactory;
062import org.springframework.security.oauth2.provider.endpoint.RedirectResolver;
063import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
064import org.springframework.security.web.util.matcher.RequestMatcher;
065import org.springframework.stereotype.Component;
066import org.springframework.web.filter.GenericFilterBean;
067
068import com.google.common.base.Splitter;
069import com.google.common.base.Strings;
070
071/**
072 * @author jricher
073 *
074 */
075@Component("authRequestFilter")
076public class AuthorizationRequestFilter extends GenericFilterBean {
077
078        /**
079         * Logger for this class
080         */
081        private static final Logger logger = LoggerFactory.getLogger(AuthorizationRequestFilter.class);
082
083        public final static String PROMPTED = "PROMPT_FILTER_PROMPTED";
084        public final static String PROMPT_REQUESTED = "PROMPT_FILTER_REQUESTED";
085
086        @Autowired
087        private OAuth2RequestFactory authRequestFactory;
088
089        @Autowired
090        private ClientDetailsEntityService clientService;
091
092        @Autowired
093        private RedirectResolver redirectResolver;
094
095        @Autowired(required = false)
096        private LoginHintExtracter loginHintExtracter = new RemoveLoginHintsWithHTTP();
097
098        private RequestMatcher requestMatcher = new AntPathRequestMatcher("/authorize");
099
100        /**
101         *
102         */
103        @Override
104        public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
105
106                HttpServletRequest request = (HttpServletRequest) req;
107                HttpServletResponse response = (HttpServletResponse) res;
108                HttpSession session = request.getSession();
109
110                // skip everything that's not an authorize URL
111                if (!requestMatcher.matches(request)) {
112                        chain.doFilter(req, res);
113                        return;
114                }
115
116                try {
117                        // we have to create our own auth request in order to get at all the parmeters appropriately
118                        AuthorizationRequest authRequest = null;
119
120                        ClientDetailsEntity client = null;
121
122                        authRequest = authRequestFactory.createAuthorizationRequest(createRequestMap(request.getParameterMap()));
123                        if (!Strings.isNullOrEmpty(authRequest.getClientId())) {
124                                client = clientService.loadClientByClientId(authRequest.getClientId());
125                        }
126
127                        // save the login hint to the session
128                        // but first check to see if the login hint makes any sense
129                        String loginHint = loginHintExtracter.extractHint((String) authRequest.getExtensions().get(LOGIN_HINT));
130                        if (!Strings.isNullOrEmpty(loginHint)) {
131                                session.setAttribute(LOGIN_HINT, loginHint);
132                        } else {
133                                session.removeAttribute(LOGIN_HINT);
134                        }
135
136                        if (authRequest.getExtensions().get(PROMPT) != null) {
137                                // we have a "prompt" parameter
138                                String prompt = (String)authRequest.getExtensions().get(PROMPT);
139                                List<String> prompts = Splitter.on(PROMPT_SEPARATOR).splitToList(Strings.nullToEmpty(prompt));
140
141                                if (prompts.contains(PROMPT_NONE)) {
142                                        // see if the user's logged in
143                                        Authentication auth = SecurityContextHolder.getContext().getAuthentication();
144
145                                        if (auth != null) {
146                                                // user's been logged in already (by session management)
147                                                // we're OK, continue without prompting
148                                                chain.doFilter(req, res);
149                                        } else {
150                                                logger.info("Client requested no prompt");
151                                                // user hasn't been logged in, we need to "return an error"
152                                                if (client != null && authRequest.getRedirectUri() != null) {
153
154                                                        // if we've got a redirect URI then we'll send it
155
156                                                        String url = redirectResolver.resolveRedirect(authRequest.getRedirectUri(), client);
157
158                                                        try {
159                                                                URIBuilder uriBuilder = new URIBuilder(url);
160
161                                                                uriBuilder.addParameter(ERROR, LOGIN_REQUIRED);
162                                                                if (!Strings.isNullOrEmpty(authRequest.getState())) {
163                                                                        uriBuilder.addParameter(STATE, authRequest.getState()); // copy the state parameter if one was given
164                                                                }
165
166                                                                response.sendRedirect(uriBuilder.toString());
167                                                                return;
168
169                                                        } catch (URISyntaxException e) {
170                                                                logger.error("Can't build redirect URI for prompt=none, sending error instead", e);
171                                                                response.sendError(HttpServletResponse.SC_FORBIDDEN, "Access Denied");
172                                                                return;
173                                                        }
174                                                }
175
176                                                response.sendError(HttpServletResponse.SC_FORBIDDEN, "Access Denied");
177                                                return;
178                                        }
179                                } else if (prompts.contains(PROMPT_LOGIN)) {
180
181                                        // first see if the user's already been prompted in this session
182                                        if (session.getAttribute(PROMPTED) == null) {
183                                                // user hasn't been PROMPTED yet, we need to check
184
185                                                session.setAttribute(PROMPT_REQUESTED, Boolean.TRUE);
186
187                                                // see if the user's logged in
188                                                Authentication auth = SecurityContextHolder.getContext().getAuthentication();
189                                                if (auth != null) {
190                                                        // user's been logged in already (by session management)
191                                                        // log them out and continue
192                                                        SecurityContextHolder.getContext().setAuthentication(null);
193                                                        chain.doFilter(req, res);
194                                                } else {
195                                                        // user hasn't been logged in yet, we can keep going since we'll get there
196                                                        chain.doFilter(req, res);
197                                                }
198                                        } else {
199                                                // user has been PROMPTED, we're fine
200
201                                                // but first, undo the prompt tag
202                                                session.removeAttribute(PROMPTED);
203                                                chain.doFilter(req, res);
204                                        }
205                                } else {
206                                        // prompt parameter is a value we don't care about, not our business
207                                        chain.doFilter(req, res);
208                                }
209
210                        } else if (authRequest.getExtensions().get(MAX_AGE) != null ||
211                                        (client != null && client.getDefaultMaxAge() != null)) {
212
213                                // default to the client's stored value, check the string parameter
214                                Integer max = (client != null ? client.getDefaultMaxAge() : null);
215                                String maxAge = (String) authRequest.getExtensions().get(MAX_AGE);
216                                if (maxAge != null) {
217                                        max = Integer.parseInt(maxAge);
218                                }
219
220                                if (max != null) {
221
222                                        Date authTime = (Date) session.getAttribute(AuthenticationTimeStamper.AUTH_TIMESTAMP);
223
224                                        Date now = new Date();
225                                        if (authTime != null) {
226                                                long seconds = (now.getTime() - authTime.getTime()) / 1000;
227                                                if (seconds > max) {
228                                                        // session is too old, log the user out and continue
229                                                        SecurityContextHolder.getContext().setAuthentication(null);
230                                                }
231                                        }
232                                }
233                                chain.doFilter(req, res);
234                        } else {
235                                // no prompt parameter, not our business
236                                chain.doFilter(req, res);
237                        }
238
239                } catch (InvalidClientException e) {
240                        // we couldn't find the client, move on and let the rest of the system catch the error
241                        chain.doFilter(req, res);
242                }
243        }
244
245        /**
246         * @param parameterMap
247         * @return
248         */
249        private Map<String, String> createRequestMap(Map<String, String[]> parameterMap) {
250                Map<String, String> requestMap = new HashMap<>();
251                for (String key : parameterMap.keySet()) {
252                        String[] val = parameterMap.get(key);
253                        if (val != null && val.length > 0) {
254                                requestMap.put(key, val[0]); // add the first value only (which is what Spring seems to do)
255                        }
256                }
257
258                return requestMap;
259        }
260
261        /**
262         * @return the requestMatcher
263         */
264        public RequestMatcher getRequestMatcher() {
265                return requestMatcher;
266        }
267
268        /**
269         * @param requestMatcher the requestMatcher to set
270         */
271        public void setRequestMatcher(RequestMatcher requestMatcher) {
272                this.requestMatcher = requestMatcher;
273        }
274
275}