tirea_agentos_server/protocol/ai_sdk_v6/
http.rs1use axum::extract::{Path, Query, State};
2use axum::http::{header, HeaderValue, StatusCode};
3use axum::response::{IntoResponse, Response};
4use axum::routing::{get, post};
5use axum::{Json, Router};
6use bytes::Bytes;
7use std::convert::Infallible;
8use tirea_agentos::runtime::AgentOsRunError;
9use tirea_protocol_ai_sdk_v6::{
10 AiSdkEncoder, AiSdkTrigger, AiSdkV6HistoryEncoder, AiSdkV6RunRequest, UIStreamEvent,
11 AI_SDK_VERSION,
12};
13
14use super::runtime::apply_ai_sdk_extensions;
15use tokio::sync::broadcast;
16
17use crate::service::{
18 current_run_id_for_thread, encode_message_page, forward_dialog_decisions_by_thread,
19 load_message_page, start_http_dialog_run, truncate_thread_at_message, ApiError, AppState,
20 MessageQueryParams,
21};
22use crate::transport::http_run::{wire_http_sse_relay, HttpSseRelayConfig};
23use crate::transport::http_sse::{sse_body_stream, sse_response};
24
25const RUN_PATH: &str = "/agents/:agent_id/runs";
26const RESUME_STREAM_PATH: &str = "/agents/:agent_id/chats/:chat_id/stream";
27const LEGACY_RESUME_STREAM_PATH: &str = "/agents/:agent_id/runs/:chat_id/stream";
30const THREAD_MESSAGES_PATH: &str = "/threads/:id/messages";
31
32pub fn routes() -> Router<AppState> {
34 Router::new()
35 .route(RUN_PATH, post(run))
36 .route(RESUME_STREAM_PATH, get(resume_stream))
37 .route(LEGACY_RESUME_STREAM_PATH, get(resume_stream))
38 .route(THREAD_MESSAGES_PATH, get(thread_messages))
39}
40
41async fn thread_messages(
42 State(st): State<AppState>,
43 Path(id): Path<String>,
44 Query(params): Query<MessageQueryParams>,
45) -> Result<impl IntoResponse, ApiError> {
46 let page = load_message_page(&st.read_store, &id, ¶ms).await?;
47 let encoded = encode_message_page(page, AiSdkV6HistoryEncoder::encode_message);
48 Ok(Json(encoded))
49}
50
51async fn run(
52 State(st): State<AppState>,
53 Path(agent_id): Path<String>,
54 Json(req): Json<AiSdkV6RunRequest>,
55) -> Result<Response, ApiError> {
56 req.validate().map_err(ApiError::BadRequest)?;
57 if req.trigger == Some(AiSdkTrigger::RegenerateMessage) {
58 truncate_thread_at_message(&st.os, &req.thread_id, req.message_id.as_deref().unwrap())
59 .await?;
60 }
61
62 let suspension_decisions = req.suspension_decisions();
63 let maybe_forwarded = forward_dialog_decisions_by_thread(
64 &st.os,
65 &agent_id,
66 &req.thread_id,
67 req.has_user_input(),
68 None,
69 &suspension_decisions,
70 )
71 .await?;
72 if let Some(forwarded) = maybe_forwarded {
73 return Ok((
74 StatusCode::ACCEPTED,
75 Json(serde_json::json!({
76 "status": "decision_forwarded",
77 "threadId": forwarded.thread_id,
78 })),
79 )
80 .into_response());
81 }
82
83 let mut resolved = st.os.resolve(&agent_id).map_err(AgentOsRunError::from)?;
84 apply_ai_sdk_extensions(&mut resolved, &req);
85 let run_request = req.into_runtime_run_request(agent_id.clone());
86 let prepared = start_http_dialog_run(&st.os, resolved, run_request, &agent_id).await?;
87 let (fanout, _) = broadcast::channel::<Bytes>(128);
88 if !st
89 .os
90 .bind_thread_run_stream_fanout(&prepared.run_id, fanout.clone())
91 .await
92 {
93 return Err(ApiError::Internal(format!(
94 "active run handle missing for run '{}'",
95 prepared.run_id
96 )));
97 }
98 let run_id_for_cleanup = prepared.run_id.clone();
99 let os_for_cleanup = st.os.clone();
100
101 let encoder = AiSdkEncoder::new();
102 let sse_rx = wire_http_sse_relay(
103 prepared.starter,
104 encoder,
105 prepared.ingress_rx,
106 HttpSseRelayConfig {
107 thread_id: prepared.thread_id,
108 fanout: Some(fanout.clone()),
109 resumable_downstream: true,
110 protocol_label: "ai-sdk",
111 on_relay_done: move |sse_tx: tokio::sync::mpsc::Sender<Bytes>| async move {
112 let trailer = Bytes::from("data: [DONE]\n\n");
113 let _ = fanout.send(trailer.clone());
114 if sse_tx.send(trailer).await.is_err() {
115 let _ = os_for_cleanup
116 .cancel_active_run_by_id(&run_id_for_cleanup)
117 .await;
118 }
119 },
120 error_formatter: |msg| {
121 let json = serde_json::to_string(&UIStreamEvent::error(&msg)).unwrap_or_default();
122 Bytes::from(format!("data: {json}\n\n"))
123 },
124 },
125 );
126
127 Ok(ai_sdk_sse_response(sse_body_stream(sse_rx)))
128}
129
130async fn resume_stream(
131 State(st): State<AppState>,
132 Path((agent_id, chat_id)): Path<(String, String)>,
133) -> Result<Response, ApiError> {
134 let Some(run_id) =
135 current_run_id_for_thread(&st.os, &agent_id, &chat_id, st.read_store.as_ref()).await?
136 else {
137 return Ok(StatusCode::NO_CONTENT.into_response());
138 };
139 let Some(mut receiver) = st.os.subscribe_thread_run_stream(&run_id).await else {
140 return Ok(StatusCode::NO_CONTENT.into_response());
141 };
142
143 let stream = async_stream::stream! {
144 loop {
145 match receiver.recv().await {
146 Ok(chunk) => yield Ok::<Bytes, Infallible>(chunk),
147 Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
148 Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
149 }
150 }
151 };
152 Ok(ai_sdk_sse_response(stream))
153}
154
155fn ai_sdk_sse_response<S>(stream: S) -> Response
156where
157 S: futures::Stream<Item = Result<Bytes, Infallible>> + Send + 'static,
158{
159 let mut response = sse_response(stream);
160 response.headers_mut().insert(
161 header::HeaderName::from_static("x-vercel-ai-ui-message-stream"),
162 HeaderValue::from_static("v1"),
163 );
164 response.headers_mut().insert(
165 header::HeaderName::from_static("x-tirea-ai-sdk-version"),
166 HeaderValue::from_static(AI_SDK_VERSION),
167 );
168 response
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::runtime_endpoint::RunStarter;
    use std::pin::Pin;
    use tirea_agentos::contracts::{AgentEvent, RunRequest, ToolCallDecision};
    use tirea_agentos::runtime::RunStream;
    use tirea_contract::{RunOrigin, RuntimeInput};
    use tokio::sync::mpsc;

    /// Minimal `RunRequest` for agent `"test"` with every optional field unset.
    fn test_run_request() -> RunRequest {
        RunRequest {
            agent_id: "test".into(),
            thread_id: None,
            run_id: None,
            parent_run_id: None,
            parent_thread_id: None,
            resource_id: None,
            origin: RunOrigin::default(),
            state: None,
            messages: Vec::new(),
            initial_decisions: Vec::new(),
            source_mailbox_entry_id: None,
        }
    }

    /// Builds a `RunStream` that yields `events` in order and then ends.
    ///
    /// The events are pushed by a spawned task through a bounded channel. The
    /// decision receiver is kept only until this function returns.
    fn fake_run(events: Vec<AgentEvent>) -> RunStream {
        let (decision_tx, _decision_rx) = mpsc::unbounded_channel::<ToolCallDecision>();
        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(16);

        tokio::spawn(async move {
            for event in events {
                let _ = event_tx.send(event).await;
            }
        });

        let event_stream: Pin<Box<dyn futures::Stream<Item = AgentEvent> + Send>> =
            Box::pin(async_stream::stream! {
                let mut receiver = event_rx;
                while let Some(event) = receiver.recv().await {
                    yield event;
                }
            });

        RunStream {
            thread_id: "thread-ai-sdk".to_string(),
            run_id: "run-ai-sdk".to_string(),
            decision_tx,
            events: event_stream,
        }
    }

    /// Formats `msg` as one `data:` SSE frame carrying an AI SDK error event.
    fn ai_sdk_error_chunk(msg: &str) -> Bytes {
        let json =
            serde_json::to_string(&UIStreamEvent::error(msg)).expect("serialize ai-sdk error");
        let mut frame = String::with_capacity(json.len() + 8);
        frame.push_str("data: ");
        frame.push_str(&json);
        frame.push_str("\n\n");
        Bytes::from(frame)
    }

    #[test]
    fn ai_sdk_error_chunk_matches_ui_message_stream_schema() {
        let chunk = ai_sdk_error_chunk("Web stream error for model 'openai::gemini-2.5-flash '");
        let text = String::from_utf8(chunk.to_vec()).expect("utf-8 sse");
        let payload = text.trim().strip_prefix("data: ").expect("sse payload");
        let event: UIStreamEvent = serde_json::from_str(payload).expect("valid ai-sdk event");

        let error_text = match event {
            UIStreamEvent::Error { error_text } => error_text,
            other => panic!("expected ai-sdk error event, got {other:?}"),
        };
        assert!(error_text.contains("Web stream error for model"));
        assert!(!payload.contains("recoverable"));
        assert!(!payload.contains("\"message\""));
    }

    #[tokio::test]
    async fn runtime_error_event_streams_as_valid_ai_sdk_error_chunk() {
        let starter: RunStarter = Box::new(move |_request| {
            Box::pin(async move {
                Ok(fake_run(vec![AgentEvent::Error {
                    message: "provider stream failed".to_string(),
                    code: Some("PROVIDER_ERROR".to_string()),
                }]))
            })
        });
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        ingress_tx
            .send(RuntimeInput::Run(test_run_request()))
            .expect("send run request");
        // No further runtime input will be sent.
        drop(ingress_tx);

        let mut sse_rx = wire_http_sse_relay(
            starter,
            AiSdkEncoder::new(),
            ingress_rx,
            HttpSseRelayConfig {
                thread_id: "thread-ai-sdk".to_string(),
                fanout: None,
                resumable_downstream: true,
                protocol_label: "ai-sdk",
                on_relay_done: |_sse_tx| async move {},
                error_formatter: |msg: String| ai_sdk_error_chunk(&msg),
            },
        );

        let mut chunks: Vec<Bytes> = Vec::new();
        while let Some(chunk) = sse_rx.recv().await {
            chunks.push(chunk);
        }

        let payloads: Vec<&str> = chunks
            .iter()
            .filter_map(|chunk| {
                let text = std::str::from_utf8(chunk).ok()?;
                text.trim().strip_prefix("data: ")
            })
            .collect();

        assert_eq!(
            payloads.len(),
            1,
            "unexpected ai-sdk payloads: {payloads:?}"
        );
        let event: UIStreamEvent =
            serde_json::from_str(payloads[0]).expect("valid ai-sdk runtime error event");
        assert!(matches!(
            event,
            UIStreamEvent::Error { ref error_text } if error_text == "provider stream failed"
        ));
    }

    #[tokio::test]
    async fn ai_sdk_sse_response_sets_protocol_headers() {
        /// Header value as `&str`, or `None` when absent or not visible ASCII.
        fn header_str<'a>(headers: &'a header::HeaderMap, name: &str) -> Option<&'a str> {
            headers.get(name).and_then(|value| value.to_str().ok())
        }

        let response = ai_sdk_sse_response(futures::stream::empty::<Result<Bytes, Infallible>>());
        let headers = response.headers();

        assert_eq!(header_str(headers, "content-type"), Some("text/event-stream"));
        assert_eq!(header_str(headers, "x-vercel-ai-ui-message-stream"), Some("v1"));
        assert_eq!(header_str(headers, "x-tirea-ai-sdk-version"), Some(AI_SDK_VERSION));
    }
}