From 2bee0007afd644ce4f12be5e1d3e443af61bdad8 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 21 Jul 2023 18:33:27 +0800 Subject: [PATCH] =?UTF-8?q?=E5=BC=95=E5=85=A5=E7=BA=BF=E7=A8=8B=E6=B1=A0?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E5=BC=82=E6=AD=A5=E5=88=86=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sql/create_table.sql | 2 + .../config/ThreadPoolExecutorConfig.java | 44 +++++ .../answerbi/controller/ChartController.java | 154 +++++++++++++++--- .../top/peng/answerbi/manager/AiManager.java | 3 +- .../top/peng/answerbi/model/entity/Chart.java | 10 ++ .../model/enums/BiTaskStatusEnum.java | 70 ++++++++ 6 files changed, 256 insertions(+), 27 deletions(-) create mode 100644 src/main/java/top/peng/answerbi/config/ThreadPoolExecutorConfig.java create mode 100644 src/main/java/top/peng/answerbi/model/enums/BiTaskStatusEnum.java diff --git a/sql/create_table.sql b/sql/create_table.sql index 5a49ee7..a46200b 100644 --- a/sql/create_table.sql +++ b/sql/create_table.sql @@ -37,5 +37,7 @@ create table if not exists chart chart_type varchar(128) null comment '图表类型', gen_chart text null comment '生成的图表数据', gen_result text null comment '生成的分析结论', + status varchar(128) not null default 'wait' comment '任务状态,取值wait、running、succeed、failed', + exec_message text null comment '执行信息', index idx_userId (user_id) ) comment '图表信息表' collate = utf8mb4_unicode_ci; diff --git a/src/main/java/top/peng/answerbi/config/ThreadPoolExecutorConfig.java b/src/main/java/top/peng/answerbi/config/ThreadPoolExecutorConfig.java new file mode 100644 index 0000000..828c5bb --- /dev/null +++ b/src/main/java/top/peng/answerbi/config/ThreadPoolExecutorConfig.java @@ -0,0 +1,44 @@ +/* + * @(#)ThreeadPoolExecutorConfig.java + * + * Copyright © 2023 YunPeng Corporation. + */ +package top.peng.answerbi.config; + +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.jetbrains.annotations.NotNull; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * ThreadPoolExecutorConfig 线程池配置 + * + * @author yunpeng + * @version 1.0 2023/7/21 + */ +@Configuration +public class ThreadPoolExecutorConfig { + + @Bean + public ThreadPoolExecutor threadPoolExecutor(){ + //创建一个线程工厂 + ThreadFactory threadFactory = new ThreadFactory() { + //初始化线程数为1 + private int count = 1; + @Override + public Thread newThread(@NotNull Runnable r) { + Thread thread = new Thread(r); + thread.setName("线程" + count); + count++; + return thread; + } + }; + //创建一个线程池,核心大小为2,最大线程数为4, 非核心线程空闲100秒即被释放 + //任务队列为阻塞队列,长度为4,使用上方创建的线程工厂 threadFactory 来创建线程 + return new ThreadPoolExecutor(2,4,100, + TimeUnit.SECONDS,new ArrayBlockingQueue<>(4),threadFactory); + } +} diff --git a/src/main/java/top/peng/answerbi/controller/ChartController.java b/src/main/java/top/peng/answerbi/controller/ChartController.java index 8c54dfb..b2354a2 100644 --- a/src/main/java/top/peng/answerbi/controller/ChartController.java +++ b/src/main/java/top/peng/answerbi/controller/ChartController.java @@ -1,13 +1,25 @@ package top.peng.answerbi.controller; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; -import com.google.gson.Gson; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Resource; +import javax.servlet.http.HttpServletRequest; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestPart; +import org.springframework.web.bind.annotation.RestController; import org.springframework.web.multipart.MultipartFile; import top.peng.answerbi.annotation.AuthCheck; -import top.peng.answerbi.annotation.GuavaRateLimiter; import top.peng.answerbi.annotation.RedissonRateLimiter; import top.peng.answerbi.common.CommonResponse; import top.peng.answerbi.common.DeleteRequest; @@ -25,18 +37,10 @@ import top.peng.answerbi.model.dto.chart.ChartUpdateRequest; import top.peng.answerbi.model.dto.chart.GenChartByAiRequest; import top.peng.answerbi.model.entity.Chart; import top.peng.answerbi.model.entity.User; +import top.peng.answerbi.model.enums.BiTaskStatusEnum; import top.peng.answerbi.model.vo.BiResponse; import top.peng.answerbi.service.ChartService; import top.peng.answerbi.service.UserService; -import javax.annotation.Resource; -import javax.servlet.http.HttpServletRequest; -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.BeanUtils; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; import top.peng.answerbi.utils.ExcelUtils; import top.peng.answerbi.utils.ValidUtils; @@ -60,7 +64,8 @@ public class ChartController { @Resource private AiManager aiManager; - private final static Gson GSON = new Gson(); + @Resource + private ThreadPoolExecutor threadPoolExecutor; // region 增删改查 @@ -224,7 +229,7 @@ public class ChartController { } /** - * 智能分析 + * 智能分析 (同步) * * @param multipartFile * @param genChartByAiRequest @@ -235,6 +240,87 @@ public class ChartController { @RedissonRateLimiter(qps = 1) public CommonResponse genChartByAi(@RequestPart("file") MultipartFile multipartFile, GenChartByAiRequest genChartByAiRequest, HttpServletRequest request) { + + //生成参数 + Map aiMsgAndChartData = genAiMsgAndChartData(genChartByAiRequest, multipartFile, request); + String userInput = (String) aiMsgAndChartData.get("userInput"); + Chart chart = (Chart) aiMsgAndChartData.get("chartObj"); + + //调用AI + String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput); + BiResponse biResponse = aiManager.aiAnsToBiResp(aiResult); + + //插入数据库 + BeanUtils.copyProperties(biResponse,chart); + chart.setStatus(BiTaskStatusEnum.SUCCEED.getValue()); + boolean saveResult = chartService.save(chart); + ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "图表保存失败"); + biResponse.setChartId(chart.getId()); + return ResultUtils.success(biResponse); + } + + + /** + * 智能分析 (异步) + * + * @param multipartFile + * @param genChartByAiRequest + * @param request + * @return + */ + @PostMapping("/gen/async") + @RedissonRateLimiter(qps = 1) + public CommonResponse genChartByAiAsync(@RequestPart("file") MultipartFile multipartFile, + GenChartByAiRequest genChartByAiRequest, HttpServletRequest request) { + + //生成参数 + Map aiMsgAndChartData = genAiMsgAndChartData(genChartByAiRequest, multipartFile, request); + String userInput = (String) aiMsgAndChartData.get("userInput"); + Chart chart = (Chart) aiMsgAndChartData.get("chartObj"); + + //先插入数据库, 状态为排队中 + chart.setStatus(BiTaskStatusEnum.WAIT.getValue()); + boolean saveResult = chartService.save(chart); + ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "图表保存失败"); + + AtomicReference ResBiResponse = new AtomicReference<>(new BiResponse()); + //创建线程任务 + CompletableFuture.runAsync(() -> { + //先修改图表任务状态为“执行中”; + handleChartStatusUpdate(chart.getId(),BiTaskStatusEnum.RUNNING.getValue(), ""); + + //调用AI + String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput); + BiResponse biResponse; + try { + biResponse = aiManager.aiAnsToBiResp(aiResult); + } catch (BusinessException e) { + //执行失败,状态修改为“失败”,记录任务失败信息 + handleChartStatusUpdate(chart.getId(),BiTaskStatusEnum.FAILED.getValue(), e.getMessage()); + throw e; + } + //执行成功后,修改为“已完成”、保存执行结果 + biResponse.setChartId(chart.getId()); + handleChartSuccessStatusUpdate(biResponse); + ResBiResponse.set(biResponse); + }, threadPoolExecutor); + + return ResultUtils.success(ResBiResponse.get()); + } + + /** + * 根据用户输入构建 要发送给AI的消息 和 要存入数据库的 Chart 对象 + * + * @param genChartByAiRequest + * @param multipartFile + * @param request + * @return + */ + private Map genAiMsgAndChartData(GenChartByAiRequest genChartByAiRequest,MultipartFile multipartFile, HttpServletRequest request){ + + //通过request对象拿到用户id(必须登录才能使用) + User loginUser = userService.getLoginUser(request); + String chartName = genChartByAiRequest.getChartName(); String goal = genChartByAiRequest.getGoal(); String chartType = genChartByAiRequest.getChartType(); @@ -244,11 +330,7 @@ public class ChartController { ThrowUtils.throwIf(StringUtils.isBlank(goal),ErrorCode.PARAMS_ERROR,"分析目标为空"); //如果名称不为空,并且名称长度大于100,就抛出异常,并给出提示 ThrowUtils.throwIf(StringUtils.isNotBlank(chartName) && chartName.length() > 100,ErrorCode.PARAMS_ERROR,"图表名称过长"); - - ValidUtils.validFile(multipartFile, 1, Arrays.asList("jpeg", "jpg", "svg", "png", "webp","xls","xlsx")); - - //通过request对象拿到用户id(必须登录才能使用) - User loginUser = userService.getLoginUser(request); + ValidUtils.validFile(multipartFile, 1, Arrays.asList("xls","xlsx")); //用户输入 StringBuilder userInput = new StringBuilder(); @@ -263,21 +345,41 @@ public class ChartController { //压缩后的数据 String csvData = ExcelUtils.excelToCsv(multipartFile); userInput.append(csvData).append("\n"); - String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput.toString()); - BiResponse biResponse = aiManager.aiAnsToBiResp(aiResult); - //插入数据库 Chart chart = new Chart(); - BeanUtils.copyProperties(biResponse,chart); chart.setChartName(chartName); chart.setGoal(goal); chart.setChartType(chartType); chart.setChartData(csvData); chart.setUserId(loginUser.getId()); - boolean saveResult = chartService.save(chart); - ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "图表保存失败"); - biResponse.setChartId(chart.getId()); - return ResultUtils.success(biResponse); + + Map result = new HashMap<>(); + result.put("userInput",userInput.toString()); + result.put("chartObj", chart); + return result; + } + + private void handleChartStatusUpdate(long chartId,String status,String execMessage){ + Chart updateChart = new Chart(); + updateChart.setId(chartId); + updateChart.setStatus(status); + updateChart.setExecMessage(execMessage); + boolean updateResult = chartService.updateById(updateChart); + if (!updateResult){ + log.error("更新图表[{}]状态失败", chartId); + } + } + + private void handleChartSuccessStatusUpdate(BiResponse biResponse){ + Chart updateChart = new Chart(); + updateChart.setId(biResponse.getChartId()); + updateChart.setStatus(BiTaskStatusEnum.SUCCEED.getValue()); + updateChart.setGenChart(biResponse.getGenChart()); + updateChart.setGenResult(biResponse.getGenResult()); + boolean updateResult = chartService.updateById(updateChart); + if (!updateResult){ + log.error("更新图表[{}]结果失败", biResponse.getChartId()); + } } } diff --git a/src/main/java/top/peng/answerbi/manager/AiManager.java b/src/main/java/top/peng/answerbi/manager/AiManager.java index d9bc887..f9eeebc 100644 --- a/src/main/java/top/peng/answerbi/manager/AiManager.java +++ b/src/main/java/top/peng/answerbi/manager/AiManager.java @@ -13,6 +13,7 @@ import javax.annotation.Resource; import org.springframework.stereotype.Service; import top.peng.answerbi.common.ErrorCode; import top.peng.answerbi.constant.BiConstant; +import top.peng.answerbi.exception.BusinessException; import top.peng.answerbi.exception.ThrowUtils; import top.peng.answerbi.model.vo.BiResponse; @@ -52,7 +53,7 @@ public class AiManager { * @param aiAnswer AI 对话 结果 * @return BiResponse对象 */ - public BiResponse aiAnsToBiResp(String aiAnswer){ + public BiResponse aiAnsToBiResp(String aiAnswer) throws BusinessException { String[] aiResultSplit = aiAnswer.split(BiConstant.BI_RESULT_SEPARATOR); ThrowUtils.throwIf(aiResultSplit.length < 3,ErrorCode.SYSTEM_ERROR,"AI 生成错误"); BiResponse biResponse = new BiResponse(); diff --git a/src/main/java/top/peng/answerbi/model/entity/Chart.java b/src/main/java/top/peng/answerbi/model/entity/Chart.java index c37440f..cea5a64 100644 --- a/src/main/java/top/peng/answerbi/model/entity/Chart.java +++ b/src/main/java/top/peng/answerbi/model/entity/Chart.java @@ -73,6 +73,16 @@ public class Chart implements Serializable { */ private String genResult; + /** + * 任务状态 + */ + private String status; + + /** + * 执行信息 + */ + private String execMessage; + @TableField(exist = false) private static final long serialVersionUID = 1L; } \ No newline at end of file diff --git a/src/main/java/top/peng/answerbi/model/enums/BiTaskStatusEnum.java b/src/main/java/top/peng/answerbi/model/enums/BiTaskStatusEnum.java new file mode 100644 index 0000000..d4d537c --- /dev/null +++ b/src/main/java/top/peng/answerbi/model/enums/BiTaskStatusEnum.java @@ -0,0 +1,70 @@ +/* + * @(#)BiTaskStatus.java + * + * Copyright © 2023 YunPeng Corporation. + */ +package top.peng.answerbi.model.enums; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.commons.lang3.ObjectUtils; + +/** + * BiTaskStatus 智能分析任务状态枚举 + * + * @author yunpeng + * @version 1.0 2023/7/21 + */ +public enum BiTaskStatusEnum { + + WAIT("排队中","wait"), + RUNNING("生成中","running"), + SUCCEED("成功","succeed"), + FAILED("失败","failed"), + ; + + private final String text; + + private final String value; + + BiTaskStatusEnum(String text, String value) { + this.text = text; + this.value = value; + } + + /** + * 获取值列表 + * + * @return + */ + public static List getValues() { + return Arrays.stream(values()).map(item -> item.value).collect(Collectors.toList()); + } + + /** + * 根据 value 获取枚举 + * + * @param value + * @return + */ + public static BiTaskStatusEnum getEnumByValue(String value) { + if (ObjectUtils.isEmpty(value)) { + return null; + } + for (BiTaskStatusEnum anEnum : BiTaskStatusEnum.values()) { + if (anEnum.value.equals(value)) { + return anEnum; + } + } + return null; + } + + public String getValue() { + return value; + } + + public String getText() { + return text; + } +}