/* * Copyright 2006-2008 Sxip Identity Corporation */ package org.openid4java.message; import org.openid4java.message.ax.AxMessage; import org.openid4java.message.sreg.SRegMessage; import org.openid4java.message.sreg.SReg11ExtensionFactory; import org.openid4java.message.pape.PapeMessage; import java.io.UnsupportedEncodingException; import java.util.*; import java.net.URLEncoder; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; /** * @author Marius Scurtescu, Johnny Bufu */ public class Message { private static Log _log = LogFactory.getLog(Message.class); private static final boolean DEBUG = _log.isDebugEnabled(); // message constants public static final String MODE_IDRES = "id_res"; public static final String MODE_CANCEL = "cancel"; public static final String MODE_SETUP_NEEDED = "setup_needed"; public static final String OPENID2_NS = "http://specs.openid.net/auth/2.0"; private ParameterList _params; private int _extCounter; // extention type URI -> extension alias : extension present in the message private Map _extAliases; // extension type URI -> MessageExtensions : extracted extension objects private Map _extesion; // the URL where this message should be sent, where applicable // should remain null for received messages (created from param lists) protected String _destinationUrl; // type URI -> message extension factory : supported extensions private static Map _extensionFactories = new HashMap(); static { _extensionFactories.put(AxMessage.OPENID_NS_AX, AxMessage.class); _extensionFactories.put(SRegMessage.OPENID_NS_SREG, SRegMessage.class); _extensionFactories.put(SRegMessage.OPENID_NS_SREG11, SReg11ExtensionFactory.class); _extensionFactories.put(PapeMessage.OPENID_NS_PAPE, PapeMessage.class); } protected Message() { _params = new ParameterList(); _extCounter = 0; _extAliases = new HashMap(); _extesion = new HashMap(); } protected Message (ParameterList params) { this(); this._params = params; //build the extension list when creating a message from a param list Iterator iter = _params.getParameters().iterator(); // simple registration is a special case; we support only: // SREG1.0 (no namespace, "sreg" alias hardcoded) in : // - OpenID1 messages // - OpenID2 messages (against the 2.0 spec), // to accomodate Yahoo's non-2.0-compliant implementation // SREG1.1 (namespace, any possible alias) in OpenID2 messages boolean hasOpenidDotSreg = false; while (iter.hasNext()) { String key = ((Parameter) iter.next()).getKey(); if (key.startsWith("openid.ns.") && key.length() > 10) _extAliases.put(_params.getParameter(key).getValue(), key.substring(10)); if (key.startsWith("openid.sreg.")) hasOpenidDotSreg = true; } // only do the workaround for OpenID1 messages if ( hasOpenidDotSreg && ! _extAliases.values().contains("sreg") /*! todo: revert this: hasParameter("openid.ns")*/ ) _extAliases.put(SRegMessage.OPENID_NS_SREG, "sreg"); _extCounter = _extAliases.size(); } public static Message createMessage() throws MessageException { Message message = new Message(); message.validate(); if (DEBUG) _log.debug("Created message:\n" + message.keyValueFormEncoding()); return message; } public static Message createMessage(ParameterList params) throws MessageException { Message message = new Message(params); message.validate(); if (DEBUG) _log.debug("Created message from parameter list:\n" + message.keyValueFormEncoding()); return message; } protected Parameter getParameter(String name) { return _params.getParameter(name); } public String getParameterValue(String name) { return _params.getParameterValue(name); } public boolean hasParameter(String name) { return _params.hasParameter(name); } protected void set(String name, String value) { _params.set(new Parameter(name, value)); } protected List getParameters() { return _params.getParameters(); } public Map getParameterMap() { Map params = new LinkedHashMap(); Iterator iter = _params.getParameters().iterator(); while (iter.hasNext()) { Parameter p = (Parameter) iter.next(); params.put( p.getKey(), p.getValue() ); } return params; } /** * Checks that all required parameters are present */ public void validate() throws MessageException { List requiredFields = getRequiredFields(); Iterator paramIter = _params.getParameters().iterator(); while (paramIter.hasNext()) { Parameter param = (Parameter) paramIter.next(); if (!param.isValid()) throw new MessageException("Invalid parameter: " + param); } if (requiredFields == null) return; Iterator reqIter = requiredFields.iterator(); while(reqIter.hasNext()) { String required = (String) reqIter.next(); if (! hasParameter(required)) throw new MessageException( "Required parameter missing: " + required); } } public List getRequiredFields() { return null; } public String keyValueFormEncoding() { return _params.toString(); } public String wwwFormEncoding() { StringBuffer allParams = new StringBuffer(""); List parameters = _params.getParameters(); Iterator iterator = parameters.iterator(); while (iterator.hasNext()) { Parameter parameter = (Parameter) iterator.next(); // All of the keys in the request message MUST be prefixed with "openid." if ( ! parameter.getKey().startsWith("openid.")) allParams.append("openid."); try { allParams.append(URLEncoder.encode(parameter.getKey(), "UTF-8")); allParams.append('='); allParams.append(URLEncoder.encode(parameter.getValue(), "UTF-8")); allParams.append('&'); } catch (UnsupportedEncodingException e) { return null; } } // remove the trailing '&' if (allParams.length() > 0) allParams.deleteCharAt(allParams.length() -1); return allParams.toString(); } /** * Gets the URL where the message should be sent, where applicable. * Null for received messages. * * @param httpGet If true, the wwwFormEncoding() is appended to the * destination URL; the return value should be used * with a GET-redirect. * If false, the verbatim destination URL is returned, * which should be used with a FORM POST redirect. * * @see #wwwFormEncoding() */ public String getDestinationUrl(boolean httpGet) { if (_destinationUrl == null) throw new IllegalStateException("Destination URL not set; " + "is this a received message?"); if (httpGet) // append wwwFormEncoding to the destination URL { boolean hasQuery = _destinationUrl.indexOf("?") > 0; String initialChar = hasQuery ? "&" : "?"; return _destinationUrl + initialChar + wwwFormEncoding(); } else // should send the keyValueFormEncoding in POST data return _destinationUrl; } // ------------ extensions implementation ------------ /** * Adds a new extension factory. * * @param clazz The implementation class for the extension factory, * must implement {@link MessageExtensionFactory}. */ public static void addExtensionFactory(Class clazz) throws MessageException { try { MessageExtensionFactory extensionFactory = (MessageExtensionFactory) clazz.newInstance(); if (DEBUG) _log.debug("Adding extension factory for " + extensionFactory.getTypeUri()); _extensionFactories.put(extensionFactory.getTypeUri(), clazz); } catch (Exception e) { throw new MessageException( "Cannot instantiante message extension factory class: " + clazz.getName()); } } /** * Returns true if there is an extension factory available for extension * identified by the specified Type URI, or false otherwise. * * @param typeUri The Type URI that identifies an extension. */ public static boolean hasExtensionFactory(String typeUri) { return _extensionFactories.containsKey(typeUri); } /** * Gets a MessageExtensionFactory for the specified Type URI * if an implementation is available, or null otherwise. * * @param typeUri The Type URI that identifies a extension. * @see MessageExtensionFactory Message */ public static MessageExtensionFactory getExtensionFactory(String typeUri) { if (! hasExtensionFactory(typeUri)) return null; MessageExtensionFactory extensionFactory; try { Class extensionClass = (Class) _extensionFactories.get(typeUri); extensionFactory = (MessageExtensionFactory) extensionClass.newInstance(); } catch (Exception e) { _log.error("Error getting extension factory for " + typeUri); return null; } return extensionFactory; } /** * Returns true if the message has parameters for the specified * extension type URI. * * @param typeUri The URI that identifies the extension. */ public boolean hasExtension(String typeUri) { return _extAliases.containsKey(typeUri); } /** * Gets a set of extension Type URIs that are present in the message. */ public Set getExtensions() { return _extAliases.keySet(); } /** * Retrieves the extension alias that will be used for the extension * identified by the supplied extension type URI. *
* If the message contains no parameters for the specified extension, * null will be returned. * * @param extensionTypeUri The URI that identifies the extension * @return The extension alias associated with the * extension specifid by the Type URI */ public String getExtensionAlias(String extensionTypeUri) { return (_extAliases.get(extensionTypeUri) != null) ? (String) _extAliases.get(extensionTypeUri) : null; } /** * Adds a set of extension-specific parameters to a message. *
* The parameter names must NOT contain the "openid.
* The "openid.ns.
* The returned object will contain the parameters from the message
* belonging to the specified extension.
*
* @param typeUri The Type URI that identifies a extension.
*/
public MessageExtension getExtension(String typeUri) throws MessageException
{
if (!_extesion.containsKey(typeUri))
{
if (hasExtensionFactory(typeUri))
{
MessageExtensionFactory extensionFactory = getExtensionFactory(typeUri);
String mode = getParameterValue("openid.mode");
MessageExtension extension = extensionFactory.getExtension(
getExtensionParams(typeUri), mode.startsWith("checkid_"));
if (this instanceof AuthSuccess && extension.signRequired())
{
List signedParams = Arrays.asList(
((AuthSuccess)this).getSignList().split(",") );
String alias = getExtensionAlias(typeUri);
Iterator iter = extension.getParameters().getParameters().iterator();
while (iter.hasNext())
{
Parameter param = (Parameter) iter.next();
if (! signedParams.contains(alias + "." + param.getKey()))
{
throw new MessageException(
"Extension " + typeUri + " MUST be signed; " +
"field " + param.getKey() + " is NOT signed.");
}
}
}
_extesion.put(typeUri, extension);
}
else
throw new MessageException("Cannot instantiate extension: " + typeUri);
}
if (DEBUG) _log.debug("Extracting " + typeUri +" extension from message...");
return (MessageExtension) _extesion.get(typeUri);
}
}