Spring Security OAuth2實(shí)現(xiàn)簡(jiǎn)單的密鑰輪換及配置資源服務(wù)器JWK緩存

Spring Security OAuth2實(shí)現(xiàn)簡(jiǎn)單的密鑰輪換及配置資源服務(wù)器JWK緩存

概述

在OAuth2協(xié)議中授權(quán)服務(wù)器或者OIDC中身份提供服務(wù)常使用私鑰對(duì)JWT令牌進(jìn)行簽名,第三方服務(wù)客戶端或者資源服務(wù)使用已知URL上發(fā)布的公鑰對(duì)令牌進(jìn)行驗(yàn)證。

這些密鑰構(gòu)成了各方之間安全的基礎(chǔ)。為了維護(hù)安全性,保持私鑰免受任何網(wǎng)絡(luò)攻擊是所必需的。

已確定的用于保護(hù)密鑰不被泄露的最佳做法之一稱為密鑰滾動(dòng)更新密鑰輪換。在此方法中,我們丟棄當(dāng)前密鑰并生成一對(duì)新密鑰,用于對(duì)令牌進(jìn)行簽名和驗(yàn)證。

為什么我們需要密鑰輪換?

為了確保公鑰和私鑰對(duì)的安全性免受黑客的攻擊,建議在一段時(shí)間后輪換密鑰。必須丟棄以前的密鑰,并且必須將新生成的密鑰用于進(jìn)一步的加密操作。根據(jù)NIST指南,密鑰必須至少每?jī)赡贻啌Q一次

如何實(shí)現(xiàn)密鑰輪換?

所有公鑰都由授權(quán)服務(wù)或身份提供服務(wù)在 Web 上發(fā)布的URL提供,URL返回一個(gè)對(duì)象,稱為 JSON Web Keys或 JWKS,其中包含多個(gè)JSON Web Key (是一種 JSON 數(shù)據(jù)結(jié)構(gòu),表示一組公鑰),通常稱為 JWK。在驗(yàn)證由私鑰簽名的 JWT時(shí),將使用與私鑰相對(duì)應(yīng)的JWK。以下是JWKS 示例,其keys包含一個(gè) JWK 數(shù)組。

{
"keys": [
  {
    "alg": "RS256",
    "kty": "RSA",
    "use": "sig",
    "x5c": [
      "MIIC+DCCAeCgAwIBAgIJBIGjYW6hFpn2MA0GCSqGSIb3DQEBBQUAMCMxITAfBgNVBAMTGGN1c3RvbWVyLWRlbW9zLmF1dGgwLmNvbTAeFw0xNjExMjIyMjIyMDVaFw0zMDA4MDEyMjIyMDVaMCMxITAfBgNVBAMTGGN1c3RvbWVyLWRlbW9zLmF1dGgwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMnjZc5bm/eGIHq09N9HKHahM7Y31P0ul+A2wwP4lSpIwFrWHzxw88/7Dwk9QMc+orGXX95R6av4GF+Es/nG3uK45ooMVMa/hYCh0Mtx3gnSuoTavQEkLzCvSwTqVwzZ+5noukWVqJuMKNwjL77GNcPLY7Xy2/skMCT5bR8UoWaufooQvYq6SyPcRAU4BtdquZRiBT4U5f+4pwNTxSvey7ki50yc1tG49Per/0zA4O6Tlpv8x7Red6m1bCNHt7+Z5nSl3RX/QYyAEUX1a28VcYmR41Osy+o2OUCXYdUAphDaHo4/8rbKTJhlu8jEcc1KoMXAKjgaVZtG/v5ltx6AXY0CAwEAAaMvMC0wDAYDVR0TBAUwAwEB/zAdBgNVHQ4EFgQUQxFG602h1cG+pnyvJoy9pGJJoCswDQYJKoZIhvcNAQEFBQADggEBAGvtCbzGNBUJPLICth3mLsX0Z4z8T8iu4tyoiuAshP/Ry/ZBnFnXmhD8vwgMZ2lTgUWwlrvlgN+fAtYKnwFO2G3BOCFw96Nm8So9sjTda9CCZ3dhoH57F/hVMBB0K6xhklAc0b5ZxUpCIN92v/w+xZoz1XQBHe8ZbRHaP1HpRM4M7DJk2G5cgUCyu3UBvYS41sHvzrxQ3z7vIePRA4WF4bEkfX12gvny0RsPkrbVMXX1Rj9t6V7QXrbPYBAO+43JvDGYawxYVvLhz+BJ45x50GFQmHszfY3BR9TPK8xmMmQwtIvLu1PMttNCs7niCYkSiUv2sc2mlq1i3IashGkkgmo="
    ],
    "n": "yeNlzlub94YgerT030codqEztjfU_S6X4DbDA_iVKkjAWtYfPHDzz_sPCT1Axz6isZdf3lHpq_gYX4Sz-cbe4rjmigxUxr-FgKHQy3HeCdK6hNq9ASQvMK9LBOpXDNn7mei6RZWom4wo3CMvvsY1w8tjtfLb-yQwJPltHxShZq5-ihC9irpLI9xEBTgG12q5lGIFPhTl_7inA1PFK97LuSLnTJzW0bj096v_TMDg7pOWm_zHtF53qbVsI0e3v5nmdKXdFf9BjIARRfVrbxVxiZHjU6zL6jY5QJdh1QCmENoejj_ytspMmGW7yMRxzUqgxcAqOBpVm0b-_mW3HoBdjQ",
    "e": "AQAB",
    "kid": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg",
    "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg"
  }
]}

