diff --git a/pom.xml b/pom.xml index 6c45f46..28c7775 100644 --- a/pom.xml +++ b/pom.xml @@ -126,6 +126,11 @@ yucongming-java-sdk 0.0.3 + + com.google.guava + guava + 29.0-jre + diff --git a/src/main/java/top/peng/answerbi/annotation/RateLimiterTag.java b/src/main/java/top/peng/answerbi/annotation/RateLimiterTag.java new file mode 100644 index 0000000..4ac6d62 --- /dev/null +++ b/src/main/java/top/peng/answerbi/annotation/RateLimiterTag.java @@ -0,0 +1,40 @@ +/* + * @(#)RateLimiter.java + * + * Copyright © 2023 YunPeng Corporation. + */ +package top.peng.answerbi.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.concurrent.TimeUnit; + +/** + * RateLimiter 限流注解 + * + * @author yunpeng + * @version 1.0 2023/7/20 + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface RateLimiterTag { + int NOT_LIMITED = 0; + /** + * 用户qps, 每个用户每秒的请求限制 + */ + double qps() default NOT_LIMITED; + + /** + * 超时时长 + */ + int timeout() default 0; + + /** + * 超时时间单位 + */ + TimeUnit timeUnit() default TimeUnit.MILLISECONDS; +} diff --git a/src/main/java/top/peng/answerbi/aop/RateLimiterInterceptor.java b/src/main/java/top/peng/answerbi/aop/RateLimiterInterceptor.java new file mode 100644 index 0000000..c52e96a --- /dev/null +++ b/src/main/java/top/peng/answerbi/aop/RateLimiterInterceptor.java @@ -0,0 +1,74 @@ +/* + * @(#)RateLimiterInterceptor.java + * + * Copyright © 2023 YunPeng Corporation. + */ +package top.peng.answerbi.aop; + +import com.google.common.util.concurrent.RateLimiter; +import java.lang.reflect.Method; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import javax.annotation.Resource; +import javax.servlet.http.HttpServletRequest; +import lombok.extern.slf4j.Slf4j; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.reflect.MethodSignature; +import org.springframework.stereotype.Component; +import org.springframework.web.context.request.RequestAttributes; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import top.peng.answerbi.annotation.RateLimiterTag; +import top.peng.answerbi.common.ErrorCode; +import top.peng.answerbi.exception.ThrowUtils; +import top.peng.answerbi.model.entity.User; +import top.peng.answerbi.service.UserService; + +/** + * RateLimiterInterceptor 限流切面 + * + * @author yunpeng + * @version 1.0 2023/7/20 + */ +@Slf4j +@Aspect +@Component +public class RateLimiterInterceptor { + @Resource + private UserService userService; + + private static final ConcurrentMap RATE_LIMITER_CACHE = new ConcurrentHashMap<>(); + + @Around("@annotation(rateLimiterTag)") + public Object doInterceptor(ProceedingJoinPoint point, RateLimiterTag rateLimiterTag) throws Throwable { + + RequestAttributes requestAttributes = RequestContextHolder.currentRequestAttributes(); + HttpServletRequest request = ((ServletRequestAttributes) requestAttributes).getRequest(); + // 当前登录用户 + User loginUser = userService.getLoginUser(request); + // 当前请求方法 + Method method = ((MethodSignature) point.getSignature()).getMethod(); + String key = loginUser.getId() + "#" + method.getName(); + if (rateLimiterTag != null && rateLimiterTag.qps() > RateLimiterTag.NOT_LIMITED) { + double qps = rateLimiterTag.qps(); + + if (RATE_LIMITER_CACHE.get(key) == null) { + // 初始化 QPS + RATE_LIMITER_CACHE.put(key, RateLimiter.create(qps)); + } + + log.debug("【{}】每个用户的QPS设置为: {}", method.getName(), RATE_LIMITER_CACHE.get(key).getRate()); + // 尝试获取令牌 + if (RATE_LIMITER_CACHE.get(key) != null){ + RateLimiter limiter = RATE_LIMITER_CACHE.get(key); + ThrowUtils.throwIf( + !limiter.tryAcquire(rateLimiterTag.timeout(), rateLimiterTag.timeUnit()), + ErrorCode.TOO_MANY_REQUEST); + } + } + return point.proceed(); + } + +} diff --git a/src/main/java/top/peng/answerbi/common/ErrorCode.java b/src/main/java/top/peng/answerbi/common/ErrorCode.java index 67c1618..8742a60 100644 --- a/src/main/java/top/peng/answerbi/common/ErrorCode.java +++ b/src/main/java/top/peng/answerbi/common/ErrorCode.java @@ -13,6 +13,7 @@ public enum ErrorCode { NOT_LOGIN_ERROR(40100, "未登录"), NO_AUTH_ERROR(40101, "无权限"), NOT_FOUND_ERROR(40400, "请求数据不存在"), + TOO_MANY_REQUEST(42900,"请求过于频繁"), FORBIDDEN_ERROR(40300, "禁止访问"), SYSTEM_ERROR(50000, "系统内部异常"), OPERATION_ERROR(50001, "操作失败"); diff --git a/src/main/java/top/peng/answerbi/controller/ChartController.java b/src/main/java/top/peng/answerbi/controller/ChartController.java index bd97e05..d7052e3 100644 --- a/src/main/java/top/peng/answerbi/controller/ChartController.java +++ b/src/main/java/top/peng/answerbi/controller/ChartController.java @@ -3,12 +3,15 @@ package top.peng.answerbi.controller; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.google.gson.Gson; import java.io.File; +import java.util.Arrays; +import java.util.List; import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.StringUtils; import org.apache.poi.util.StringUtil; import org.springframework.web.bind.annotation.RequestPart; import org.springframework.web.multipart.MultipartFile; import top.peng.answerbi.annotation.AuthCheck; +import top.peng.answerbi.annotation.RateLimiterTag; import top.peng.answerbi.common.CommonResponse; import top.peng.answerbi.common.DeleteRequest; import top.peng.answerbi.common.ErrorCode; @@ -41,6 +44,7 @@ import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import top.peng.answerbi.utils.ExcelUtils; +import top.peng.answerbi.utils.ValidUtils; /** * 图表接口 @@ -234,6 +238,7 @@ public class ChartController { * @return */ @PostMapping("/gen") + @RateLimiterTag(qps = 1.0, timeout = 100) public CommonResponse genChartByAi(@RequestPart("file") MultipartFile multipartFile, GenChartByAiRequest genChartByAiRequest, HttpServletRequest request) { String chartName = genChartByAiRequest.getChartName(); @@ -246,6 +251,8 @@ public class ChartController { //如果名称不为空,并且名称长度大于100,就抛出异常,并给出提示 ThrowUtils.throwIf(StringUtils.isNotBlank(chartName) && chartName.length() > 100,ErrorCode.PARAMS_ERROR,"图表名称过长"); + ValidUtils.validFile(multipartFile, 1, Arrays.asList("jpeg", "jpg", "svg", "png", "webp","xls","xlsx")); + //通过request对象拿到用户id(必须登录才能使用) User loginUser = userService.getLoginUser(request); @@ -262,8 +269,10 @@ public class ChartController { //压缩后的数据 String csvData = ExcelUtils.excelToCsv(multipartFile); userInput.append(csvData).append("\n"); - - String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput.toString()); + BiResponse biResponse = new BiResponse(); + biResponse.setGenChart(userInput.toString()); + return ResultUtils.success(biResponse); + /*String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput.toString()); BiResponse biResponse = aiManager.aiAnsToBiResp(aiResult); //插入数据库 @@ -277,7 +286,7 @@ public class ChartController { boolean saveResult = chartService.save(chart); ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "图表保存失败"); biResponse.setChartId(chart.getId()); - return ResultUtils.success(biResponse); + return ResultUtils.success(biResponse);*/ } } diff --git a/src/main/java/top/peng/answerbi/controller/FileController.java b/src/main/java/top/peng/answerbi/controller/FileController.java index 7ed5e64..655fdee 100644 --- a/src/main/java/top/peng/answerbi/controller/FileController.java +++ b/src/main/java/top/peng/answerbi/controller/FileController.java @@ -22,6 +22,7 @@ import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestPart; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; +import top.peng.answerbi.utils.ValidUtils; /** * 文件接口 @@ -91,18 +92,8 @@ public class FileController { * @param fileUploadBizEnum 业务类型 */ private void validFile(MultipartFile multipartFile, FileUploadBizEnum fileUploadBizEnum) { - // 文件大小 - long fileSize = multipartFile.getSize(); - // 文件后缀 - String fileSuffix = FileUtil.getSuffix(multipartFile.getOriginalFilename()); - final long ONE_M = 1024 * 1024L; if (FileUploadBizEnum.USER_AVATAR.equals(fileUploadBizEnum)) { - if (fileSize > ONE_M) { - throw new BusinessException(ErrorCode.PARAMS_ERROR, "文件大小不能超过 1M"); - } - if (!Arrays.asList("jpeg", "jpg", "svg", "png", "webp").contains(fileSuffix)) { - throw new BusinessException(ErrorCode.PARAMS_ERROR, "文件类型错误"); - } + ValidUtils.validFile(multipartFile, 1, Arrays.asList("jpeg", "jpg", "svg", "png", "webp")); } } } diff --git a/src/main/java/top/peng/answerbi/service/ChartService.java b/src/main/java/top/peng/answerbi/service/ChartService.java index 5fd845f..bfca9e7 100644 --- a/src/main/java/top/peng/answerbi/service/ChartService.java +++ b/src/main/java/top/peng/answerbi/service/ChartService.java @@ -17,7 +17,7 @@ public interface ChartService extends IService { /** * 获取查询条件 * - * @param postQueryRequest + * @param chartQueryRequest * @return */ QueryWrapper getQueryWrapper(ChartQueryRequest chartQueryRequest); diff --git a/src/main/java/top/peng/answerbi/service/impl/ChartServiceImpl.java b/src/main/java/top/peng/answerbi/service/impl/ChartServiceImpl.java index 063ea70..6432fa5 100644 --- a/src/main/java/top/peng/answerbi/service/impl/ChartServiceImpl.java +++ b/src/main/java/top/peng/answerbi/service/impl/ChartServiceImpl.java @@ -41,7 +41,7 @@ public class ChartServiceImpl extends ServiceImpl String sortOrder = chartQueryRequest.getSortOrder(); queryWrapper.eq(id != null && id > 0, "id", id); - queryWrapper.eq(StringUtils.isNotBlank(chartName), "chart_name", goal); + queryWrapper.like(StringUtils.isNotBlank(chartName), "chart_name", chartName); queryWrapper.eq(StringUtils.isNotBlank(goal), "goal", goal); queryWrapper.eq(StringUtils.isNotBlank(chartType), "chart_type", chartType); queryWrapper.eq(ObjectUtils.isNotEmpty(userId), "user_id", userId); diff --git a/src/main/java/top/peng/answerbi/utils/ValidUtils.java b/src/main/java/top/peng/answerbi/utils/ValidUtils.java new file mode 100644 index 0000000..0e80037 --- /dev/null +++ b/src/main/java/top/peng/answerbi/utils/ValidUtils.java @@ -0,0 +1,38 @@ +/* + * @(#)ValidUtils.java + * + * Copyright © 2023 YunPeng Corporation. + */ +package top.peng.answerbi.utils; + +import cn.hutool.core.io.FileUtil; +import java.util.Arrays; +import java.util.List; +import org.springframework.web.multipart.MultipartFile; +import top.peng.answerbi.common.ErrorCode; +import top.peng.answerbi.exception.BusinessException; +import top.peng.answerbi.exception.ThrowUtils; +import top.peng.answerbi.model.enums.FileUploadBizEnum; + +/** + * ValidUtils 校验工具类 + * + * @author yunpeng + * @version 1.0 2023/7/20 + */ +public class ValidUtils { + public static void validFile(MultipartFile multipartFile, int limitSizeMb, List validFileSuffixList) { + // 文件大小 + long fileSize = multipartFile.getSize(); + + final long LIMIT_M = limitSizeMb * 1024 * 1024L; + + ThrowUtils.throwIf(fileSize > LIMIT_M, ErrorCode.PARAMS_ERROR, "文件大小不能超过 "+ limitSizeMb +" M"); + + // 文件后缀 + String fileSuffix = FileUtil.getSuffix(multipartFile.getOriginalFilename()); + + + ThrowUtils.throwIf(!validFileSuffixList.contains(fileSuffix), ErrorCode.PARAMS_ERROR, "文件类型错误"); + } +}