001/*******************************************************************************
002 * Copyright 2018 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018/**
019 *
020 */
021package org.mitre.oauth2.service.impl;
022
023import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE;
024import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD;
025import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_VERIFIER;
026
027import java.nio.charset.StandardCharsets;
028import java.security.MessageDigest;
029import java.security.NoSuchAlgorithmException;
030import java.util.Collection;
031import java.util.Date;
032import java.util.HashSet;
033import java.util.List;
034import java.util.Set;
035import java.util.UUID;
036
037import org.mitre.data.AbstractPageOperationTemplate;
038import org.mitre.data.DefaultPageCriteria;
039import org.mitre.oauth2.model.AuthenticationHolderEntity;
040import org.mitre.oauth2.model.ClientDetailsEntity;
041import org.mitre.oauth2.model.OAuth2AccessTokenEntity;
042import org.mitre.oauth2.model.OAuth2RefreshTokenEntity;
043import org.mitre.oauth2.model.PKCEAlgorithm;
044import org.mitre.oauth2.model.SystemScope;
045import org.mitre.oauth2.repository.AuthenticationHolderRepository;
046import org.mitre.oauth2.repository.OAuth2TokenRepository;
047import org.mitre.oauth2.service.ClientDetailsEntityService;
048import org.mitre.oauth2.service.OAuth2TokenEntityService;
049import org.mitre.oauth2.service.SystemScopeService;
050import org.mitre.openid.connect.model.ApprovedSite;
051import org.mitre.openid.connect.service.ApprovedSiteService;
052import org.slf4j.Logger;
053import org.slf4j.LoggerFactory;
054import org.springframework.beans.factory.annotation.Autowired;
055import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
056import org.springframework.security.core.AuthenticationException;
057import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
058import org.springframework.security.oauth2.common.exceptions.InvalidRequestException;
059import org.springframework.security.oauth2.common.exceptions.InvalidScopeException;
060import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
061import org.springframework.security.oauth2.provider.OAuth2Authentication;
062import org.springframework.security.oauth2.provider.OAuth2Request;
063import org.springframework.security.oauth2.provider.TokenRequest;
064import org.springframework.security.oauth2.provider.token.TokenEnhancer;
065import org.springframework.stereotype.Service;
066import org.springframework.transaction.annotation.Transactional;
067
068import com.google.common.base.Strings;
069import com.nimbusds.jose.util.Base64URL;
070import com.nimbusds.jwt.JWTClaimsSet;
071import com.nimbusds.jwt.PlainJWT;
072
073
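/**
 * Default {@link OAuth2TokenEntityService} implementation: issues, refreshes, looks up,
 * and revokes OAuth2 access and refresh tokens through the injected token repository.
 *
 * @author jricher
 *
 */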
078@Service("defaultOAuth2ProviderTokenService")
079public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityService {
080
        /**
         * Logger for this class
         */
        private static final Logger logger = LoggerFactory.getLogger(DefaultOAuth2ProviderTokenService.class);

        // storage for access and refresh tokens (lookup, save, removal)
        @Autowired
        private OAuth2TokenRepository tokenRepository;

        // storage for the authentication holders that tokens point back to
        @Autowired
        private AuthenticationHolderRepository authenticationHolderRepository;

        // used to resolve a client from its client id
        @Autowired
        private ClientDetailsEntityService clientDetailsService;

        // applied to every new access token before it is saved
        @Autowired
        private TokenEnhancer tokenEnhancer;

        // converts between scope strings and SystemScope objects and strips reserved scopes
        @Autowired
        private SystemScopeService scopeService;

        // used to resolve the approved site referenced by an authorization request
        @Autowired
        private ApprovedSiteService approvedSiteService;
103
104        @Override
105        public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String userName) {
106                return tokenRepository.getAccessTokensByUserName(userName);
107        }
108
109        @Override
110        public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String userName) {
111                return tokenRepository.getRefreshTokensByUserName(userName);
112        }
113
114        @Override
115        public OAuth2AccessTokenEntity getAccessTokenById(Long id) {
116                return clearExpiredAccessToken(tokenRepository.getAccessTokenById(id));
117        }
118
119        @Override
120        public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) {
121                return clearExpiredRefreshToken(tokenRepository.getRefreshTokenById(id));
122        }
123
124        /**
125         * Utility function to delete an access token that's expired before returning it.
126         * @param token the token to check
127         * @return null if the token is null or expired, the input token (unchanged) if it hasn't
128         */
129        private OAuth2AccessTokenEntity clearExpiredAccessToken(OAuth2AccessTokenEntity token) {
130                if (token == null) {
131                        return null;
132                } else if (token.isExpired()) {
133                        // immediately revoke expired token
134                        logger.debug("Clearing expired access token: " + token.getValue());
135                        revokeAccessToken(token);
136                        return null;
137                } else {
138                        return token;
139                }
140        }
141
142        /**
143         * Utility function to delete a refresh token that's expired before returning it.
144         * @param token the token to check
145         * @return null if the token is null or expired, the input token (unchanged) if it hasn't
146         */
147        private OAuth2RefreshTokenEntity clearExpiredRefreshToken(OAuth2RefreshTokenEntity token) {
148                if (token == null) {
149                        return null;
150                } else if (token.isExpired()) {
151                        // immediately revoke expired token
152                        logger.debug("Clearing expired refresh token: " + token.getValue());
153                        revokeRefreshToken(token);
154                        return null;
155                } else {
156                        return token;
157                }
158        }
159
160        @Override
161        @Transactional(value="defaultTransactionManager")
162        public OAuth2AccessTokenEntity createAccessToken(OAuth2Authentication authentication) throws AuthenticationException, InvalidClientException {
163                if (authentication != null && authentication.getOAuth2Request() != null) {
164                        // look up our client
165                        OAuth2Request request = authentication.getOAuth2Request();
166
167                        ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
168
169                        if (client == null) {
170                                throw new InvalidClientException("Client not found: " + request.getClientId());
171                        }
172
173                        // handle the PKCE code challenge if present
174                        if (request.getExtensions().containsKey(CODE_CHALLENGE)) {
175                                String challenge = (String) request.getExtensions().get(CODE_CHALLENGE);
176                                PKCEAlgorithm alg = PKCEAlgorithm.parse((String) request.getExtensions().get(CODE_CHALLENGE_METHOD));
177
178                                String verifier = request.getRequestParameters().get(CODE_VERIFIER);
179
180                                if (alg.equals(PKCEAlgorithm.plain)) {
181                                        // do a direct string comparison
182                                        if (!challenge.equals(verifier)) {
183                                                throw new InvalidRequestException("Code challenge and verifier do not match");
184                                        }
185                                } else if (alg.equals(PKCEAlgorithm.S256)) {
186                                        // hash the verifier
187                                        try {
188                                                MessageDigest digest = MessageDigest.getInstance("SHA-256");
189                                                String hash = Base64URL.encode(digest.digest(verifier.getBytes(StandardCharsets.US_ASCII))).toString();
190                                                if (!challenge.equals(hash)) {
191                                                        throw new InvalidRequestException("Code challenge and verifier do not match");
192                                                }
193                                        } catch (NoSuchAlgorithmException e) {
194                                                logger.error("Unknown algorithm for PKCE digest", e);
195                                        }
196                                }
197
198                        }
199
200                        OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();//accessTokenFactory.createNewAccessToken();
201
202                        // attach the client
203                        token.setClient(client);
204
205                        // inherit the scope from the auth, but make a new set so it is
206                        //not unmodifiable. Unmodifiables don't play nicely with Eclipselink, which
207                        //wants to use the clone operation.
208                        Set<SystemScope> scopes = scopeService.fromStrings(request.getScope());
209
210                        // remove any of the special system scopes
211                        scopes = scopeService.removeReservedScopes(scopes);
212
213                        token.setScope(scopeService.toStrings(scopes));
214
215                        // make it expire if necessary
216                        if (client.getAccessTokenValiditySeconds() != null && client.getAccessTokenValiditySeconds() > 0) {
217                                Date expiration = new Date(System.currentTimeMillis() + (client.getAccessTokenValiditySeconds() * 1000L));
218                                token.setExpiration(expiration);
219                        }
220
221                        // attach the authorization so that we can look it up later
222                        AuthenticationHolderEntity authHolder = new AuthenticationHolderEntity();
223                        authHolder.setAuthentication(authentication);
224                        authHolder = authenticationHolderRepository.save(authHolder);
225
226                        token.setAuthenticationHolder(authHolder);
227
228                        // attach a refresh token, if this client is allowed to request them and the user gets the offline scope
229                        if (client.isAllowRefresh() && token.getScope().contains(SystemScopeService.OFFLINE_ACCESS)) {
230                                OAuth2RefreshTokenEntity savedRefreshToken = createRefreshToken(client, authHolder);
231
232                                token.setRefreshToken(savedRefreshToken);
233                        }
234
235                        //Add approved site reference, if any
236                        OAuth2Request originalAuthRequest = authHolder.getAuthentication().getOAuth2Request();
237
238                        if (originalAuthRequest.getExtensions() != null && originalAuthRequest.getExtensions().containsKey("approved_site")) {
239
240                                Long apId = Long.parseLong((String) originalAuthRequest.getExtensions().get("approved_site"));
241                                ApprovedSite ap = approvedSiteService.getById(apId);
242
243                                token.setApprovedSite(ap);
244                        }
245
246                        OAuth2AccessTokenEntity enhancedToken = (OAuth2AccessTokenEntity) tokenEnhancer.enhance(token, authentication);
247
248                        OAuth2AccessTokenEntity savedToken = saveAccessToken(enhancedToken);
249
250                        if (savedToken.getRefreshToken() != null) {
251                                tokenRepository.saveRefreshToken(savedToken.getRefreshToken()); // make sure we save any changes that might have been enhanced
252                        }
253
254                        return savedToken;
255                }
256
257                throw new AuthenticationCredentialsNotFoundException("No authentication credentials found");
258        }
259
260
261        private OAuth2RefreshTokenEntity createRefreshToken(ClientDetailsEntity client, AuthenticationHolderEntity authHolder) {
262                OAuth2RefreshTokenEntity refreshToken = new OAuth2RefreshTokenEntity(); //refreshTokenFactory.createNewRefreshToken();
263                JWTClaimsSet.Builder refreshClaims = new JWTClaimsSet.Builder();
264
265
266                // make it expire if necessary
267                if (client.getRefreshTokenValiditySeconds() != null) {
268                        Date expiration = new Date(System.currentTimeMillis() + (client.getRefreshTokenValiditySeconds() * 1000L));
269                        refreshToken.setExpiration(expiration);
270                        refreshClaims.expirationTime(expiration);
271                }
272
273                // set a random identifier
274                refreshClaims.jwtID(UUID.randomUUID().toString());
275
276                // TODO: add issuer fields, signature to JWT
277
278                PlainJWT refreshJwt = new PlainJWT(refreshClaims.build());
279                refreshToken.setJwt(refreshJwt);
280
281                //Add the authentication
282                refreshToken.setAuthenticationHolder(authHolder);
283                refreshToken.setClient(client);
284
285                // save the token first so that we can set it to a member of the access token (NOTE: is this step necessary?)
286                OAuth2RefreshTokenEntity savedRefreshToken = tokenRepository.saveRefreshToken(refreshToken);
287                return savedRefreshToken;
288        }
289
290        @Override
291        @Transactional(value="defaultTransactionManager")
292        public OAuth2AccessTokenEntity refreshAccessToken(String refreshTokenValue, TokenRequest authRequest) throws AuthenticationException {
293                
294                if (Strings.isNullOrEmpty(refreshTokenValue)) {
295                        // throw an invalid token exception if there's no refresh token value at all
296                        throw new InvalidTokenException("Invalid refresh token: " + refreshTokenValue);
297                }
298
299                OAuth2RefreshTokenEntity refreshToken = clearExpiredRefreshToken(tokenRepository.getRefreshTokenByValue(refreshTokenValue));
300
301                if (refreshToken == null) {
302                        // throw an invalid token exception if we couldn't find the token
303                        throw new InvalidTokenException("Invalid refresh token: " + refreshTokenValue);
304                }
305
306                ClientDetailsEntity client = refreshToken.getClient();
307
308                AuthenticationHolderEntity authHolder = refreshToken.getAuthenticationHolder();
309
310                // make sure that the client requesting the token is the one who owns the refresh token
311                ClientDetailsEntity requestingClient = clientDetailsService.loadClientByClientId(authRequest.getClientId());
312                if (!client.getClientId().equals(requestingClient.getClientId())) {
313                        tokenRepository.removeRefreshToken(refreshToken);
314                        throw new InvalidClientException("Client does not own the presented refresh token");
315                }
316
317                //Make sure this client allows access token refreshing
318                if (!client.isAllowRefresh()) {
319                        throw new InvalidClientException("Client does not allow refreshing access token!");
320                }
321
322                // clear out any access tokens
323                if (client.isClearAccessTokensOnRefresh()) {
324                        tokenRepository.clearAccessTokensForRefreshToken(refreshToken);
325                }
326
327                if (refreshToken.isExpired()) {
328                        tokenRepository.removeRefreshToken(refreshToken);
329                        throw new InvalidTokenException("Expired refresh token: " + refreshTokenValue);
330                }
331
332                OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();
333
334                // get the stored scopes from the authentication holder's authorization request; these are the scopes associated with the refresh token
335                Set<String> refreshScopesRequested = new HashSet<>(refreshToken.getAuthenticationHolder().getAuthentication().getOAuth2Request().getScope());
336                Set<SystemScope> refreshScopes = scopeService.fromStrings(refreshScopesRequested);
337                // remove any of the special system scopes
338                refreshScopes = scopeService.removeReservedScopes(refreshScopes);
339
340                Set<String> scopeRequested = authRequest.getScope() == null ? new HashSet<String>() : new HashSet<>(authRequest.getScope());
341                Set<SystemScope> scope = scopeService.fromStrings(scopeRequested);
342
343                // remove any of the special system scopes
344                scope = scopeService.removeReservedScopes(scope);
345
346                if (scope != null && !scope.isEmpty()) {
347                        // ensure a proper subset of scopes
348                        if (refreshScopes != null && refreshScopes.containsAll(scope)) {
349                                // set the scope of the new access token if requested
350                                token.setScope(scopeService.toStrings(scope));
351                        } else {
352                                String errorMsg = "Up-scoping is not allowed.";
353                                logger.error(errorMsg);
354                                throw new InvalidScopeException(errorMsg);
355                        }
356                } else {
357                        // otherwise inherit the scope of the refresh token (if it's there -- this can return a null scope set)
358                        token.setScope(scopeService.toStrings(refreshScopes));
359                }
360
361                token.setClient(client);
362
363                if (client.getAccessTokenValiditySeconds() != null) {
364                        Date expiration = new Date(System.currentTimeMillis() + (client.getAccessTokenValiditySeconds() * 1000L));
365                        token.setExpiration(expiration);
366                }
367
368                if (client.isReuseRefreshToken()) {
369                        // if the client re-uses refresh tokens, do that
370                        token.setRefreshToken(refreshToken);
371                } else {
372                        // otherwise, make a new refresh token
373                        OAuth2RefreshTokenEntity newRefresh = createRefreshToken(client, authHolder);
374                        token.setRefreshToken(newRefresh);
375
376                        // clean up the old refresh token
377                        tokenRepository.removeRefreshToken(refreshToken);
378                }
379
380                token.setAuthenticationHolder(authHolder);
381
382                tokenEnhancer.enhance(token, authHolder.getAuthentication());
383
384                tokenRepository.saveAccessToken(token);
385
386                return token;
387        }
388
389        @Override
390        public OAuth2Authentication loadAuthentication(String accessTokenValue) throws AuthenticationException {
391                OAuth2AccessTokenEntity accessToken = clearExpiredAccessToken(tokenRepository.getAccessTokenByValue(accessTokenValue));
392
393                if (accessToken == null) {
394                        throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
395                } else {
396                        return accessToken.getAuthenticationHolder().getAuthentication();
397                }
398        }
399
400
401        /**
402         * Get an access token from its token value.
403         */
404        @Override
405        public OAuth2AccessTokenEntity readAccessToken(String accessTokenValue) throws AuthenticationException {
406                OAuth2AccessTokenEntity accessToken = clearExpiredAccessToken(tokenRepository.getAccessTokenByValue(accessTokenValue));
407                if (accessToken == null) {
408                        throw new InvalidTokenException("Access token for value " + accessTokenValue + " was not found");
409                } else {
410                        return accessToken;
411                }
412        }
413
414        /**
415         * Get an access token by its authentication object.
416         */
417        @Override
418        public OAuth2AccessTokenEntity getAccessToken(OAuth2Authentication authentication) {
419                // TODO: implement this against the new service (#825)
420                throw new UnsupportedOperationException("Unable to look up access token from authentication object.");
421        }
422
423        /**
424         * Get a refresh token by its token value.
425         */
426        @Override
427        public OAuth2RefreshTokenEntity getRefreshToken(String refreshTokenValue) throws AuthenticationException {
428                OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenByValue(refreshTokenValue);
429                if (refreshToken == null) {
430                        throw new InvalidTokenException("Refresh token for value " + refreshTokenValue + " was not found");
431                }
432                else {
433                        return refreshToken;
434                }
435        }
436
437        /**
438         * Revoke a refresh token and all access tokens issued to it.
439         */
440        @Override
441        @Transactional(value="defaultTransactionManager")
442        public void revokeRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
443                tokenRepository.clearAccessTokensForRefreshToken(refreshToken);
444                tokenRepository.removeRefreshToken(refreshToken);
445        }
446
447        /**
448         * Revoke an access token.
449         */
450        @Override
451        @Transactional(value="defaultTransactionManager")
452        public void revokeAccessToken(OAuth2AccessTokenEntity accessToken) {
453                tokenRepository.removeAccessToken(accessToken);
454        }
455
456        @Override
457        public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) {
458                return tokenRepository.getAccessTokensForClient(client);
459        }
460
461        @Override
462        public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) {
463                return tokenRepository.getRefreshTokensForClient(client);
464        }
465
466        /**
467         * Clears out expired tokens and any abandoned authentication objects
468         */
469        @Override
470        public void clearExpiredTokens() {
471                logger.debug("Cleaning out all expired tokens");
472
473                new AbstractPageOperationTemplate<OAuth2AccessTokenEntity>("clearExpiredAccessTokens") {
474                        @Override
475                        public Collection<OAuth2AccessTokenEntity> fetchPage() {
476                                return tokenRepository.getAllExpiredAccessTokens(new DefaultPageCriteria());
477                        }
478
479                        @Override
480                        public void doOperation(OAuth2AccessTokenEntity item) {
481                                revokeAccessToken(item);
482                        }
483                }.execute();
484
485                new AbstractPageOperationTemplate<OAuth2RefreshTokenEntity>("clearExpiredRefreshTokens") {
486                        @Override
487                        public Collection<OAuth2RefreshTokenEntity> fetchPage() {
488                                return tokenRepository.getAllExpiredRefreshTokens(new DefaultPageCriteria());
489                        }
490
491                        @Override
492                        public void doOperation(OAuth2RefreshTokenEntity item) {
493                                revokeRefreshToken(item);
494                        }
495                }.execute();
496
497                new AbstractPageOperationTemplate<AuthenticationHolderEntity>("clearExpiredAuthenticationHolders") {
498                        @Override
499                        public Collection<AuthenticationHolderEntity> fetchPage() {
500                                return authenticationHolderRepository.getOrphanedAuthenticationHolders(new DefaultPageCriteria());
501                        }
502
503                        @Override
504                        public void doOperation(AuthenticationHolderEntity item) {
505                                authenticationHolderRepository.remove(item);
506                        }
507                }.execute();
508        }
509
510        /* (non-Javadoc)
511         * @see org.mitre.oauth2.service.OAuth2TokenEntityService#saveAccessToken(org.mitre.oauth2.model.OAuth2AccessTokenEntity)
512         */
513        @Override
514        @Transactional(value="defaultTransactionManager")
515        public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity accessToken) {
516                OAuth2AccessTokenEntity newToken = tokenRepository.saveAccessToken(accessToken);
517
518                // if the old token has any additional information for the return from the token endpoint, carry it through here after save
519                if (accessToken.getAdditionalInformation() != null && !accessToken.getAdditionalInformation().isEmpty()) {
520                        newToken.getAdditionalInformation().putAll(accessToken.getAdditionalInformation());
521                }
522
523                return newToken;
524        }
525
526        /* (non-Javadoc)
527         * @see org.mitre.oauth2.service.OAuth2TokenEntityService#saveRefreshToken(org.mitre.oauth2.model.OAuth2RefreshTokenEntity)
528         */
529        @Override
530        @Transactional(value="defaultTransactionManager")
531        public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
532                return tokenRepository.saveRefreshToken(refreshToken);
533        }
534
535        /**
536         * @return the tokenEnhancer
537         */
538        public TokenEnhancer getTokenEnhancer() {
539                return tokenEnhancer;
540        }
541
542        /**
543         * @param tokenEnhancer the tokenEnhancer to set
544         */
545        public void setTokenEnhancer(TokenEnhancer tokenEnhancer) {
546                this.tokenEnhancer = tokenEnhancer;
547        }
548
549        @Override
550        public OAuth2AccessTokenEntity getRegistrationAccessTokenForClient(ClientDetailsEntity client) {
551                List<OAuth2AccessTokenEntity> allTokens = getAccessTokensForClient(client);
552
553                for (OAuth2AccessTokenEntity token : allTokens) {
554                        if ((token.getScope().contains(SystemScopeService.REGISTRATION_TOKEN_SCOPE) || token.getScope().contains(SystemScopeService.RESOURCE_TOKEN_SCOPE))
555                                        && token.getScope().size() == 1) {
556                                // if it only has the registration scope, then it's a registration token
557                                return token;
558                        }
559                }
560
561                return null;
562        }
563}