典型的密鑰輪換策略可避免客戶端發(fā)送使用以前頒發(fā)的密鑰簽名的 JWT 的驗(yàn)證失敗潛在問(wèn)題,因此,在令牌完全過(guò)期之前,我們需要在一段時(shí)間內(nèi)保持兩個(gè)密鑰(先前和當(dāng)前)有效 - 剛好足以為客戶端提供更新其本地緩存的空間。

先決條件

  • java8+
  • Redis
  • JWT


在閱讀文章前,首先說(shuō)明下文密鑰JWK表示相同含義。雖然JWK從規(guī)范定義層面表示一組公鑰,但是在代碼層面JWK所指定的是一組密鑰。例如RSAKey,ECKey等等。

授權(quán)服務(wù)器實(shí)現(xiàn)密鑰輪換

本節(jié)中我們將使用Spring Authorization Server 搭建一個(gè)簡(jiǎn)單的授權(quán)服務(wù)器,并實(shí)現(xiàn)JWKSource自定義密鑰輪換邏輯,密鑰緩存策略提供本地內(nèi)存,caffeine,redis三種實(shí)現(xiàn)方式。

Maven依賴

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-security</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.security</groupId>
  <artifactId>spring-security-oauth2-authorization-server</artifactId>
  <version>0.3.1</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-web</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-cache</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-data-redis</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>com.github.ben-manes.caffeine</groupId>
  <artifactId>caffeine</artifactId>
  <version>2.9.3</version>
</dependency>

配置

首先我們從application.yml配置開(kāi)始,這里我們指定授權(quán)服務(wù)器端口為8080,并添加redis連接配置信息:

server:
  port: 8080

spring:
  redis:
    host: localhost
    database: 0
    port: 6379
    password: 123456
    timeout: 1800
    lettuce:
      pool:
        max-active: 20
        max-wait: 60
        max-idle: 5
        min-idle: 0
      shutdown-timeout: 100


接下來(lái)我們將創(chuàng)建AuthorizationServerConfig配置類,用于配置OAuth2及OIDC所需Bean,首先我們將新增OAuth2客戶端信息:

    @Bean
    public RegisteredClientRepository registeredClientRepository() {
        RegisteredClient registeredClient = RegisteredClient.withId(UUID.randomUUID().toString())
                .clientId("relive-client")
                .clientSecret("{noop}relive-client")
                .clientAuthenticationMethods(s -> {
                    s.add(ClientAuthenticationMethod.CLIENT_SECRET_POST);
                    s.add(ClientAuthenticationMethod.CLIENT_SECRET_BASIC);
                })
                .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
                .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
                .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
                .authorizationGrantType(AuthorizationGrantType.PASSWORD)
                .redirectUri("http://127.0.0.1:8070/login/oauth2/code/messaging-client-authorization-code")
                .scope("message.read")
                .clientSettings(ClientSettings.builder()
                        .requireAuthorizationConsent(true)
                        .requireProofKey(false)
                        .build())
                .tokenSettings(TokenSettings.builder()
                        .accessTokenFormat(OAuth2TokenFormat.SELF_CONTAINED) 
                        .idTokenSignatureAlgorithm(SignatureAlgorithm.RS256)
                        .accessTokenTimeToLive(Duration.ofSeconds(30 * 60))
                        .refreshTokenTimeToLive(Duration.ofSeconds(60 * 60))
                        .reuseRefreshTokens(true)
                        .build())
                .build();


        return new InMemoryRegisteredClientRepository(registeredClient);
    }

和以往文章一樣,指定OAuth2客戶端信息,并將OAuth2客戶端信息存儲(chǔ)在內(nèi)存中,如果你需要配置數(shù)據(jù)庫(kù)存儲(chǔ),請(qǐng)參考文章將JWT與Spring Security OAuth2結(jié)合使用


簡(jiǎn)化其他高級(jí)配置,使用OAuth2授權(quán)服務(wù)默認(rèn)配置,并將未認(rèn)證的授權(quán)請(qǐng)求重定向到登錄頁(yè)面:

    @Bean
    @Order(Ordered.HIGHEST_PRECEDENCE)
    public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception {
        OAuth2AuthorizationServerConfiguration.applyDefaultSecurity(http);
        return http.exceptionHandling(exceptions -> exceptions.
                        authenticationEntryPoint(new LoginUrlAuthenticationEntryPoint("/login"))).build();
    }


自定義JWKSource實(shí)現(xiàn)密鑰輪換

之前文章中授權(quán)服務(wù)啟動(dòng)時(shí)隨機(jī)生成一個(gè)2048字節(jié)的RSA密鑰,用于令牌的簽名密鑰。本示例中我們將自定義JWKSource并實(shí)現(xiàn)密鑰輪換策略:

public final class RotateJwkSource<C extends SecurityContext> implements JWKSource<C> {
    private final JWKSource<C> failoverJWKSource;
    private final JWKSetCache jwkSetCache;
    private final JWKGenerator<? extends JWK> jwkGenerator;
    private KeyIDStrategy keyIDStrategy;

