引入线程池进行异步分析

This commit is contained in:
brian 2023-07-21 18:33:27 +08:00
parent a13ec103c5
commit 2bee0007af
6 changed files with 256 additions and 27 deletions

View File

@ -37,5 +37,7 @@ create table if not exists chart
chart_type varchar(128) null comment '图表类型',
gen_chart text null comment '生成的图表数据',
gen_result text null comment '生成的分析结论',
status varchar(128) not null default 'wait' comment '任务状态,取值wait、running、succeed、failed',
exec_message text null comment '执行信息',
index idx_userId (user_id)
) comment '图表信息表' collate = utf8mb4_unicode_ci;

View File

@ -0,0 +1,44 @@
/*
* @(#)ThreeadPoolExecutorConfig.java
*
* Copyright © 2023 YunPeng Corporation.
*/
package top.peng.answerbi.config;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.jetbrains.annotations.NotNull;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* ThreadPoolExecutorConfig 线程池配置
*
* @author yunpeng
* @version 1.0 2023/7/21
*/
@Configuration
public class ThreadPoolExecutorConfig {
@Bean
public ThreadPoolExecutor threadPoolExecutor(){
//创建一个线程工厂
ThreadFactory threadFactory = new ThreadFactory() {
//初始化线程数为1
private int count = 1;
@Override
public Thread newThread(@NotNull Runnable r) {
Thread thread = new Thread(r);
thread.setName("线程" + count);
count++;
return thread;
}
};
//创建一个线程池核心大小为2最大线程数为4, 非核心线程空闲100秒即被释放
//任务队列为阻塞队列长度为4使用上方创建的线程工厂 threadFactory 来创建线程
return new ThreadPoolExecutor(2,4,100,
TimeUnit.SECONDS,new ArrayBlockingQueue<>(4),threadFactory);
}
}

View File

