Skip to content

Commit 91cb5a7

Browse files
committed
lastOf support
Signed-off-by: Nick Mitchell <[email protected]>
1 parent ba5744b commit 91cb5a7

File tree

4 files changed

+224
-78
lines changed

4 files changed

+224
-78
lines changed

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use ::std::collections::HashMap;
22
use serde::{Deserialize, Serialize};
3-
use serde_json::{Number, Value};
3+
use serde_json::{to_string, Number, Value};
44

55
#[derive(Serialize, Deserialize, Debug, Clone)]
66
//why doesn't this work? #[serde(rename_all_fields(serialize = "lowercase"))]
@@ -74,6 +74,86 @@ impl CallBlock {
7474
}
7575
}
7676

77+
pub trait SequencingBlock {
78+
fn kind(&self) -> &str;
79+
fn description(&self) -> &Option<String>;
80+
fn role(&self) -> &Option<Role>;
81+
fn def(&self) -> &Option<String>;
82+
fn defs(&self) -> &Option<HashMap<String, PdlBlock>>;
83+
fn items(&self) -> &Vec<PdlBlock>;
84+
fn with_items(&self, items: Vec<PdlBlock>) -> Self;
85+
fn parser(&self) -> &Option<PdlParser>;
86+
fn to_block(&self) -> PdlBlock;
87+
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult;
88+
fn messages_for<T: Clone>(&self, output_messages: Vec<T>) -> Vec<T>;
89+
}
90+
91+
/// Return the value of the last block if the list of blocks
92+
#[derive(Serialize, Deserialize, Debug, Clone)]
93+
pub struct LastOfBlock {
94+
/// Sequence of blocks to execute
95+
#[serde(rename = "lastOf")]
96+
pub last_of: Vec<PdlBlock>,
97+
98+
#[serde(skip_serializing_if = "Option::is_none")]
99+
pub description: Option<String>,
100+
101+
#[serde(skip_serializing_if = "Option::is_none")]
102+
pub role: Option<Role>,
103+
104+
#[serde(skip_serializing_if = "Option::is_none")]
105+
pub defs: Option<HashMap<String, PdlBlock>>,
106+
107+
#[serde(skip_serializing_if = "Option::is_none")]
108+
pub parser: Option<PdlParser>,
109+
110+
#[serde(skip_serializing_if = "Option::is_none")]
111+
pub def: Option<String>,
112+
}
113+
impl SequencingBlock for LastOfBlock {
114+
fn kind(&self) -> &str {
115+
"lastOf"
116+
}
117+
fn description(&self) -> &Option<String> {
118+
&self.description
119+
}
120+
fn role(&self) -> &Option<Role> {
121+
&self.role
122+
}
123+
fn def(&self) -> &Option<String> {
124+
return &self.def;
125+
}
126+
fn defs(&self) -> &Option<HashMap<String, PdlBlock>> {
127+
&self.defs
128+
}
129+
fn items(&self) -> &Vec<PdlBlock> {
130+
&self.last_of
131+
}
132+
fn with_items(&self, items: Vec<PdlBlock>) -> Self {
133+
let mut b = self.clone();
134+
b.last_of = items;
135+
b
136+
}
137+
fn parser(&self) -> &Option<PdlParser> {
138+
&self.parser
139+
}
140+
fn to_block(&self) -> PdlBlock {
141+
PdlBlock::LastOf(self.clone())
142+
}
143+
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
144+
match output_results.last() {
145+
Some(result) => result.clone(),
146+
None => "".into(),
147+
}
148+
}
149+
fn messages_for<T: Clone>(&self, output_messages: Vec<T>) -> Vec<T> {
150+
match output_messages.last() {
151+
Some(m) => vec![m.clone()],
152+
None => vec![],
153+
}
154+
}
155+
}
156+
77157
/// Create the concatenation of the stringify version of the result of
78158
/// each block of the list of blocks.
79159
#[derive(Serialize, Deserialize, Debug, Clone)]
@@ -96,6 +176,49 @@ pub struct TextBlock {
96176
#[serde(skip_serializing_if = "Option::is_none")]
97177
pub def: Option<String>,
98178
}
179+
impl SequencingBlock for TextBlock {
180+
fn kind(&self) -> &str {
181+
"text"
182+
}
183+
fn description(&self) -> &Option<String> {
184+
&self.description
185+
}
186+
fn role(&self) -> &Option<Role> {
187+
&self.role
188+
}
189+
fn def(&self) -> &Option<String> {
190+
return &self.def;
191+
}
192+
fn defs(&self) -> &Option<HashMap<String, PdlBlock>> {
193+
&self.defs
194+
}
195+
fn items(&self) -> &Vec<PdlBlock> {
196+
&self.text
197+
}
198+
fn with_items(&self, items: Vec<PdlBlock>) -> Self {
199+
let mut b = self.clone();
200+
b.text = items;
201+
b
202+
}
203+
fn parser(&self) -> &Option<PdlParser> {
204+
&self.parser
205+
}
206+
fn to_block(&self) -> PdlBlock {
207+
PdlBlock::Text(self.clone())
208+
}
209+
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
210+
PdlResult::String(
211+
output_results
212+
.into_iter()
213+
.map(|m| m.to_string())
214+
.collect::<Vec<_>>()
215+
.join("\n"),
216+
)
217+
}
218+
fn messages_for<T: Clone>(&self, output_messages: Vec<T>) -> Vec<T> {
219+
output_messages
220+
}
221+
}
99222

