001/**
002 * Copyright 2022 Emmanuel Bourg
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package net.jsign.jca;
018
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyStoreException;
import java.security.MessageDigest;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TimeZone;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
045
046import com.cedarsoftware.util.io.JsonWriter;
047import org.apache.commons.codec.binary.Hex;
048
049import net.jsign.DigestAlgorithm;
050
051import static java.nio.charset.StandardCharsets.*;
052
053/**
054 * Signing service using the AWS API.
055 *
056 * @since 4.3
057 * @see <a href="https://docs.aws.amazon.com/kms/latest/APIReference/">AWS Key Management Service API Reference</a>
058 * @see <a href="https://docs.aws.amazon.com/general/latest/gr/signing_aws_api_requests.html">Signing AWS API Requests</a>
059 */
060public class AmazonSigningService implements SigningService {
061
062    /** Source for the certificates */
063    private final Function<String, Certificate[]> certificateStore;
064
065    /** Cache of private keys indexed by id */
066    private final Map<String, SigningServicePrivateKey> keys = new HashMap<>();
067
068    private final RESTClient client;
069
070    /** Mapping between Java and AWS signing algorithms */
071    private final Map<String, String> algorithmMapping = new HashMap<>();
072    {
073        algorithmMapping.put("SHA256withRSA", "RSASSA_PKCS1_V1_5_SHA_256");
074        algorithmMapping.put("SHA384withRSA", "RSASSA_PKCS1_V1_5_SHA_384");
075        algorithmMapping.put("SHA512withRSA", "RSASSA_PKCS1_V1_5_SHA_512");
076        algorithmMapping.put("SHA256withECDSA", "ECDSA_SHA_256");
077        algorithmMapping.put("SHA384withECDSA", "ECDSA_SHA_384");
078        algorithmMapping.put("SHA512withECDSA", "ECDSA_SHA_512");
079        algorithmMapping.put("SHA256withRSA/PSS", "RSASSA_PSS_SHA_256");
080        algorithmMapping.put("SHA384withRSA/PSS", "RSASSA_PSS_SHA_384");
081        algorithmMapping.put("SHA512withRSA/PSS", "RSASSA_PSS_SHA_512");
082    }
083
084    /**
085     * Creates a new AWS signing service.
086     *
087     * @param region           the AWS region holding the keys (for example <tt>eu-west-3</tt>)
088     * @param credentials      the AWS credentials: <tt>accessKey|secretKey|sessionToken</tt> (the session token is optional)
089     * @param certificateStore provides the certificate chain for the keys
090     */
091    public AmazonSigningService(String region, String credentials, Function<String, Certificate[]> certificateStore) {
092        this.certificateStore = certificateStore;
093
094        // parse the credentials
095        String[] elements = credentials.split("\\|", 3);
096        if (elements.length < 2) {
097            throw new IllegalArgumentException("Invalid AWS credentials: " + credentials);
098        }
099        String accessKey = elements[0];
100        String secretKey = elements[1];
101        String sessionToken = elements.length > 2 ? elements[2] : null;
102
103        this.client = new RESTClient("https://kms." + region + ".amazonaws.com", (conn, data) -> sign(conn, accessKey, secretKey, sessionToken, data, null));
104    }
105
106    @Override
107    public String getName() {
108        return "AWS";
109    }
110
111    @Override
112    public List<String> aliases() throws KeyStoreException {
113        List<String> aliases = new ArrayList<>();
114
115        try {
116            // kms:ListKeys (https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html)
117            Map<String, ?> response = query("TrentService.ListKeys", "{}");
118            Object[] keys = (Object[]) response.get("Keys");
119            for (Object key : keys) {
120                aliases.add((String) ((Map) key).get("KeyId"));
121            }
122        } catch (IOException e) {
123            throw new KeyStoreException(e);
124        }
125
126        return aliases;
127    }
128
129    @Override
130    public Certificate[] getCertificateChain(String alias) throws KeyStoreException {
131        return certificateStore.apply(alias);
132    }
133
134    @Override
135    public SigningServicePrivateKey getPrivateKey(String alias, char[] password) throws UnrecoverableKeyException {
136        if (keys.containsKey(alias)) {
137            return keys.get(alias);
138        }
139
140        String algorithm;
141
142        try {
143            // kms:DescribeKey (https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html)
144            Map<String, ?> response = query("TrentService.DescribeKey", "{\"KeyId\":\"" + normalizeKeyId(alias) + "\"}");
145            Map<String, ?> keyMetadata = (Map<String, ?>) response.get("KeyMetadata");
146
147            String keyUsage = (String) keyMetadata.get("KeyUsage");
148            if (!"SIGN_VERIFY".equals(keyUsage)) {
149                throw new UnrecoverableKeyException("The key '" + alias + "' is not a signing key");
150            }
151
152            String keyState = (String) keyMetadata.get("KeyState");
153            if (!"Enabled".equals(keyState)) {
154                throw new UnrecoverableKeyException("The key '" + alias + "' is not enabled (" + keyState + ")");
155            }
156
157            String keySpec = (String) keyMetadata.get("KeySpec");
158            algorithm = keySpec.substring(0, keySpec.indexOf('_'));
159        } catch (IOException e) {
160            throw (UnrecoverableKeyException) new UnrecoverableKeyException("Unable to fetch AWS key '" + alias + "'").initCause(e);
161        }
162
163        SigningServicePrivateKey key = new SigningServicePrivateKey(alias, algorithm);
164        keys.put(alias, key);
165        return key;
166    }
167
168    @Override
169    public byte[] sign(SigningServicePrivateKey privateKey, String algorithm, byte[] data) throws GeneralSecurityException {
170        String alg = algorithmMapping.get(algorithm);
171        if (alg == null) {
172            throw new InvalidAlgorithmParameterException("Unsupported signing algorithm: " + algorithm);
173        }
174
175        DigestAlgorithm digestAlgorithm = DigestAlgorithm.of(algorithm.substring(0, algorithm.toLowerCase().indexOf("with")));
176        data = digestAlgorithm.getMessageDigest().digest(data);
177
178        // kms:Sign (https://docs.aws.amazon.com/kms/latest/APIReference/API_Sign.html)
179        Map<String, String> request = new HashMap<>();
180        request.put("KeyId", normalizeKeyId(privateKey.getId()));
181        request.put("MessageType", "DIGEST");
182        request.put("Message", Base64.getEncoder().encodeToString(data));
183        request.put("SigningAlgorithm", alg);
184        request.put(JsonWriter.TYPE, "false");
185
186        try {
187            Map<String, ?> response = query("TrentService.Sign", JsonWriter.objectToJson(request));
188            String signature = (String) response.get("Signature");
189            return Base64.getDecoder().decode(signature);
190        } catch (IOException e) {
191            throw new GeneralSecurityException(e);
192        }
193    }
194
195    /**
196     * Sends a request to the AWS API.
197     */
198    private Map<String, ?> query(String target, String body) throws IOException {
199        Map<String, String> headers = new HashMap<>();
200        headers.put("X-Amz-Target", target);
201        headers.put("Content-Type", "application/x-amz-json-1.1");
202        return client.post("/", body, headers);
203    }
204
205    /**
206     * Prefixes the key id with <tt>alias/</tt> if necessary.
207     */
208    private String normalizeKeyId(String keyId) {
209        if (keyId.startsWith("arn:") || keyId.startsWith("alias/")) {
210            return keyId;
211        }
212
213        if (!keyId.matches("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$")) {
214            return "alias/" + keyId;
215        } else {
216            return keyId;
217        }
218    }
219
220    /**
221     * Signs the request
222     *
223     * @see <a href="https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html">Signature Version 4 signing process</a>
224     */
225    void sign(HttpURLConnection conn, String accessKey, String secretKey, String sessionToken, byte[] content, Date date) {
226        DateFormat dateFormat = new SimpleDateFormat("yyyyMMdd");
227        dateFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
228        DateFormat dateTimeFormat = new SimpleDateFormat("yyyyMMdd'T'HHmmss'Z'");
229        dateTimeFormat.setTimeZone(TimeZone.getTimeZone("UTC"));
230        if (date == null) {
231            date = new Date();
232        }
233
234        // Extract the service name and the region from the endpoint
235        URL endpoint = conn.getURL();
236        Pattern hostnamePattern = Pattern.compile("^([^.]+)\\.([^.]+)\\.amazonaws\\.com$");
237        String host = endpoint.getHost();
238        Matcher matcher = hostnamePattern.matcher(host);
239        String regionName = matcher.matches() ? matcher.group(2) : "us-east-1";
240        String serviceName = matcher.matches() ? matcher.group(1) : host.substring(0, host.indexOf('.'));
241
242        String credentialScope = dateFormat.format(date) + "/" + regionName + "/" + serviceName + "/" + "aws4_request";
243
244        conn.addRequestProperty("X-Amz-Date", dateTimeFormat.format(date));
245
246        // Create the canonical request (https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html)
247        Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
248        headers.putAll(conn.getRequestProperties());
249        headers.put("Host", Collections.singletonList(host));
250
251        String canonicalRequest = conn.getRequestMethod() + "\n"
252                + endpoint.getPath() + (endpoint.getPath().endsWith("/") ? "" : "/") + "\n"
253                + /* canonical query string, not used for kms operations */ "\n"
254                + canonicalHeaders(headers) + "\n"
255                + signedHeaders(headers) + "\n"
256                + sha256(content);
257
258        // Create the string to sign (https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html)
259        String stringToSign = "AWS4-HMAC-SHA256" + "\n"
260                + dateTimeFormat.format(date) + "\n"
261                + credentialScope + "\n"
262                + sha256(canonicalRequest.getBytes(UTF_8));
263
264        // Derive the signing key (https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html)
265        byte[] key = ("AWS4" + secretKey).getBytes(UTF_8);
266        byte[] signingKey = hmac("aws4_request", hmac(serviceName, hmac(regionName, hmac(dateFormat.format(date), key))));
267
268        // Compute the signature (https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html)
269        byte[] signature = hmac(stringToSign, signingKey);
270
271        conn.setRequestProperty("Authorization",
272                "AWS4-HMAC-SHA256 Credential=" + accessKey + "/" + credentialScope
273                + ", SignedHeaders=" + signedHeaders(headers)
274                + ", Signature=" + Hex.encodeHexString(signature).toLowerCase());
275
276        if (sessionToken != null) {
277            conn.setRequestProperty("X-Amz-Security-Token", sessionToken);
278        }
279    }
280
281    private String canonicalHeaders(Map<String, List<String>> headers) {
282        return headers.entrySet().stream()
283                .map(entry -> entry.getKey().toLowerCase() + ":" + String.join(",", entry.getValue()).replaceAll("\\s+", " "))
284                .collect(Collectors.joining("\n")) + "\n";
285    }
286
287    private String signedHeaders(Map<String, List<String>> headers) {
288        return headers.keySet().stream()
289                .map(String::toLowerCase)
290                .collect(Collectors.joining(";"));
291    }
292
293    private byte[] hmac(String data, byte[] key) {
294        return hmac(data.getBytes(UTF_8), key);
295    }
296
297    private byte[] hmac(byte[] data, byte[] key) {
298        try {
299            Mac mac = Mac.getInstance("HmacSHA256");
300            mac.init(new SecretKeySpec(key, mac.getAlgorithm()));
301            return mac.doFinal(data);
302        } catch (Exception e) {
303            throw new RuntimeException(e);
304        }
305    }
306
307    private String sha256(byte[] data) {
308        MessageDigest digest =  DigestAlgorithm.SHA256.getMessageDigest();
309        digest.update(data);
310        return Hex.encodeHexString(digest.digest()).toLowerCase();
311    }
312}