DefaultJWTSigningAndValidationService.java

/*******************************************************************************
 * Copyright 2018 The MIT Internet Trust Consortium
 *
 * Portions copyright 2011-2013 The MITRE Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
package org.mitre.jwt.signer.service.impl;

import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

import org.mitre.jose.keystore.JWKSetKeyStore;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Strings;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.JWSVerifier;
import com.nimbusds.jose.crypto.ECDSASigner;
import com.nimbusds.jose.crypto.ECDSAVerifier;
import com.nimbusds.jose.crypto.MACSigner;
import com.nimbusds.jose.crypto.MACVerifier;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.SignedJWT;

public class DefaultJWTSigningAndValidationService implements JWTSigningAndValidationService {

	// map of identifier to signer
	private Map<String, JWSSigner> signers = new HashMap<>();

	// map of identifier to verifier
	private Map<String, JWSVerifier> verifiers = new HashMap<>();

	/**
	 * Logger for this class
	 */
	private static final Logger logger = LoggerFactory.getLogger(DefaultJWTSigningAndValidationService.class);

	private String defaultSignerKeyId;

	private JWSAlgorithm defaultAlgorithm;

	// map of identifier to key
	private Map<String, JWK> keys = new HashMap<>();

	/**
	 * Build this service based on the keys given. All public keys will be used
	 * to make verifiers, all private keys will be used to make signers.
	 *
	 * @param keys
	 *            A map of key identifier to key
	 *
	 * @throws InvalidKeySpecException
	 *             If the keys in the JWKs are not valid
	 * @throws NoSuchAlgorithmException
	 *             If there is no appropriate algorithm to tie the keys to.
	 */
	public DefaultJWTSigningAndValidationService(Map<String, JWK> keys) throws NoSuchAlgorithmException, InvalidKeySpecException {
		this.keys = keys;
		buildSignersAndVerifiers();
	}

	/**
	 * Build this service based on the given keystore. All keys must have a key
	 * id ({@code kid}) field in order to be used.
	 *
	 * @param keyStore
	 *            the keystore to load all keys from
	 *
	 * @throws InvalidKeySpecException
	 *             If the keys in the JWKs are not valid
	 * @throws NoSuchAlgorithmException
	 *             If there is no appropriate algorithm to tie the keys to.
	 */
	public DefaultJWTSigningAndValidationService(JWKSetKeyStore keyStore) throws NoSuchAlgorithmException, InvalidKeySpecException {
		// convert all keys in the keystore to a map based on key id
		if (keyStore!= null && keyStore.getJwkSet() != null) {
			for (JWK key : keyStore.getKeys()) {
				if (!Strings.isNullOrEmpty(key.getKeyID())) {
					// use the key ID that's built into the key itself
					this.keys.put(key.getKeyID(), key);
				} else {
					// create a random key id
					String fakeKid = UUID.randomUUID().toString();
					this.keys.put(fakeKid, key);
				}
			}
		}
		buildSignersAndVerifiers();
	}


	/**
	 * @return the defaultSignerKeyId
	 */
	@Override
	public String getDefaultSignerKeyId() {
		return defaultSignerKeyId;
	}

	/**
	 * @param defaultSignerKeyId the defaultSignerKeyId to set
	 */
	public void setDefaultSignerKeyId(String defaultSignerId) {
		this.defaultSignerKeyId = defaultSignerId;
	}

	/**
	 * @return
	 */
	@Override
	public JWSAlgorithm getDefaultSigningAlgorithm() {
		return defaultAlgorithm;
	}

	public void setDefaultSigningAlgorithmName(String algName) {
		defaultAlgorithm = JWSAlgorithm.parse(algName);
	}

	public String getDefaultSigningAlgorithmName() {
		if (defaultAlgorithm != null) {
			return defaultAlgorithm.getName();
		} else {
			return null;
		}
	}

	/**
	 * Build all of the signers and verifiers for this based on the key map.
	 * @throws InvalidKeySpecException If the keys in the JWKs are not valid
	 * @throws NoSuchAlgorithmException If there is no appropriate algorithm to tie the keys to.
	 */
	private void buildSignersAndVerifiers() throws NoSuchAlgorithmException, InvalidKeySpecException {
		for (Map.Entry<String, JWK> jwkEntry : keys.entrySet()) {

			String id = jwkEntry.getKey();
			JWK jwk = jwkEntry.getValue();

			try {
				if (jwk instanceof RSAKey) {
					// build RSA signers & verifiers

					if (jwk.isPrivate()) { // only add the signer if there's a private key
						RSASSASigner signer = new RSASSASigner((RSAKey) jwk);
						signers.put(id, signer);
					}

					RSASSAVerifier verifier = new RSASSAVerifier((RSAKey) jwk);
					verifiers.put(id, verifier);

				} else if (jwk instanceof ECKey) {
					// build EC signers & verifiers

					if (jwk.isPrivate()) {
						ECDSASigner signer = new ECDSASigner((ECKey) jwk);
						signers.put(id, signer);
					}

					ECDSAVerifier verifier = new ECDSAVerifier((ECKey) jwk);
					verifiers.put(id, verifier);

				} else if (jwk instanceof OctetSequenceKey) {
					// build HMAC signers & verifiers

					if (jwk.isPrivate()) { // technically redundant check because all HMAC keys are private
						MACSigner signer = new MACSigner((OctetSequenceKey) jwk);
						signers.put(id, signer);
					}

					MACVerifier verifier = new MACVerifier((OctetSequenceKey) jwk);
					verifiers.put(id, verifier);

				} else {
					logger.warn("Unknown key type: " + jwk);
				}
			} catch (JOSEException e) {
				logger.warn("Exception loading signer/verifier", e);
			}
		}

		if (defaultSignerKeyId == null && keys.size() == 1) {
			// if there's only one key, it's the default
			setDefaultSignerKeyId(keys.keySet().iterator().next());
		}
	}

	/**
	 * Sign a jwt in place using the configured default signer.
	 */
	@Override
	public void signJwt(SignedJWT jwt) {
		if (getDefaultSignerKeyId() == null) {
			throw new IllegalStateException("Tried to call default signing with no default signer ID set");
		}

		JWSSigner signer = signers.get(getDefaultSignerKeyId());

		try {
			jwt.sign(signer);
		} catch (JOSEException e) {

			logger.error("Failed to sign JWT, error was: ", e);
		}

	}

	@Override
	public void signJwt(SignedJWT jwt, JWSAlgorithm alg) {

		JWSSigner signer = null;

		for (JWSSigner s : signers.values()) {
			if (s.supportedJWSAlgorithms().contains(alg)) {
				signer = s;
				break;
			}
		}

		if (signer == null) {
			//If we can't find an algorithm that matches, we can't sign
			logger.error("No matching algirthm found for alg=" + alg);

		}

		try {
			jwt.sign(signer);
		} catch (JOSEException e) {

			logger.error("Failed to sign JWT, error was: ", e);
		}

	}

	@Override
	public boolean validateSignature(SignedJWT jwt) {

		for (JWSVerifier verifier : verifiers.values()) {
			try {
				if (jwt.verify(verifier)) {
					return true;
				}
			} catch (JOSEException e) {

				logger.error("Failed to validate signature with " + verifier + " error message: " + e.getMessage());
			}
		}
		return false;
	}

	@Override
	public Map<String, JWK> getAllPublicKeys() {
		Map<String, JWK> pubKeys = new HashMap<>();

		// pull all keys out of the verifiers if we know how
		for (String keyId : keys.keySet()) {
			JWK key = keys.get(keyId);
			JWK pub = key.toPublicJWK();
			if (pub != null) {
				pubKeys.put(keyId, pub);
			}
		}

		return pubKeys;
	}

	/* (non-Javadoc)
	 * @see org.mitre.jwt.signer.service.JwtSigningAndValidationService#getAllSigningAlgsSupported()
	 */
	@Override
	public Collection<JWSAlgorithm> getAllSigningAlgsSupported() {

		Set<JWSAlgorithm> algs = new HashSet<>();

		for (JWSSigner signer : signers.values()) {
			algs.addAll(signer.supportedJWSAlgorithms());
		}

		for (JWSVerifier verifier : verifiers.values()) {
			algs.addAll(verifier.supportedJWSAlgorithms());
		}

		return algs;

	}

}