Agent Development Kit (ADK): Wrapping a DeepSeek Model with a Custom Java Implementation

Original article · Author: 礼兴 · Modified 2025-09-24 11:00:36
This article is included in the column: 个人总结系列 (Personal Summary Series)

1. ADK Version

Code language: xml
<!-- The ADK Core dependency -->
<dependency>
    <groupId>com.google.adk</groupId>
    <artifactId>google-adk</artifactId>
    <version>0.2.0</version>
</dependency>
<!-- The ADK Dev Web UI to debug your agent (Optional) -->
<dependency>
    <groupId>com.google.adk</groupId>
    <artifactId>google-adk-dev</artifactId>
    <version>0.2.0</version>
</dependency>   

2. Wrapping DeepSeek as a Custom ADK Model

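Out of the box, ADK's LlmAgent only supports Gemini and Claude models, so using a model such as DeepSeek requires a custom wrapper built on BaseLlm. The wrapper below reaches DeepSeek through an OpenAI-compatible API using Spring AI's OpenAiChatModel, so besides the ADK dependencies above, the project also needs Spring AI's OpenAI module (artifact spring-ai-openai) and Lombok (for @Slf4j) on the classpath.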
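To make the translation concrete, here is a stdlib-only sketch of what such a wrapper effectively sends: ADK's "model" role becomes the OpenAI "assistant" role, and the messages are posted to an OpenAI-compatible chat-completions endpoint with a Bearer token. The class name ChatRequestSketch, the example base URL, the /v1/chat/completions path and the hand-built JSON are all assumptions for illustration. In the real code, Spring AI's OpenAiApi builds and sends this request for you.

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical illustration only: shows the OpenAI-compatible request shape.
// It is not part of ADK or Spring AI; the endpoint path is an assumption.
class ChatRequestSketch {

    // A minimal (role, text) pair standing in for an ADK Content with text parts.
    record Msg(String role, String text) {}

    // ADK uses "model" for assistant turns; OpenAI-style APIs expect "assistant".
    static String toOpenAiRole(String adkRole) {
        if (adkRole == null) {
            return "user";
        }
        switch (adkRole.toLowerCase()) {
            case "model":
            case "assistant":
                return "assistant";
            case "system":
                return "system";
            default:
                return "user";
        }
    }

    // Tiny JSON string escaper (backslash, quote, newline) that is sufficient for this sketch only.
    static String jsonString(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n") + "\"";
    }

    // Builds the chat-completions JSON body for the given model and messages.
    static String buildBody(String model, List<Msg> messages) {
        String msgs = messages.stream()
                .map(m -> "{\"role\":" + jsonString(toOpenAiRole(m.role()))
                        + ",\"content\":" + jsonString(m.text()) + "}")
                .collect(Collectors.joining(","));
        return "{\"model\":" + jsonString(model) + ",\"messages\":[" + msgs + "]}";
    }

    // Builds (but does not send) the HTTP request.
    static HttpRequest buildRequest(String baseUrl, String apiKey, String body) {
        return HttpRequest.newBuilder(URI.create(baseUrl + "/v1/chat/completions"))
                .header("Authorization", "Bearer " + apiKey)
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
    }

    public static void main(String[] args) {
        String body = buildBody("deepseek/deepseek-v3.1", List.of(
                new Msg("system", "You are a weather assistant."),
                new Msg("user", "What's the weather in New York?"),
                new Msg("model", "Let me check.")));
        HttpRequest req = buildRequest("https://api.example.com", "sk-test", body);
        System.out.println(req.method() + " " + req.uri());
        System.out.println(body);
    }
}
```

This mapping mirrors contentToSpringAiMessage in the class below, which also treats "model" and "assistant" as assistant turns and falls back to a user message for unknown roles.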

Code language: java
package com.adk.models;

import com.google.adk.models.BaseLlm;
import com.google.adk.models.BaseLlmConnection;
import com.google.adk.models.LlmRequest;
import com.google.adk.models.LlmResponse;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Blob;
import com.google.genai.types.Content;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Completable;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.messages.AssistantMessage;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

@Slf4j
public class DeepSeekModel extends BaseLlm {
    // Placeholders: replace with your own endpoint, API key, and model name
    private static final String API_URL = "https://xxx";
    private static final String API_KEY = "sk_xxx";
    private static final String MODEL_NAME = "deepseek/deepseek-v3.1";

    private final OpenAiChatModel chatModel;

    private DeepSeekModel(String modelName, OpenAiChatModel chatModel) {
        super(modelName);
        this.chatModel = chatModel;
    }

    /**
     * Creates a DeepSeek model instance with the default model name.
     */
    public static DeepSeekModel create() {
        return create(MODEL_NAME);
    }

    /**
     * Creates a DeepSeek model instance for the given model name.
     */
    public static DeepSeekModel create(String modelName) {
        // Custom OpenAI API configuration that points at DeepSeek's API
        OpenAiApi openAiApi = OpenAiApi.builder().apiKey(API_KEY).baseUrl(API_URL).build();

        // Default chat options: model name, max output tokens, sampling temperature
        OpenAiChatOptions chatOptions = OpenAiChatOptions.builder()
                .model(modelName)
                .maxTokens(8192)
                .temperature(0.7)
                .build();

        // Spring AI chat model that performs the actual HTTP calls
        OpenAiChatModel chatModel = OpenAiChatModel.builder().defaultOptions(chatOptions).openAiApi(openAiApi).build();

        return new DeepSeekModel(modelName, chatModel);
    }

    @Override
    public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) {
        // Limitations as written: the stream flag is ignored (one blocking call, one response),
        // and the tool declarations carried by llmRequest are not forwarded to DeepSeek; only text
        // is extracted from the reply, so agent tools (e.g. getWeather) are never invoked via this wrapper.
        try {
            log.debug("DeepSeek generateContent called with stream: {}", stream);

            // Build the message list
            List<Message> messages = new ArrayList<>();

            // Add the system instruction, if any
            String systemText = extractSystemInstruction(llmRequest);
            if (!systemText.isEmpty()) {
                messages.add(new SystemMessage(systemText));
            }

            // Convert ADK contents to Spring AI messages
            for (Content content : llmRequest.contents()) {
                Message message = contentToSpringAiMessage(content);
                if (message != null) {
                    messages.add(message);
                }
            }

            // Create the prompt
            Prompt prompt = new Prompt(messages);
            log.debug("Created prompt with {} messages", messages.size());

            // Call the OpenAI-compatible chat model (blocking)
            ChatResponse response = chatModel.call(prompt);

            if (response == null) {
                log.error("ChatModel returned a null response");
                return Flowable.error(new RuntimeException("ChatModel returned a null response"));
            }

            log.debug("DeepSeek response received: {}", response.getClass().getSimpleName());

            // Convert the response to an ADK LlmResponse
            LlmResponse llmResponse = convertSpringAiResponseToLlmResponse(response);

            if (llmResponse == null) {
                log.error("Converted LlmResponse is null");
                return Flowable.error(new RuntimeException("Failed to convert response"));
            }

            return Flowable.just(llmResponse);

        } catch (Exception e) {
            log.error("Error in DeepSeek generateContent: ", e);
            return Flowable.error(e);
        }
    }

    /**
     * Extracts the system instruction text from the request config.
     */
    private String extractSystemInstruction(LlmRequest llmRequest) {
        Optional<GenerateContentConfig> configOpt = llmRequest.config();
        if (configOpt.isPresent()) {
            Optional<Content> systemInstructionOpt = configOpt.get().systemInstruction();
            if (systemInstructionOpt.isPresent()) {
                return systemInstructionOpt.get().parts().orElse(ImmutableList.of()).stream()
                        .filter(p -> p.text().isPresent())
                        .map(p -> p.text().get())
                        .collect(Collectors.joining("\n"));
            }
        }
        return "";
    }

    /**
     * Converts an ADK Content into a Spring AI Message (text parts only).
     */
    private Message contentToSpringAiMessage(Content content) {
        String role = content.role().orElse("user");
        // Only text parts are kept; functionCall and functionResponse parts are dropped
        String text = content.parts().orElse(ImmutableList.of()).stream()
                .filter(p -> p.text().isPresent())
                .map(p -> p.text().get())
                .collect(Collectors.joining("\n"));

        if (text.isEmpty()) {
            return null;
        }

        switch (role.toLowerCase()) {
            case "user":
                return new UserMessage(text);
            case "model":
            case "assistant":
                return new AssistantMessage(text);
            case "system":
                return new SystemMessage(text);
            default:
                return new UserMessage(text);
        }
    }

    /**
     * Converts a Spring AI ChatResponse into an ADK LlmResponse.
     */
    private LlmResponse convertSpringAiResponseToLlmResponse(ChatResponse chatResponse) {
        LlmResponse.Builder responseBuilder = LlmResponse.builder();

        try {
            // Safely extract the response text
            String content = extractContentSafely(chatResponse);

            if (content != null && !content.trim().isEmpty()) {
                Part part = Part.builder()
                        .text(content)
                        .build();

                Content responseContent = Content.builder()
                        .role("model")
                        .parts(ImmutableList.of(part))
                        .build();

                responseBuilder.content(responseContent);
                log.debug("Converted response content: {}", content.substring(0, Math.min(content.length(), 100)));
            } else {
                log.warn("Response content is empty or invalid");
                // Fall back to a default message
                Part errorPart = Part.builder()
                        .text("Sorry, no valid response content was received.")
                        .build();

                Content errorContent = Content.builder()
                        .role("model")
                        .parts(ImmutableList.of(errorPart))
                        .build();

                responseBuilder.content(errorContent);
            }
        } catch (Exception e) {
            log.error("Error while converting the response: {}", e.getMessage(), e);

            // Build an error response
            Part errorPart = Part.builder()
                    .text("Sorry, an error occurred while processing the response: " + e.getMessage())
                    .build();

            Content errorContent = Content.builder()
                    .role("model")
                    .parts(ImmutableList.of(errorPart))
                    .build();

            responseBuilder.content(errorContent);
        }

        return responseBuilder.build();
    }

    /**
     * Safely extracts the text from a ChatResponse, returning null if any level is missing.
     */
    private String extractContentSafely(ChatResponse chatResponse) {
        if (chatResponse == null) {
            log.warn("ChatResponse is null");
            return null;
        }

        if (chatResponse.getResult() == null) {
            log.warn("ChatResponse.getResult() is null");
            return null;
        }

        if (chatResponse.getResult().getOutput() == null) {
            log.warn("ChatResponse.getResult().getOutput() is null");
            return null;
        }

        String content = chatResponse.getResult().getOutput().getText();
        if (content == null) {
            log.warn("ChatResponse text is null");
            return null;
        }

        return content;
    }

    @Override
    public BaseLlmConnection connect(LlmRequest llmRequest) {
        log.info("Creating simulated DeepSeek connection");

        return new BaseLlmConnection() {
            private boolean connected = true;
            private final List<Content> conversationHistory = new ArrayList<>();

            @Override
            public Completable sendHistory(List<Content> history) {
                return Completable.fromAction(() -> {
                    log.debug("DeepSeek connection received history, size: {}", history.size());
                    conversationHistory.clear();
                    conversationHistory.addAll(history);
                });
            }

            @Override
            public Completable sendContent(Content content) {
                return Completable.fromAction(() -> {
                    log.debug("DeepSeek connection sending content: {}", content);
                    conversationHistory.add(content);
                });
            }

            @Override
            public Completable sendRealtime(Blob blob) {
                return Completable.fromAction(() -> {
                    // DeepSeek does not support realtime Blob input; only log a warning
                    log.warn("DeepSeek model does not support realtime Blob input; ignoring request");
                });
            }

            @Override
            public Flowable<LlmResponse> receive() {
                return Flowable.defer(() -> {
                    if (!connected) {
                        return Flowable.error(new IllegalStateException("Connection is closed"));
                    }

                    if (conversationHistory.isEmpty()) {
                        log.warn("No content to process; returning an empty response");
                        return Flowable.empty();
                    }

                    // Build an LlmRequest from the full conversation history
                    // (the original llmRequest's config, e.g. the system instruction, is not carried over)
                    LlmRequest request = LlmRequest.builder()
                            .contents(new ArrayList<>(conversationHistory))
                            .build();

                    // Delegate to the main generateContent method
                    return DeepSeekModel.this.generateContent(request, false);
                });
            }

            @Override
            public void close() {
                connected = false;
                conversationHistory.clear();
                log.debug("Simulated DeepSeek connection closed");
            }

            @Override
            public void close(Throwable throwable) {
                connected = false;
                conversationHistory.clear();
                log.error("Simulated DeepSeek connection closed due to error: {}", throwable.getMessage(), throwable);
            }
        };
    }
}
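The class above hardcodes API_URL and API_KEY. A common alternative is to read these settings from the environment and fail fast when the key is missing. The sketch below is one way to do that. The variable names DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL and DEEPSEEK_MODEL and the helper class DeepSeekSettings are hypothetical, not an ADK or Spring AI convention.

```java
import java.util.Map;
import java.util.Objects;

// Hypothetical helper: resolves connection settings from environment variables
// instead of hardcoding them. Variable names are assumptions for illustration.
final class DeepSeekSettings {
    static final String KEY_VAR = "DEEPSEEK_API_KEY";
    static final String URL_VAR = "DEEPSEEK_BASE_URL";
    static final String MODEL_VAR = "DEEPSEEK_MODEL";

    final String apiKey;
    final String baseUrl;
    final String modelName;

    private DeepSeekSettings(String apiKey, String baseUrl, String modelName) {
        this.apiKey = apiKey;
        this.baseUrl = baseUrl;
        this.modelName = modelName;
    }

    // Reads settings from the given map (pass System.getenv() in production).
    static DeepSeekSettings from(Map<String, String> env, String defaultUrl, String defaultModel) {
        String key = env.get(KEY_VAR);
        if (key == null || key.isBlank()) {
            // Fail fast rather than sending requests with a missing key
            throw new IllegalStateException(KEY_VAR + " is not set");
        }
        String url = env.getOrDefault(URL_VAR, defaultUrl);
        String model = env.getOrDefault(MODEL_VAR, defaultModel);
        return new DeepSeekSettings(key.trim(), Objects.requireNonNull(url), Objects.requireNonNull(model));
    }

    static DeepSeekSettings fromSystemEnv(String defaultUrl, String defaultModel) {
        return from(System.getenv(), defaultUrl, defaultModel);
    }

    // Never print the full key; show only a short prefix for debugging
    @Override
    public String toString() {
        String shown = apiKey.length() <= 4 ? "****" : apiKey.substring(0, 4) + "****";
        return "DeepSeekSettings{baseUrl=" + baseUrl + ", model=" + modelName + ", apiKey=" + shown + "}";
    }

    public static void main(String[] args) {
        try {
            System.out.println(fromSystemEnv("https://xxx", "deepseek/deepseek-v3.1"));
        } catch (IllegalStateException e) {
            System.out.println("Configuration error: " + e.getMessage());
        }
    }
}
```

Inside create(), the resolved values could then be passed to the same builder calls the class already uses, e.g. OpenAiApi.builder().apiKey(settings.apiKey).baseUrl(settings.baseUrl).build().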

The registry class below creates the DeepSeekModel once and caches it for reuse.

Code language: java
package com.adk.models;

import lombok.extern.slf4j.Slf4j;

/**
 * DeepSeek model registry.
 * Manages the one-time initialization and configuration of the DeepSeek model.
 */
@Slf4j
public class DeepSeekModelRegistry {

    // volatile so the unsynchronized reads in isRegistered()/getModel() see a fully initialized model
    private static volatile boolean initialized = false;
    private static volatile DeepSeekModel deepSeekModel;

    /**
     * Initializes the DeepSeek model (safe to call more than once).
     */
    public static synchronized void registerDeepSeekModel() {
        if (initialized) {
            log.debug("DeepSeek model is already initialized");
            return;
        }

        try {
            // Create the DeepSeek model instance
            deepSeekModel = DeepSeekModel.create();

            initialized = true;
            log.info("✅ DeepSeek model initialized successfully");

        } catch (Exception e) {
            log.error("❌ Failed to initialize DeepSeek model: {}", e.getMessage(), e);
            throw new RuntimeException("Failed to initialize DeepSeek model", e);
        }
    }

    /**
     * Returns whether the model has been initialized.
     */
    public static boolean isRegistered() {
        return initialized;
    }

    /**
     * Returns the DeepSeek model instance, initializing it on first use.
     */
    public static DeepSeekModel getModel() {
        if (!initialized) {
            registerDeepSeekModel();
        }
        return deepSeekModel;
    }
}
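DeepSeekModelRegistry is a lazily initialized singleton guarded by synchronized. An alternative that needs no explicit locking is the initialization-on-demand holder idiom: the JVM initializes a nested holder class the first time it is accessed, exactly once and thread-safely. In the sketch below, FakeModel and the CREATIONS counter are hypothetical stand-ins for DeepSeekModel and DeepSeekModel.create(), so the example runs with the JDK alone.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Sketch of the initialization-on-demand holder idiom.
// FakeModel is a hypothetical stand-in for DeepSeekModel.
final class LazyModelRegistry {
    // Counts how many times the factory ran; used only to demonstrate single creation
    static final AtomicInteger CREATIONS = new AtomicInteger();

    static final class FakeModel {
        final String name;

        FakeModel(String name) {
            this.name = name;
        }
    }

    private LazyModelRegistry() {}

    // Holder is initialized only when getModel() first reads Holder.INSTANCE;
    // the JVM guarantees class initialization runs once, even under contention.
    private static final class Holder {
        static final FakeModel INSTANCE = create();
    }

    private static FakeModel create() {
        CREATIONS.incrementAndGet();
        return new FakeModel("deepseek/deepseek-v3.1");
    }

    static FakeModel getModel() {
        return Holder.INSTANCE;
    }

    public static void main(String[] args) {
        System.out.println(getModel().name + " created " + CREATIONS.get() + " time(s)");
    }
}
```

The trade-off is that the holder idiom cannot retry after a failed creation: if the factory throws, later accesses fail with NoClassDefFoundError. The synchronized version in the article can simply try again on the next call.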

3. Creating and Running the Agent

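The agent below plugs the custom model into LlmAgent.builder().model(...), registers getWeather as a FunctionTool, and runs an interactive console session with InMemoryRunner.

Code language: java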
import com.adk.models.DeepSeekModel;
import com.adk.models.DeepSeekModelRegistry;

import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.LlmAgent;
import com.google.adk.events.Event;
import com.google.adk.runner.InMemoryRunner;
import com.google.adk.sessions.Session;
import com.google.adk.tools.Annotations.Schema;
import com.google.adk.tools.FunctionTool;
import com.google.genai.types.Content;
import com.google.genai.types.Part;
import io.reactivex.rxjava3.core.Flowable;
import lombok.extern.slf4j.Slf4j;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;

@Slf4j
public class MultiToolAgentMain {
    private static final String USER_ID = "examples";
    private static final String NAME = "multi_tool_agent";

    // To run your agent with the Dev UI, ROOT_AGENT must be a global public static final field.
    public static final BaseAgent ROOT_AGENT = initAgent();

    public static BaseAgent initAgent() {
        DeepSeekModelRegistry.registerDeepSeekModel();
        DeepSeekModel deepSeekModel = DeepSeekModelRegistry.getModel();

        return LlmAgent.builder()
                .name(NAME)
                .model(deepSeekModel)
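                .description("An assistant that looks up a country's capital or a city's weather")
                .instruction(
                        "Help the user find the capital of a given country or the weather in a given city.")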
                .tools(
//                        FunctionTool.create(MultiToolAgentMain.class, "getCapitalCity"),
                        FunctionTool.create(MultiToolAgentMain.class, "getWeather"))
                .build();
    }

    public static Map<String, Object> getCapitalCity(
            @Schema(name = "country", description = "The country whose capital to look up") String country) {
        System.out.printf("%n-- Tool Call: getCapitalCity(country='%s') --%n", country);

        Map<String, String> countryCapitals = new HashMap<>();
        countryCapitals.put("united states", "Washington, D.C.");
        countryCapitals.put("canada", "Ottawa");
        countryCapitals.put("france", "Paris");
        countryCapitals.put("japan", "Tokyo");

        String result =
                countryCapitals.getOrDefault(
                        country.toLowerCase(), "Sorry, I couldn't find the capital of " + country + ".");
        System.out.printf("-- Tool Result: '%s' --%n", result);
        return Map.of("result", result); // Tools must return a Map
    }

    public static Map<String, String> getWeather(
            @Schema(name = "city",
                    description = "The city to get the weather for")
            String city) {
        System.out.printf("%n-- Tool Call: getWeather(city='%s') --%n", city);
        // SLF4J uses {} placeholders, not printf-style %s / %n
        log.info("Tool call: getWeather(city='{}')", city);
        if ("new york".equalsIgnoreCase(city)) {
            return Map.of(
                    "status",
                    "success",
                    "result",
                    "The weather in New York is sunny, with a temperature of 25 and humidity of 77.");

        } else {
            return Map.of(
                    "status", "error", "result", "No weather data was found for the city " + city + ".");
        }
    }

    public static void main(String[] args) throws Exception {
        // 🎯 Approach 1: colored system output
        printSystemInfo("🚀 Starting the DeepSeek Multi-Tool Agent");
        printSystemInfo("📋 Initializing session...");

        InMemoryRunner runner = new InMemoryRunner(ROOT_AGENT);

        Session session =
                runner
                        .sessionService()
                        .createSession(NAME, USER_ID)
                        .blockingGet();

        printSystemSuccess("✅ Session created: " + session.id());
        printSystemInfo("💡 Type 'quit' to exit or 'help' for help");
        printSeparator();

        try (Scanner scanner = new Scanner(System.in, StandardCharsets.UTF_8)) {
            while (true) {
                // 🎯 Approach 2: formatted prompt
                printPrompt("You");
                String userInput = scanner.nextLine();

                if ("quit".equalsIgnoreCase(userInput)) {
                    printSystemInfo("👋 Goodbye!");
                    break;
                }

                if ("help".equalsIgnoreCase(userInput)) {
                    printHelp();
                    continue;
                }

                // 🎯 Approach 3: show processing status
                printSystemInfo("🤔 Thinking...");
                log.info("User input: {}", userInput);

                Content userMsg = Content.fromParts(Part.fromText(userInput));
                Flowable<Event> events = runner.runAsync(USER_ID, session.id(), userMsg);

                printPrompt("Agent");

                // 🎯 Approach 4: print details while handling events
                events.blockingForEach(event -> {
                    String content = event.stringifyContent();
                    if (content != null && !content.trim().isEmpty()) {
                        System.out.println(content);
                        log.debug("Agent 响应: {}", content);
                    }
                });

                printSeparator();
            }
        }
    }

    // 🎯 Helper methods for different kinds of output

    /**
     * Prints an info message (blue).
     */
    private static void printSystemInfo(String message) {
        System.out.println("\u001B[34m[SYSTEM] " + message + "\u001B[0m");
        log.info(message);
    }

    /**
     * Prints a success message (green).
     */
    private static void printSystemSuccess(String message) {
        System.out.println("\u001B[32m[SUCCESS] " + message + "\u001B[0m");
        log.info(message);
    }

    /**
     * Prints an error message (red).
     */
    private static void printSystemError(String message) {
        System.out.println("\u001B[31m[ERROR] " + message + "\u001B[0m");
        log.error(message);
    }

    /**
     * Prints a warning message (yellow).
     */
    private static void printSystemWarning(String message) {
        System.out.println("\u001B[33m[WARNING] " + message + "\u001B[0m");
        log.warn(message);
    }

    /**
     * Prints the input prompt.
     */
    private static void printPrompt(String role) {
        System.out.print("\n\u001B[36m" + role + " > \u001B[0m");
    }

    /**
     * Prints a separator line.
     */
    private static void printSeparator() {
        System.out.println("\u001B[90m" + "─".repeat(50) + "\u001B[0m");
    }

    /**
     * Prints help information.
     */
    private static void printHelp() {
        System.out.println("\n\u001B[35m📖 Help:\u001B[0m");
        System.out.println("  • Ask about the time: 'What time is it in Beijing now?'");
        System.out.println("  • Ask about the weather: 'What's the weather like in New York?'");
        System.out.println("  • Exit: 'quit'");
        System.out.println("  • Show help: 'help'");
    }
}
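For context on registering tools with FunctionTool.create(MultiToolAgentMain.class, "getWeather"): roughly, when a model emits a function call, ADK looks up the tool by name, passes the call's arguments, and returns the tool's Map result to the model as the function response. The stdlib-only sketch below imitates that round trip with a hypothetical name-to-function map (ToolDispatchSketch). It is not ADK's FunctionTool implementation. Also, a model only emits such calls if the tool declarations actually reach it.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, stdlib-only imitation of name-based tool dispatch.
// It is NOT ADK's FunctionTool; it only illustrates the call/response round trip.
final class ToolDispatchSketch {
    private final Map<String, Function<Map<String, Object>, Map<String, Object>>> tools = new HashMap<>();

    void register(String name, Function<Map<String, Object>, Map<String, Object>> fn) {
        tools.put(name, fn);
    }

    // Executes a "function call" (name + args) and returns the map that would go back to the model
    Map<String, Object> dispatch(String name, Map<String, Object> args) {
        Function<Map<String, Object>, Map<String, Object>> fn = tools.get(name);
        if (fn == null) {
            return Map.of("status", "error", "result", "Unknown tool: " + name);
        }
        return fn.apply(args);
    }

    // Same logic as getWeather in the article, adapted to take an argument map
    static Map<String, Object> getWeather(Map<String, Object> args) {
        Object city = args.get("city");
        if (city instanceof String && "new york".equalsIgnoreCase((String) city)) {
            return Map.of("status", "success",
                    "result", "The weather in New York is sunny, with a temperature of 25 and humidity of 77.");
        }
        return Map.of("status", "error", "result", "No weather data was found for the city " + city + ".");
    }

    public static void main(String[] args) {
        ToolDispatchSketch dispatcher = new ToolDispatchSketch();
        dispatcher.register("getWeather", ToolDispatchSketch::getWeather);
        System.out.println(dispatcher.dispatch("getWeather", Map.of("city", "New York")));
        System.out.println(dispatcher.dispatch("getCapitalCity", Map.of("country", "japan")));
    }
}
```

Returning a status/result map for unknown tools, instead of throwing, mirrors how the article's tools report failures back to the model as data.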

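Original content statement: This article was published on the Tencent Cloud Developer Community with the author's authorization. It may not be reproduced without permission.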

In case of infringement, please contact cloudcommunity@tencent.com to have it removed.
