Files
kiro.rs/src/anthropic/types.rs
2026-03-05 21:28:41 +08:00

316 lines
8.6 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Anthropic API type definitions.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// === Error responses ===
/// API error response, serialized as `{"error": {"type": ..., "message": ...}}`.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// The error's type and message.
pub error: ErrorDetail,
}
/// Error details: an error type plus a message.
#[derive(Debug, Serialize)]
pub struct ErrorDetail {
/// Error category, serialized under the JSON key `"type"`
/// (e.g. `"authentication_error"`).
#[serde(rename = "type")]
pub error_type: String,
/// Error description.
pub message: String,
}
impl ErrorResponse {
    /// Builds an error response from an error type and a message.
    pub fn new(error_type: impl Into<String>, message: impl Into<String>) -> Self {
        let detail = ErrorDetail {
            error_type: error_type.into(),
            message: message.into(),
        };
        Self { error: detail }
    }

    /// Builds the response used for a rejected API key:
    /// type `"authentication_error"`, message `"Invalid API key"`.
    pub fn authentication_error() -> Self {
        let (kind, text) = ("authentication_error", "Invalid API key");
        Self::new(kind, text)
    }
}
// === Models endpoint types ===
/// Model information: one entry in the models endpoint response.
///
/// Optional fields that are `None` are omitted from the serialized JSON.
#[derive(Debug, Serialize)]
pub struct Model {
pub id: String,
pub object: String,
/// NOTE(review): presumably a Unix timestamp — confirm with the code that fills it.
pub created: i64,
pub owned_by: String,
pub display_name: String,
/// Model type, serialized under the JSON key `"type"`.
#[serde(rename = "type")]
pub model_type: String,
pub max_tokens: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub context_length: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking: Option<bool>,
}
/// Model list response returned by the models endpoint.
#[derive(Debug, Serialize)]
pub struct ModelsResponse {
/// NOTE(review): presumably a fixed list tag such as `"list"` — confirm with the caller.
pub object: String,
/// The listed models.
pub data: Vec<Model>,
}
// === Messages endpoint types ===
/// Upper bound for a client-supplied thinking `budget_tokens` (128,000 tokens).
/// Larger requested budgets are clamped down to this value during deserialization.
const MAX_BUDGET_TOKENS: i32 = 128 * 1_000;
/// Thinking configuration from a Messages request.
#[derive(Debug, Deserialize, Clone)]
pub struct Thinking {
/// Thinking mode, read from the JSON key `"type"`. `is_enabled` treats
/// `"enabled"` and `"adaptive"` as on.
#[serde(rename = "type")]
pub thinking_type: String,
/// Thinking token budget. Defaults to 20,000 when the field is absent, and is
/// clamped to at most `MAX_BUDGET_TOKENS` (128,000) when present.
#[serde(
default = "default_budget_tokens",
deserialize_with = "deserialize_budget_tokens"
)]
pub budget_tokens: i32,
}
impl Thinking {
    /// Returns `true` when thinking is turned on, i.e. the mode is exactly
    /// `"enabled"` or `"adaptive"` (case-sensitive). Any other mode counts as off.
    pub fn is_enabled(&self) -> bool {
        matches!(self.thinking_type.as_str(), "enabled" | "adaptive")
    }
}
/// Thinking budget used when a request omits `budget_tokens`: 20,000 tokens.
fn default_budget_tokens() -> i32 {
    20_000
}
/// Reads `budget_tokens` as an `i32` and caps it at [`MAX_BUDGET_TOKENS`].
///
/// Values at or below the cap pass through unchanged; note that this includes
/// zero and negative values. NOTE(review): confirm whether the upstream needs a lower bound.
///
/// # Errors
/// Propagates the deserializer's error if the value is not a valid `i32`.
fn deserialize_budget_tokens<'de, D>(deserializer: D) -> Result<i32, D::Error>
where
    D: serde::Deserializer<'de>,
{
    i32::deserialize(deserializer).map(|requested| requested.min(MAX_BUDGET_TOKENS))
}
/// Output configuration from a Messages request.
#[derive(Debug, Deserialize, Clone)]
pub struct OutputConfig {
/// Effort level; defaults to `"high"` when the field is absent.
#[serde(default = "default_effort")]
pub effort: String,
}
/// Effort level used when a request's `output_config` omits `effort`.
fn default_effort() -> String {
    String::from("high")
}
/// The `metadata` object carried by Claude Code requests.
#[derive(Debug, Clone, Deserialize)]
pub struct Metadata {
/// User ID, in a format like: user_xxx_account__session_0b4445e1-f5be-49e1-87ce-62bbc28ad705
/// (the example carries a session identifier after `__session_`).
pub user_id: Option<String>,
}
/// Request body for the Messages endpoint.
#[derive(Debug, Deserialize)]
pub struct MessagesRequest {
pub model: String,
/// Kept for Anthropic API compatibility; not actually forwarded to the Kiro upstream.
pub max_tokens: i32,
pub messages: Vec<Message>,
/// Defaults to `false` when absent.
#[serde(default)]
pub stream: bool,
/// Accepts either a string or an array of system messages (see `deserialize_system`);
/// `None` when the field is absent.
#[serde(default, deserialize_with = "deserialize_system")]
pub system: Option<Vec<SystemMessage>>,
pub tools: Option<Vec<Tool>>,
/// Kept as raw JSON. NOTE(review): the `allow(dead_code)` suggests nothing reads
/// this field yet — confirm before relying on it.
#[allow(dead_code)]
pub tool_choice: Option<serde_json::Value>,
pub thinking: Option<Thinking>,
pub output_config: Option<OutputConfig>,
/// Claude Code request metadata, which includes session information.
pub metadata: Option<Metadata>,
}
/// Deserializes the `system` field, accepting either shape the Anthropic API allows:
///
/// - a plain string, which becomes a single [`SystemMessage`];
/// - an array of system-message objects (each must carry a `text` field).
///
/// JSON `null` and an empty array both yield `None`.
///
/// # Errors
/// Returns the deserializer's error for any other JSON type, or when an array
/// element cannot be parsed as a [`SystemMessage`].
fn deserialize_system<'de, D>(deserializer: D) -> Result<Option<Vec<SystemMessage>>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    /// Visitor mapping string / array / null input onto `Option<Vec<SystemMessage>>`.
    struct SystemVisitor;

    impl<'de> serde::de::Visitor<'de> for SystemVisitor {
        type Value = Option<Vec<SystemMessage>>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a string or an array of system messages")
        }

        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(Some(vec![SystemMessage {
                text: value.to_string(),
            }]))
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
        where
            A: serde::de::SeqAccess<'de>,
        {
            let mut messages = Vec::new();
            while let Some(msg) = seq.next_element()? {
                messages.push(msg);
            }
            // An empty array carries no system prompt; normalize it to `None`.
            Ok(if messages.is_empty() {
                None
            } else {
                Some(messages)
            })
        }

        // Self-describing formats such as serde_json report JSON `null` from
        // `deserialize_any` as a unit value, not via `visit_none`. Without this
        // method, `"system": null` was rejected as an invalid type.
        fn visit_unit<E>(self) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(None)
        }

        fn visit_none<E>(self) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(None)
        }

        // Re-enter this visitor for the wrapped value so the string form keeps working.
        // The blanket `Option<Vec<_>>` impl used previously would accept only an array.
        fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            deserializer.deserialize_any(SystemVisitor)
        }
    }

    deserializer.deserialize_any(SystemVisitor)
}
/// A conversation message.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
pub role: String,
/// Either a string or an array of ContentBlocks, kept as raw JSON.
pub content: serde_json::Value,
}
/// A system message entry.
///
/// Only `text` is captured. serde's derive ignores any other keys in the JSON object,
/// because unknown fields are not denied.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SystemMessage {
pub text: String,
}
/// Tool definition.
///
/// Two shapes are accepted:
/// 1. Regular tool: { name, description, input_schema }
/// 2. WebSearch tool: { type: "web_search_20250305", name: "web_search", max_uses: 8 }
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
/// Tool type, e.g. "web_search_20250305" (optional; WebSearch tools only).
/// Uses the JSON key `"type"`; omitted from serialized output when `None`.
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
pub tool_type: Option<String>,
/// Tool name; defaults to an empty string when absent.
#[serde(default)]
pub name: String,
/// Tool description (required for regular tools, optional for WebSearch tools).
/// Deserialization does not enforce this: it defaults to an empty string.
#[serde(default)]
pub description: String,
/// Input parameter schema (required for regular tools; WebSearch tools have none).
/// Defaults to an empty map when absent.
#[serde(default)]
pub input_schema: HashMap<String, serde_json::Value>,
/// Maximum number of uses (WebSearch tools only); omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_uses: Option<i32>,
}
impl Tool {
    /// Returns `true` if this tool's `type` starts with `"web_search"`
    /// (e.g. `"web_search_20250305"`). Tools without a `type` return `false`.
    // NOTE(review): the allow suggests no caller uses this yet; it silences the dead-code lint.
    #[allow(dead_code)]
    pub fn is_web_search(&self) -> bool {
        match self.tool_type.as_deref() {
            Some(kind) => kind.starts_with("web_search"),
            None => false,
        }
    }
}
/// A content block.
///
/// The block kind is read from / written to the JSON key `"type"`. Every other field
/// is optional and omitted from serialized output when `None`.
/// NOTE(review): which optional fields are populated presumably depends on
/// `block_type` (text / thinking / tool_use / tool_result / image) — confirm against callers.
#[derive(Debug, Deserialize, Serialize)]
pub struct ContentBlock {
#[serde(rename = "type")]
pub block_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_use_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub input: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub is_error: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub source: Option<ImageSource>,
}
/// Image data source.
///
/// All three fields are required, so a source object missing `media_type` or `data`
/// fails to deserialize. NOTE(review): this may reject URL-style sources — confirm
/// whether they need support.
#[derive(Debug, Deserialize, Serialize)]
pub struct ImageSource {
/// Source kind, JSON key `"type"`.
#[serde(rename = "type")]
pub source_type: String,
pub media_type: String,
/// NOTE(review): presumably base64-encoded image bytes — confirm.
pub data: String,
}
// === Count Tokens endpoint types ===
/// Token count request.
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensRequest {
pub model: String,
pub messages: Vec<Message>,
/// Accepts a string or an array of system messages on input (see `deserialize_system`).
/// Omitted from serialized output when `None`.
#[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deserialize_system"
)]
pub system: Option<Vec<SystemMessage>>,
/// Omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
}
/// Token count response.
#[derive(Debug, Serialize, Deserialize)]
pub struct CountTokensResponse {
/// Number of input tokens.
pub input_tokens: i32,
}
/// Returns the context window size, in tokens, for a model name.
///
/// Matching is case-insensitive and substring-based:
/// - names containing `opus` or `sonnet` together with a `4-6` / `4.6` version
///   marker (the Opus 4.6 and Sonnet 4.6 series) get 1,000,000 tokens;
/// - every other model gets 200,000 tokens.
pub fn get_context_window_size(model: &str) -> i32 {
    const LARGE_WINDOW: i32 = 1_000_000;
    const DEFAULT_WINDOW: i32 = 200_000;

    let name = model.to_lowercase();
    let is_opus_or_sonnet = ["opus", "sonnet"]
        .iter()
        .any(|family| name.contains(*family));
    let is_version_4_6 = ["4-6", "4.6"]
        .iter()
        .any(|version| name.contains(*version));

    match (is_opus_or_sonnet, is_version_4_6) {
        (true, true) => LARGE_WINDOW,
        _ => DEFAULT_WINDOW,
    }
}
}