Building a Custom JWT Token Cache with Spring: Performance and Security
Deep dive into implementing a high-performance JWT token cache using Spring Boot, Redis, and custom validation logic for enterprise applications.
Building a Custom JWT Token Cache with Spring: Performance and Security
In high-traffic enterprise applications, JWT token validation can become a performance bottleneck. While JWTs are designed to be stateless, certain scenarios benefit from caching: token blacklisting, enhanced security validation, and reducing cryptographic operations. This guide explores building a sophisticated JWT token cache with Spring Boot.
Why Cache JWT Tokens?
Performance Benefits
- Reduce CPU-intensive signature verification
- Minimize database lookups for user permissions
- Cache expensive claims validation logic
- Improve response times for protected endpoints
Security Enhancements
- Implement token blacklisting for immediate revocation
- Track token usage patterns for anomaly detection
- Cache user permissions for consistent authorization
- Enable emergency token invalidation
Architecture Overview
@Configuration
@EnableCaching
public class JwtCacheConfig {
@Bean
public CacheManager cacheManager() {
RedisCacheManager.Builder builder = RedisCacheManager
.RedisCacheManagerBuilder
.fromConnectionFactory(redisConnectionFactory())
.cacheDefaults(cacheConfiguration());
return builder.build();
}
@Bean
public RedisCacheConfiguration cacheConfiguration() {
return RedisCacheConfiguration.defaultCacheConfig()
.entryTtl(Duration.ofMinutes(30))
.disableCachingNullValues()
.serializeKeysWith(RedisSerializationContext.SerializationPair
.fromSerializer(new StringRedisSerializer()))
.serializeValuesWith(RedisSerializationContext.SerializationPair
.fromSerializer(new GenericJackson2JsonRedisSerializer()));
}
}
JWT Token Cache Service
@Service
@Slf4j
public class JwtTokenCacheService {
private static final String TOKEN_CACHE = "jwt_tokens";
private static final String BLACKLIST_CACHE = "jwt_blacklist";
private static final String USER_PERMISSIONS_CACHE = "user_permissions";
@Autowired
private RedisTemplate<String, Object> redisTemplate;
@Autowired
private JwtTokenProvider jwtTokenProvider;
@Cacheable(value = TOKEN_CACHE, key = "#tokenHash")
public JwtValidationResult validateAndCache(String token) {
String tokenHash = generateTokenHash(token);
// Check blacklist first
if (isTokenBlacklisted(tokenHash)) {
return JwtValidationResult.blacklisted();
}
try {
// Perform expensive validation
Claims claims = jwtTokenProvider.validateToken(token);
String userId = claims.getSubject();
// Cache user permissions
Set<String> permissions = loadUserPermissions(userId);
cacheUserPermissions(userId, permissions);
// Cache validation result
JwtValidationResult result = JwtValidationResult.builder()
.valid(true)
.userId(userId)
.permissions(permissions)
.expirationTime(claims.getExpiration())
.build();
// Set cache TTL based on token expiration
long ttlSeconds = calculateTtl(claims.getExpiration());
cacheValidationResult(tokenHash, result, ttlSeconds);
return result;
} catch (JwtException e) {
log.warn("JWT validation failed for token hash: {}", tokenHash, e);
return JwtValidationResult.invalid(e.getMessage());
}
}
private String generateTokenHash(String token) {
return DigestUtils.sha256Hex(token);
}
private long calculateTtl(Date expiration) {
long ttl = (expiration.getTime() - System.currentTimeMillis()) / 1000;
return Math.max(ttl, 0);
}
}
JWT Validation Result Model
/**
 * Outcome of validating a JWT, designed to be stored in and read back from the cache.
 *
 * <p>Lombok generates getters/setters, a builder, and no-args/all-args constructors. The no-args
 * constructor and setters allow serializers to rebuild instances read from Redis.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class JwtValidationResult {
    /** Whether the token was accepted. */
    private boolean valid;
    /** Token subject; populated for accepted tokens. */
    private String userId;
    /** Permissions granted to the user; populated for accepted tokens. */
    private Set<String> permissions;
    /** Token expiry ({@code exp} claim); populated for accepted tokens. */
    private Date expirationTime;
    /** Reason for rejection; populated for rejected tokens. */
    private String errorMessage;
    /** Detailed outcome, consistent with {@code valid}. */
    private ValidationStatus status;

    /**
     * Creates an accepted result with status {@link ValidationStatus#VALID}.
     *
     * <p>Building a valid result by hand with only {@code valid(true)} leaves {@code status}
     * {@code null}. This factory keeps {@code valid} and {@code status} in agreement.
     *
     * @param userId the token subject
     * @param permissions the permissions granted to {@code userId}
     * @param expirationTime the token's expiry, may be {@code null}
     * @return a valid result
     */
    public static JwtValidationResult valid(String userId, Set<String> permissions, Date expirationTime) {
        return JwtValidationResult.builder()
                .valid(true)
                .status(ValidationStatus.VALID)
                .userId(userId)
                .permissions(permissions)
                .expirationTime(expirationTime)
                .build();
    }

    /** Creates a rejected result for a token found on the blacklist. */
    public static JwtValidationResult blacklisted() {
        return JwtValidationResult.builder()
                .valid(false)
                .status(ValidationStatus.BLACKLISTED)
                .errorMessage("Token has been blacklisted")
                .build();
    }

    /**
     * Creates a rejected result for a token that failed validation.
     *
     * @param errorMessage the reason the token was rejected
     * @return an invalid result
     */
    public static JwtValidationResult invalid(String errorMessage) {
        return JwtValidationResult.builder()
                .valid(false)
                .status(ValidationStatus.INVALID)
                .errorMessage(errorMessage)
                .build();
    }

    /** Fine-grained validation outcomes. */
    public enum ValidationStatus {
        VALID, INVALID, BLACKLISTED, EXPIRED
    }
}
Token Blacklisting Service
/**
 * Maintains the Redis-backed token blacklist.
 *
 * <p>Entries are keyed by {@code blacklist:<sha256(token)>}, so raw tokens are never stored.
 * Each entry expires together with the token it revokes.
 */
@Service
@Slf4j
public class JwtBlacklistService {
    private static final String BLACKLIST_KEY_PREFIX = "blacklist:";
    /** Redis set per user holding the SHA-256 hashes of tokens issued to that user. */
    private static final String USER_TOKENS_KEY_PREFIX = "user_tokens:";
    /**
     * Lifetime of the per-user token index and of entries created by
     * {@link #blacklistUserTokens}, where only the hash (and so not the expiry) is known.
     * Assumes no token lives longer than this; TODO confirm against your token lifetime.
     */
    private static final Duration USER_REVOCATION_TTL = Duration.ofHours(24);

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    /**
     * Blacklists a single token until it expires.
     *
     * @param token the raw JWT to revoke
     * @param reason audit reason stored with the entry
     */
    public void blacklistToken(String token, String reason) {
        // Entry TTL matches the token's remaining lifetime; afterwards the token is invalid anyway.
        long ttl = getTokenTtl(token);
        blacklistTokenHash(DigestUtils.sha256Hex(token), reason, Duration.ofSeconds(ttl));
    }

    /**
     * Checks whether a token hash is currently blacklisted.
     *
     * @param tokenHash SHA-256 hex digest of the token
     * @return {@code true} if a blacklist entry exists
     */
    public boolean isTokenBlacklisted(String tokenHash) {
        // hasKey returns a Boolean that can be null (e.g. inside a pipeline); avoid unboxing NPE.
        return Boolean.TRUE.equals(redisTemplate.hasKey(BLACKLIST_KEY_PREFIX + tokenHash));
    }

    /**
     * Records that {@code token} was issued to {@code userId}, so that
     * {@link #blacklistUserTokens} can revoke it later. Call this when the token is issued.
     *
     * @param userId the token subject
     * @param token the raw JWT; only its hash is stored
     */
    public void trackUserToken(String userId, String token) {
        String indexKey = USER_TOKENS_KEY_PREFIX + userId;
        redisTemplate.opsForSet().add(indexKey, DigestUtils.sha256Hex(token));
        redisTemplate.expire(indexKey, USER_REVOCATION_TTL);
    }

    /**
     * Blacklists every token recorded for {@code userId} via {@link #trackUserToken}.
     *
     * <p>The previous implementation had two problems. It scanned Redis with {@code KEYS}, which
     * blocks the server. It also blacklisted the hash of each matching <em>key name</em> rather
     * than of a token, so no real token was ever revoked.
     *
     * @param userId the user whose tokens are revoked
     * @param reason audit reason stored with each entry
     */
    public void blacklistUserTokens(String userId, String reason) {
        Set<String> tokenHashes = findUserActiveTokenHashes(userId);
        tokenHashes.forEach(tokenHash -> blacklistTokenHash(tokenHash, reason, USER_REVOCATION_TTL));
        // The tracked tokens are now all blacklisted; the index is no longer needed.
        redisTemplate.delete(USER_TOKENS_KEY_PREFIX + userId);
        log.info("Blacklisted {} tokens for user: {}, reason: {}",
                tokenHashes.size(), userId, reason);
    }

    /** Writes a blacklist entry for {@code tokenHash}; skips tokens that have already expired. */
    private void blacklistTokenHash(String tokenHash, String reason, Duration ttl) {
        if (ttl.isZero() || ttl.isNegative()) {
            // An expired token is already rejected by validation; a zero TTL is not storable.
            log.debug("Skipping blacklist of expired token: hash={}", tokenHash);
            return;
        }
        BlacklistEntry entry = BlacklistEntry.builder()
                .tokenHash(tokenHash)
                .reason(reason)
                .blacklistedAt(Instant.now())
                .build();
        redisTemplate.opsForValue().set(BLACKLIST_KEY_PREFIX + tokenHash, entry, ttl);
        log.info("Token blacklisted: hash={}, reason={}", tokenHash, reason);
    }

    /** Reads the token hashes recorded for {@code userId}; empty if none are tracked. */
    private Set<String> findUserActiveTokenHashes(String userId) {
        Set<Object> members = redisTemplate.opsForSet().members(USER_TOKENS_KEY_PREFIX + userId);
        Set<String> tokenHashes = new java.util.HashSet<>();
        if (members != null) {
            members.forEach(member -> tokenHashes.add(String.valueOf(member)));
        }
        return tokenHashes;
    }
}
Cached Permission Service
/**
 * Serves user permissions from Redis, falling back to the database on a miss.
 *
 * <p>Permissions are cached under {@code permissions:<userId>} for {@link #PERMISSIONS_TTL}.
 * The previous version also wrapped this method in {@code @Cacheable("user_permissions")}.
 * That outer Spring cache used the cache manager's default TTL, so permissions could be served
 * for longer than {@code PERMISSIONS_TTL}. A single explicit cache layer keeps the intended
 * lifetime.
 */
@Service
@Slf4j
public class CachedPermissionService {
    private static final String PERMISSIONS_KEY_PREFIX = "permissions:";
    private static final Duration PERMISSIONS_TTL = Duration.ofMinutes(15);

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;
    @Autowired
    private UserPermissionRepository permissionRepository;

    /**
     * Returns the permissions of {@code userId}, loading and caching them on a miss.
     *
     * @param userId the user whose permissions are requested
     * @return the user's permissions; never {@code null} (empty if the repository has none)
     */
    public Set<String> getUserPermissions(String userId) {
        String cacheKey = PERMISSIONS_KEY_PREFIX + userId;
        Set<String> cachedPermissions = getCachedPermissions(cacheKey);
        if (cachedPermissions != null) {
            return cachedPermissions;
        }
        Set<String> permissions = permissionRepository.findPermissionsByUserId(userId);
        if (permissions == null) {
            // Cache "no permissions" as an empty set so the database is not hit on every request.
            permissions = new java.util.HashSet<>();
        }
        redisTemplate.opsForValue().set(cacheKey, permissions, PERMISSIONS_TTL);
        return permissions;
    }

    /**
     * Removes the cached permissions of {@code userId}; the next lookup reloads them.
     *
     * @param userId the user whose cache entry is evicted
     */
    public void evictUserPermissions(String userId) {
        String cacheKey = PERMISSIONS_KEY_PREFIX + userId;
        redisTemplate.delete(cacheKey);
        log.info("Evicted permissions cache for user: {}", userId);
    }

    /** Returns the cached set for {@code cacheKey}, or {@code null} on a miss or unexpected type. */
    @SuppressWarnings("unchecked") // only Set<String> values are ever written under this prefix
    private Set<String> getCachedPermissions(String cacheKey) {
        Object cached = redisTemplate.opsForValue().get(cacheKey);
        if (cached instanceof Set) {
            return (Set<String>) cached;
        }
        return null;
    }
}
Cache-Aware JWT Filter
/**
 * Authenticates requests that carry a JWT, using {@link JwtTokenCacheService} for validation.
 *
 * <p>Requests without a token pass through unauthenticated. Requests with a rejected token get
 * {@code 401} and the chain stops. Requests with an accepted token get a populated
 * {@link SecurityContext}.
 */
@Component
public class CachedJwtAuthenticationFilter extends OncePerRequestFilter {
    @Autowired
    private JwtTokenCacheService tokenCacheService;

    @Override
    protected void doFilterInternal(HttpServletRequest request,
                                    HttpServletResponse response,
                                    FilterChain filterChain) throws ServletException, IOException {
        String token = extractTokenFromRequest(request);
        if (token != null) {
            JwtValidationResult validationResult = tokenCacheService.validateAndCache(token);
            if (!validationResult.isValid()) {
                response.setHeader("X-Token-Cache-Status", "INVALID");
                response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
                return;
            }
            SecurityContext context = SecurityContextHolder.createEmptyContext();
            UsernamePasswordAuthenticationToken authentication =
                    new UsernamePasswordAuthenticationToken(
                            validationResult.getUserId(),
                            null,
                            mapPermissionsToAuthorities(validationResult.getPermissions())
                    );
            context.setAuthentication(authentication);
            SecurityContextHolder.setContext(context);
            // The filter cannot tell whether the result came from the cache, so it reports the
            // validation outcome. The original always reported "HIT", which was misleading.
            response.setHeader("X-Token-Cache-Status", "VALID");
        }
        filterChain.doFilter(request, response);
    }

    /**
     * Converts permission strings into Spring Security authorities.
     *
     * @param permissions permission names; {@code null} is treated as no permissions
     * @return one {@link SimpleGrantedAuthority} per permission
     */
    private Collection<? extends GrantedAuthority> mapPermissionsToAuthorities(Set<String> permissions) {
        if (permissions == null) {
            return java.util.Collections.emptyList();
        }
        return permissions.stream()
                .map(SimpleGrantedAuthority::new)
                .collect(Collectors.toList());
    }
}
Performance Monitoring
/**
 * Micrometer metrics for the JWT cache: hit/miss counters, a hit-ratio gauge and a validation
 * timer.
 */
@Component
public class JwtCacheMetrics {
    private final MeterRegistry meterRegistry;
    private final Counter cacheHits;
    private final Counter cacheMisses;
    private final Timer validationTimer;

    /**
     * Registers all meters with {@code meterRegistry}.
     *
     * @param meterRegistry the registry to publish to
     */
    public JwtCacheMetrics(MeterRegistry meterRegistry) {
        this.meterRegistry = meterRegistry;
        this.cacheHits = Counter.builder("jwt.cache.hits")
                .description("JWT cache hits")
                .register(meterRegistry);
        this.cacheMisses = Counter.builder("jwt.cache.misses")
                .description("JWT cache misses")
                .register(meterRegistry);
        this.validationTimer = Timer.builder("jwt.validation.time")
                .description("JWT validation time")
                .register(meterRegistry);
        // Register the gauge once as a function of this bean; Micrometer evaluates
        // calculateHitRatio() whenever it is sampled. The original called
        // meterRegistry.gauge(name, double) on a schedule. That only registers the first boxed
        // value, which the registry holds weakly, so the gauge never changed and could read NaN.
        meterRegistry.gauge("jwt.cache.hit_ratio", this, JwtCacheMetrics::calculateHitRatio);
    }

    /** Counts one cache hit. */
    public void recordCacheHit() {
        cacheHits.increment();
    }

    /** Counts one cache miss. */
    public void recordCacheMiss() {
        cacheMisses.increment();
    }

    /**
     * Starts timing a validation; finish it with {@link #stopValidationTimer}.
     *
     * @return a sample to pass to {@link #stopValidationTimer}
     */
    public Timer.Sample startValidationTimer() {
        return Timer.start(meterRegistry);
    }

    /**
     * Records the elapsed time of {@code sample} into the {@code jwt.validation.time} timer.
     *
     * @param sample a sample obtained from {@link #startValidationTimer}
     * @return the recorded duration in nanoseconds
     */
    public long stopValidationTimer(Timer.Sample sample) {
        return sample.stop(validationTimer);
    }

    /**
     * No longer needed: the hit-ratio gauge is computed on demand.
     *
     * @deprecated the gauge registered in the constructor is always current; retained as a
     *     no-op so existing callers keep compiling.
     */
    @Deprecated
    public void recordCacheStatistics() {
        // Intentionally empty.
    }

    /** Hits divided by total lookups, or 0 when nothing has been recorded yet. */
    private double calculateHitRatio() {
        double hits = cacheHits.count();
        double total = hits + cacheMisses.count();
        return total == 0 ? 0.0 : hits / total;
    }
}
Cache Warming Strategy
/**
 * Pre-populates the JWT cache at startup and periodically refreshes entries close to expiry.
 */
@Service
@Slf4j
public class JwtCacheWarmingService {
    @Autowired
    private JwtTokenCacheService tokenCacheService;
    @Autowired
    private ActiveSessionRepository sessionRepository;

    /**
     * Validates every active token once the application is ready, so the first real requests
     * hit a warm cache.
     *
     * <p>Runs sequentially: each validation performs blocking Redis I/O. With
     * {@code parallelStream()} that I/O would occupy the shared ForkJoin common pool.
     * Only tokens that validate successfully are counted as warmed.
     */
    @EventListener(ApplicationReadyEvent.class)
    public void warmCache() {
        log.info("Starting JWT cache warming...");
        List<String> activeTokens = sessionRepository.findActiveTokens();
        int warmed = 0;
        for (String token : activeTokens) {
            try {
                if (tokenCacheService.validateAndCache(token).isValid()) {
                    warmed++;
                }
            } catch (Exception e) {
                // Best effort: a failure here must not prevent startup.
                log.warn("Failed to warm cache for token", e);
            }
        }
        log.info("JWT cache warming completed. Warmed {} of {} tokens", warmed, activeTokens.size());
    }

    /** Every 5 minutes, refreshes cache entries that are about to expire. */
    @Scheduled(fixedRate = 300000) // Every 5 minutes
    public void refreshExpiringEntries() {
        refreshSoonToExpireEntries();
    }
}
Testing the JWT Cache
@SpringBootTest
@TestPropertySource(properties = {
"spring.redis.host=localhost",
"spring.redis.port=6379"
})
class JwtTokenCacheServiceTest {
@Autowired
private JwtTokenCacheService tokenCacheService;
@Autowired
private RedisTemplate<String, Object> redisTemplate;
@MockBean
private JwtTokenProvider jwtTokenProvider;
@Test
void shouldCacheValidationResult() {
String token = "valid.jwt.token";
Claims claims = createMockClaims();
when(jwtTokenProvider.validateToken(token)).thenReturn(claims);
// First call - should hit the database
JwtValidationResult result1 = tokenCacheService.validateAndCache(token);
// Second call - should hit the cache
JwtValidationResult result2 = tokenCacheService.validateAndCache(token);
assertThat(result1.isValid()).isTrue();
assertThat(result2.isValid()).isTrue();
// Verify provider was called only once
verify(jwtTokenProvider, times(1)).validateToken(token);
}
@Test
void shouldRespectBlacklist() {
String token = "blacklisted.jwt.token";
String tokenHash = DigestUtils.sha256Hex(token);
// Blacklist the token
redisTemplate.opsForValue().set("blacklist:" + tokenHash,
new BlacklistEntry(), Duration.ofMinutes(30));
JwtValidationResult result = tokenCacheService.validateAndCache(token);
assertThat(result.isValid()).isFalse();
assertThat(result.getStatus()).isEqualTo(ValidationStatus.BLACKLISTED);
}
}
Configuration Properties
# application.yml
spring:
  redis:
    # NOTE(review): Spring Boot 3.x reads these from spring.data.redis.* instead of
    # spring.redis.* — verify against your Boot version.
    host: localhost
    port: 6379
    password: ${REDIS_PASSWORD:}
    timeout: 2000ms
    # Pool settings apply only when the Jedis client is in use; with Lettuce, the equivalent
    # keys live under the lettuce.pool section instead.
    jedis:
      pool:
        max-active: 20
        max-idle: 10
        min-idle: 2
jwt:
  cache:
    # Custom (non-Spring) keys: they take effect only if bound, e.g. through a
    # @ConfigurationProperties(prefix = "jwt.cache") class.
    default-ttl: PT30M
    max-entries: 10000
    blacklist-ttl: PT24H
    permissions-ttl: PT15M
logging:
  level:
    com.company.security: DEBUG
    org.springframework.cache: DEBUG
Best Practices
- TTL Management: Never let a cached validation result outlive its token — set the entry's TTL to the token's remaining lifetime
- Memory Usage: Monitor Redis memory consumption
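- Cache Eviction: Evict cached permissions when a user's roles change, and check the blacklist on every request, not only on cache misses — otherwise a cached "valid" result keeps a revoked token alive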
- Security: Hash tokens before using as cache keys
- Monitoring: Track hit ratios and performance metrics
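- Fallback: Handle cache failures gracefully, and decide deliberately whether authentication fails open or closed when Redis is unavailable — for blacklist checks, failing closed is the safer default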
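- Blacklist Cleanup: Give each blacklist entry a TTL equal to the token's remaining lifetime, so Redis removes it automatically once the token has expired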
Conclusion
A well-implemented JWT token cache can significantly improve application performance while enhancing security capabilities. The combination of Redis caching, proper TTL management, and comprehensive monitoring creates a robust foundation for enterprise authentication systems.
Remember to balance performance gains with security requirements, and always have fallback mechanisms when the cache is unavailable. The patterns shown here provide a solid foundation for scaling JWT-based authentication in high-traffic applications.