    public RotateJwkSource() {
        this(new InMemoryJWKSetCache(), null, null, null);
    }

    public RotateJwkSource(JWKSetCache jwkSetCache) {
        this(jwkSetCache, null, null, null);
    }

    public RotateJwkSource(JWKSetCache jwkSetCache, JWKSource<C> failoverJWKSource) {
        this(jwkSetCache, failoverJWKSource, null, null);
    }

    public RotateJwkSource(JWKSetCache jwkSetCache, JWKGenerator<? extends JWK> jwkGenerator) {
        this(jwkSetCache, null, jwkGenerator, null);
    }

    public RotateJwkSource(JWKSetCache jwkSetCache, JWKSource<C> failoverJWKSource, JWKGenerator<? extends JWK> jwkGenerator, KeyIDStrategy keyIDStrategy) {
        Assert.notNull(jwkSetCache, "jwkSetCache cannot be null");
        this.jwkSetCache = jwkSetCache;
        this.failoverJWKSource = failoverJWKSource;
        if (jwkGenerator == null) {
            this.jwkGenerator = new RSAKeyGenerator(RSAKeyGenerator.MIN_KEY_SIZE_BITS);
        } else {
            this.jwkGenerator = jwkGenerator;
        }
        if (keyIDStrategy == null) {
            this.keyIDStrategy = new TimestampKeyIDStrategy();
        } else {
            this.keyIDStrategy = keyIDStrategy;
        }

    }

    @Override
    public List<JWK> get(JWKSelector jwkSelector, C context) throws RotateKeySourceException {
        JWKSet jwkSet = this.jwkSetCache.get();
        if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
            try {
                synchronized (this) {
                    jwkSet = this.jwkSetCache.get();
                    if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
                        jwkSet = this.updateJWKSet(jwkSet);
                    }
                }
            } catch (Exception e) {
                List<JWK> failoverMatches = this.failover(e, jwkSelector, context);
                if (failoverMatches != null) {
                    return failoverMatches;
                }

                if (jwkSet == null) {
                    throw e;
                }
            }
        }
        List<JWK> jwks = jwkSelector.select(jwkSet);
        if (!jwks.isEmpty()) {
            return jwks;
        } else {
            return Collections.emptyList();
        }
    }

    private JWKSet updateJWKSet(JWKSet jwkSet) throws RotateKeySourceException {
        JWK jwk;
        try {
            jwkGenerator.keyID(keyIDStrategy.generateKeyID());
            jwk = jwkGenerator.generate();
        } catch (JOSEException e) {
            throw new RotateKeySourceException("Couldn't generate JWK:" + e.getMessage(), e);
        }
        JWKSet updateJWKSet = new JWKSet(jwk);
        this.jwkSetCache.put(updateJWKSet);
        if (jwkSet != null) {
            List<JWK> keys = jwkSet.getKeys();
            List<JWK> updateJwks = new ArrayList<>(keys);
            updateJwks.add(jwk);
            updateJWKSet = new JWKSet(updateJwks);
        }
        return updateJWKSet;
    }

    private List<JWK> failover(Exception exception, JWKSelector jwkSelector, C context) throws RotateKeySourceException {
        if (this.getFailoverJWKSource() == null) {
            return null;
        } else {
            try {
                return this.getFailoverJWKSource().get(jwkSelector, context);
            } catch (KeySourceException e) {
                throw new RotateKeySourceException(exception.getMessage() + "; Failover JWK source retrieval failed with: " + e.getMessage(), e);
            }
        }
    }

    public void setKeyIDStrategy(KeyIDStrategy keyIDStrategy) {
        this.keyIDStrategy = keyIDStrategy;
    }
}

RotateJwkSource為包含密鑰輪換的JWKSource的實(shí)現(xiàn)類,遵循以下步驟:

  • 首先從JWKSetCache緩存中獲取JWKSet(JWKSet僅包含未過(guò)期JWK)。本示例中自定義JWKSetCache實(shí)現(xiàn)類有InMemoryJWKSetCacheCaffeineJWKSetCacheRedisJWKSetCache

  • 如果JWKSet不為空或不需要刷新密鑰,則通過(guò)JWKSelector從指定的 JWK 集中選擇與配置的條件匹配的JWK。

  • 否則,執(zhí)行updateJWKSet(JWKSet jwkSet)生成新的密鑰對(duì)添加進(jìn)緩存,并返回新的JWKSet(JWKSet僅包含未過(guò)期JWK)。

JWKSetCache 定義密鑰刷新周期及密鑰過(guò)期時(shí)間。

RotateJwkSource屬性介紹:

  • failoverJWKSource:故障轉(zhuǎn)移 JWKSource。
  • jwkSetCache:JWKSet緩存接口類,定義密鑰刷新周期,密鑰過(guò)期時(shí)間。本示例中提供三種實(shí)現(xiàn)類,InMemoryJWKSetCacheCaffeineJWKSetCacheRedisJWKSetCache
  • jwkGenerator:密鑰生成器,RotateJwkSource默認(rèn)使用RSAKeyGenerator
  • KeyIDStrategykid生成策略,本示例中使用時(shí)間戳表示kid


基于本地內(nèi)存,caffeine,redis的JWKSetCache

