tirea_agentos_server/service/
messages.rs

1use serde::Deserialize;
2use serde::Serialize;
3use std::sync::Arc;
4use tirea_agentos::contracts::storage::{
5    MessagePage, MessageQuery, SortOrder, ThreadReader, ThreadStoreError,
6};
7use tirea_agentos::contracts::thread::Message;
8use tirea_agentos::contracts::thread::Visibility;
9
10use super::ApiError;
11
/// Supplies the value of `MessageQueryParams::limit` when the request omits it.
///
/// Referenced by name from the `#[serde(default = "...")]` attribute on that field.
fn default_message_limit() -> usize {
    const DEFAULT_LIMIT: usize = 50;
    DEFAULT_LIMIT
}
15
16#[derive(Debug, Deserialize)]
17pub struct MessageQueryParams {
18    #[serde(default)]
19    pub after: Option<i64>,
20    #[serde(default)]
21    pub before: Option<i64>,
22    #[serde(default = "default_message_limit")]
23    pub limit: usize,
24    #[serde(default)]
25    pub order: Option<String>,
26    #[serde(default)]
27    pub visibility: Option<String>,
28    #[serde(default)]
29    pub run_id: Option<String>,
30}
31
32pub fn parse_message_query(params: &MessageQueryParams) -> MessageQuery {
33    let limit = params.limit.clamp(1, 200);
34    let order = match params.order.as_deref() {
35        Some("desc") => SortOrder::Desc,
36        _ => SortOrder::Asc,
37    };
38    let visibility = match params.visibility.as_deref() {
39        Some("internal") => Some(Visibility::Internal),
40        Some("none") => None,
41        _ => Some(Visibility::All),
42    };
43    MessageQuery {
44        after: params.after,
45        before: params.before,
46        limit,
47        order,
48        visibility,
49        run_id: params.run_id.clone(),
50    }
51}
52
53pub async fn load_message_page(
54    read_store: &Arc<dyn ThreadReader>,
55    thread_id: &str,
56    params: &MessageQueryParams,
57) -> Result<MessagePage, ApiError> {
58    let query = parse_message_query(params);
59    read_store
60        .load_messages(thread_id, &query)
61        .await
62        .map_err(|e| match e {
63            ThreadStoreError::NotFound(_) => ApiError::ThreadNotFound(thread_id.to_string()),
64            other => ApiError::Internal(other.to_string()),
65        })
66}
67
68#[derive(Debug, Serialize)]
69pub struct EncodedMessagePage<M: Serialize> {
70    pub messages: Vec<M>,
71    pub has_more: bool,
72    #[serde(skip_serializing_if = "Option::is_none")]
73    pub next_cursor: Option<i64>,
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub prev_cursor: Option<i64>,
76}
77
78pub fn encode_message_page<M: Serialize>(
79    page: MessagePage,
80    encode: impl Fn(&Message) -> M,
81) -> EncodedMessagePage<M> {
82    EncodedMessagePage {
83        messages: page.messages.iter().map(|m| encode(&m.message)).collect(),
84        has_more: page.has_more,
85        next_cursor: page.next_cursor,
86        prev_cursor: page.prev_cursor,
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds parameters with every optional field unset and the given `limit`.
    fn params_with_limit(limit: usize) -> MessageQueryParams {
        MessageQueryParams {
            after: None,
            before: None,
            limit,
            order: None,
            visibility: None,
            run_id: None,
        }
    }

    #[test]
    fn parse_message_query_defaults_and_visibility() {
        // An oversized limit is capped, and missing order/visibility take
        // their fallbacks.
        let params = params_with_limit(999);
        let query = parse_message_query(&params);
        assert_eq!(query.limit, 200);
        assert!(matches!(query.order, SortOrder::Asc));
        assert!(matches!(query.visibility, Some(Visibility::All)));

        let params = MessageQueryParams {
            order: Some("desc".to_string()),
            visibility: Some("internal".to_string()),
            run_id: Some("r1".to_string()),
            ..params_with_limit(1)
        };
        let query = parse_message_query(&params);
        assert_eq!(query.limit, 1);
        assert!(matches!(query.order, SortOrder::Desc));
        assert!(matches!(query.visibility, Some(Visibility::Internal)));
        assert_eq!(query.run_id.as_deref(), Some("r1"));
    }

    #[test]
    fn parse_message_query_raises_zero_limit_to_minimum() {
        let query = parse_message_query(&params_with_limit(0));
        assert_eq!(query.limit, 1);
    }

    #[test]
    fn parse_message_query_keeps_limit_inside_bounds() {
        assert_eq!(parse_message_query(&params_with_limit(50)).limit, 50);
        assert_eq!(parse_message_query(&params_with_limit(200)).limit, 200);
        assert_eq!(parse_message_query(&params_with_limit(201)).limit, 200);
    }

    #[test]
    fn parse_message_query_none_visibility_maps_to_none() {
        let params = MessageQueryParams {
            visibility: Some("none".to_string()),
            ..params_with_limit(10)
        };
        let query = parse_message_query(&params);
        assert!(query.visibility.is_none());
    }

    #[test]
    fn parse_message_query_unrecognised_strings_take_fallbacks() {
        // Matching is exact, so differently cased or unknown values are not
        // rejected; they fall through to the default arms.
        let params = MessageQueryParams {
            order: Some("DESC".to_string()),
            visibility: Some("secret".to_string()),
            ..params_with_limit(10)
        };
        let query = parse_message_query(&params);
        assert!(matches!(query.order, SortOrder::Asc));
        assert!(matches!(query.visibility, Some(Visibility::All)));
    }

    #[test]
    fn parse_message_query_passes_cursors_and_run_id_through() {
        let params = MessageQueryParams {
            after: Some(5),
            before: Some(42),
            ..params_with_limit(10)
        };
        let query = parse_message_query(&params);
        assert_eq!(query.after, Some(5));
        assert_eq!(query.before, Some(42));
        assert!(query.run_id.is_none());
    }
}
123}