100223
impl TextBlock {
101224
pub fn new(text: Vec<PdlBlock>) -> Self {
@@ -366,6 +489,7 @@ pub enum PdlBlock {
366489
Message(MessageBlock),
367490
Repeat(RepeatBlock),
368491
Text(TextBlock),
492+
LastOf(LastOfBlock),
369493
Model(ModelBlock),
370494
Function(FunctionBlock),
371495
PythonCode(PythonCodeBlock),
@@ -389,3 +513,44 @@ impl From<&str> for Box<PdlBlock> {
389513
Box::new(PdlBlock::String(s.into()))
390514
}
391515
}
516+
517+
pub type Scope = HashMap<String, PdlResult>;
518+
519+
#[derive(Serialize, Deserialize, Debug, Clone)]
520+
pub struct Closure {
521+
pub scope: Scope,
522+
pub function: FunctionBlock,
523+
}
524+
525+
#[derive(Serialize, Deserialize, Debug, Clone)]
526+
#[serde(untagged)]
527+
pub enum PdlResult {
528+
Number(Number),
529+
String(String),
530+
Bool(bool),
531+
Block(PdlBlock),
532+
Closure(Closure),
533+
List(Vec<PdlResult>),
534+
Dict(HashMap<String, PdlResult>),
535+
}
536+
impl ::std::fmt::Display for PdlResult {
537+
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
538+
let s = to_string(&self).unwrap(); // TODO: .map_err(|e| e.to_string())?;
539+
write!(f, "{}", s)
540+
}
541+
}
542+
impl From<&str> for PdlResult {
543+
fn from(s: &str) -> Self {
544+
PdlResult::String(s.to_string())
545+
}
546+
}
547+
impl From<String> for PdlResult {
548+
fn from(s: String) -> Self {
549+
PdlResult::String(s)
550+
}
551+
}
552+
impl From<Number> for PdlResult {
553+
fn from(n: Number) -> Self {
554+
PdlResult::Number(n)
555+
}
556+
}

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 35 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -21,57 +21,16 @@ use ollama_rs::{
2121
Ollama,
2222
};
2323

24-
use serde::{Deserialize, Serialize};
25-
use serde_json::{from_str, to_string, Number, Value};
24+
use serde_json::{from_str, to_string, Value};
2625
use serde_norway::{from_reader, from_str as from_yaml_str};
2726

2827
use crate::pdl::ast::{
29-
ArrayBlock, CallBlock, FunctionBlock, IfBlock, ListOrString, MessageBlock, ModelBlock,
30-
ObjectBlock, PdlBlock, PdlParser, PdlUsage, PythonCodeBlock, ReadBlock, RepeatBlock, Role,
31-
StringOrBoolean, TextBlock,
28+
ArrayBlock, CallBlock, Closure, FunctionBlock, IfBlock, ListOrString, MessageBlock, ModelBlock,
29+
ObjectBlock, PdlBlock, PdlParser, PdlResult, PdlUsage, PythonCodeBlock, ReadBlock, RepeatBlock,
30+
Role, Scope, SequencingBlock, StringOrBoolean,
3231
};
3332

34-
#[derive(Serialize, Deserialize, Debug, Clone)]
35-
pub struct Closure {
36-
pub scope: Scope,
37-
pub function: FunctionBlock,
38-
}
39-
40-
#[derive(Serialize, Deserialize, Debug, Clone)]
41-
#[serde(untagged)]
42-
pub enum PdlResult {
43-
Number(Number),
44-
String(String),
45-
Bool(bool),
46-
Block(PdlBlock),
47-
Closure(Closure),
48-
List(Vec<PdlResult>),
49-
Dict(HashMap<String, PdlResult>),
50-
}
51-
impl ::std::fmt::Display for PdlResult {
52-
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
53-
let s = to_string(&self).unwrap(); // TODO: .map_err(|e| e.to_string())?;
54-
write!(f, "{}", s)
55-
}
56-
}
57-
impl From<&str> for PdlResult {
58-
fn from(s: &str) -> Self {
59-
PdlResult::String(s.to_string())
60-
}
61-
}
62-
impl From<String> for PdlResult {
63-
fn from(s: String) -> Self {
64-
PdlResult::String(s)
65-
}
66-
}
67-
impl From<Number> for PdlResult {
68-
fn from(n: Number) -> Self {
69-
PdlResult::Number(n)
70-
}
71-
}
72-
7333
type Context = Vec<ChatMessage>;
74-
type Scope = HashMap<String, PdlResult>;
7534
type PdlError = Box<dyn Error + Send + Sync>;
7635
type Interpretation = Result<(PdlResult, Context, PdlBlock), PdlError>;
7736
type InterpretationSync = Result<(PdlResult, Context, PdlBlock), Box<dyn Error>>;
@@ -146,14 +105,15 @@ impl<'a> Interpreter<'a> {
146105
PdlBlock::PythonCode(block) => self.run_python_code(block, context).await,
147106
PdlBlock::Read(block) => self.run_read(block, context).await,
148107
PdlBlock::Repeat(block) => self.run_repeat(block, context).await,
149-
PdlBlock::Text(block) => self.run_text(block, context).await,
108+
PdlBlock::LastOf(block) => self.run_sequence(block, context).await,
109+
PdlBlock::Text(block) => self.run_sequence(block, context).await,
150110
PdlBlock::Array(block) => self.run_array(block, context).await,
151111
PdlBlock::Message(block) => self.run_message(block, context).await,
152112
_ => Err(Box::from(format!("Unsupported block {:?}", program))),
153113
}?;
154114