本示例用于測(cè)試需要,密鑰刷新周期定為5分鐘,密鑰過(guò)期時(shí)間定為15分鐘,實(shí)際應(yīng)用中請(qǐng)根據(jù)需要修改。

InMemoryJWKSetCache實(shí)現(xiàn)方式相對(duì)簡(jiǎn)單。由JWKWithTimestamp存儲(chǔ)密鑰對(duì),lifespan為密鑰過(guò)期時(shí)間,refreshTime為密鑰刷新周期。為確保密鑰輪換正常使用,建議 lifespan >= refreshTime + exp

public class InMemoryJWKSetCache implements JWKSetCache {
    private final long lifespan;
    private final long refreshTime;
    private final TimeUnit timeUnit;
    private volatile Set<JWKWithTimestamp> jwkWithTimestamps;


    public InMemoryJWKSetCache() {
        this(15L, 5L, TimeUnit.MINUTES);
    }

    public InMemoryJWKSetCache(long lifespan, long refreshTime, TimeUnit timeUnit) {
        this.lifespan = lifespan;
        this.refreshTime = refreshTime;
        if ((lifespan > -1L || refreshTime > -1L) && timeUnit == null) {
            throw new IllegalArgumentException("A time unit must be specified for non-negative lifespans or refresh times");
        } else {
            this.timeUnit = timeUnit;
        }
        this.jwkWithTimestamps = new LinkedHashSet<>();
    }

    @Override
    public void put(JWKSet jwkSet) {
        if (jwkSet != null) {
            if (!CollectionUtils.isEmpty(jwkSet.getKeys())) {
                List<JWKWithTimestamp> updateJWKWithTs = jwkSet.getKeys().stream().map(JWKWithTimestamp::new)
                        .collect(Collectors.toList());
                this.jwkWithTimestamps.addAll(updateJWKWithTs);
            }
        }
    }

    @Override
    public JWKSet get() {
        return !CollectionUtils.isEmpty(this.jwkWithTimestamps) && !this.isExpired() ? new JWKSet(this.jwkWithTimestamps.stream()
                .filter(t -> t.getDate().getTime() + TimeUnit.MILLISECONDS.convert(this.lifespan, this.timeUnit) > (new Date()).getTime())
                .map(JWKWithTimestamp::getJwk).collect(Collectors.toList())) : null;
    }

    @Override
    public boolean requiresRefresh() {
        return !CollectionUtils.isEmpty(this.jwkWithTimestamps) && this.refreshTime > -1L && this.jwkWithTimestamps.stream().map(jwkWithTimestamp -> jwkWithTimestamp.getDate().getTime())
                .max(Long::compareTo)
                .filter(time -> (new Date()).getTime() > time + TimeUnit.MILLISECONDS.convert(this.refreshTime, this.timeUnit))
                .isPresent();
    }

    public boolean isExpired() {
        return !CollectionUtils.isEmpty(this.jwkWithTimestamps) && this.lifespan > -1L && this.jwkWithTimestamps.stream().map(jwkWithTimestamp -> jwkWithTimestamp.getDate().getTime())
                .max(Long::compareTo)
                .filter(time -> (new Date()).getTime() > time + TimeUnit.MILLISECONDS.convert(this.lifespan, this.timeUnit))
                .isPresent();
    }

    public long getLifespan(TimeUnit timeUnit) {
        return this.lifespan < 0L ? this.lifespan : timeUnit.convert(this.lifespan, this.timeUnit);
    }

    public long getRefreshTime(TimeUnit timeUnit) {
        return this.refreshTime < 0L ? this.refreshTime : timeUnit.convert(this.refreshTime, this.timeUnit);
    }
}

InMemoryJWKSetCache中put方法將密鑰對(duì)及當(dāng)前時(shí)間封裝為JWKWithTimestamp并添加到LinkedHashSet。get方法從LinkedHashSet過(guò)濾獲取未過(guò)期JWK并返回。

此方式中使用ScheduledFuture開(kāi)啟單獨(dú)任務(wù)清除過(guò)期密鑰。


Caffeine — 一個(gè)用于 Java 的高性能緩存庫(kù)CaffeineJWKSetCache基于Caffeine實(shí)現(xiàn)密鑰存儲(chǔ)。lifespan為密鑰過(guò)期時(shí)間,refreshTime為密鑰刷新周期。建議 lifespan >= refreshTime + exp

Caffeine三種緩存填充策略:手動(dòng)、同步加載和異步加載。其中我們選用手動(dòng)填充將密鑰放入緩存中,并在get()中檢索它們。

public class CaffeineJWKSetCache implements JWKSetCache {
    private final long lifespan;
    private final long refreshTime;
    private final TimeUnit timeUnit;
    private final Cache<Long, JWK> cache;

    public CaffeineJWKSetCache() {
        this(15L, 5L, TimeUnit.MINUTES);
    }

    public CaffeineJWKSetCache(long lifespan, long refreshTime, TimeUnit timeUnit) {
        this.lifespan = lifespan;
        this.refreshTime = refreshTime;
        if ((lifespan > -1L || refreshTime > -1L) && timeUnit == null) {
            throw new IllegalArgumentException("A time unit must be specified for non-negative lifespans or refresh times");
        } else {
            this.timeUnit = timeUnit;
        }
        Caffeine<Object, Object> caffeine = Caffeine.newBuilder().maximumSize(10);
        if (lifespan > -1L) {
            caffeine.expireAfterWrite(this.lifespan, this.timeUnit);
        }
        this.cache = caffeine.build();

    }

