SpringBoot项目接入讯飞星火大模型Api

马肤
这是懒羊羊

SpringBoot项目接入讯飞星火大模型Api

  1. 说明:此代码使用 webSocket 连接 ai 大模型,前端页面使用 websocket 连接后台服务端,

  2. 自行注册讯飞星火大模型平台,申请tokens。

  3. 代码:

    1. maven依赖文件:

              
                  org.springframework.boot
                  spring-boot-starter-web
              
              
                  org.springframework.boot
                  spring-boot-starter-test
                  test
              
              
              
                  com.alibaba
                  fastjson
                  1.2.67
              
              
              
                  com.google.code.gson
                  gson
                  2.8.5
              
              
      
      
      
      
      
              
                  org.springframework.boot
                  spring-boot-starter-websocket
              
              
              
                  com.squareup.okhttp3
                  okhttp
                  4.10.0
              
              
              
                  com.squareup.okio
                  okio
                  2.10.0
              
              
                  org.springframework.boot
                  spring-boot
                  2.7.8
              
              
                  org.projectlombok
                  lombok
              
              
              
                  com.alibaba
                  druid
                  1.1.20
              
              
              
                  cn.hutool
                  hutool-all
                  5.7.22
              
      
    2. java调用 ai 大模型代码:

      import com.alibaba.fastjson.JSON;
      import com.alibaba.fastjson.JSONArray;
      import com.alibaba.fastjson.JSONObject;
      import com.google.gson.Gson;
      import okhttp3.*;
      import javax.crypto.Mac;
      import javax.crypto.spec.SecretKeySpec;
      import java.io.IOException;
      import java.net.URL;
      import java.nio.charset.StandardCharsets;
      import java.text.SimpleDateFormat;
      import java.util.*;
      /**
       * @author hanyiming
       */
      public class BigModelNew extends WebSocketListener {
          // 地址与鉴权信息  https://spark-api.xf-yun.com/v1.1/chat   1.5地址  domain参数为general
          // 地址与鉴权信息  https://spark-api.xf-yun.com/v2.1/chat   2.0地址  domain参数为generalv2
          public static final String hostUrl = "https://spark-api.xf-yun.com/v3.5/chat";
          // 以下参数替换为自己的身份认证信息
          public static final String appid = "appid";
          public static final String apiSecret = "apiSecret";
          public static final String apiKey = "apiKey";
          public static List historyList = new ArrayList(); // 对话历史存储集合
          public static String totalAnswer = ""; // 大模型的答案汇总
          // 环境治理的重要性  环保  人口老龄化  我爱我的祖国
          public static String NewQuestion = "";
          public static final Gson gson = new Gson();
          // 个性化参数
          private String userId;
          private Boolean wsCloseFlag;
          private static Boolean totalFlag = true; // 控制提示用户是否输入
          // 构造函数
          public BigModelNew(String userId, Boolean wsCloseFlag) {
              this.userId = userId;
              this.wsCloseFlag = wsCloseFlag;
          }
          public static boolean canAddHistory() {  // 由于历史记录最大上线1.2W左右,需要判断是能能加入历史
              int history_length = 0;
              for (RoleContent temp : historyList) {
                  history_length = history_length + temp.content.length();
              }
              if (history_length > 12000) {
                  historyList.remove(0);
                  historyList.remove(1);
                  historyList.remove(2);
                  historyList.remove(3);
                  historyList.remove(4);
                  return false;
              } else {
                  return true;
              }
          }
          // 线程来发送音频与参数
          class MyThread extends Thread {
              private WebSocket webSocket;
              public MyThread(WebSocket webSocket) {
                  this.webSocket = webSocket;
              }
              public void run() {
                  try {
                      JSONObject requestJson = new JSONObject();
                      JSONObject header = new JSONObject();  // header参数
                      header.put("app_id", appid);
                      header.put("uid", UUID.randomUUID().toString().substring(0, 10));
                      JSONObject parameter = new JSONObject(); // parameter参数
                      JSONObject chat = new JSONObject();
                      chat.put("domain", "generalv2");
                      chat.put("temperature", 0.5);
                      chat.put("max_tokens", 4096);
                      parameter.put("chat", chat);
                      JSONObject payload = new JSONObject(); // payload参数
                      JSONObject message = new JSONObject();
                      JSONArray text = new JSONArray();
                      // 历史问题获取
                      if (historyList.size() > 0) {
                          for (RoleContent tempRoleContent : historyList) {
                              text.add(JSON.toJSON(tempRoleContent));
                          }
                      }
                      // 最新问题
                      RoleContent roleContent = new RoleContent();
                      roleContent.role = "user";
                      roleContent.content = NewQuestion;
                      text.add(JSON.toJSON(roleContent));
                      historyList.add(roleContent);
                      message.put("text", text);
                      payload.put("message", message);
                      requestJson.put("header", header);
                      requestJson.put("parameter", parameter);
                      requestJson.put("payload", payload);
                      // System.err.println(requestJson); // 可以打印看每次的传参明细
                      webSocket.send(requestJson.toString());
                      // 等待服务端返回完毕后关闭
                      while (true) {
                          // System.err.println(wsCloseFlag + "---");
                          Thread.sleep(200);
                          if (wsCloseFlag) {
                              break;
                          }
                      }
                      webSocket.close(1000, "");
                  } catch (Exception e) {
                      e.printStackTrace();
                  }
              }
          }
          @Override
          public void onOpen(WebSocket webSocket, Response response) {
              super.onOpen(webSocket, response);
              System.out.print("大模型:");
              MyThread myThread = new MyThread(webSocket);
              myThread.start();
          }
          @Override
          public void onMessage(WebSocket webSocket, String text) {
              // System.out.println(userId + "用来区分那个用户的结果" + text);
              JsonParse myJsonParse = gson.fromJson(text, JsonParse.class);
              if (myJsonParse.header.code != 0) {
                  System.out.println("发生错误,错误码为:" + myJsonParse.header.code);
                  System.out.println("本次请求的sid为:" + myJsonParse.header.sid);
                  webSocket.close(1000, "");
              }
              List textList = myJsonParse.payload.choices.text;
              for (Text temp : textList) {
                  // 在此处给前端页面发送回答信息,如有存储问答需求,请在此处存储回答信息
                  WebSocketClient.sendInfo(temp.content);
                  System.out.print(temp.content);
                  totalAnswer = totalAnswer + temp.content;
              }
              if (myJsonParse.header.status == 2) {
                  // 可以关闭连接,释放资源
                  System.out.println();
                  System.out.println("*************************************************************************************");
                  if (canAddHistory()) {
                      RoleContent roleContent = new RoleContent();
                      roleContent.setRole("assistant");
                      roleContent.setContent(totalAnswer);
                      historyList.add(roleContent);
                  } else {
                      historyList.remove(0);
                      RoleContent roleContent = new RoleContent();
                      roleContent.setRole("assistant");
                      roleContent.setContent(totalAnswer);
                      historyList.add(roleContent);
                  }
                  wsCloseFlag = true;
                  totalFlag = true;
              }
          }
          @Override
          public void onFailure(WebSocket webSocket, Throwable t, Response response) {
              super.onFailure(webSocket, t, response);
              try {
                  if (null != response) {
                      int code = response.code();
                      System.out.println("onFailure code:" + code);
                      System.out.println("onFailure body:" + response.body().string());
                      if (101 != code) {
                          System.out.println("connection failed");
                          System.exit(0);
                      }
                  }
              } catch (IOException e) {
                  // TODO Auto-generated catch block
                  e.printStackTrace();
              }
          }
          // 鉴权方法
          public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
              URL url = new URL(hostUrl);
              // 时间
              SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
              format.setTimeZone(TimeZone.getTimeZone("GMT"));
              String date = format.format(new Date());
              // 拼接
              String preStr = "host: " + url.getHost() + "\n" + "date: " + date + "\n" + "GET " + url.getPath() + " HTTP/1.1";
              // System.err.println(preStr);
              // SHA256加密
              Mac mac = Mac.getInstance("hmacsha256");
              SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
              mac.init(spec);
              byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
              // Base64加密
              String sha = Base64.getEncoder().encodeToString(hexDigits);
              // System.err.println(sha);
              // 拼接
              String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
              // 拼接地址
              HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().//
                      addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).//
                      addQueryParameter("date", date).//
                      addQueryParameter("host", url.getHost()).//
                      build();
              // System.err.println(httpUrl.toString());
              return httpUrl.toString();
          }
          //返回的json结果拆解
          class JsonParse {
              Header header;
              Payload payload;
          }
          class Header {
              int code;
              int status;
              String sid;
          }
          class Payload {
              Choices choices;
          }
          class Choices {
              List text;
          }
          class Text {
              String role;
              String content;
          }
          class RoleContent {
              String role;
              String content;
              public String getRole() {
                  return role;
              }
              public void setRole(String role) {
                  this.role = role;
              }
              public String getContent() {
                  return content;
              }
              public void setContent(String content) {
                  this.content = content;
              }
          }
      }
      
    3. 编写给前端页面使用的 websocket 连接接口

      import cn.hutool.core.util.StrUtil;
      import com.alibaba.druid.util.StringUtils;
      import lombok.extern.slf4j.Slf4j;
      import okhttp3.OkHttpClient;
      import okhttp3.Request;
      import okhttp3.WebSocket;
      import org.springframework.stereotype.Component;
      import javax.websocket.*;
      import javax.websocket.server.PathParam;
      import javax.websocket.server.ServerEndpoint;
      import java.io.IOException;
      import java.util.concurrent.ConcurrentHashMap;
      /**
       *
       * @author HanYiMing
       * @date 2024/3/1
       * @description websocket配置类
       */
      @ServerEndpoint(value = "/websocketClient/{userId}")
      @Component
      @Slf4j
      public class WebSocketClient {
          /**
           * 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的
           */
          private static int onlineCount = 0;
          /**
           * concurrent包的线程安全Map,用来存放每个客户端对应的MyWebSocket对象
           */
          private static final ConcurrentHashMap webSocketMap = new ConcurrentHashMap();
          public ConcurrentHashMap getWebSocketMap() {
              return webSocketMap;
          }
          /**
           * 与某个客户端的连接会话,需要通过它来给客户端发送数据
           */
          private Session session;
          /**
           * 用户id 唯一标识
           */
          private String userId;
          /**
           * 连接建立成功调用的方法
           */
          @OnOpen
          public void onOpen(Session session, @PathParam("userId") String userId) {
              this.session = session;
              this.userId = userId;
              //加入map
              webSocketMap.put(userId, this);
              //在线数加1
              addOnlineCount();
              log.info("WebSocket客户端{}连接成功,客户端标识:{},当前在线人数:{}", session.getId(), userId, getOnlineCount());
              sendMessage("用户" + userId + "连接成功!");
          }
          /**
           * 连接关闭调用的方法
           */
          @OnClose
          public void onClose() {
              //从map中删除
              webSocketMap.remove(userId);
              //在线数减1
              subOnlineCount();
              log.info("WebSocket客户端{}连接断开,客户端标识:{},当前在线人数:{}", session.getId(), userId, getOnlineCount());
          }
          /**
           * 收到客户端消息后调用的方法
           *
           * @param message 客户端发送过来的消息
           */
          @OnMessage
          public void onMessage(String message, Session session) throws Exception {
              // 心跳检测响应
              if (StringUtils.equalsIgnoreCase("ping", message)) {
                  sendMessage("pong");
                  log.info("WebSocket服务端已回复客户端{}的心跳检测:pong", session.getId());
                  return;
              }
              BigModelNew.NewQuestion = message;
              // 构建鉴权url
              String authUrl = BigModelNew.getAuthUrl(BigModelNew.hostUrl, BigModelNew.apiKey, BigModelNew.apiSecret);
              OkHttpClient client = new OkHttpClient.Builder().build();
              String url = authUrl.toString().replace("http://", "ws://").replace("https://", "wss://");
              Request request = new Request.Builder().url(url).build();
              for (int i = 0; i  
    4. 使用 postman 进行测试:

      1. 创建一个websocket 连接测试案例

        SpringBoot项目接入讯飞星火大模型Api,image-20240301172900672,词库加载错误:未能找到文件“C:\Users\Administrator\Desktop\火车头9.8破解版\Configuration\Dict_Stopwords.txt”。,服务,li,进行,第1张

      2. 输入端口号,点击connect连接

        SpringBoot项目接入讯飞星火大模型Api,image-20240301172959923,词库加载错误:未能找到文件“C:\Users\Administrator\Desktop\火车头9.8破解版\Configuration\Dict_Stopwords.txt”。,服务,li,进行,第2张

      3. 返回连接成功信息:

        SpringBoot项目接入讯飞星火大模型Api,image-20240301173036059,词库加载错误:未能找到文件“C:\Users\Administrator\Desktop\火车头9.8破解版\Configuration\Dict_Stopwords.txt”。,服务,li,进行,第3张

      4. 发送文字,ai回答,测试成功。

        SpringBoot项目接入讯飞星火大模型Api,image-20240301173937156,词库加载错误:未能找到文件“C:\Users\Administrator\Desktop\火车头9.8破解版\Configuration\Dict_Stopwords.txt”。,服务,li,进行,第4张


文章版权声明:除非注明,否则均为VPS857原创文章,转载或复制请以超链接形式并注明出处。

发表评论

快捷回复:表情:
评论列表 (暂无评论,0人围观)

还没有评论,来说两句吧...

目录[+]

取消
微信二维码
微信二维码
支付宝二维码