001/******************************************************************************* 002 * Copyright 2018 The MIT Internet Trust Consortium 003 * 004 * Portions copyright 2011-2013 The MITRE Corporation 005 * 006 * Licensed under the Apache License, Version 2.0 (the "License"); 007 * you may not use this file except in compliance with the License. 008 * You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, software 013 * distributed under the License is distributed on an "AS IS" BASIS, 014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 015 * See the License for the specific language governing permissions and 016 * limitations under the License. 017 *******************************************************************************/ 018package org.mitre.oauth2.repository.impl; 019 020import java.text.ParseException; 021import java.util.ArrayList; 022import java.util.Date; 023import java.util.HashSet; 024import java.util.LinkedHashSet; 025import java.util.List; 026import java.util.Set; 027 028import javax.persistence.EntityManager; 029import javax.persistence.PersistenceContext; 030import javax.persistence.Query; 031import javax.persistence.TypedQuery; 032import javax.persistence.criteria.CriteriaBuilder; 033import javax.persistence.criteria.CriteriaDelete; 034import javax.persistence.criteria.Root; 035 036import org.mitre.data.DefaultPageCriteria; 037import org.mitre.data.PageCriteria; 038import org.mitre.oauth2.model.ClientDetailsEntity; 039import org.mitre.oauth2.model.OAuth2AccessTokenEntity; 040import org.mitre.oauth2.model.OAuth2RefreshTokenEntity; 041import org.mitre.oauth2.repository.OAuth2TokenRepository; 042import org.mitre.openid.connect.model.ApprovedSite; 043import org.mitre.uma.model.ResourceSet; 044import org.mitre.util.jpa.JpaUtil; 045import org.slf4j.Logger; 046import org.slf4j.LoggerFactory; 047import org.springframework.stereotype.Repository; 048import org.springframework.transaction.annotation.Transactional; 049 050import com.nimbusds.jwt.JWT; 051import com.nimbusds.jwt.JWTParser; 052 053@Repository 054public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { 055 056 private static final int MAXEXPIREDRESULTS = 1000; 057 058 private static final Logger logger = LoggerFactory.getLogger(JpaOAuth2TokenRepository.class); 059 060 @PersistenceContext(unitName="defaultPersistenceUnit") 061 private EntityManager manager; 062 063 @Override 064 public Set<OAuth2AccessTokenEntity> getAllAccessTokens() { 065 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_ALL, OAuth2AccessTokenEntity.class); 066 return new LinkedHashSet<>(query.getResultList()); 067 } 068 069 @Override 070 public Set<OAuth2RefreshTokenEntity> getAllRefreshTokens() { 071 TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_ALL, OAuth2RefreshTokenEntity.class); 072 return new LinkedHashSet<>(query.getResultList()); 073 } 074 075 076 @Override 077 public OAuth2AccessTokenEntity getAccessTokenByValue(String accessTokenValue) { 078 try { 079 JWT jwt = JWTParser.parse(accessTokenValue); 080 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2AccessTokenEntity.class); 081 query.setParameter(OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE, jwt); 082 return JpaUtil.getSingleResult(query.getResultList()); 083 } catch (ParseException e) { 084 return null; 085 } 086 } 087 088 @Override 089 public OAuth2AccessTokenEntity getAccessTokenById(Long id) { 090 return manager.find(OAuth2AccessTokenEntity.class, id); 091 } 092 093 @Override 094 @Transactional(value="defaultTransactionManager") 095 public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity token) { 096 return JpaUtil.saveOrUpdate(token.getId(), manager, token); 097 } 098 099 @Override 100 @Transactional(value="defaultTransactionManager") 101 public void removeAccessToken(OAuth2AccessTokenEntity accessToken) { 102 OAuth2AccessTokenEntity found = getAccessTokenById(accessToken.getId()); 103 if (found != null) { 104 manager.remove(found); 105 } else { 106 throw new IllegalArgumentException("Access token not found: " + accessToken); 107 } 108 } 109 110 @Override 111 @Transactional(value="defaultTransactionManager") 112 public void clearAccessTokensForRefreshToken(OAuth2RefreshTokenEntity refreshToken) { 113 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_REFRESH_TOKEN, OAuth2AccessTokenEntity.class); 114 query.setParameter(OAuth2AccessTokenEntity.PARAM_REFERSH_TOKEN, refreshToken); 115 List<OAuth2AccessTokenEntity> accessTokens = query.getResultList(); 116 for (OAuth2AccessTokenEntity accessToken : accessTokens) { 117 removeAccessToken(accessToken); 118 } 119 } 120 121 @Override 122 public OAuth2RefreshTokenEntity getRefreshTokenByValue(String refreshTokenValue) { 123 try { 124 JWT jwt = JWTParser.parse(refreshTokenValue); 125 TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2RefreshTokenEntity.class); 126 query.setParameter(OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE, jwt); 127 return JpaUtil.getSingleResult(query.getResultList()); 128 } catch (ParseException e) { 129 return null; 130 } 131 } 132 133 @Override 134 public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) { 135 return manager.find(OAuth2RefreshTokenEntity.class, id); 136 } 137 138 @Override 139 @Transactional(value="defaultTransactionManager") 140 public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) { 141 return JpaUtil.saveOrUpdate(refreshToken.getId(), manager, refreshToken); 142 } 143 144 @Override 145 @Transactional(value="defaultTransactionManager") 146 public void removeRefreshToken(OAuth2RefreshTokenEntity refreshToken) { 147 OAuth2RefreshTokenEntity found = getRefreshTokenById(refreshToken.getId()); 148 if (found != null) { 149 manager.remove(found); 150 } else { 151 throw new IllegalArgumentException("Refresh token not found: " + refreshToken); 152 } 153 } 154 155 @Override 156 @Transactional(value="defaultTransactionManager") 157 public void clearTokensForClient(ClientDetailsEntity client) { 158 TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); 159 queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); 160 List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList(); 161 for (OAuth2AccessTokenEntity accessToken : accessTokens) { 162 removeAccessToken(accessToken); 163 } 164 TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); 165 queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); 166 List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList(); 167 for (OAuth2RefreshTokenEntity refreshToken : refreshTokens) { 168 removeRefreshToken(refreshToken); 169 } 170 } 171 172 @Override 173 public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) { 174 TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); 175 queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); 176 List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList(); 177 return accessTokens; 178 } 179 180 @Override 181 public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) { 182 TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); 183 queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); 184 List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList(); 185 return refreshTokens; 186 } 187 188 @Override 189 public Set<OAuth2AccessTokenEntity> getAccessTokensByUserName(String name) { 190 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_NAME, OAuth2AccessTokenEntity.class); 191 query.setParameter(OAuth2AccessTokenEntity.PARAM_NAME, name); 192 List<OAuth2AccessTokenEntity> results = query.getResultList(); 193 return results != null ? new HashSet<>(results) : new HashSet<>(); 194 } 195 196 @Override 197 public Set<OAuth2RefreshTokenEntity> getRefreshTokensByUserName(String name) { 198 TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_NAME, OAuth2RefreshTokenEntity.class); 199 query.setParameter(OAuth2RefreshTokenEntity.PARAM_NAME, name); 200 List<OAuth2RefreshTokenEntity> results = query.getResultList(); 201 return results != null ? new HashSet<>(results) : new HashSet<>(); 202 } 203 204 @Override 205 public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens() { 206 DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); 207 return getAllExpiredAccessTokens(pageCriteria); 208 } 209 210 @Override 211 public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens(PageCriteria pageCriteria) { 212 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2AccessTokenEntity.class); 213 query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); 214 return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria)); 215 } 216 217 @Override 218 public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens() { 219 DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); 220 return getAllExpiredRefreshTokens(pageCriteria); 221 } 222 223 @Override 224 public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens(PageCriteria pageCriteria) { 225 TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2RefreshTokenEntity.class); 226 query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); 227 return new LinkedHashSet<>(JpaUtil.getResultPage(query,pageCriteria)); 228 } 229 230 @Override 231 public Set<OAuth2AccessTokenEntity> getAccessTokensForResourceSet(ResourceSet rs) { 232 TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, OAuth2AccessTokenEntity.class); 233 query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, rs.getId()); 234 return new LinkedHashSet<>(query.getResultList()); 235 } 236 237 @Override 238 @Transactional(value="defaultTransactionManager") 239 public void clearDuplicateAccessTokens() { 240 Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); 241 @SuppressWarnings("unchecked") 242 List<Object[]> resultList = query.getResultList(); 243 List<JWT> values = new ArrayList<>(); 244 for (Object[] r : resultList) { 245 logger.warn("Found duplicate access tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]); 246 values.add((JWT) r[0]); 247 } 248 if (values.size() > 0) { 249 CriteriaBuilder cb = manager.getCriteriaBuilder(); 250 CriteriaDelete<OAuth2AccessTokenEntity> criteriaDelete = cb.createCriteriaDelete(OAuth2AccessTokenEntity.class); 251 Root<OAuth2AccessTokenEntity> root = criteriaDelete.from(OAuth2AccessTokenEntity.class); 252 criteriaDelete.where(root.get("jwt").in(values)); 253 int result = manager.createQuery(criteriaDelete).executeUpdate(); 254 logger.warn("Deleted {} duplicate access tokens", result); 255 } 256 } 257 258 @Override 259 @Transactional(value="defaultTransactionManager") 260 public void clearDuplicateRefreshTokens() { 261 Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); 262 @SuppressWarnings("unchecked") 263 List<Object[]> resultList = query.getResultList(); 264 List<JWT> values = new ArrayList<>(); 265 for (Object[] r : resultList) { 266 logger.warn("Found duplicate refresh tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]); 267 values.add((JWT) r[0]); 268 } 269 if (values.size() > 0) { 270 CriteriaBuilder cb = manager.getCriteriaBuilder(); 271 CriteriaDelete<OAuth2RefreshTokenEntity> criteriaDelete = cb.createCriteriaDelete(OAuth2RefreshTokenEntity.class); 272 Root<OAuth2RefreshTokenEntity> root = criteriaDelete.from(OAuth2RefreshTokenEntity.class); 273 criteriaDelete.where(root.get("jwt").in(values)); 274 int result = manager.createQuery(criteriaDelete).executeUpdate(); 275 logger.warn("Deleted {} duplicate refresh tokens", result); 276 } 277 278 } 279 280 @Override 281 public List<OAuth2AccessTokenEntity> getAccessTokensForApprovedSite(ApprovedSite approvedSite) { 282 TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, OAuth2AccessTokenEntity.class); 283 queryA.setParameter(OAuth2AccessTokenEntity.PARAM_APPROVED_SITE, approvedSite); 284 List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList(); 285 return accessTokens; 286 } 287 288}