    @Override
    public void put(JWKSet jwkSet) {
        if (jwkSet != null) {
            if (!CollectionUtils.isEmpty(jwkSet.getKeys())) {
                jwkSet.getKeys().forEach(jwk -> cache.put(new Date().getTime(), jwk));
            }
        }
    }

    @Override
    public JWKSet get() {
        List<@NonNull JWK> jwks = new ArrayList<>(cache.asMap().values());
        return CollectionUtils.isEmpty(jwks) ? null : new JWKSet(jwks);
    }

    @Override
    public boolean requiresRefresh() {
        return this.refreshTime > -1L && cache.asMap().keySet().stream()
                .max(Long::compareTo)
                .filter(time -> (new Date()).getTime() > time + TimeUnit.MILLISECONDS.convert(this.refreshTime, this.timeUnit))
                .isPresent();
    }
}


Redis — 流行的內(nèi)存數(shù)據(jù)結(jié)構(gòu)存儲(chǔ)。RedisJWKSetCache使用Redis有序集合(sorted set)存儲(chǔ)密鑰,scope為密鑰放進(jìn)緩存的時(shí)間。

Redis有序集合不支持對(duì)單個(gè)元素設(shè)置過(guò)期時(shí)間,所以我們將通過(guò)使用scope存儲(chǔ)密鑰緩存時(shí)間,并在每次更新緩存時(shí)計(jì)算已過(guò)期密鑰,使用zRemRangeByScore命令移除已過(guò)期密鑰。建議 lifespan >= refreshTime + exp

public class RedisJWKSetCache implements JWKSetCache {
    private static final boolean springDataRedis_2_0 = ClassUtils.isPresent("org.springframework.data.redis.connection.RedisStandaloneConfiguration", RedisJWKSetCache.class.getClassLoader());
    private final RedisConnectionFactory connectionFactory;
    private final String JWK_KEY = "jwks";
    private String prefix = "";
    private RedisSerializer<String> redisSerializeKey = new StringRedisSerializer();
    private RedisSerializer<String> redisSerializerValue = new Jackson2JsonRedisSerializer<>(String.class);
    private Method redisConnectionSet_2_0;
    private final long lifespan;
    private final long refreshTime;
    private final TimeUnit timeUnit;

    public RedisJWKSetCache(RedisConnectionFactory connectionFactory) {
        this(15L, 5L, TimeUnit.MINUTES, connectionFactory);
    }

    public RedisJWKSetCache(long lifespan, long refreshTime, TimeUnit timeUnit, RedisConnectionFactory connectionFactory) {
        this.lifespan = lifespan;
        this.refreshTime = refreshTime;
        if ((lifespan > -1L || refreshTime > -1L) && timeUnit == null) {
            throw new IllegalArgumentException("A time unit must be specified for non-negative lifespans or refresh times");
        } else {
            this.timeUnit = timeUnit;
        }
        Assert.notNull(connectionFactory, "redisConnectionFactory cannot be null");
        this.connectionFactory = connectionFactory;
        if (springDataRedis_2_0) {
            this.loadRedisConnectionMethods_2_0();
        }

    }


    @Override
    public void put(JWKSet jwkSet) {
        if (jwkSet != null) {
            if (!CollectionUtils.isEmpty(jwkSet.getKeys())) {
                RedisConnection connection = this.getConnection();
                byte[] key = this.serializeKey(JWK_KEY);

                connection.openPipeline();

                if (this.lifespan > -1) {
                    long max = new Date().getTime() - TimeUnit.MILLISECONDS.convert(this.lifespan, this.timeUnit);
                    connection.zRemRangeByScore(key, Range.range().lte(max));
                }

                List<JWK> keys = jwkSet.getKeys();
                try {
                    for (JWK jwk : keys) {
                        byte[] value = this.serialize(jwk.toJSONString());

                        if (springDataRedis_2_0) {
                            try {
                                this.redisConnectionSet_2_0.invoke(connection, key, new Date().getTime(), value);
                            } catch (Exception e) {
                                throw new RuntimeException(e);
                            }
                        } else {
                            connection.zAdd(key, new Date().getTime(), value);
                        }
                    }
                    connection.closePipeline();
                } finally {
                    connection.close();
                }
            }
        }
    }

    @Override
    public JWKSet get() {
        RedisConnection connection = this.getConnection();
        byte[] key = this.serializeKey(JWK_KEY);
        try {
            Long efficientCount = Optional.ofNullable(connection.zCard(key)).orElse(0L);
            if (efficientCount > 0) {
                Set<byte[]> jwkBytes = connection.zRevRangeByScore(key, Range.range());
                List<JWK> jwks = jwkBytes.stream().map(this::deserialize).map(this::parse).collect(Collectors.toList());
                return new JWKSet(jwks);
            }

            return null;
        } finally {
            connection.close();
        }
    }

