依赖
<dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-aop</artifactId> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-redis</artifactId> <version>1.4.2.RELEASE</version> </dependency> <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <version>23.0</version> </dependency>
自定义注解
@Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) public @interface Limit { // 资源名称,用于描述接口功能 String name() default ""; // 资源 key String key() default ""; // key prefix String prefix() default ""; // 时间的,单位秒 int period(); // 限制访问次数 int count(); // 限制类型 LimitType limitType() default LimitType.CUSTOMER; }
枚举类
public enum LimitType { // 传统类型 CUSTOMER, // 根据 IP 限制 IP; }
异常类
public class LimitAccessException extends Exception { private static final long serialVersionUID = -3608667856397125671L; public LimitAccessException(String message) { super(message); } }
获取ip地址工具类
public class IPUtil { private static final String UNKNOWN = "unknown"; protected IPUtil(){ } /** * 获取 IP地址 * 使用 Nginx等反向代理软件, 则不能通过 request.getRemoteAddr()获取 IP地址 * 如果使用了多级反向代理的话,X-Forwarded-For的值并不止一个,而是一串IP地址, * X-Forwarded-For中第一个非 unknown的有效IP字符串,则为真实IP地址 */ public static String getIpAddr(HttpServletRequest request) { String ip = request.getHeader("x-forwarded-for"); if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) { ip = request.getHeader("Proxy-Client-IP"); } if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) { ip = request.getHeader("WL-Proxy-Client-IP"); } if (ip == null || ip.length() == 0 || UNKNOWN.equalsIgnoreCase(ip)) { ip = request.getRemoteAddr(); } return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : ip; } }
aop实现接口限流,redis执行lua脚本说明如下
tonumber把字符串转换为数字 tostring则把数字转换为字符串。 Lua认为false和nil为假,true 和非nil为真 Redis Incr 命令将 key 中储存的数字值增一。 如果 key 不存在,那么 key 的值会先被初始化为 0 Redis Expire 命令用于设置 key 的过期时间,key 过期后将不再可用。单位以秒计。 主要使用redistempalte封装的execute来实现调用 其中第二个参数为lua脚本中的KEYS[1],后面为可变参数 可变参数即传入我们需要的值,分别对应ARGV[1],ARGV[2]
/** * 接口限流 */ @Slf4j @Aspect @Component public class LimitAspect { private final RedisTemplate<String, Serializable> limitRedisTemplate; @Autowired public LimitAspect(RedisTemplate<String, Serializable> limitRedisTemplate) { this.limitRedisTemplate = limitRedisTemplate; } @Pointcut("@annotation(com.bruce.kls.common.annotation.Limit)") public void pointcut() { // do nothing } @Around("pointcut()") public Object around(ProceedingJoinPoint point) throws Throwable { HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest(); MethodSignature signature = (MethodSignature) point.getSignature(); Method method = signature.getMethod(); Limit limitAnnotation = method.getAnnotation(Limit.class); LimitType limitType = limitAnnotation.limitType(); String name = limitAnnotation.name(); String key; String ip = IPUtil.getIpAddr(request); int limitPeriod = limitAnnotation.period(); int limitCount = limitAnnotation.count(); switch (limitType) { case IP: key = ip; break; case CUSTOMER: key = limitAnnotation.key(); break; default: key = StringUtils.upperCase(method.getName()); } ImmutableList<String> keys = ImmutableList.of(StringUtils.join(limitAnnotation.prefix() + "_", key, ip)); String luaScript = buildLuaScript(); RedisScript<Number> redisScript = new DefaultRedisScript<>(luaScript, Number.class); Number count = limitRedisTemplate.execute(redisScript, keys, limitCount, limitPeriod); log.info("IP:{} 第 {} 次访问key为 {},描述为 [{}] 的接口", ip, count, keys, name); if (count != null && count.intValue() <= limitCount) { return point.proceed(); } else { throw new LimitAccessException("接口访问超出频率限制"); } } /** * 限流脚本 * 调用的时候不超过阈值,则执行计算器自加。 * @return lua脚本 */ private String buildLuaScript() { return "local c" + "\nc = redis.call('get',KEYS[1])" + "\nif c and tonumber(c) > tonumber(ARGV[1]) then" + "\nreturn c;" + "\nend" + "\nc = redis.call('incr',KEYS[1])" + "\nif tonumber(c) == 1 then" + "\nredis.call('expire',KEYS[1],ARGV[2])" + "\nend" + "\nreturn c;"; } }
测试使用
@GetMapping("/hello") @Limit(key = "hello", period = 10, count = 3, name = "测试接口", prefix = "limit") public Response<?> hello(){ return Response.success("ok"); }