Back to Blog

Building a Custom JWT Token Cache with Spring: Performance and Security

Deep dive into implementing a high-performance JWT token cache using Spring Boot, Redis, and custom validation logic for enterprise applications.

19 min read

Building a Custom JWT Token Cache with Spring: Performance and Security

In high-traffic enterprise applications, JWT validation can become a performance bottleneck. Although JWTs are designed to be stateless, some scenarios benefit from server-side caching: token blacklisting (revocation), additional security checks, and avoiding repeated cryptographic verification of the same token. This guide walks through building a JWT token cache with Spring Boot and Redis.

Why Cache JWT Tokens?

Performance Benefits

  • Avoid repeating CPU-intensive signature verification for the same token
  • Minimize database lookups for user permissions
  • Cache expensive claims validation logic
  • Improve response times for protected endpoints

Security Enhancements

  • Implement token blacklisting for immediate revocation
  • Track token usage patterns for anomaly detection
  • Cache user permissions for consistent authorization
  • Enable emergency token invalidation

Architecture Overview

@Configuration
@EnableCaching
public class JwtCacheConfig {

    /**
     * Creates the Redis-backed {@link CacheManager} used by Spring's caching annotations.
     *
     * @param connectionFactory Redis connection factory supplied by the application context
     *     (for example by Spring Boot's Redis auto-configuration); injected as a bean-method
     *     parameter because this class defines no factory of its own
     * @return a cache manager whose caches use {@link #cacheConfiguration()} by default
     */
    @Bean
    public CacheManager cacheManager(RedisConnectionFactory connectionFactory) {
        return RedisCacheManager.builder(connectionFactory)
            .cacheDefaults(cacheConfiguration())
            .build();
    }

    /**
     * Default configuration for every Redis-backed Spring cache.
     *
     * <p>Entries live for 30 minutes, null values are rejected, keys are stored as plain
     * strings and values as JSON via {@link GenericJackson2JsonRedisSerializer}.
     *
     * <p>NOTE: a fixed 30-minute TTL can outlive a short-lived JWT. Data tied to a token's
     * lifetime should be written with its own per-entry TTL rather than relying on this default.
     *
     * @return the shared default cache configuration
     */
    @Bean
    public RedisCacheConfiguration cacheConfiguration() {
        return RedisCacheConfiguration.defaultCacheConfig()
            .entryTtl(Duration.ofMinutes(30))
            .disableCachingNullValues()
            .serializeKeysWith(RedisSerializationContext.SerializationPair
                .fromSerializer(new StringRedisSerializer()))
            .serializeValuesWith(RedisSerializationContext.SerializationPair
                .fromSerializer(new GenericJackson2JsonRedisSerializer()));
    }
}

JWT Token Cache Service

@Service
@Slf4j
public class JwtTokenCacheService {
    
    private static final String TOKEN_CACHE = "jwt_tokens";
    private static final String BLACKLIST_CACHE = "jwt_blacklist";
    private static final String USER_PERMISSIONS_CACHE = "user_permissions";
    
    @Autowired
    private RedisTemplate<String, Object> redisTemplate;
    
    @Autowired
    private JwtTokenProvider jwtTokenProvider;
    
    @Cacheable(value = TOKEN_CACHE, key = "#tokenHash")
    public JwtValidationResult validateAndCache(String token) {
        String tokenHash = generateTokenHash(token);
        
        // Check blacklist first
        if (isTokenBlacklisted(tokenHash)) {
            return JwtValidationResult.blacklisted();
        }
        
        try {
            // Perform expensive validation
            Claims claims = jwtTokenProvider.validateToken(token);
            String userId = claims.getSubject();
            
            // Cache user permissions
            Set<String> permissions = loadUserPermissions(userId);
            cacheUserPermissions(userId, permissions);
            
            // Cache validation result
            JwtValidationResult result = JwtValidationResult.builder()
                .valid(true)
                .userId(userId)
                .permissions(permissions)
                .expirationTime(claims.getExpiration())
                .build();
            
            // Set cache TTL based on token expiration
            long ttlSeconds = calculateTtl(claims.getExpiration());
            cacheValidationResult(tokenHash, result, ttlSeconds);
            
            return result;
            
        } catch (JwtException e) {
            log.warn("JWT validation failed for token hash: {}", tokenHash, e);
            return JwtValidationResult.invalid(e.getMessage());
        }
    }
    
    private String generateTokenHash(String token) {
        return DigestUtils.sha256Hex(token);
    }
    
    private long calculateTtl(Date expiration) {
        long ttl = (expiration.getTime() - System.currentTimeMillis()) / 1000;
        return Math.max(ttl, 0);
    }
}

JWT Validation Result Model

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class JwtValidationResult {

    /** Whether the token may be used to authenticate; true only for {@link ValidationStatus#VALID}. */
    private boolean valid;
    /** Token subject; set only for valid results. */
    private String userId;
    /** Permissions of {@link #userId}; set only for valid results. */
    private Set<String> permissions;
    /** Token expiry ({@code exp} claim); set only for valid results. */
    private Date expirationTime;
    /** Human-readable reason for a non-valid result. */
    private String errorMessage;
    /** Outcome category; kept consistent with {@link #valid} by the static factories. */
    private ValidationStatus status;

    /**
     * Creates a successful result.
     *
     * <p>Prefer this over the builder: the builder leaves {@code status} null unless it is set
     * explicitly.
     *
     * @param userId         token subject
     * @param permissions    permissions granted to the subject
     * @param expirationTime token expiry
     * @return a result with {@code valid=true} and status {@link ValidationStatus#VALID}
     */
    public static JwtValidationResult valid(String userId, Set<String> permissions, Date expirationTime) {
        return JwtValidationResult.builder()
            .valid(true)
            .status(ValidationStatus.VALID)
            .userId(userId)
            .permissions(permissions)
            .expirationTime(expirationTime)
            .build();
    }

    /** Creates a result for a token present on the blacklist. */
    public static JwtValidationResult blacklisted() {
        return JwtValidationResult.builder()
            .valid(false)
            .status(ValidationStatus.BLACKLISTED)
            .errorMessage("Token has been blacklisted")
            .build();
    }

    /**
     * Creates a result for a token that failed validation.
     *
     * @param errorMessage reason for the failure
     */
    public static JwtValidationResult invalid(String errorMessage) {
        return JwtValidationResult.builder()
            .valid(false)
            .status(ValidationStatus.INVALID)
            .errorMessage(errorMessage)
            .build();
    }

    /**
     * Creates a result for a token whose expiry has passed.
     *
     * @param errorMessage reason for the failure
     */
    public static JwtValidationResult expired(String errorMessage) {
        return JwtValidationResult.builder()
            .valid(false)
            .status(ValidationStatus.EXPIRED)
            .errorMessage(errorMessage)
            .build();
    }

    /** Outcome category of a validation attempt. */
    public enum ValidationStatus {
        VALID, INVALID, BLACKLISTED, EXPIRED
    }
}

Token Blacklisting Service

@Service
@Slf4j
public class JwtBlacklistService {

    private static final String BLACKLIST_KEY_PREFIX = "blacklist:";

    /** Prefix of a per-user Redis hash mapping token hash -> token expiry (epoch millis, as a String). */
    private static final String USER_TOKENS_KEY_PREFIX = "user_tokens:";

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    /**
     * Blacklists a single token until it would have expired anyway.
     *
     * @param token  raw JWT to revoke
     * @param reason free-text reason, stored with the entry and logged
     */
    public void blacklistToken(String token, String reason) {
        blacklistTokenHash(DigestUtils.sha256Hex(token), reason, getTokenTtl(token));
    }

    /**
     * Reports whether a token has been blacklisted.
     *
     * @param tokenHash SHA-256 hex digest of the raw token
     * @return true if a blacklist entry exists for the hash
     */
    public boolean isTokenBlacklisted(String tokenHash) {
        // hasKey returns a boxed Boolean that may be null (e.g. inside a pipeline or transaction).
        return Boolean.TRUE.equals(redisTemplate.hasKey(BLACKLIST_KEY_PREFIX + tokenHash));
    }

    /**
     * Records an issued token in the user's index so {@link #blacklistUserTokens} can revoke it.
     * Only the token's hash and expiry are stored, never the raw token. Call this wherever
     * tokens are issued.
     *
     * @param userId owner of the token
     * @param token  raw JWT that was issued
     */
    public void registerUserToken(String userId, String token) {
        long ttl = getTokenTtl(token);
        if (ttl <= 0) {
            return;
        }
        String indexKey = USER_TOKENS_KEY_PREFIX + userId;
        long expiresAt = System.currentTimeMillis() + ttl * 1000;
        // Stored as a String so the value round-trips regardless of the template's hash serializer.
        redisTemplate.opsForHash().put(indexKey, DigestUtils.sha256Hex(token), String.valueOf(expiresAt));

        // Keep the index alive at least as long as its longest-lived token
        // (getExpire is -1 when the key has no TTL yet).
        Long remaining = redisTemplate.getExpire(indexKey);
        if (remaining == null || remaining < ttl) {
            redisTemplate.expire(indexKey, Duration.ofSeconds(ttl));
        }
    }

    /**
     * Blacklists every still-unexpired token registered for a user via
     * {@link #registerUserToken}. Tokens that were never registered cannot be found.
     *
     * @param userId user whose tokens are revoked
     * @param reason free-text reason, stored with each entry and logged
     */
    public void blacklistUserTokens(String userId, String reason) {
        Map<Object, Object> entries = redisTemplate.opsForHash().entries(USER_TOKENS_KEY_PREFIX + userId);
        long now = System.currentTimeMillis();
        int blacklisted = 0;

        if (entries != null) {
            for (Map.Entry<Object, Object> entry : entries.entrySet()) {
                long ttl = (Long.parseLong(String.valueOf(entry.getValue())) - now) / 1000;
                if (ttl > 0) {
                    blacklistTokenHash(String.valueOf(entry.getKey()), reason, ttl);
                    blacklisted++;
                }
            }
        }

        log.info("Blacklisted {} tokens for user: {}, reason: {}",
                blacklisted, userId, reason);
    }

    /**
     * Writes a blacklist entry for a token hash that lives for {@code ttlSeconds} seconds.
     * Non-positive TTLs are skipped: the token is already past its expiry, and Redis rejects
     * a zero or negative expiry.
     */
    private void blacklistTokenHash(String tokenHash, String reason, long ttlSeconds) {
        if (ttlSeconds <= 0) {
            log.debug("Token already expired, not blacklisting: hash={}", tokenHash);
            return;
        }

        BlacklistEntry entry = BlacklistEntry.builder()
            .tokenHash(tokenHash)
            .reason(reason)
            .blacklistedAt(Instant.now())
            .build();

        redisTemplate.opsForValue().set(BLACKLIST_KEY_PREFIX + tokenHash, entry, Duration.ofSeconds(ttlSeconds));

        log.info("Token blacklisted: hash={}, reason={}", tokenHash, reason);
    }
}

Cached Permission Service

@Service
@Slf4j
public class CachedPermissionService {

    private static final String PERMISSIONS_KEY_PREFIX = "permissions:";
    private static final Duration PERMISSIONS_TTL = Duration.ofMinutes(15);

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Autowired
    private UserPermissionRepository permissionRepository;

    /**
     * Returns a user's permissions, read-through cached in Redis for {@link #PERMISSIONS_TTL}.
     *
     * <p>Deliberately not {@code @Cacheable}: a Spring cache in front of this one would serve
     * values for the cache manager's default TTL, so PERMISSIONS_TTL would never take effect.
     *
     * @param userId user whose permissions to load
     * @return the permission names; empty (never null) when the repository returns none
     */
    public Set<String> getUserPermissions(String userId) {
        String cacheKey = PERMISSIONS_KEY_PREFIX + userId;

        Set<String> cachedPermissions = getCachedPermissions(cacheKey);
        if (cachedPermissions != null) {
            return cachedPermissions;
        }

        Set<String> permissions = permissionRepository.findPermissionsByUserId(userId);
        if (permissions == null) {
            permissions = new HashSet<>();
        }

        redisTemplate.opsForValue().set(cacheKey, permissions, PERMISSIONS_TTL);

        return permissions;
    }

    /**
     * Removes a user's cached permissions so the next lookup reloads them from the repository.
     *
     * @param userId user whose cache entry to drop
     */
    public void evictUserPermissions(String userId) {
        String cacheKey = PERMISSIONS_KEY_PREFIX + userId;
        redisTemplate.delete(cacheKey);
        log.info("Evicted permissions cache for user: {}", userId);
    }

    /**
     * Reads cached permissions.
     *
     * <p>Any {@link Collection} is accepted because JSON value serializers may deserialize a
     * stored Set as a List.
     *
     * @return a fresh set of permission names, or {@code null} on a cache miss
     */
    private Set<String> getCachedPermissions(String cacheKey) {
        Object cached = redisTemplate.opsForValue().get(cacheKey);
        if (!(cached instanceof Collection)) {
            return null;
        }
        Set<String> permissions = new HashSet<>();
        for (Object permission : (Collection<?>) cached) {
            if (permission != null) {
                permissions.add(String.valueOf(permission));
            }
        }
        return permissions;
    }
}

Cache-Aware JWT Filter

@Component
public class CachedJwtAuthenticationFilter extends OncePerRequestFilter {

    /** Authorization-scheme prefix; matched case-insensitively. */
    private static final String BEARER_PREFIX = "Bearer ";

    @Autowired
    private JwtTokenCacheService tokenCacheService;

    /**
     * Authenticates the request from its bearer token, if any.
     *
     * <p>Requests without a bearer token pass through unauthenticated. Requests with a token
     * that fails validation get 401 with a {@code WWW-Authenticate} challenge and do not reach
     * the rest of the chain. No success header is emitted: the former {@code HIT} value was set
     * regardless of whether the result actually came from the cache.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request,
                                   HttpServletResponse response,
                                   FilterChain filterChain) throws ServletException, IOException {

        String token = extractTokenFromRequest(request);

        if (token != null) {
            JwtValidationResult validationResult = tokenCacheService.validateAndCache(token);

            if (!validationResult.isValid()) {
                response.setHeader("X-Token-Cache-Status", "INVALID");
                response.setHeader("WWW-Authenticate", "Bearer error=\"invalid_token\"");
                response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
                return;
            }

            UsernamePasswordAuthenticationToken authentication =
                new UsernamePasswordAuthenticationToken(
                    validationResult.getUserId(),
                    null,
                    mapPermissionsToAuthorities(validationResult.getPermissions())
                );

            SecurityContext context = SecurityContextHolder.createEmptyContext();
            context.setAuthentication(authentication);
            SecurityContextHolder.setContext(context);
        }

        filterChain.doFilter(request, response);
    }

    /**
     * Extracts the token from an {@code Authorization: Bearer <token>} header.
     *
     * @return the token, or {@code null} if the header is absent, uses another scheme, or is blank
     */
    private String extractTokenFromRequest(HttpServletRequest request) {
        String header = request.getHeader("Authorization");
        if (header == null
                || !header.regionMatches(true, 0, BEARER_PREFIX, 0, BEARER_PREFIX.length())) {
            return null;
        }
        String token = header.substring(BEARER_PREFIX.length()).trim();
        return token.isEmpty() ? null : token;
    }

    /** Maps permission names to authorities; a null set yields no authorities. */
    private Collection<? extends GrantedAuthority> mapPermissionsToAuthorities(Set<String> permissions) {
        if (permissions == null) {
            return Collections.emptyList();
        }
        return permissions.stream()
            .map(SimpleGrantedAuthority::new)
            .collect(Collectors.toList());
    }
}

Performance Monitoring

@Component
public class JwtCacheMetrics {

    private final MeterRegistry meterRegistry;
    private final Counter cacheHits;
    private final Counter cacheMisses;
    private final Timer validationTimer;

    /**
     * Registers the JWT cache meters: hit/miss counters, a validation timer and a hit-ratio gauge.
     *
     * @param meterRegistry registry to publish meters to
     */
    public JwtCacheMetrics(MeterRegistry meterRegistry) {
        this.meterRegistry = meterRegistry;
        this.cacheHits = Counter.builder("jwt.cache.hits")
            .description("JWT cache hits")
            .register(meterRegistry);
        this.cacheMisses = Counter.builder("jwt.cache.misses")
            .description("JWT cache misses")
            .register(meterRegistry);
        this.validationTimer = Timer.builder("jwt.validation.time")
            .description("JWT validation time")
            .register(meterRegistry);

        // Register the gauge once, as a function of this bean that is evaluated on every scrape.
        // MeterRegistry.gauge(name, number) only weakly references the number, so re-gauging a
        // freshly boxed double on a schedule reports NaN or a stale first value.
        meterRegistry.gauge("jwt.cache.hit_ratio", this, JwtCacheMetrics::calculateHitRatio);
    }

    /** Counts one cache hit. */
    public void recordCacheHit() {
        cacheHits.increment();
    }

    /** Counts one cache miss. */
    public void recordCacheMiss() {
        cacheMisses.increment();
    }

    /**
     * Starts timing a validation; finish it with {@link #stopValidationTimer(Timer.Sample)}.
     *
     * @return the running sample
     */
    public Timer.Sample startValidationTimer() {
        return Timer.start(meterRegistry);
    }

    /**
     * Stops a sample from {@link #startValidationTimer()} and records it on "jwt.validation.time".
     *
     * @param sample the running sample
     * @return the recorded duration in nanoseconds
     */
    public long stopValidationTimer(Timer.Sample sample) {
        return sample.stop(validationTimer);
    }

    /**
     * Formerly pushed the hit ratio on a schedule. The gauge registered in the constructor now
     * computes the ratio on each scrape, so this method does nothing.
     *
     * @deprecated no longer needed; retained only for source compatibility
     */
    @Deprecated
    public void recordCacheStatistics() {
        // Intentionally empty.
    }

    /** Hits divided by (hits + misses); 0 before any lookup has been recorded. */
    private double calculateHitRatio() {
        double hits = cacheHits.count();
        double total = hits + cacheMisses.count();
        return total == 0 ? 0.0 : hits / total;
    }
}

Cache Warming Strategy

@Service
@Slf4j
public class JwtCacheWarmingService {

    @Autowired
    private JwtTokenCacheService tokenCacheService;

    @Autowired
    private ActiveSessionRepository sessionRepository;

    /**
     * Pre-validates all active session tokens once the application is ready, so first
     * requests hit a warm cache.
     *
     * <p>Runs sequentially. Each validation does blocking Redis (and possibly database) I/O, and
     * {@code parallelStream()} would run that on the shared {@code ForkJoinPool.commonPool()},
     * starving other users of that pool. Per-token failures are logged and skipped.
     */
    @EventListener(ApplicationReadyEvent.class)
    public void warmCache() {
        log.info("Starting JWT cache warming...");

        List<String> activeTokens = sessionRepository.findActiveTokens();
        int warmed = 0;

        for (String token : activeTokens) {
            try {
                // Only valid tokens produce a cache entry worth counting.
                if (tokenCacheService.validateAndCache(token).isValid()) {
                    warmed++;
                }
            } catch (Exception e) {
                log.warn("Failed to warm cache for token", e);
            }
        }

        log.info("JWT cache warming completed. Warmed {} of {} tokens", warmed, activeTokens.size());
    }

    /**
     * Periodically refreshes cache entries close to expiry (every 5 minutes).
     * NOTE(review): refreshSoonToExpireEntries() is not defined in this snippet; confirm where it lives.
     */
    @Scheduled(fixedRate = 300000) // Every 5 minutes
    public void refreshExpiringEntries() {
        refreshSoonToExpireEntries();
    }
}

Testing the JWT Cache

/**
 * Integration tests for {@link JwtTokenCacheService}.
 * Requires a Redis server reachable at localhost:6379.
 */
@SpringBootTest
@TestPropertySource(properties = {
    "spring.redis.host=localhost",
    "spring.redis.port=6379"
})
class JwtTokenCacheServiceTest {

    @Autowired
    private JwtTokenCacheService tokenCacheService;

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @MockBean
    private JwtTokenProvider jwtTokenProvider;

    @Test
    void shouldCacheValidationResult() {
        // Unique per run: Redis state outlives the test JVM, so a fixed token cached by an earlier
        // run would be served from cache and the times(1) verification below would fail.
        String token = "valid.jwt.token." + UUID.randomUUID();
        Claims claims = createMockClaims();

        when(jwtTokenProvider.validateToken(token)).thenReturn(claims);

        // First call - performs full validation through the provider
        JwtValidationResult result1 = tokenCacheService.validateAndCache(token);

        // Second call - should be served from the cache
        JwtValidationResult result2 = tokenCacheService.validateAndCache(token);

        assertThat(result1.isValid()).isTrue();
        assertThat(result2.isValid()).isTrue();

        // Verify provider was called only once
        verify(jwtTokenProvider, times(1)).validateToken(token);
    }

    @Test
    void shouldRespectBlacklist() {
        String token = "blacklisted.jwt.token." + UUID.randomUUID();
        String tokenHash = DigestUtils.sha256Hex(token);
        String blacklistKey = "blacklist:" + tokenHash;

        // Blacklist the token
        redisTemplate.opsForValue().set(blacklistKey,
            new BlacklistEntry(), Duration.ofMinutes(30));

        try {
            JwtValidationResult result = tokenCacheService.validateAndCache(token);

            assertThat(result.isValid()).isFalse();
            assertThat(result.getStatus()).isEqualTo(ValidationStatus.BLACKLISTED);
        } finally {
            // Do not leave test data in the shared Redis instance.
            redisTemplate.delete(blacklistKey);
        }
    }

    /** Builds mock claims for subject "test-user" that expire one hour from now. */
    private Claims createMockClaims() {
        Claims claims = mock(Claims.class);
        when(claims.getSubject()).thenReturn("test-user");
        when(claims.getExpiration()).thenReturn(new Date(System.currentTimeMillis() + 3_600_000L));
        return claims;
    }
}

Configuration Properties

# application.yml
# NOTE(review): `spring.redis.*` is the Spring Boot 2.x namespace; Spring Boot 3.x reads
# `spring.data.redis.*` instead. Confirm against the Boot version in use.
spring:
  redis:
    host: localhost
    port: 6379
    # Falls back to an empty password when REDIS_PASSWORD is unset.
    password: ${REDIS_PASSWORD:}
    timeout: 2000ms
    # Pool settings under `jedis` apply only when the Jedis client is on the classpath;
    # Spring Boot's default Redis client is Lettuce (per Spring Boot docs).
    jedis:
      pool:
        max-active: 20
        max-idle: 10
        min-idle: 2

# Application-specific cache settings (ISO-8601 durations).
# NOTE(review): none of the Java snippets shown bind these properties. Their TTLs are
# hard-coded (30 min cache default, 15 min permissions), so edits here have no effect until
# a @ConfigurationProperties class reads them.
jwt:
  cache:
    default-ttl: PT30M
    max-entries: 10000
    blacklist-ttl: PT24H
    permissions-ttl: PT15M
  
# DEBUG for the security package and Spring's cache abstraction; verbose, likely meant for development.
logging:
  level:
    com.company.security: DEBUG
    org.springframework.cache: DEBUG

Best Practices

  1. TTL Management: Set each cached entry's TTL to the token's remaining lifetime, so no cached result outlives the token
  2. Memory Usage: Monitor Redis memory consumption
  3. Cache Eviction: Implement proper cleanup strategies
  4. Security: Hash tokens (e.g. with SHA-256) before using them as cache keys, so raw bearer tokens are never stored in Redis
  5. Monitoring: Track hit ratios and performance metrics
  6. Fallback: Always handle cache failures gracefully
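  7. Blacklist Cleanup: Give blacklist entries a TTL equal to the token's remaining lifetime, so Redis removes them automatically once the token has expired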

Conclusion

A well-implemented JWT token cache can significantly improve application performance while enhancing security capabilities. The combination of Redis caching, proper TTL management, and comprehensive monitoring creates a robust foundation for enterprise authentication systems.

Remember to balance performance gains with security requirements, and always have fallback mechanisms when the cache is unavailable. The patterns shown here provide a solid foundation for scaling JWT-based authentication in high-traffic applications.