package cc.mrbird.febs.common.aspect; 
 | 
  
 | 
import cc.mrbird.febs.common.annotation.Limit; 
 | 
import cc.mrbird.febs.common.entity.LimitType; 
 | 
import cc.mrbird.febs.common.exception.LimitAccessException; 
 | 
import cc.mrbird.febs.common.utils.HttpContextUtil; 
 | 
import cc.mrbird.febs.common.utils.IpUtil; 
 | 
import com.google.common.collect.ImmutableList; 
 | 
import lombok.RequiredArgsConstructor; 
 | 
import lombok.extern.slf4j.Slf4j; 
 | 
import org.apache.commons.lang3.StringUtils; 
 | 
import org.aspectj.lang.ProceedingJoinPoint; 
 | 
import org.aspectj.lang.annotation.Around; 
 | 
import org.aspectj.lang.annotation.Aspect; 
 | 
import org.aspectj.lang.annotation.Pointcut; 
 | 
import org.springframework.data.redis.core.RedisTemplate; 
 | 
import org.springframework.data.redis.core.script.DefaultRedisScript; 
 | 
import org.springframework.data.redis.core.script.RedisScript; 
 | 
import org.springframework.stereotype.Component; 
 | 
  
 | 
import javax.servlet.http.HttpServletRequest; 
 | 
import java.lang.reflect.Method; 
 | 
  
 | 
  
 | 
/** 
 | 
 * 接口限流 
 | 
 * 
 | 
 * @author MrBird 
 | 
 */ 
 | 
@Slf4j 
 | 
@Aspect 
 | 
@Component 
 | 
@RequiredArgsConstructor 
 | 
public class LimitAspect extends BaseAspectSupport { 
 | 
  
 | 
    private final RedisTemplate<String, Object> redisTemplate; 
 | 
  
 | 
    @Pointcut("@annotation(cc.mrbird.febs.common.annotation.Limit)") 
 | 
    public void pointcut() { 
 | 
    } 
 | 
  
 | 
    @Around("pointcut()") 
 | 
    public Object around(ProceedingJoinPoint point) throws Throwable { 
 | 
        HttpServletRequest request = HttpContextUtil.getHttpServletRequest(); 
 | 
        Method method = resolveMethod(point); 
 | 
        Limit limitAnnotation = method.getAnnotation(Limit.class); 
 | 
        LimitType limitType = limitAnnotation.limitType(); 
 | 
        String name = limitAnnotation.name(); 
 | 
        String key; 
 | 
        String ip = IpUtil.getIpAddr(request); 
 | 
        int limitPeriod = limitAnnotation.period(); 
 | 
        int limitCount = limitAnnotation.count(); 
 | 
        switch (limitType) { 
 | 
            case IP: 
 | 
                key = ip; 
 | 
                break; 
 | 
            case CUSTOMER: 
 | 
                key = limitAnnotation.key(); 
 | 
                break; 
 | 
            default: 
 | 
                key = StringUtils.upperCase(method.getName()); 
 | 
        } 
 | 
        ImmutableList<String> keys = ImmutableList.of(StringUtils.join(limitAnnotation.prefix() + "_", key, ip)); 
 | 
        String luaScript = buildLuaScript(); 
 | 
        RedisScript<Number> redisScript = new DefaultRedisScript<>(luaScript, Number.class); 
 | 
        Number count = redisTemplate.execute(redisScript, keys, limitCount, limitPeriod); 
 | 
        log.info("IP:{} 第 {} 次访问key为 {},描述为 [{}] 的接口", ip, count, keys, name); 
 | 
        if (count != null && count.intValue() <= limitCount) { 
 | 
            return point.proceed(); 
 | 
        } else { 
 | 
            throw new LimitAccessException("接口访问超出频率限制"); 
 | 
        } 
 | 
    } 
 | 
  
 | 
    /** 
 | 
     * 限流脚本 
 | 
     * 调用的时候不超过阈值,则直接返回并执行计算器自加。 
 | 
     * 
 | 
     * @return lua脚本 
 | 
     */ 
 | 
    private String buildLuaScript() { 
 | 
        return "local c" + 
 | 
                "\nc = redis.call('get',KEYS[1])" + 
 | 
                "\nif c and tonumber(c) > tonumber(ARGV[1]) then" + 
 | 
                "\nreturn c;" + 
 | 
                "\nend" + 
 | 
                "\nc = redis.call('incr',KEYS[1])" + 
 | 
                "\nif tonumber(c) == 1 then" + 
 | 
                "\nredis.call('expire',KEYS[1],ARGV[2])" + 
 | 
                "\nend" + 
 | 
                "\nreturn c;"; 
 | 
    } 
 | 
} 
 |