155115
if match program {
156-
PdlBlock::Call(_) | PdlBlock::Model(_) | PdlBlock::Text(_) => false,
116+
PdlBlock::Call(_) | PdlBlock::Model(_) => false,
157117
_ => self.emit,
158118
} {
159119
println!("{}", pretty_print(&messages));
@@ -733,60 +693,58 @@ impl<'a> Interpreter<'a> {
733693
Ok(())
734694
}
735695

736-
/// Run a PdlBlock::Text
737-
async fn run_text(&mut self, block: &TextBlock, context: Context) -> Interpretation {
696+
/// Run a sequencing block (e.g. TextBlock, LastOfBlock)
697+
async fn run_sequence(
698+
&mut self,
699+
block: &impl SequencingBlock,
700+
context: Context,
701+
) -> Interpretation {
738702
if self.debug {
739-
eprintln!(
740-
"Text {:?}",
741-
block
742-
.description
743-
.clone()
744-
.unwrap_or("<no description>".to_string())
745-
);
703+
let description = if let Some(d) = block.description() {
704+
d
705+
} else {
706+
&"<no description>".to_string()
707+
};
708+
eprintln!("{} {description}", block.kind());
746709
}
747710

748711
let mut input_messages = context.clone();
749712
let mut output_results = vec![];
750713
let mut output_messages = vec![];
751714
let mut output_blocks = vec![];
752715

753-
self.process_defs(&block.defs).await?;
716+
self.process_defs(block.defs()).await?;
754717

755-
let mut iter = block.text.iter();
718+
let mut iter = block.items().iter();
756719
while let Some(block) = iter.next() {
757720
// run each element of the Text block
758721
let (this_result, this_messages, trace) =
759-
self.run(&block, input_messages.clone()).await?;
722+
self.run_quiet(&block, input_messages.clone()).await?;
760723
input_messages.extend(this_messages.clone());
761724
output_results.push(this_result);
725+
762726
output_messages.extend(this_messages);
763727
output_blocks.push(trace);
764728
}
765729
self.scope.pop();
766730

767-
let mut trace = block.clone();
768-
trace.text = output_blocks;
769-
770-
let result_string = output_results
771-
.into_iter()
772-
.map(|m| m.to_string())
773-
.collect::<Vec<_>>()
774-
.join("\n");
775-
776-
// FIXME, use output_results
731+
let trace = block.with_items(output_blocks);
777732
let result = self.def(
778-
&block.def,
779-
&PdlResult::String(result_string.clone()),
780-
&block.parser,
733+
trace.def(),
734+
&trace.result_for(output_results),
735+
trace.parser(),
781736
)?;
782-
737+
let result_messages = trace.messages_for::<ChatMessage>(output_messages);
783738
Ok((
784739
result,
785-
match &block.role {
786-
Some(role) => vec![ChatMessage::new(self.to_ollama_role(role), result_string)],
787-
None => output_messages,
740+
match block.role() {
741+
Some(role) => result_messages
742+
.into_iter()
743+
.map(|m| ChatMessage::new(self.to_ollama_role(role), m.content))
744+
.collect(),
745+
None => result_messages,
788746
},
789-
PdlBlock::Text(trace),
747+
trace.to_block(),
790748
))
791749
}
792750

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,23 @@ mod tests {
138138
Ok(())
139139
}
140140

141+
#[test]
142+
fn last_of_parser_json() -> Result<(), Box<dyn Error>> {
143+
let json = "{\"key\":\"value\"}";
144+
let program = json!({
145+
"lastOf": [
146+
{ "def": "foo", "parser": "json", "text": [json] },
147+
"${ foo.key }"
148+
]
149+
});
150+
151+
let (_, messages, _) = run_json(program, false)?;
152+
assert_eq!(messages.len(), 1);
153+
assert_eq!(messages[0].role, MessageRole::User);
154+
assert_eq!(messages[0].content, "value");
155+
Ok(())
156+
}
157+
141158
#[test]
142159
fn text_call_function_no_args() -> Result<(), Box<dyn Error>> {
143160
let program = json!({
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
lastOf:
2+
- text:
3+
- '{"key": "value"}'
4+
parser: json
5+
def: foo
6+
- ${ foo.key }

0 commit comments

Comments
 (0)