Spring AI 1.1.0在 Tool 调用时,很难让开发者监听开始调用Tool和结束调用Tool。这篇文章就是为了解决该问题。
Spring AI 1.1.0工具调用监控:基于方法引用的最优雅强类型 Tool Callback 方案。
1. Spring AI 在 Tool 调用上的一个现实问题
Spring AI 提供了非常方便的工具调用机制:只要在 Bean 方法上加上@Tool注解,然后在对话时把这些 Bean 传给ChatClient,模型就可以自行决定何时调用这些工具。
典型的代码大概是这样:
Flux<String> flux = chatClient.prompt() .user("帮我查一下北京现在的天气,并顺便点评一下今天适合做什么运动") .tools(weatherTool) // WeatherTool 里有若干 @Tool 方法 .stream() .content();这个能力很好用,但有一个明显的缺口:
调用侧几乎感知不到工具的生命周期。
通常会想做这些事:
- 在工具调用前,告诉用户“正在调用某某工具”(例如天气服务、OCR 服务)。
- 在工具成功返回后,记录结果,或者推送一段辅助说明到前端。
- 在工具异常时,做降级、兜底,甚至区分是网络问题还是业务问题。
- 只对某几个关键工具做这些监控,而不是所有工具。
Spring AI 自带的ToolCallback更偏向底层能力,对上层业务来说缺少一个“强类型、按方法粒度订阅工具调用”的入口。
2. 核心思路:用方法引用绑定到具体工具方法
目标很简单:
调用侧用方法引用声明自己关心哪个工具方法,例如:
new ToolCallObserver<>(weatherTool::getCurrentWeather) { ... }在工具真正被调用时,能够精确知道:
- 调用的是哪一个
@Tool方法(Method对象)。 - 对应的目标对象是谁(具体的 Bean 实例)。
- 调用前、调用后、调用异常的节点,都能获得通知。
- 调用的是哪一个
最终生成一组ToolCallback[]交给 Spring AI 的ChatClient使用。
Spring AI 自身已经提供了从 Bean 到ToolCallback[]的能力(ToolCallbacks.from(...)),所以这里只需要在它外面加一层观察逻辑即可。
3. 原理:从方法引用还原 Method 和目标对象
关键在于“方法引用”这件事。
对于下面这种写法:
SerializableFunction<String, String> fn = weatherTool::getCurrentWeather;编译器会生成一个“可序列化的 lambda 类”,这个类:
- 实现了目标函数式接口(这里是
SerializableFunction)。 - 内部持有若干“捕获变量”,其中一个就是
weatherTool这个 Bean 实例。 - 实现了一个
writeReplace()方法,用于序列化时返回SerializedLambda。
通过反射调用这个writeReplace(),可以拿到java.lang.invoke.SerializedLambda对象,里面包含:
implClass:真实实现类,例如dev.w0fv1.ai_tool_callback.WeatherTool。implMethodName:方法名,例如getCurrentWeather。implMethodSignature:方法签名。capturedArgs[n]:捕获到的实例参数,例如第 0 个就是weatherTool实例。
这样就可以做到:
- 根据
implClass + implMethodName (+ implMethodSignature)还原为真实的Method对象。 - 根据
capturedArgs[0]拿到当时绑定的 Bean 实例。
接下来,只要在工具真正执行时,把“当次实际调用的Method”和“我们事先从方法引用解析出的Method”做一次对比,就能知道这次是不是要通知对应的观察者。
整个过程不依赖其他信息,也不会受到方法重命名或参数变化的影响——编译器会负责在方法引用处校验签名。
4. 配套代码
下面给出一个完整的实现组合:
SerializableFunction:为了让方法引用成为“可序列化 lambda”。ToolCallObserverRegistry:管理所有工具观察者,并生成 Spring AI 所需的ToolCallback[]。- 在上层业务中如何使用。
4.1 SerializableFunction:可序列化的函数式接口
你要监控的Tool的方法需要实现一个可序列化的函数接口,没有它我们无法从lambda中得到object和method。
比如下面的是一个传入一个参数,有返回的方法的函数式接口。
你可以提前写一个到多个传入参数的接口。
// 文件路径:src/main/java/dev/w0fv1/ai_tool_callback/SerializableFunction.java package dev.w0fv1.ai_tool_callback; import java.io.Serializable; import java.util.function.Function; /** * 可序列化的 Function,用于支持方法引用解析 Method。 */ @FunctionalInterface public interface SerializableFunction<T, R> extends Function<T, R>, Serializable { }以后有更多签名需求,可以按同样模式定义SerializableConsumer、SerializableBiFunction等。
4.2 ToolCallObserverRegistry:工具调用观察注册中心
直接复制到你的仓库里就行。
// 文件路径:src/main/java/dev/w0fv1/ai_tool_callback/ToolCallObserverRegistry.java package dev.w0fv1.ai_tool_callback; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.support.ToolCallbacks; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.annotation.Tool; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.metadata.ToolMetadata; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import java.io.Serializable; import java.lang.invoke.SerializedLambda; import java.lang.reflect.Method; import java.util.*; /** * 专门管理 Tool 调用监听的注册中心: * 1. 保存所有 ToolCallObserver(强类型方法引用)。 * 2. 在工具执行前/后/异常时分发事件。 * 3. 根据方法引用推导出需要暴露给 Spring AI 的 ToolCallback[]。 */ @Slf4j public class ToolCallObserverRegistry { /** * 工具调用观察者抽象类: * - 持有一个方法引用 T methodRef(必须可序列化)。 * - 提供 before/after/error 钩子。 */ public static abstract class ToolCallObserver<T extends Serializable> implements Serializable { private final T methodRef; protected ToolCallObserver(T methodRef) { this.methodRef = methodRef; } public T methodRef() { return methodRef; } public void before(ToolDefinition toolDefinition, Object[] args) { } public void after(ToolDefinition toolDefinition, Object[] args, Object result) { } public void error(ToolDefinition toolDefinition, Object[] args, Throwable throwable) { } } private final List<ToolCallObserver<?>> observers = new ArrayList<>(); public void addObserver(ToolCallObserver<?> observer) { if (observer != null) { this.observers.add(observer); } } public List<ToolCallObserver<?>> getObservers() { return Collections.unmodifiableList(observers); } // ============ 对外暴露的通知入口(由内部包装 ToolCallback 调用) ============ public void notifyBefore(ToolDefinition toolDefinition, Object target, Method invokedMethod, Object[] args) { for (ToolCallObserver<?> observer : observers) { if (!matches(observer, target, invokedMethod)) { continue; } try { observer.before(toolDefinition, args); } catch (Exception e) { log.warn("ToolCallObserver before 执行异常, tool={}", toolDefinition.name(), e); } } } public void notifyAfter(ToolDefinition toolDefinition, Object target, Method invokedMethod, Object[] args, Object result) { for (ToolCallObserver<?> observer : observers) { if (!matches(observer, target, invokedMethod)) { continue; } try { observer.after(toolDefinition, args, result); } catch (Exception e) { log.warn("ToolCallObserver after 执行异常, tool={}", toolDefinition.name(), e); } } } public void notifyError(ToolDefinition toolDefinition, Object target, Method invokedMethod, Object[] args, Throwable throwable) { for (ToolCallObserver<?> observer : observers) { if (!matches(observer, target, invokedMethod)) { continue; } try { observer.error(toolDefinition, args, throwable); } catch (Exception e) { log.warn("ToolCallObserver error 执行异常, tool={}", toolDefinition.name(), e); } } } // ============ 构造 Spring AI 用的 ToolCallback[](只包含需要监控的 Tool) ============ /** * 基于当前注册的 ToolCallObserver(methodRef),推导出对应的工具实例, * 并构造一组可直接给 ChatClient 使用的 Spring AI ToolCallback[]。 * * 注意: * - 只包含“需要监控”的工具。 * - 不需要监控的工具,直接用 ToolCallbacks.from(otherBeans...) 单独生成再合并。 */ public ToolCallback[] buildSpringToolCallbacks() { if (observers.isEmpty()) { return new ToolCallback[0]; } // 1. 从所有 methodRef 中解析出绑定的工具实例对象 List<Object> toolBeans = new ArrayList<>(); for (ToolCallObserver<?> observer : observers) { Object fn = observer.methodRef(); if (fn == null) { continue; } Object target = resolveTargetFromLambda(fn); if (target == null) { continue; } boolean exists = false; for (Object bean : toolBeans) { if (bean == target) { exists = true; break; } } if (!exists) { toolBeans.add(target); } } if (toolBeans.isEmpty()) { return new ToolCallback[0]; } return buildObservedCallbacks(this, toolBeans.toArray()); } // ============ 内部实现:lambda 解析 & ToolCallback 包装 ============ private boolean matches(ToolCallObserver<?> observer, Object target, Method invokedMethod) { Object fn = observer.methodRef(); if (fn == null) { // 未指定 methodRef,认为匹配所有工具 return true; } Method keyMethod = resolveMethodFromLambda(fn); if (keyMethod == null) { return false; } return keyMethod.equals(invokedMethod); } /** * 从方法引用 / lambda 中解析底层 Method。 */ private static Method resolveMethodFromLambda(Object lambda) { try { Method writeReplace = lambda.getClass().getDeclaredMethod("writeReplace"); writeReplace.setAccessible(true); Object serializedForm = writeReplace.invoke(lambda); if (!(serializedForm instanceof SerializedLambda)) { return null; } SerializedLambda sl = (SerializedLambda) serializedForm; String implClassName = sl.getImplClass().replace('/', '.'); String implMethodName = sl.getImplMethodName(); Class<?> implClass = Class.forName(implClassName); Method[] methods = implClass.getDeclaredMethods(); for (Method m : methods) { if (!m.getName().equals(implMethodName)) { continue; } m.setAccessible(true); return m; } return null; } catch (Exception e) { log.warn("解析方法引用失败(获取 Method),lambda={}", lambda, e); return null; } } /** * 从方法引用 / lambda 中解析绑定实例对象。 * * 仅对绑定实例方法引用有效,例如:weatherTool::getCurrentWeather */ private static Object resolveTargetFromLambda(Object lambda) { try { Method writeReplace = lambda.getClass().getDeclaredMethod("writeReplace"); writeReplace.setAccessible(true); Object serializedForm = writeReplace.invoke(lambda); if (!(serializedForm instanceof SerializedLambda)) { return null; } SerializedLambda sl = (SerializedLambda) serializedForm; int capturedCount = sl.getCapturedArgCount(); if (capturedCount <= 0) { return null; } return sl.getCapturedArg(0); } catch (Exception e) { log.warn("解析方法引用失败(获取 target),lambda={}", lambda, e); return null; } } /** * 根据 toolBeans 构建带监控的 ToolCallback[]。 */ private static ToolCallback[] buildObservedCallbacks(ToolCallObserverRegistry registry, Object... toolBeans) { Assert.notNull(toolBeans, "toolBeans 不能为空"); if (toolBeans.length == 0) { return new ToolCallback[0]; } // 1. 基础 ToolCallback 数组(Spring AI 官方工具) ToolCallback[] baseCallbacks = ToolCallbacks.from(toolBeans); if (registry.getObservers().isEmpty()) { // 没有观察者就不包裹,直接返回 return baseCallbacks; } // 2. 建立 toolName -> (target, method) 映射 Map<String, TargetMethod> toolNameIndex = buildToolNameIndex(toolBeans); ToolCallback[] observed = new ToolCallback[baseCallbacks.length]; for (int i = 0; i < baseCallbacks.length; i++) { ToolCallback base = baseCallbacks[i]; String toolName = base.getToolDefinition().name(); TargetMethod tm = toolNameIndex.get(toolName); if (tm == null) { log.warn("未在 toolBeans 中找到名称为 {} 的工具方法,将不进行调用观察", toolName); observed[i] = base; continue; } observed[i] = new ObservingToolCallback( base, registry, tm.target(), tm.method() ); } return observed; } private static Map<String, TargetMethod> buildToolNameIndex(Object[] toolBeans) { Map<String, TargetMethod> index = new HashMap<>(); for (Object bean : toolBeans) { Class<?> userClass = ClassUtils.getUserClass(bean); Method[] methods = userClass.getMethods(); for (Method method : methods) { Tool ann = AnnotatedElementUtils.findMergedAnnotation(method, Tool.class); if (ann == null) { continue; } String name = ann.name(); if (name == null || name.isBlank()) { name = method.getName(); } if (index.containsKey(name)) { throw new IllegalStateException("检测到重复的工具名称: " + name + ",请确保所有 @Tool 的 name 唯一"); } index.put(name, new TargetMethod(bean, method)); } } return index; } private record TargetMethod(Object target, Method method) { } /** * 静态内置类:对 Spring AI 的 ToolCallback 进行包装, * 在调用前/后/异常时,把事件转发给 ToolCallObserverRegistry。 */ private static class ObservingToolCallback implements ToolCallback { private final ToolCallback delegate; private final ToolCallObserverRegistry registry; private final Object target; private final Method method; private ObservingToolCallback(ToolCallback delegate, ToolCallObserverRegistry registry, Object target, Method method) { this.delegate = delegate; this.registry = registry; this.target = target; this.method = method; } @Override public ToolDefinition getToolDefinition() { return delegate.getToolDefinition(); } @Override public ToolMetadata getToolMetadata() { return delegate.getToolMetadata(); } @Override public String call(String toolInput) { ToolDefinition def = delegate.getToolDefinition(); Object[] args = new Object[]{toolInput}; registry.notifyBefore(def, target, method, args); try { String result = delegate.call(toolInput); registry.notifyAfter(def, target, method, args, result); return result; } catch (RuntimeException e) { registry.notifyError(def, target, method, args, e); throw e; } } @Override public String call(String toolInput, ToolContext toolContext) { ToolDefinition def = delegate.getToolDefinition(); Object[] args = new Object[]{toolInput}; registry.notifyBefore(def, target, method, args); try { String result = delegate.call(toolInput, toolContext); registry.notifyAfter(def, target, method, args, result); return result; } catch (RuntimeException e) { registry.notifyError(def, target, method, args, e); throw e; } } } }4.3 使用示例:注册观察者 + 挂到 ChatClient
假设有一个闲聊服务,在这里想监控“天气查询工具”的调用情况:
// 文件路径:src/main/java/dev/w0fv1/ai_tool_callback/ConversationService.java package dev.w0fv1.ai_tool_callback; import dev.w0fv1.ai_tool_callback.ToolCallObserverRegistry; import dev.w0fv1.ai_tool_callback.WeatherTool; import dev.w0fv1.ai_tool_callback.SerializableFunction; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.tool.ToolCallback; import org.springframework.stereotype.Service; import reactor.core.publisher.Flux; import java.util.List; @Slf4j @Service public class ConversationService { private final ChatClient chatClient; private final WeatherTool weatherTool; public ConversationService(OpenAiChatModel textQuality, WeatherTool weatherTool) { this.chatClient = ChatClient.builder(textQuality).build(); this.weatherTool = weatherTool; } public Flux<String> chat(String userInput) { // 1. 创建工具调用观察注册中心,并注册一个针对 WeatherTool.getCurrentWeather 的观察者 ToolCallObserverRegistry registry = new ToolCallObserverRegistry(); registry.addObserver( new ToolCallObserverRegistry.ToolCallObserver<SerializableFunction<String, String>>(weatherTool::getCurrentWeather) { @Override public void before(org.springframework.ai.tool.definition.ToolDefinition toolDefinition, Object[] args) { log.info("开始执行工具:{}, args={}", toolDefinition.name(), args); } @Override public void after(org.springframework.ai.tool.definition.ToolDefinition toolDefinition, Object[] args, Object result) { log.info("工具执行完成:{}, result={}", toolDefinition.name(), result); } @Override public void error(org.springframework.ai.tool.definition.ToolDefinition toolDefinition, Object[] args, Throwable throwable) { log.error("工具执行异常:{}", toolDefinition.name(), throwable); } } ); // 2. 由注册中心生成“需要监控”的 ToolCallback[] ToolCallback[] observedCallbacks = registry.buildSpringToolCallbacks(); // 3. 构造请求并挂上这些回调 ChatClient.ChatClientRequestSpec spec = chatClient.prompt() .user(userInput); return spec .toolCallbacks(observedCallbacks) .stream() .content(); } }如果还有一些工具完全不想监控,可以额外用:
ToolCallback[] plain = ToolCallbacks.from(otherToolBean1, otherToolBean2); // 然后合并 observedCallbacks + plain 再传给 ChatClient5. 小结
整个方案的关键点:
- 利用“可序列化方法引用 + SerializedLambda”从调用侧的
xxx::method还原到具体的Method+ Bean 实例。 - 把“谁关心哪个工具方法”抽象成
ToolCallObserverRegistry+ToolCallObserver。 - 使用 Spring AI 官方提供的
ToolCallbacks.from(...)作为基础,再在外面包一层ObservingToolCallback做调用前/后/异常通知。 - 所有 API 都围绕方法引用展开,和业务代码天然对齐,代码改动在编译期就能暴露问题。
最终效果是:在不入侵 Spring AI 核心机制的前提下,为工具调用加上了一条强类型、可订阅、可扩展的“旁路”,上层可以非常精细地监控和控制每一次 Tool 调用的生命周期。