tirea_agentos_server/protocol/ai_sdk_v6/
http.rs

1use axum::extract::{Path, Query, State};
2use axum::http::{header, HeaderValue, StatusCode};
3use axum::response::{IntoResponse, Response};
4use axum::routing::{get, post};
5use axum::{Json, Router};
6use bytes::Bytes;
7use std::convert::Infallible;
8use tirea_agentos::runtime::AgentOsRunError;
9use tirea_protocol_ai_sdk_v6::{
10    AiSdkEncoder, AiSdkTrigger, AiSdkV6HistoryEncoder, AiSdkV6RunRequest, UIStreamEvent,
11    AI_SDK_VERSION,
12};
13
14use super::runtime::apply_ai_sdk_extensions;
15use tokio::sync::broadcast;
16
17use crate::service::{
18    current_run_id_for_thread, encode_message_page, forward_dialog_decisions_by_thread,
19    load_message_page, start_http_dialog_run, truncate_thread_at_message, ApiError, AppState,
20    MessageQueryParams,
21};
22use crate::transport::http_run::{wire_http_sse_relay, HttpSseRelayConfig};
23use crate::transport::http_sse::{sse_body_stream, sse_response};
24
/// POST endpoint that starts a new run for an agent (or forwards decisions).
const RUN_PATH: &str = "/agents/:agent_id/runs";
/// GET endpoint that re-attaches to an in-flight run's SSE stream.
const RESUME_STREAM_PATH: &str = "/agents/:agent_id/chats/:chat_id/stream";
/// Legacy path kept for backward-compatibility with AI SDK clients that reconnect
/// via `/runs/:chat_id/stream` after a network drop.
const LEGACY_RESUME_STREAM_PATH: &str = "/agents/:agent_id/runs/:chat_id/stream";
/// GET endpoint for paging a thread's persisted message history.
const THREAD_MESSAGES_PATH: &str = "/threads/:id/messages";
31
32/// Build AI SDK v6 HTTP routes.
33pub fn routes() -> Router<AppState> {
34    Router::new()
35        .route(RUN_PATH, post(run))
36        .route(RESUME_STREAM_PATH, get(resume_stream))
37        .route(LEGACY_RESUME_STREAM_PATH, get(resume_stream))
38        .route(THREAD_MESSAGES_PATH, get(thread_messages))
39}
40
41async fn thread_messages(
42    State(st): State<AppState>,
43    Path(id): Path<String>,
44    Query(params): Query<MessageQueryParams>,
45) -> Result<impl IntoResponse, ApiError> {
46    let page = load_message_page(&st.read_store, &id, &params).await?;
47    let encoded = encode_message_page(page, AiSdkV6HistoryEncoder::encode_message);
48    Ok(Json(encoded))
49}
50
51async fn run(
52    State(st): State<AppState>,
53    Path(agent_id): Path<String>,
54    Json(req): Json<AiSdkV6RunRequest>,
55) -> Result<Response, ApiError> {
56    req.validate().map_err(ApiError::BadRequest)?;
57    if req.trigger == Some(AiSdkTrigger::RegenerateMessage) {
58        truncate_thread_at_message(&st.os, &req.thread_id, req.message_id.as_deref().unwrap())
59            .await?;
60    }
61
62    let suspension_decisions = req.suspension_decisions();
63    let maybe_forwarded = forward_dialog_decisions_by_thread(
64        &st.os,
65        &agent_id,
66        &req.thread_id,
67        req.has_user_input(),
68        None,
69        &suspension_decisions,
70    )
71    .await?;
72    if let Some(forwarded) = maybe_forwarded {
73        return Ok((
74            StatusCode::ACCEPTED,
75            Json(serde_json::json!({
76                "status": "decision_forwarded",
77                "threadId": forwarded.thread_id,
78            })),
79        )
80            .into_response());
81    }
82
83    let mut resolved = st.os.resolve(&agent_id).map_err(AgentOsRunError::from)?;
84    apply_ai_sdk_extensions(&mut resolved, &req);
85    let run_request = req.into_runtime_run_request(agent_id.clone());
86    let prepared = start_http_dialog_run(&st.os, resolved, run_request, &agent_id).await?;
87    let (fanout, _) = broadcast::channel::<Bytes>(128);
88    if !st
89        .os
90        .bind_thread_run_stream_fanout(&prepared.run_id, fanout.clone())
91        .await
92    {
93        return Err(ApiError::Internal(format!(
94            "active run handle missing for run '{}'",
95            prepared.run_id
96        )));
97    }
98    let run_id_for_cleanup = prepared.run_id.clone();
99    let os_for_cleanup = st.os.clone();
100
101    let encoder = AiSdkEncoder::new();
102    let sse_rx = wire_http_sse_relay(
103        prepared.starter,
104        encoder,
105        prepared.ingress_rx,
106        HttpSseRelayConfig {
107            thread_id: prepared.thread_id,
108            fanout: Some(fanout.clone()),
109            resumable_downstream: true,
110            protocol_label: "ai-sdk",
111            on_relay_done: move |sse_tx: tokio::sync::mpsc::Sender<Bytes>| async move {
112                let trailer = Bytes::from("data: [DONE]\n\n");
113                let _ = fanout.send(trailer.clone());
114                if sse_tx.send(trailer).await.is_err() {
115                    let _ = os_for_cleanup
116                        .cancel_active_run_by_id(&run_id_for_cleanup)
117                        .await;
118                }
119            },
120            error_formatter: |msg| {
121                let json = serde_json::to_string(&UIStreamEvent::error(&msg)).unwrap_or_default();
122                Bytes::from(format!("data: {json}\n\n"))
123            },
124        },
125    );
126
127    Ok(ai_sdk_sse_response(sse_body_stream(sse_rx)))
128}
129
130async fn resume_stream(
131    State(st): State<AppState>,
132    Path((agent_id, chat_id)): Path<(String, String)>,
133) -> Result<Response, ApiError> {
134    let Some(run_id) =
135        current_run_id_for_thread(&st.os, &agent_id, &chat_id, st.read_store.as_ref()).await?
136    else {
137        return Ok(StatusCode::NO_CONTENT.into_response());
138    };
139    let Some(mut receiver) = st.os.subscribe_thread_run_stream(&run_id).await else {
140        return Ok(StatusCode::NO_CONTENT.into_response());
141    };
142
143    let stream = async_stream::stream! {
144        loop {
145            match receiver.recv().await {
146                Ok(chunk) => yield Ok::<Bytes, Infallible>(chunk),
147                Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
148                Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
149            }
150        }
151    };
152    Ok(ai_sdk_sse_response(stream))
153}
154
155fn ai_sdk_sse_response<S>(stream: S) -> Response
156where
157    S: futures::Stream<Item = Result<Bytes, Infallible>> + Send + 'static,
158{
159    let mut response = sse_response(stream);
160    response.headers_mut().insert(
161        header::HeaderName::from_static("x-vercel-ai-ui-message-stream"),
162        HeaderValue::from_static("v1"),
163    );
164    response.headers_mut().insert(
165        header::HeaderName::from_static("x-tirea-ai-sdk-version"),
166        HeaderValue::from_static(AI_SDK_VERSION),
167    );
168    response
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::runtime_endpoint::RunStarter;
    use std::pin::Pin;
    use tirea_agentos::contracts::{AgentEvent, RunRequest, ToolCallDecision};
    use tirea_agentos::runtime::RunStream;
    use tirea_contract::RunOrigin;
    use tirea_contract::RuntimeInput;
    use tokio::sync::mpsc;

    /// Minimal `RunRequest` fixture: agent id "test", everything else empty
    /// or defaulted.
    fn test_run_request() -> RunRequest {
        RunRequest {
            agent_id: "test".into(),
            thread_id: None,
            run_id: None,
            parent_run_id: None,
            parent_thread_id: None,
            resource_id: None,
            origin: RunOrigin::default(),
            state: None,
            messages: vec![],
            initial_decisions: vec![],
            source_mailbox_entry_id: None,
        }
    }

    /// Build a `RunStream` that replays `events` from a background task and
    /// then ends (the channel closes when the spawned sender is dropped).
    /// The decision receiver is intentionally discarded — these tests never
    /// send tool-call decisions.
    fn fake_run(events: Vec<AgentEvent>) -> RunStream {
        let (decision_tx, _decision_rx) = mpsc::unbounded_channel::<ToolCallDecision>();
        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(16);

        tokio::spawn(async move {
            for event in events {
                let _ = event_tx.send(event).await;
            }
        });

        let stream: Pin<Box<dyn futures::Stream<Item = AgentEvent> + Send>> =
            Box::pin(async_stream::stream! {
                let mut rx = event_rx;
                while let Some(item) = rx.recv().await {
                    yield item;
                }
            });

        RunStream {
            thread_id: "thread-ai-sdk".to_string(),
            run_id: "run-ai-sdk".to_string(),
            decision_tx,
            events: stream,
        }
    }

    /// Encode `msg` as an AI SDK error event framed as a single SSE `data:`
    /// chunk — the same shape `error_formatter` produces in the `run` handler.
    fn ai_sdk_error_chunk(msg: &str) -> Bytes {
        let json =
            serde_json::to_string(&UIStreamEvent::error(msg)).expect("serialize ai-sdk error");
        Bytes::from(format!("data: {json}\n\n"))
    }

    /// The SSE error chunk must round-trip through the AI SDK UI message
    /// stream schema: payload parses as `UIStreamEvent::Error { error_text }`
    /// and does not leak internal field names like `recoverable`/`message`.
    #[test]
    fn ai_sdk_error_chunk_matches_ui_message_stream_schema() {
        let chunk = ai_sdk_error_chunk("Web stream error for model 'openai::gemini-2.5-flash '");
        let text = String::from_utf8(chunk.to_vec()).expect("utf-8 sse");
        let payload = text.trim().strip_prefix("data: ").expect("sse payload");
        let event: UIStreamEvent = serde_json::from_str(payload).expect("valid ai-sdk event");

        match event {
            UIStreamEvent::Error { error_text } => {
                assert!(error_text.contains("Web stream error for model"));
                assert!(!payload.contains("recoverable"));
                assert!(!payload.contains("\"message\""));
            }
            other => panic!("expected ai-sdk error event, got {other:?}"),
        }
    }

    /// End-to-end through `wire_http_sse_relay`: a runtime `AgentEvent::Error`
    /// must come out as exactly one SSE payload that parses back into an
    /// AI SDK error event with the original message text.
    #[tokio::test]
    async fn runtime_error_event_streams_as_valid_ai_sdk_error_chunk() {
        let starter: RunStarter = Box::new(move |_request| {
            Box::pin(async move {
                Ok(fake_run(vec![AgentEvent::Error {
                    message: "provider stream failed".to_string(),
                    code: Some("PROVIDER_ERROR".to_string()),
                }]))
            })
        });
        // Queue one run request, then drop the sender so the relay terminates.
        let (ingress_tx, ingress_rx) = mpsc::unbounded_channel::<RuntimeInput>();
        ingress_tx
            .send(RuntimeInput::Run(test_run_request()))
            .expect("send run request");
        drop(ingress_tx);

        let mut sse_rx = wire_http_sse_relay(
            starter,
            AiSdkEncoder::new(),
            ingress_rx,
            HttpSseRelayConfig {
                thread_id: "thread-ai-sdk".to_string(),
                fanout: None,
                resumable_downstream: true,
                protocol_label: "ai-sdk",
                on_relay_done: |_sse_tx| async move {},
                error_formatter: |msg: String| ai_sdk_error_chunk(&msg),
            },
        );

        // Drain the relay to completion before inspecting the chunks.
        let chunks: Vec<Bytes> = async {
            let mut out = Vec::new();
            while let Some(chunk) = sse_rx.recv().await {
                out.push(chunk);
            }
            out
        }
        .await;

        let payloads: Vec<&str> = chunks
            .iter()
            .filter_map(|chunk| std::str::from_utf8(chunk).ok())
            .filter_map(|text| text.trim().strip_prefix("data: "))
            .collect();

        assert_eq!(
            payloads.len(),
            1,
            "unexpected ai-sdk payloads: {payloads:?}"
        );
        let event: UIStreamEvent =
            serde_json::from_str(payloads[0]).expect("valid ai-sdk runtime error event");
        assert!(matches!(
            event,
            UIStreamEvent::Error { ref error_text } if error_text == "provider stream failed"
        ));
    }

    /// The SSE response wrapper must advertise `text/event-stream` plus both
    /// AI SDK protocol headers, even for an empty stream.
    #[tokio::test]
    async fn ai_sdk_sse_response_sets_protocol_headers() {
        let response = ai_sdk_sse_response(futures::stream::empty::<Result<Bytes, Infallible>>());
        let headers = response.headers();

        assert_eq!(
            headers.get("content-type").and_then(|v| v.to_str().ok()),
            Some("text/event-stream")
        );
        assert_eq!(
            headers
                .get("x-vercel-ai-ui-message-stream")
                .and_then(|v| v.to_str().ok()),
            Some("v1")
        );
        assert_eq!(
            headers
                .get("x-tirea-ai-sdk-version")
                .and_then(|v| v.to_str().ok()),
            Some(AI_SDK_VERSION)
        );
    }
}