    private JWK parse(String jwkJsonString) {
        try {
            return JWK.parse(jwkJsonString);
        } catch (ParseException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public boolean requiresRefresh() {
        RedisConnection connection = this.getConnection();
        byte[] key = this.serializeKey("jwks");
        try {
            Long efficientCount = Optional.ofNullable(connection.zCard(key)).orElse(0L);
            Set<Tuple> maximumScoreTuple = connection.zRevRangeByScoreWithScores(key, Range.range(), Limit.limit().count(1));

            long lastRefreshTime = 0L;
            if (!CollectionUtils.isEmpty(maximumScoreTuple)) {
                lastRefreshTime = maximumScoreTuple.stream().findFirst().orElse(new DefaultTuple(null, 0.0)).getScore().longValue();
            }
            return efficientCount > 0 && this.refreshTime > -1L && (new Date()).getTime() > lastRefreshTime + TimeUnit.MILLISECONDS.convert(this.refreshTime, this.timeUnit);
        } finally {
            connection.close();
        }

    }

    private byte[] serializeKey(String key) {
        return this.redisSerializeKey.serialize(this.prefix + key);
    }

    private byte[] serialize(String value) {
        return this.redisSerializerValue.serialize(value);
    }

    private String deserialize(byte[] bytes) {
        return this.redisSerializerValue.deserialize(bytes);
    }

    private void loadRedisConnectionMethods_2_0() {
        this.redisConnectionSet_2_0 = ReflectionUtils.findMethod(RedisConnection.class, "zAdd", new Class[]{byte[].class, double.class, byte[].class});
    }

    private RedisConnection getConnection() {
        return this.connectionFactory.getConnection();
    }

    public void setPrefix(String prefix) {
        this.prefix = prefix;
    }

    public void setRedisSerializerKey(RedisSerializer<String> redisSerializer) {
        this.redisSerializeKey = redisSerializer;
    }

    public void setRedisSerializerValue(RedisSerializer<String> redisSerializer) {
        this.redisSerializerValue = redisSerializer;
    }
}



介紹完本示例中密鑰輪換實(shí)現(xiàn)邏輯,接下來(lái)讓我們配置RotateJwkSource應(yīng)用于授權(quán)服務(wù):

    @Bean
    public JWKSource<SecurityContext> jwkSource(RedisConnectionFactory connectionFactory) {
        RedisJWKSetCache redisJWKSetCache = new RedisJWKSetCache(connectionFactory);
        redisJWKSetCache.setPrefix("auth-server");

        return new RotateJwkSource<>(redisJWKSetCache);
    }


是否記得前面提到的避免客戶端發(fā)送使用以前頒發(fā)的密鑰簽名的 JWT 的驗(yàn)證失敗潛在問(wèn)題,在令牌完全過(guò)期之前,我們需要在一段時(shí)間內(nèi)保持兩個(gè)密鑰。所以授權(quán)服務(wù)在簽發(fā)JWT令牌時(shí),由于某一段時(shí)間存在多個(gè)密鑰,因此在JwtEncoder生成JWT時(shí)將提示以下錯(cuò)誤信息:

org.springframework.security.oauth2.jwt.JwtEncodingException: An error occurred while attempting to encode the Jwt: Found multiple JWK signing keys for algorithm 'RS256'

所以我們需要生成JWT前指定kid屬性, JWKSelector將從指定的 JWKS 中選擇與kid相對(duì)應(yīng)的JWK用于生成JWT。Spring Authorization Server中OAuth2TokenCustomizer提供了自定義屬性的能力,根據(jù)密鑰輪換策略,我們需要使用最新密鑰生成JWT,RotateJwkSource中kid生成策略由時(shí)間戳定義,所以JWKS中最新的密鑰將會(huì)是最大值kid對(duì)應(yīng)的密鑰,我們將獲取最大值kid放入JWT的Header中。

    @Bean
    public OAuth2TokenCustomizer<JwtEncodingContext> tokenCustomizer(JWKSource<SecurityContext> jwkSource) {
        return (context) -> {
            if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType()) ||
                    OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {

                JWKSelector jwkSelector = new JWKSelector(new JWKMatcher.Builder().build());
                List<JWK> jwks;
                try {
                    jwks = jwkSource.get(jwkSelector, null);
                } catch (KeySourceException e) {
                    throw new IllegalStateException("Failed to select the JWK(s) -> " + e.getMessage(), e);
                }
                String kid = jwks.stream().map(JWK::getKeyID)
                        .max(String::compareTo)
                        .orElseThrow(() -> new IllegalArgumentException("kid not found"));
                context.getHeaders().keyId(kid);
            }
        };
    }

本示例中kid由時(shí)間戳定義,所以確保密鑰輪換后使用最新密鑰,我們將獲取最大值kid所對(duì)應(yīng)的密鑰進(jìn)行簽名。但是kid若不使用類似于時(shí)間戳的遞增值,將建議按照FIFO(先進(jìn)先出)結(jié)構(gòu),其格式是將新生成的密鑰推送到末尾。


最后讓我們配置Form表單認(rèn)證方式,并設(shè)置用戶名和密碼:

    @Bean
    SecurityFilterChain defaultSecurityFilterChain(HttpSecurity http) throws Exception {
        http.authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated())
                .formLogin(Customizer.withDefaults());

        return http.build();
    }

    @Bean
    UserDetailsService userDetailsService() {
        UserDetails userDetails = User.withUsername("admin")
                .password("{noop}password")
                .roles("ADMIN")
                .build();
        return new InMemoryUserDetailsManager(userDetails);
    }

配置資源服務(wù)

