tirea_agentos_server/service/
messages.rs1use serde::Deserialize;
2use serde::Serialize;
3use std::sync::Arc;
4use tirea_agentos::contracts::storage::{
5 MessagePage, MessageQuery, SortOrder, ThreadReader, ThreadStoreError,
6};
7use tirea_agentos::contracts::thread::Message;
8use tirea_agentos::contracts::thread::Visibility;
9
10use super::ApiError;
11
/// Page size used when a request omits `limit`.
///
/// Referenced by name from the `#[serde(default = "default_message_limit")]`
/// attribute on `MessageQueryParams::limit`.
fn default_message_limit() -> usize {
    const DEFAULT_LIMIT: usize = 50;
    DEFAULT_LIMIT
}

16#[derive(Debug, Deserialize)]
17pub struct MessageQueryParams {
18 #[serde(default)]
19 pub after: Option<i64>,
20 #[serde(default)]
21 pub before: Option<i64>,
22 #[serde(default = "default_message_limit")]
23 pub limit: usize,
24 #[serde(default)]
25 pub order: Option<String>,
26 #[serde(default)]
27 pub visibility: Option<String>,
28 #[serde(default)]
29 pub run_id: Option<String>,
30}
31
32pub fn parse_message_query(params: &MessageQueryParams) -> MessageQuery {
33 let limit = params.limit.clamp(1, 200);
34 let order = match params.order.as_deref() {
35 Some("desc") => SortOrder::Desc,
36 _ => SortOrder::Asc,
37 };
38 let visibility = match params.visibility.as_deref() {
39 Some("internal") => Some(Visibility::Internal),
40 Some("none") => None,
41 _ => Some(Visibility::All),
42 };
43 MessageQuery {
44 after: params.after,
45 before: params.before,
46 limit,
47 order,
48 visibility,
49 run_id: params.run_id.clone(),
50 }
51}
52
53pub async fn load_message_page(
54 read_store: &Arc<dyn ThreadReader>,
55 thread_id: &str,
56 params: &MessageQueryParams,
57) -> Result<MessagePage, ApiError> {
58 let query = parse_message_query(params);
59 read_store
60 .load_messages(thread_id, &query)
61 .await
62 .map_err(|e| match e {
63 ThreadStoreError::NotFound(_) => ApiError::ThreadNotFound(thread_id.to_string()),
64 other => ApiError::Internal(other.to_string()),
65 })
66}
67
68#[derive(Debug, Serialize)]
69pub struct EncodedMessagePage<M: Serialize> {
70 pub messages: Vec<M>,
71 pub has_more: bool,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub next_cursor: Option<i64>,
74 #[serde(skip_serializing_if = "Option::is_none")]
75 pub prev_cursor: Option<i64>,
76}
77
78pub fn encode_message_page<M: Serialize>(
79 page: MessagePage,
80 encode: impl Fn(&Message) -> M,
81) -> EncodedMessagePage<M> {
82 EncodedMessagePage {
83 messages: page.messages.iter().map(|m| encode(&m.message)).collect(),
84 has_more: page.has_more,
85 next_cursor: page.next_cursor,
86 prev_cursor: page.prev_cursor,
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_message_query_defaults_and_visibility() {
        // Unset keywords fall back to ascending order and `Visibility::All`;
        // an oversized limit is clamped down to the 200 maximum.
        let params = MessageQueryParams {
            after: None,
            before: None,
            limit: 999,
            order: None,
            visibility: None,
            run_id: None,
        };
        let query = parse_message_query(&params);
        assert_eq!(query.limit, 200);
        assert!(matches!(query.order, SortOrder::Asc));
        assert!(matches!(query.visibility, Some(Visibility::All)));

        // Explicit keywords and run_id are honored.
        let params = MessageQueryParams {
            after: None,
            before: None,
            limit: 1,
            order: Some("desc".to_string()),
            visibility: Some("internal".to_string()),
            run_id: Some("r1".to_string()),
        };
        let query = parse_message_query(&params);
        assert_eq!(query.limit, 1);
        assert!(matches!(query.order, SortOrder::Desc));
        assert!(matches!(query.visibility, Some(Visibility::Internal)));
        assert_eq!(query.run_id.as_deref(), Some("r1"));
    }

    #[test]
    fn parse_message_query_zero_limit_none_visibility_and_cursors() {
        // A zero limit is raised to the minimum of 1, `none` yields no
        // visibility value, and cursor bounds pass through untouched.
        let params = MessageQueryParams {
            after: Some(10),
            before: Some(20),
            limit: 0,
            order: None,
            visibility: Some("none".to_string()),
            run_id: None,
        };
        let query = parse_message_query(&params);
        assert_eq!(query.limit, 1);
        assert!(query.visibility.is_none());
        assert_eq!(query.after, Some(10));
        assert_eq!(query.before, Some(20));
        assert!(query.run_id.is_none());
    }
}