001/*******************************************************************************
002 * Copyright 2018 The MIT Internet Trust Consortium
003 *
004 * Portions copyright 2011-2013 The MITRE Corporation
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License");
007 * you may not use this file except in compliance with the License.
008 * You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 *******************************************************************************/
018package org.mitre.oauth2.repository.impl;
019
020import java.text.ParseException;
021import java.util.ArrayList;
022import java.util.Date;
023import java.util.HashSet;
024import java.util.LinkedHashSet;
025import java.util.List;
026import java.util.Set;
027
028import javax.persistence.EntityManager;
029import javax.persistence.PersistenceContext;
030import javax.persistence.Query;
031import javax.persistence.TypedQuery;
032import javax.persistence.criteria.CriteriaBuilder;
033import javax.persistence.criteria.CriteriaDelete;
034import javax.persistence.criteria.Root;
035
036import org.mitre.data.DefaultPageCriteria;
037import org.mitre.data.PageCriteria;
038import org.mitre.oauth2.model.ClientDetailsEntity;
039import org.mitre.oauth2.model.OAuth2AccessTokenEntity;
040import org.mitre.oauth2.model.OAuth2RefreshTokenEntity;
041import org.mitre.oauth2.repository.OAuth2TokenRepository;
042import org.mitre.openid.connect.model.ApprovedSite;
043import org.mitre.uma.model.ResourceSet;
044import org.mitre.util.jpa.JpaUtil;
045import org.slf4j.Logger;
046import org.slf4j.LoggerFactory;
047import org.springframework.stereotype.Repository;
048import org.springframework.transaction.annotation.Transactional;
049
050import com.nimbusds.jwt.JWT;
051import com.nimbusds.jwt.JWTParser;
052
053@Repository
054public class JpaOAuth2TokenRepository implements OAuth2TokenRepository {
055
056        private static final int MAXEXPIREDRESULTS = 1000;
057
058        private static final Logger logger = LoggerFactory.getLogger(JpaOAuth2TokenRepository.class);
059
060        @PersistenceContext(unitName="defaultPersistenceUnit")
061        private EntityManager manager;
062
063        @Override
064        public Set<OAuth2AccessTokenEntity> getAllAccessTokens() {
065                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_ALL, OAuth2AccessTokenEntity.class);
066                return new LinkedHashSet<>(query.getResultList());
067        }
068
069        @Override
070        public Set<OAuth2RefreshTokenEntity> getAllRefreshTokens() {
071                TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_ALL, OAuth2RefreshTokenEntity.class);
072                return new LinkedHashSet<>(query.getResultList());
073        }
074
075
076        @Override
077        public OAuth2AccessTokenEntity getAccessTokenByValue(String accessTokenValue) {
078                try {
079                        JWT jwt = JWTParser.parse(accessTokenValue);
080                        TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2AccessTokenEntity.class);
081                        query.setParameter(OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE, jwt);
082                        return JpaUtil.getSingleResult(query.getResultList());
083                } catch (ParseException e) {
084                        return null;
085                }
086        }
087
088        @Override
089        public OAuth2AccessTokenEntity getAccessTokenById(Long id) {
090                return manager.find(OAuth2AccessTokenEntity.class, id);
091        }
092
093        @Override
094        @Transactional(value="defaultTransactionManager")
095        public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity token) {
096                return JpaUtil.saveOrUpdate(token.getId(), manager, token);
097        }
098
099        @Override
100        @Transactional(value="defaultTransactionManager")
101        public void removeAccessToken(OAuth2AccessTokenEntity accessToken) {
102                OAuth2AccessTokenEntity found = getAccessTokenById(accessToken.getId());
103                if (found != null) {
104                        manager.remove(found);
105                } else {
106                        throw new IllegalArgumentException("Access token not found: " + accessToken);
107                }
108        }
109
110        @Override
111        @Transactional(value="defaultTransactionManager")
112        public void clearAccessTokensForRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
113                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_REFRESH_TOKEN, OAuth2AccessTokenEntity.class);
114                query.setParameter(OAuth2AccessTokenEntity.PARAM_REFERSH_TOKEN, refreshToken);
115                List<OAuth2AccessTokenEntity> accessTokens = query.getResultList();
116                for (OAuth2AccessTokenEntity accessToken : accessTokens) {
117                        removeAccessToken(accessToken);
118                }
119        }
120
121        @Override
122        public OAuth2RefreshTokenEntity getRefreshTokenByValue(String refreshTokenValue) {
123                try {
124                        JWT jwt = JWTParser.parse(refreshTokenValue);
125                        TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2RefreshTokenEntity.class);
126                        query.setParameter(OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE, jwt);
127                        return JpaUtil.getSingleResult(query.getResultList());
128                } catch (ParseException e) {
129                        return null;
130                }
131        }
132
133        @Override
134        public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) {
135                return manager.find(OAuth2RefreshTokenEntity.class, id);
136        }
137
138        @Override
139        @Transactional(value="defaultTransactionManager")
140        public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
141                return JpaUtil.saveOrUpdate(refreshToken.getId(), manager, refreshToken);
142        }
143
144        @Override
145        @Transactional(value="defaultTransactionManager")
146        public void removeRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
147                OAuth2RefreshTokenEntity found = getRefreshTokenById(refreshToken.getId());
148                if (found != null) {
149                        manager.remove(found);
150                } else {
151                        throw new IllegalArgumentException("Refresh token not found: " + refreshToken);
152                }
153        }
154
155        @Override
156        @Transactional(value="defaultTransactionManager")
157        public void clearTokensForClient(ClientDetailsEntity client) {
158                TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class);
159                queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client);
160                List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList();
161                for (OAuth2AccessTokenEntity accessToken : accessTokens) {
162                        removeAccessToken(accessToken);
163                }
164                TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class);
165                queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client);
166                List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList();
167                for (OAuth2RefreshTokenEntity refreshToken : refreshTokens) {
168                        removeRefreshToken(refreshToken);
169                }
170        }
171
172        @Override
173        public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) {
174                TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class);
175                queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client);
176                List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList();
177                return accessTokens;
178        }
179
180        @Override
181        public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) {
182                TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class);
183                queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client);
184                List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList();
185                return refreshTokens;
186        }
187        
188        @Override
189        public Set<OAuth2AccessTokenEntity> getAccessTokensByUserName(String name) {
190                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_NAME, OAuth2AccessTokenEntity.class);
191            query.setParameter(OAuth2AccessTokenEntity.PARAM_NAME, name);
192            List<OAuth2AccessTokenEntity> results = query.getResultList();
193            return results != null ? new HashSet<>(results) : new HashSet<>();
194        }
195        
196        @Override
197        public Set<OAuth2RefreshTokenEntity> getRefreshTokensByUserName(String name) {
198                TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_NAME, OAuth2RefreshTokenEntity.class);
199            query.setParameter(OAuth2RefreshTokenEntity.PARAM_NAME, name);
200            List<OAuth2RefreshTokenEntity> results = query.getResultList();
201            return results != null ? new HashSet<>(results) : new HashSet<>();
202        }
203
204        @Override
205        public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens() {
206                DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS);
207                return getAllExpiredAccessTokens(pageCriteria);
208        }
209
210        @Override
211        public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens(PageCriteria pageCriteria) {
212                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2AccessTokenEntity.class);
213                query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date());
214                return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria));
215        }
216
217        @Override
218        public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens() {
219                DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS);
220                return getAllExpiredRefreshTokens(pageCriteria);
221        }
222
223        @Override
224        public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens(PageCriteria pageCriteria) {
225                TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2RefreshTokenEntity.class);
226                query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date());
227                return new LinkedHashSet<>(JpaUtil.getResultPage(query,pageCriteria));
228        }
229
230        @Override
231        public Set<OAuth2AccessTokenEntity> getAccessTokensForResourceSet(ResourceSet rs) {
232                TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, OAuth2AccessTokenEntity.class);
233                query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, rs.getId());
234                return new LinkedHashSet<>(query.getResultList());
235        }
236
237        @Override
238        @Transactional(value="defaultTransactionManager")
239        public void clearDuplicateAccessTokens() {
240                Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1");
241                @SuppressWarnings("unchecked")
242                List<Object[]> resultList = query.getResultList();
243                List<JWT> values = new ArrayList<>();
244                for (Object[] r : resultList) {
245                        logger.warn("Found duplicate access tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]);
246                        values.add((JWT) r[0]);
247                }
248                if (values.size() > 0) {
249                        CriteriaBuilder cb = manager.getCriteriaBuilder();
250                        CriteriaDelete<OAuth2AccessTokenEntity> criteriaDelete = cb.createCriteriaDelete(OAuth2AccessTokenEntity.class);
251                        Root<OAuth2AccessTokenEntity> root = criteriaDelete.from(OAuth2AccessTokenEntity.class);
252                        criteriaDelete.where(root.get("jwt").in(values));
253                        int result = manager.createQuery(criteriaDelete).executeUpdate();
254                        logger.warn("Deleted {} duplicate access tokens", result);
255                }
256        }
257
258        @Override
259        @Transactional(value="defaultTransactionManager")
260        public void clearDuplicateRefreshTokens() {
261                Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1");
262                @SuppressWarnings("unchecked")
263                List<Object[]> resultList = query.getResultList();
264                List<JWT> values = new ArrayList<>();
265                for (Object[] r : resultList) {
266                        logger.warn("Found duplicate refresh tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]);
267                        values.add((JWT) r[0]);
268                }
269                if (values.size() > 0) {
270                        CriteriaBuilder cb = manager.getCriteriaBuilder();
271                        CriteriaDelete<OAuth2RefreshTokenEntity> criteriaDelete = cb.createCriteriaDelete(OAuth2RefreshTokenEntity.class);
272                        Root<OAuth2RefreshTokenEntity> root = criteriaDelete.from(OAuth2RefreshTokenEntity.class);
273                        criteriaDelete.where(root.get("jwt").in(values));
274                        int result = manager.createQuery(criteriaDelete).executeUpdate();
275                        logger.warn("Deleted {} duplicate refresh tokens", result);
276                }
277
278        }
279
280        @Override
281        public List<OAuth2AccessTokenEntity> getAccessTokensForApprovedSite(ApprovedSite approvedSite) {
282                TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, OAuth2AccessTokenEntity.class);
283                queryA.setParameter(OAuth2AccessTokenEntity.PARAM_APPROVED_SITE, approvedSite);
284                List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList();
285                return accessTokens;
286        }
287
288}