diff --git a/pom.xml b/pom.xml
index 6c45f46..28c7775 100644
--- a/pom.xml
+++ b/pom.xml
@@ -126,6 +126,11 @@
yucongming-java-sdk
0.0.3
+
+ com.google.guava
+ guava
+ 29.0-jre
+
diff --git a/src/main/java/top/peng/answerbi/annotation/RateLimiterTag.java b/src/main/java/top/peng/answerbi/annotation/RateLimiterTag.java
new file mode 100644
index 0000000..4ac6d62
--- /dev/null
+++ b/src/main/java/top/peng/answerbi/annotation/RateLimiterTag.java
@@ -0,0 +1,40 @@
+/*
+ * @(#)RateLimiter.java
+ *
+ * Copyright © 2023 YunPeng Corporation.
+ */
+package top.peng.answerbi.annotation;
+
+import java.lang.annotation.Documented;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * RateLimiter 限流注解
+ *
+ * @author yunpeng
+ * @version 1.0 2023/7/20
+ */
+@Target(ElementType.METHOD)
+@Retention(RetentionPolicy.RUNTIME)
+@Documented
+public @interface RateLimiterTag {
+ int NOT_LIMITED = 0;
+ /**
+ * 用户qps, 每个用户每秒的请求限制
+ */
+ double qps() default NOT_LIMITED;
+
+ /**
+ * 超时时长
+ */
+ int timeout() default 0;
+
+ /**
+ * 超时时间单位
+ */
+ TimeUnit timeUnit() default TimeUnit.MILLISECONDS;
+}
diff --git a/src/main/java/top/peng/answerbi/aop/RateLimiterInterceptor.java b/src/main/java/top/peng/answerbi/aop/RateLimiterInterceptor.java
new file mode 100644
index 0000000..c52e96a
--- /dev/null
+++ b/src/main/java/top/peng/answerbi/aop/RateLimiterInterceptor.java
@@ -0,0 +1,74 @@
+/*
+ * @(#)RateLimiterInterceptor.java
+ *
+ * Copyright © 2023 YunPeng Corporation.
+ */
+package top.peng.answerbi.aop;
+
+import com.google.common.util.concurrent.RateLimiter;
+import java.lang.reflect.Method;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import javax.annotation.Resource;
+import javax.servlet.http.HttpServletRequest;
+import lombok.extern.slf4j.Slf4j;
+import org.aspectj.lang.ProceedingJoinPoint;
+import org.aspectj.lang.annotation.Around;
+import org.aspectj.lang.annotation.Aspect;
+import org.aspectj.lang.reflect.MethodSignature;
+import org.springframework.stereotype.Component;
+import org.springframework.web.context.request.RequestAttributes;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+import top.peng.answerbi.annotation.RateLimiterTag;
+import top.peng.answerbi.common.ErrorCode;
+import top.peng.answerbi.exception.ThrowUtils;
+import top.peng.answerbi.model.entity.User;
+import top.peng.answerbi.service.UserService;
+
+/**
+ * RateLimiterInterceptor 限流切面
+ *
+ * @author yunpeng
+ * @version 1.0 2023/7/20
+ */
+@Slf4j
+@Aspect
+@Component
+public class RateLimiterInterceptor {
+ @Resource
+ private UserService userService;
+
+ private static final ConcurrentMap RATE_LIMITER_CACHE = new ConcurrentHashMap<>();
+
+ @Around("@annotation(rateLimiterTag)")
+ public Object doInterceptor(ProceedingJoinPoint point, RateLimiterTag rateLimiterTag) throws Throwable {
+
+ RequestAttributes requestAttributes = RequestContextHolder.currentRequestAttributes();
+ HttpServletRequest request = ((ServletRequestAttributes) requestAttributes).getRequest();
+ // 当前登录用户
+ User loginUser = userService.getLoginUser(request);
+ // 当前请求方法
+ Method method = ((MethodSignature) point.getSignature()).getMethod();
+ String key = loginUser.getId() + "#" + method.getName();
+ if (rateLimiterTag != null && rateLimiterTag.qps() > RateLimiterTag.NOT_LIMITED) {
+ double qps = rateLimiterTag.qps();
+
+ if (RATE_LIMITER_CACHE.get(key) == null) {
+ // 初始化 QPS
+ RATE_LIMITER_CACHE.put(key, RateLimiter.create(qps));
+ }
+
+ log.debug("【{}】每个用户的QPS设置为: {}", method.getName(), RATE_LIMITER_CACHE.get(key).getRate());
+ // 尝试获取令牌
+ if (RATE_LIMITER_CACHE.get(key) != null){
+ RateLimiter limiter = RATE_LIMITER_CACHE.get(key);
+ ThrowUtils.throwIf(
+ !limiter.tryAcquire(rateLimiterTag.timeout(), rateLimiterTag.timeUnit()),
+ ErrorCode.TOO_MANY_REQUEST);
+ }
+ }
+ return point.proceed();
+ }
+
+}
diff --git a/src/main/java/top/peng/answerbi/common/ErrorCode.java b/src/main/java/top/peng/answerbi/common/ErrorCode.java
index 67c1618..8742a60 100644
--- a/src/main/java/top/peng/answerbi/common/ErrorCode.java
+++ b/src/main/java/top/peng/answerbi/common/ErrorCode.java
@@ -13,6 +13,7 @@ public enum ErrorCode {
NOT_LOGIN_ERROR(40100, "未登录"),
NO_AUTH_ERROR(40101, "无权限"),
NOT_FOUND_ERROR(40400, "请求数据不存在"),
+ TOO_MANY_REQUEST(42900,"请求过于频繁"),
FORBIDDEN_ERROR(40300, "禁止访问"),
SYSTEM_ERROR(50000, "系统内部异常"),
OPERATION_ERROR(50001, "操作失败");
diff --git a/src/main/java/top/peng/answerbi/controller/ChartController.java b/src/main/java/top/peng/answerbi/controller/ChartController.java
index bd97e05..d7052e3 100644
--- a/src/main/java/top/peng/answerbi/controller/ChartController.java
+++ b/src/main/java/top/peng/answerbi/controller/ChartController.java
@@ -3,12 +3,15 @@ package top.peng.answerbi.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.google.gson.Gson;
import java.io.File;
+import java.util.Arrays;
+import java.util.List;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.poi.util.StringUtil;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.multipart.MultipartFile;
import top.peng.answerbi.annotation.AuthCheck;
+import top.peng.answerbi.annotation.RateLimiterTag;
import top.peng.answerbi.common.CommonResponse;
import top.peng.answerbi.common.DeleteRequest;
import top.peng.answerbi.common.ErrorCode;
@@ -41,6 +44,7 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import top.peng.answerbi.utils.ExcelUtils;
+import top.peng.answerbi.utils.ValidUtils;
/**
* 图表接口
@@ -234,6 +238,7 @@ public class ChartController {
* @return
*/
@PostMapping("/gen")
+ @RateLimiterTag(qps = 1.0, timeout = 100)
public CommonResponse genChartByAi(@RequestPart("file") MultipartFile multipartFile,
GenChartByAiRequest genChartByAiRequest, HttpServletRequest request) {
String chartName = genChartByAiRequest.getChartName();
@@ -246,6 +251,8 @@ public class ChartController {
//如果名称不为空,并且名称长度大于100,就抛出异常,并给出提示
ThrowUtils.throwIf(StringUtils.isNotBlank(chartName) && chartName.length() > 100,ErrorCode.PARAMS_ERROR,"图表名称过长");
+ ValidUtils.validFile(multipartFile, 1, Arrays.asList("jpeg", "jpg", "svg", "png", "webp","xls","xlsx"));
+
//通过request对象拿到用户id(必须登录才能使用)
User loginUser = userService.getLoginUser(request);
@@ -262,8 +269,10 @@ public class ChartController {
//压缩后的数据
String csvData = ExcelUtils.excelToCsv(multipartFile);
userInput.append(csvData).append("\n");
-
- String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput.toString());
+ BiResponse biResponse = new BiResponse();
+ biResponse.setGenChart(userInput.toString());
+ return ResultUtils.success(biResponse);
+ /*String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput.toString());
BiResponse biResponse = aiManager.aiAnsToBiResp(aiResult);
//插入数据库
@@ -277,7 +286,7 @@ public class ChartController {
boolean saveResult = chartService.save(chart);
ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "图表保存失败");
biResponse.setChartId(chart.getId());
- return ResultUtils.success(biResponse);
+ return ResultUtils.success(biResponse);*/
}
}
diff --git a/src/main/java/top/peng/answerbi/controller/FileController.java b/src/main/java/top/peng/answerbi/controller/FileController.java
index 7ed5e64..655fdee 100644
--- a/src/main/java/top/peng/answerbi/controller/FileController.java
+++ b/src/main/java/top/peng/answerbi/controller/FileController.java
@@ -22,6 +22,7 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
+import top.peng.answerbi.utils.ValidUtils;
/**
* 文件接口
@@ -91,18 +92,8 @@ public class FileController {
* @param fileUploadBizEnum 业务类型
*/
private void validFile(MultipartFile multipartFile, FileUploadBizEnum fileUploadBizEnum) {
- // 文件大小
- long fileSize = multipartFile.getSize();
- // 文件后缀
- String fileSuffix = FileUtil.getSuffix(multipartFile.getOriginalFilename());
- final long ONE_M = 1024 * 1024L;
if (FileUploadBizEnum.USER_AVATAR.equals(fileUploadBizEnum)) {
- if (fileSize > ONE_M) {
- throw new BusinessException(ErrorCode.PARAMS_ERROR, "文件大小不能超过 1M");
- }
- if (!Arrays.asList("jpeg", "jpg", "svg", "png", "webp").contains(fileSuffix)) {
- throw new BusinessException(ErrorCode.PARAMS_ERROR, "文件类型错误");
- }
+ ValidUtils.validFile(multipartFile, 1, Arrays.asList("jpeg", "jpg", "svg", "png", "webp"));
}
}
}
diff --git a/src/main/java/top/peng/answerbi/service/ChartService.java b/src/main/java/top/peng/answerbi/service/ChartService.java
index 5fd845f..bfca9e7 100644
--- a/src/main/java/top/peng/answerbi/service/ChartService.java
+++ b/src/main/java/top/peng/answerbi/service/ChartService.java
@@ -17,7 +17,7 @@ public interface ChartService extends IService {
/**
* 获取查询条件
*
- * @param postQueryRequest
+ * @param chartQueryRequest
* @return
*/
QueryWrapper getQueryWrapper(ChartQueryRequest chartQueryRequest);
diff --git a/src/main/java/top/peng/answerbi/service/impl/ChartServiceImpl.java b/src/main/java/top/peng/answerbi/service/impl/ChartServiceImpl.java
index 063ea70..6432fa5 100644
--- a/src/main/java/top/peng/answerbi/service/impl/ChartServiceImpl.java
+++ b/src/main/java/top/peng/answerbi/service/impl/ChartServiceImpl.java
@@ -41,7 +41,7 @@ public class ChartServiceImpl extends ServiceImpl
String sortOrder = chartQueryRequest.getSortOrder();
queryWrapper.eq(id != null && id > 0, "id", id);
- queryWrapper.eq(StringUtils.isNotBlank(chartName), "chart_name", goal);
+ queryWrapper.like(StringUtils.isNotBlank(chartName), "chart_name", chartName);
queryWrapper.eq(StringUtils.isNotBlank(goal), "goal", goal);
queryWrapper.eq(StringUtils.isNotBlank(chartType), "chart_type", chartType);
queryWrapper.eq(ObjectUtils.isNotEmpty(userId), "user_id", userId);
diff --git a/src/main/java/top/peng/answerbi/utils/ValidUtils.java b/src/main/java/top/peng/answerbi/utils/ValidUtils.java
new file mode 100644
index 0000000..0e80037
--- /dev/null
+++ b/src/main/java/top/peng/answerbi/utils/ValidUtils.java
@@ -0,0 +1,38 @@
+/*
+ * @(#)ValidUtils.java
+ *
+ * Copyright © 2023 YunPeng Corporation.
+ */
+package top.peng.answerbi.utils;
+
+import cn.hutool.core.io.FileUtil;
+import java.util.Arrays;
+import java.util.List;
+import org.springframework.web.multipart.MultipartFile;
+import top.peng.answerbi.common.ErrorCode;
+import top.peng.answerbi.exception.BusinessException;
+import top.peng.answerbi.exception.ThrowUtils;
+import top.peng.answerbi.model.enums.FileUploadBizEnum;
+
+/**
+ * ValidUtils 校验工具类
+ *
+ * @author yunpeng
+ * @version 1.0 2023/7/20
+ */
+public class ValidUtils {
+ public static void validFile(MultipartFile multipartFile, int limitSizeMb, List validFileSuffixList) {
+ // 文件大小
+ long fileSize = multipartFile.getSize();
+
+ final long LIMIT_M = limitSizeMb * 1024 * 1024L;
+
+ ThrowUtils.throwIf(fileSize > LIMIT_M, ErrorCode.PARAMS_ERROR, "文件大小不能超过 "+ limitSizeMb +" M");
+
+ // 文件后缀
+ String fileSuffix = FileUtil.getSuffix(multipartFile.getOriginalFilename());
+
+
+ ThrowUtils.throwIf(!validFileSuffixList.contains(fileSuffix), ErrorCode.PARAMS_ERROR, "文件类型错误");
+ }
+}