本節(jié)中我們將使用Spring Security搭建OAuth2資源服務(wù),并且我們將為JwtDecoder配置redis緩存。

Maven依賴

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-web</artifactId>
  <version>2.6.7</version>
</dependency>
<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-security</artifactId>
  <version>2.6.7</version>
</dependency>
<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-cache</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-data-redis</artifactId>
  <version>2.6.7</version>
</dependency>

配置

首先我們從application.yml文件配置開(kāi)始,指定端口8090,并添加redis配置和OAuth2配置信息:

server:
  port: 8090

spring:
  redis:
    host: localhost
    database: 0
    port: 6379
    password: 123456
    timeout: 1800
    lettuce:
      pool:
        max-active: 20
        max-wait: 60
        max-idle: 5
        min-idle: 0
      shutdown-timeout: 100
  security:
    oauth2:
      resourceserver:
        jwt:
          jwk-set-uri: http://127.0.0.1:8080/oauth2/jwks


Spring Boot 自動(dòng)配置一個(gè)具有默認(rèn)緩存配置的RedisCacheManager 。但是,我們可以在緩存管理器初始化之前修改此配置,將緩存過(guò)期時(shí)間設(shè)置為5分鐘

    @Bean
    public CacheManager cacheManager(RedisConnectionFactory factory) {
        RedisSerializer<String> redisSerializer = new StringRedisSerializer();
        Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        // 配置序列化(解決亂碼的問(wèn)題),過(guò)期時(shí)間5分鐘
        RedisCacheConfiguration config = RedisCacheConfiguration.defaultCacheConfig()
                .entryTtl(Duration.ofSeconds(5*60))
                .serializeKeysWith(RedisSerializationContext.SerializationPair.fromSerializer(redisSerializer))
                .serializeValuesWith(RedisSerializationContext.SerializationPair.fromSerializer(jackson2JsonRedisSerializer))
                .disableCachingNullValues();
        RedisCacheManager cacheManager = RedisCacheManager.builder(factory)
                .cacheDefaults(config)
                .build();
        return cacheManager;
    }


閱讀到這里,你是否有疑問(wèn)授權(quán)服務(wù)輪換密鑰后資源服務(wù)如何獲取最新密鑰驗(yàn)證JWT。

在此之前讓我們先了解下JwtDecoder工作原理,以下是簡(jiǎn)單聲明JwtDecoder的示例:

@Bean
public JwtDecoder jwtDecoder() {
    return NimbusJwtDecoder.withJwkSetUri(jwkSetUri).build();
}

為了清楚起見(jiàn),部分源碼細(xì)節(jié)已被省略。

當(dāng)我們查看源碼NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder構(gòu)建器中,可以看到內(nèi)部創(chuàng)建了JWKSource的實(shí)現(xiàn)類RemoteJWKSet,注意我們沒(méi)有配置Cache,所以最終執(zhí)行return new RemoteJWKSet(toURL(this.jwkSetUri), jwkSetRetriever);。這里ResourceRetriever實(shí)現(xiàn)類為RestOperationsResourceRetriever

        JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
            if (this.cache == null) {
                return new RemoteJWKSet(toURL(this.jwkSetUri), jwkSetRetriever);
            } else {
                ResourceRetriever cachingJwkSetRetriever = new NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder.CachingResourceRetriever(this.cache, jwkSetRetriever);
                return new RemoteJWKSet(toURL(this.jwkSetUri), cachingJwkSetRetriever, new NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder.NoOpJwkSetCache());
            }
        }

JwtDecoder驗(yàn)證JWT需要通過(guò)RemoteJWKSet獲取JWK, RemoteJWKSet由 JWKS URL 指定的遠(yuǎn)程 JSON Web KEY (JWK) 端點(diǎn)。檢索到的 JWKS將被緩存以最小化網(wǎng)絡(luò)調(diào)用。每當(dāng)JWKSelector嘗試獲取具有未知 kid時(shí),都會(huì)更新緩存。以下為RemoteJWKSet核心方法:

    public List<JWK> get(JWKSelector jwkSelector, C context) throws RemoteKeySourceException {
        JWKSet jwkSet = this.jwkSetCache.get();
        if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
            try {
                jwkSet = this.updateJWKSetFromURL();
            } catch (Exception var6) {
                if (jwkSet == null) {
                    throw var6;
                }
            }
        }

        List<JWK> matches = jwkSelector.select(jwkSet);
        if (!matches.isEmpty()) {
            return matches;
        } else {
            String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
            if (soughtKeyID == null) {
                return Collections.emptyList();
            } else if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
                return Collections.emptyList();
            } else {
                jwkSet = this.updateJWKSetFromURL();
                return jwkSet == null ? Collections.emptyList() : jwkSelector.select(jwkSet);
            }
        }
    }

RemoteJWKSet遵循以下步驟:

  • JWKSetCache獲取JWKSet,RemoteJWKSet中默認(rèn)實(shí)現(xiàn)為DefaultJWKSetCache,默認(rèn)情況DefaultJWKSetCache將授權(quán)服務(wù)器的 JWKS 緩存 5 分鐘。
  • JWKSetCache中JWKSet為空或者需要刷新JWK更新緩存時(shí),RestOperationsResourceRetriever將發(fā)起HTTP請(qǐng)求向授權(quán)服務(wù)獲取JWKS。
  • JWKSelector從指定的 JWKS中選擇與配置的條件匹配的JWK。若匹配為空則將重新通過(guò)RestOperationsResourceRetrieve向授權(quán)服務(wù)請(qǐng)求獲取JWKS,再次匹配結(jié)果為空則返回空值。


