Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: compress session automaticlly #333

Merged
merged 9 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ wrap_code: false # Whether wrap code block
auto_copy: false # Automatically copy the last output to the clipboard
keybindings: emacs # REPL keybindings. values: emacs, vi
prelude: '' # Set a default role or session (role:<name>, session:<name>)
compress_threshold: 1000 # Compress session if tokens exceed this value (valid when >=1000)

clients:
- type: openai
Expand Down Expand Up @@ -296,6 +297,7 @@ Usage: .file <file>... [-- text...]
> .set highlight false
> .set save false
> .set auto_copy true
> .set compress_threshold 1000
```

## Command
Expand Down
7 changes: 7 additions & 0 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ auto_copy: false # Automatically copy the last output to the cli
keybindings: emacs # REPL keybindings. (emacs, vi)
prelude: '' # Set a default role or session (role:<name>, session:<name>)

# Compress session if tokens exceed this value (valid when >=1000)
compress_threshold: 1000
# The prompt for summarizing session messages
summarize_prompt: 'Summarize the discussion briefly in 200 words or less to use as a prompt for future context.'
# The prompt for the summary of the session
summary_prompt: 'This is a summary of the chat history as a recap: '

# Custom REPL prompt, see https://github.com/sigoden/aichat/wiki/Custom-REPL-Prompt
left_prompt: '{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} '
right_prompt: '{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}'
Expand Down
2 changes: 1 addition & 1 deletion src/client/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl Message {
}
}

#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageRole {
System,
Expand Down
66 changes: 58 additions & 8 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ pub struct Config {
pub keybindings: Keybindings,
/// Set a default role or session (role:<name>, session:<name>)
pub prelude: String,
/// Compress session if tokens exceed this value (>=1000)
pub compress_threshold: usize,
/// The prompt for summarizing session messages
pub summarize_prompt: String,
// The prompt for the summary of the session
pub summary_prompt: String,
/// REPL left prompt
pub left_prompt: String,
/// REPL right prompt
Expand Down Expand Up @@ -104,6 +110,9 @@ impl Default for Config {
auto_copy: false,
keybindings: Default::default(),
prelude: String::new(),
compress_threshold: 2000,
summarize_prompt: "Summarize the discussion briefly in 200 words or less to use as a prompt for future context.".to_string(),
summary_prompt: "This is a summary of the chat history as a recap: ".into(),
left_prompt: "{color.green}{?session {session}{?role /}}{role}{color.cyan}{?session )}{!session >}{color.reset} ".to_string(),
right_prompt: "{color.purple}{?session {?consume_tokens {consume_tokens}({consume_percent}%)}{!consume_tokens {consume_tokens}}}{color.reset}"
.to_string(),
Expand Down Expand Up @@ -345,12 +354,18 @@ impl Config {
self.temperature
}

pub fn set_temperature(&mut self, value: Option<f64>) -> Result<()> {
pub fn set_temperature(&mut self, value: Option<f64>) {
self.temperature = value;
if let Some(session) = self.session.as_mut() {
session.set_temperature(value);
}
Ok(())
}

pub fn set_compress_threshold(&mut self, value: usize) {
self.compress_threshold = value;
if let Some(session) = self.session.as_mut() {
session.set_compress_threshold(value);
}
}

pub fn echo_messages(&self, input: &Input) -> String {
Expand Down Expand Up @@ -430,6 +445,7 @@ impl Config {
("auto_copy", self.auto_copy.to_string()),
("keybindings", self.keybindings.stringify().into()),
("prelude", prelude),
("compress_threshold", self.compress_threshold.to_string()),
("config_file", display_path(&Self::config_file()?)),
("roles_file", display_path(&Self::roles_file()?)),
("messages_file", display_path(&Self::messages_file()?)),
Expand All @@ -445,7 +461,7 @@ impl Config {

pub fn role_info(&self) -> Result<String> {
if let Some(role) = &self.role {
role.info()
role.export()
} else {
bail!("No role")
}
Expand All @@ -455,7 +471,7 @@ impl Config {
if let Some(session) = &self.session {
let render_options = self.get_render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
session.render(&mut markdown_render)
session.info(&mut markdown_render)
} else {
bail!("No session")
}
Expand All @@ -465,7 +481,7 @@ impl Config {
if let Some(session) = &self.session {
session.export()
} else if let Some(role) = &self.role {
role.info()
role.export()
} else {
self.sys_info()
}
Expand All @@ -486,6 +502,7 @@ impl Config {
".session" => self.list_sessions(),
".set" => vec![
"temperature ",
"compress_threshold",
"save ",
"highlight ",
"dry_run ",
Expand Down Expand Up @@ -532,7 +549,11 @@ impl Config {
let value = value.parse().with_context(|| "Invalid value")?;
Some(value)
};
self.set_temperature(value)?;
self.set_temperature(value);
}
"compress_threshold" => {
let value = value.parse().with_context(|| "Invalid value")?;
self.set_compress_threshold(value);
}
"save" => {
let value = value.parse().with_context(|| "Invalid value")?;
Expand Down Expand Up @@ -608,7 +629,7 @@ impl Config {
if let Some(mut session) = self.session.take() {
self.last_message = None;
self.temperature = self.default_temperature;
if session.should_save() {
if session.dirty {
let ans = Confirm::new("Save session?").with_default(false).prompt()?;
if !ans {
return Ok(());
Expand All @@ -634,7 +655,7 @@ impl Config {

pub fn clear_session_messages(&mut self) -> Result<()> {
if let Some(session) = self.session.as_mut() {
session.clear_messgaes();
session.clear_messages();
}
Ok(())
}
Expand All @@ -660,6 +681,35 @@ impl Config {
}
}

pub fn should_compress_session(&mut self) -> bool {
if let Some(sesion) = self.session.as_mut() {
if sesion.need_compress(self.compress_threshold) {
sesion.compressing = true;
return true;
}
}
false
}

pub fn compress_session(&mut self, summary: &str) {
if let Some(session) = self.session.as_mut() {
session.compress(format!("{}{}", self.summary_prompt, summary));
}
}

pub fn is_compressing_session(&self) -> bool {
self.session
.as_ref()
.map(|v| v.compressing)
.unwrap_or_default()
}

pub fn end_compressing_session(&mut self) {
if let Some(session) = self.session.as_mut() {
session.compressing = false;
}
}

pub fn get_render_options(&self) -> Result<RenderOptions> {
let theme = if self.highlight {
let theme_mode = if self.light_theme { "light" } else { "dark" };
Expand Down
2 changes: 1 addition & 1 deletion src/config/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ For example if the prompt is "Hello world Python", you should return "print('Hel
}
}

pub fn info(&self) -> Result<String> {
pub fn export(&self) -> Result<String> {
let output = serde_yaml::to_string(&self)
.with_context(|| format!("Unable to show info about role {}", &self.name))?;
Ok(output.trim_end().to_string())
Expand Down
57 changes: 44 additions & 13 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@ pub struct Session {
messages: Vec<Message>,
#[serde(default)]
data_urls: HashMap<String, String>,
#[serde(default)]
compressed_messages: Vec<Message>,
compress_threshold: Option<usize>,
#[serde(skip)]
pub name: String,
#[serde(skip)]
pub path: Option<String>,
#[serde(skip)]
pub dirty: bool,
#[serde(skip)]
pub compressing: bool,
#[serde(skip)]
pub role: Option<Role>,
#[serde(skip)]
pub model: Model,
Expand All @@ -41,10 +46,13 @@ impl Session {
model_id: model.id(),
temperature,
messages: vec![],
compressed_messages: vec![],
compress_threshold: None,
data_urls: Default::default(),
name: name.to_string(),
path: None,
dirty: false,
compressing: false,
role,
model,
}
Expand Down Expand Up @@ -74,6 +82,13 @@ impl Session {
self.temperature
}

pub fn need_compress(&self, current_compress_threshold: usize) -> bool {
let threshold = self
.compress_threshold
.unwrap_or(current_compress_threshold);
threshold >= 1000 && self.tokens() > threshold
}

pub fn tokens(&self) -> usize {
self.model.total_tokens(&self.messages)
}
Expand Down Expand Up @@ -106,7 +121,7 @@ impl Session {
Ok(output)
}

pub fn render(&self, render: &mut MarkdownRender) -> Result<String> {
pub fn info(&self, render: &mut MarkdownRender) -> Result<String> {
let mut items = vec![];

if let Some(path) = &self.path {
Expand All @@ -119,6 +134,10 @@ impl Session {
items.push(("temperature", temperature.to_string()));
}

if let Some(compress_threshold) = self.compress_threshold {
items.push(("compress_threshold", compress_threshold.to_string()));
}

if let Some(max_tokens) = self.model.max_tokens {
items.push(("max_tokens", max_tokens.to_string()));
}
Expand All @@ -135,7 +154,7 @@ impl Session {
for message in &self.messages {
match message.role {
MessageRole::System => {
continue;
lines.push(render.render(&message.content.render_input(resolve_url_fn)));
}
MessageRole::Assistant => {
if let MessageContent::Text(text) = &message.content {
Expand Down Expand Up @@ -181,14 +200,28 @@ impl Session {
self.temperature = value;
}

pub fn set_compress_threshold(&mut self, value: usize) {
self.compress_threshold = Some(value);
}

pub fn set_model(&mut self, model: Model) -> Result<()> {
self.model_id = model.id();
self.model = model;
Ok(())
}

pub fn compress(&mut self, prompt: String) {
self.compressed_messages.append(&mut self.messages);
self.messages.push(Message {
role: MessageRole::System,
content: MessageContent::Text(prompt),
});
self.role = None;
self.dirty = true;
}

pub fn save(&mut self, session_path: &Path) -> Result<()> {
if !self.should_save() {
if !self.dirty {
return Ok(());
}
self.path = Some(session_path.display().to_string());
Expand All @@ -208,10 +241,6 @@ impl Session {
Ok(())
}

pub fn should_save(&self) -> bool {
!self.is_empty() && self.dirty
}

pub fn guard_save(&self) -> Result<()> {
if self.path.is_none() {
bail!("Not found session '{}'", self.name)
Expand Down Expand Up @@ -258,11 +287,9 @@ impl Session {
Ok(())
}

pub fn clear_messgaes(&mut self) {
if self.messages.is_empty() {
return;
}
pub fn clear_messages(&mut self) {
self.messages.clear();
self.compressed_messages.clear();
self.data_urls.clear();
self.dirty = true;
}
Expand All @@ -275,12 +302,16 @@ impl Session {
pub fn build_emssages(&self, input: &Input) -> Vec<Message> {
let mut messages = self.messages.clone();
let mut need_add_msg = true;
if messages.is_empty() {
let len = messages.len();
if len == 0 {
if let Some(role) = self.role.as_ref() {
messages = role.build_messages(input);
need_add_msg = false;
}
};
} else if len == 1 && self.compressed_messages.len() >= 2 {
messages
.extend(self.compressed_messages[self.compressed_messages.len() - 2..].to_vec());
}
if need_add_msg {
messages.push(Message {
role: MessageRole::User,
Expand Down
20 changes: 20 additions & 0 deletions src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ impl Repl {
if text.is_empty() && files.is_empty() {
return Ok(());
}
while self.config.read().is_compressing_session() {
std::thread::sleep(std::time::Duration::from_millis(100));
}
let input = if files.is_empty() {
Input::from_str(text)
} else {
Expand All @@ -269,6 +272,14 @@ impl Repl {
let output = render_stream(&input, client.as_ref(), &self.config, self.abort.clone())?;
self.config.write().save_message(input, &output)?;
self.config.read().maybe_copy(&output);
if self.config.write().should_compress_session() {
let config = self.config.clone();
std::thread::spawn(move || -> anyhow::Result<()> {
let _ = compress_session(&config);
config.write().end_compressing_session();
Ok(())
});
}
Ok(())
}

Expand Down Expand Up @@ -418,6 +429,15 @@ fn parse_command(line: &str) -> Option<(&str, Option<&str>)> {
}
}

fn compress_session(config: &GlobalConfig) -> Result<()> {
let input = Input::from_str(&config.read().summarize_prompt);
let mut client = init_client(config)?;
ensure_model_capabilities(client.as_mut(), input.required_capabilities())?;
let summary = client.send_message(input)?;
config.write().compress_session(&summary);
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down