@ -1,13 +1,25 @@
package top.peng.answerbi.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.google.gson.Gson;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import top.peng.answerbi.annotation.AuthCheck;
import top.peng.answerbi.annotation.GuavaRateLimiter;
import top.peng.answerbi.annotation.RedissonRateLimiter;
import top.peng.answerbi.common.CommonResponse;
import top.peng.answerbi.common.DeleteRequest;
@ -25,18 +37,10 @@ import top.peng.answerbi.model.dto.chart.ChartUpdateRequest;
import top.peng.answerbi.model.dto.chart.GenChartByAiRequest;
import top.peng.answerbi.model.entity.Chart;
import top.peng.answerbi.model.entity.User;
import top.peng.answerbi.model.enums.BiTaskStatusEnum;
import top.peng.answerbi.model.vo.BiResponse;
import top.peng.answerbi.service.ChartService;
import top.peng.answerbi.service.UserService;
import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import top.peng.answerbi.utils.ExcelUtils;
import top.peng.answerbi.utils.ValidUtils;
@ -60,7 +64,8 @@ public class ChartController {
@Resource
private AiManager aiManager;
private final static Gson GSON = new Gson();
@Resource
private ThreadPoolExecutor threadPoolExecutor;
// region 增删改查
@ -224,7 +229,7 @@ public class ChartController {
}
/**
* 智能分析
* 智能分析 (同步)
*
* @param multipartFile
* @param genChartByAiRequest
@ -235,6 +240,87 @@ public class ChartController {
@RedissonRateLimiter(qps = 1)
public CommonResponse<BiResponse> genChartByAi(@RequestPart("file") MultipartFile multipartFile,
GenChartByAiRequest genChartByAiRequest, HttpServletRequest request) {
//生成参数
Map<String, Object> aiMsgAndChartData = genAiMsgAndChartData(genChartByAiRequest, multipartFile, request);
String userInput = (String) aiMsgAndChartData.get("userInput");
Chart chart = (Chart) aiMsgAndChartData.get("chartObj");
//调用AI
String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput);
BiResponse biResponse = aiManager.aiAnsToBiResp(aiResult);
//插入数据库
BeanUtils.copyProperties(biResponse,chart);
chart.setStatus(BiTaskStatusEnum.SUCCEED.getValue());
boolean saveResult = chartService.save(chart);
ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "图表保存失败");
biResponse.setChartId(chart.getId());
return ResultUtils.success(biResponse);
}
/**
* 智能分析 (异步)
*
* @param multipartFile
* @param genChartByAiRequest
* @param request
* @return
*/
@PostMapping("/gen/async")
@RedissonRateLimiter(qps = 1)
public CommonResponse<BiResponse> genChartByAiAsync(@RequestPart("file") MultipartFile multipartFile,
GenChartByAiRequest genChartByAiRequest, HttpServletRequest request) {
//生成参数
Map<String, Object> aiMsgAndChartData = genAiMsgAndChartData(genChartByAiRequest, multipartFile, request);
String userInput = (String) aiMsgAndChartData.get("userInput");
Chart chart = (Chart) aiMsgAndChartData.get("chartObj");
//先插入数据库, 状态为排队中
chart.setStatus(BiTaskStatusEnum.WAIT.getValue());
boolean saveResult = chartService.save(chart);
ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "图表保存失败");
AtomicReference<BiResponse> ResBiResponse = new AtomicReference<>(new BiResponse());
//创建线程任务
CompletableFuture.runAsync(() -> {
//先修改图表任务状态为执行中;
handleChartStatusUpdate(chart.getId(),BiTaskStatusEnum.RUNNING.getValue(), "");
//调用AI
String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput);
BiResponse biResponse;
try {
biResponse = aiManager.aiAnsToBiResp(aiResult);
} catch (BusinessException e) {
//执行失败状态修改为失败,记录任务失败信息
handleChartStatusUpdate(chart.getId(),BiTaskStatusEnum.FAILED.getValue(), e.getMessage());
throw e;
}
//执行成功后修改为已完成保存执行结果
biResponse.setChartId(chart.getId());
handleChartSuccessStatusUpdate(biResponse);
ResBiResponse.set(biResponse);
}, threadPoolExecutor);
return ResultUtils.success(ResBiResponse.get());
}
/**
* 根据用户输入构建 要发送给AI的消息 要存入数据库的 Chart 对象
*
* @param genChartByAiRequest
* @param multipartFile
* @param request
* @return
*/
private Map<String,Object> genAiMsgAndChartData(GenChartByAiRequest genChartByAiRequest,MultipartFile multipartFile, HttpServletRequest request){
//通过request对象拿到用户id(必须登录才能使用)
User loginUser = userService.getLoginUser(request);
String chartName = genChartByAiRequest.getChartName();
String goal = genChartByAiRequest.getGoal();
String chartType = genChartByAiRequest.getChartType();
@ -244,11 +330,7 @@ public class ChartController {
ThrowUtils.throwIf(StringUtils.isBlank(goal),ErrorCode.PARAMS_ERROR,"分析目标为空");
//如果名称不为空并且名称长度大于100就抛出异常并给出提示
ThrowUtils.throwIf(StringUtils.isNotBlank(chartName) && chartName.length() > 100,ErrorCode.PARAMS_ERROR,"图表名称过长");
ValidUtils.validFile(multipartFile, 1, Arrays.asList("jpeg", "jpg", "svg", "png", "webp","xls","xlsx"));
//通过request对象拿到用户id(必须登录才能使用)
User loginUser = userService.getLoginUser(request);
ValidUtils.validFile(multipartFile, 1, Arrays.asList("xls","xlsx"));
//用户输入
StringBuilder userInput = new StringBuilder();
@ -263,21 +345,41 @@ public class ChartController {
//压缩后的数据
String csvData = ExcelUtils.excelToCsv(multipartFile);
userInput.append(csvData).append("\n");
String aiResult = aiManager.doChat(BiConstant.BI_MODEL_ID, userInput.toString());
BiResponse biResponse = aiManager.aiAnsToBiResp(aiResult);
//插入数据库
Chart chart = new Chart();
BeanUtils.copyProperties(biResponse,chart);
chart.setChartName(chartName);
chart.setGoal(goal);
chart.setChartType(chartType);
chart.setChartData(csvData);
chart.setUserId(loginUser.getId());
boolean saveResult = chartService.save(chart);
ThrowUtils.throwIf(!saveResult, ErrorCode.SYSTEM_ERROR, "图表保存失败");
biResponse.setChartId(chart.getId());
return ResultUtils.success(biResponse);
Map<String, Object> result = new HashMap<>();
result.put("userInput",userInput.toString());
result.put("chartObj", chart);
return result;
}
private void handleChartStatusUpdate(long chartId,String status,String execMessage){
Chart updateChart = new Chart();
updateChart.setId(chartId);
updateChart.setStatus(status);
updateChart.setExecMessage(execMessage);
boolean updateResult = chartService.updateById(updateChart);
if (!updateResult){
log.error("更新图表[{}]状态失败", chartId);
}
}
private void handleChartSuccessStatusUpdate(BiResponse biResponse){
Chart updateChart = new Chart();
updateChart.setId(biResponse.getChartId());
updateChart.setStatus(BiTaskStatusEnum.SUCCEED.getValue());
updateChart.setGenChart(biResponse.getGenChart());
updateChart.setGenResult(biResponse.getGenResult());
boolean updateResult = chartService.updateById(updateChart);
if (!updateResult){
log.error("更新图表[{}]结果失败", biResponse.getChartId());
}
}
}

View File

@ -13,6 +13,7 @@ import javax.annotation.Resource;
import org.springframework.stereotype.Service;
import top.peng.answerbi.common.ErrorCode;
import top.peng.answerbi.constant.BiConstant;
import top.peng.answerbi.exception.BusinessException;
import top.peng.answerbi.exception.ThrowUtils;
import top.peng.answerbi.model.vo.BiResponse;
@ -52,7 +53,7 @@ public class AiManager {
* @param aiAnswer AI 对话 结果
* @return BiResponse对象
*/
public BiResponse aiAnsToBiResp(String aiAnswer){
public BiResponse aiAnsToBiResp(String aiAnswer) throws BusinessException {
String[] aiResultSplit = aiAnswer.split(BiConstant.BI_RESULT_SEPARATOR);
ThrowUtils.throwIf(aiResultSplit.length < 3,ErrorCode.SYSTEM_ERROR,"AI 生成错误");
BiResponse biResponse = new BiResponse();

View File

@ -73,6 +73,16 @@ public class Chart implements Serializable {
*/
private String genResult;
/**
* 任务状态
*/
private String status;
/**
* 执行信息
*/
private String execMessage;
@TableField(exist = false)
private static final long serialVersionUID = 1L;
}

View File

@ -0,0 +1,70 @@
/*
* @(#)BiTaskStatus.java
*
* Copyright © 2023 YunPeng Corporation.
*/
package top.peng.answerbi.model.enums;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ObjectUtils;
/**
* BiTaskStatus 智能分析任务状态枚举
*
* @author yunpeng
* @version 1.0 2023/7/21
*/
public enum BiTaskStatusEnum {
WAIT("排队中","wait"),
RUNNING("生成中","running"),
SUCCEED("成功","succeed"),
FAILED("失败","failed"),
;
private final String text;
private final String value;
BiTaskStatusEnum(String text, String value) {
this.text = text;
this.value = value;
}
/**
* 获取值列表
*
* @return
*/
public static List<String> getValues() {
return Arrays.stream(values()).map(item -> item.value).collect(Collectors.toList());
}
/**
* 根据 value 获取枚举
*
* @param value
* @return
*/
public static BiTaskStatusEnum getEnumByValue(String value) {
if (ObjectUtils.isEmpty(value)) {
return null;
}
for (BiTaskStatusEnum anEnum : BiTaskStatusEnum.values()) {
if (anEnum.value.equals(value)) {
return anEnum;
}
}
return null;
}
public String getValue() {
return value;
}
public String getText() {
return text;
}
}