通過(guò)簡(jiǎn)要了解RemoteJWKSet執(zhí)行過(guò)程,我相信對(duì)于之前授權(quán)服務(wù)器輪換密鑰后資源服務(wù)如何獲取最新密鑰已經(jīng)有了答案。

在授權(quán)服務(wù)密鑰輪換后生成JWT的Header中kid使用的是當(dāng)前最新密鑰所對(duì)應(yīng)的kid,此時(shí)資源服務(wù)收到JWT,通過(guò)RemoteJWKSet獲取JWK用于驗(yàn)證JWT時(shí),JWKSelectorJWKSetCache返回的JWKS中并沒(méi)有匹配到條件相符的JWK,所以將會(huì)使用RestOperationsResourceRetrieve重新向授權(quán)服務(wù)獲取最新JWKS,JWKSelector將再次選擇與條件相符的JWK。

但是分布式系統(tǒng)中協(xié)作服務(wù)器數(shù)量的增加,授權(quán)服務(wù)密鑰輪換后,涉及資源服務(wù)都要重新請(qǐng)求授權(quán)服務(wù)獲取最新JWK,當(dāng)然這并不會(huì)對(duì)授權(quán)服務(wù)造成太大壓力。但是為了最小化網(wǎng)絡(luò)調(diào)用,本示例使用共享緩存解決此問(wèn)題。


接下來(lái)我們將為JwtDecoder配置Redis緩存,Redis將使用 JWKS Uri 作為鍵,并使用 JWKS JSON 作為值:

    @Bean
    JwtDecoder jwtDecoder(OAuth2ResourceServerProperties properties, RestOperations restOperations, CacheManager cacheManager) {
        NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(properties.getJwt().getJwkSetUri())
                .restOperations(restOperations)
                .cache(cacheManager.getCache("jwks"))
                .jwsAlgorithms(algorithms -> {
                    algorithms.add(RS256);
                }).build();

        //自定義時(shí)間戳驗(yàn)證
        OAuth2TokenValidator<Jwt> withClockSkew = new DelegatingOAuth2TokenValidator<>(
                new JwtTimestampValidator(Duration.ofSeconds(60)));

        jwtDecoder.setJwtValidator(withClockSkew);

        return jwtDecoder;
    }

此時(shí)RemoteJWKSet的ResourceRetriever屬性實(shí)際賦值為CachingResourceRetriever, 我們使用的是Redis緩存,CachingResourceRetriever中更新JWKS會(huì)先從Redis緩存中獲取,若Redis緩存為空則將請(qǐng)求授權(quán)服務(wù),部分源碼如下:

public Resource retrieveResource(URL url) throws IOException {
    String jwkSet = (String)this.cache.get(url.toString(), () -> {
      return this.resourceRetriever.retrieveResource(url).getContent();
    });
    return new Resource(jwkSet, "UTF-8");
}

這里引出一個(gè)新的問(wèn)題,Redis緩存為空才會(huì)重新請(qǐng)求授權(quán)服務(wù)JWKS Uri,如果某個(gè)時(shí)刻授權(quán)服務(wù)密鑰輪換后,資源服務(wù)Redis緩存此時(shí)存在值,則不會(huì)重新向授權(quán)服務(wù)發(fā)起請(qǐng)求來(lái)更新資源服務(wù)JWKS緩存,此時(shí)資源服務(wù)驗(yàn)證輪換后的密鑰生成的JWT將會(huì)失敗。

解決此問(wèn)題我們可以在授權(quán)服務(wù)密鑰輪換后清除Redis中資源服務(wù)JWKS緩存信息。


最后我們將使用Spring Security保護(hù)資源服務(wù)端點(diǎn),指定受保護(hù)端點(diǎn)訪問(wèn)權(quán)限:

    @Bean
    SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
        http.authorizeHttpRequests((authorize) -> authorize
                .antMatchers("/resource/article").hasAuthority("SCOPE_message.read")
                .anyRequest().authenticated())
                .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt);

        return http.build();
    }
@RestController
public class ArticleController {


    @GetMapping("/resource/article")
    public Map<String, Object> getArticle(@AuthenticationPrincipal Jwt jwt) {
        Map<String, Object> result = new HashMap<>();
        result.put("principal", jwt.getClaims());
        result.put("article", Arrays.asList("article1", "article2", "article3"));
        return result;
    }
}

測(cè)試我們的應(yīng)用程序

本文中并沒(méi)有說(shuō)明客戶端服務(wù)創(chuàng)建,客戶端服務(wù)并不是本文介紹重點(diǎn),若有疑問(wèn)可以參考之前文章或者通過(guò)文末鏈接獲取源碼。

接下來(lái)我們將啟動(dòng)所有服務(wù)并訪問(wèn)http://127.0.0.1:8070/client/article 。在等待授權(quán)服務(wù)輪換密鑰后,訪問(wèn)依舊正常。

結(jié)論

與往常一樣,本文中使用的源代碼可在 GitHub 上獲得。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

推薦閱讀更多精彩內(nèi)容