Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow overriding execute/code role #331

Merged
merged 1 commit into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,17 +285,23 @@ impl Config {
}

pub fn set_execute_role(&mut self) -> Result<()> {
let role = Role::for_execute();
let role = self
.retrieve_role(Role::EXECUTE)
.unwrap_or_else(|_| Role::for_execute());
self.set_role_obj(role)
}

pub fn set_describe_role(&mut self) -> Result<()> {
let role = Role::for_describe();
pub fn set_describe_command_role(&mut self) -> Result<()> {
let role = self
.retrieve_role(Role::DESCRIBE_COMMAND)
.unwrap_or_else(|_| Role::for_describe_command());
self.set_role_obj(role)
}

pub fn set_code_role(&mut self) -> Result<()> {
let role = Role::for_code();
let role = self
.retrieve_role(Role::CODE)
.unwrap_or_else(|_| Role::for_code());
self.set_role_obj(role)
}

Expand Down
12 changes: 8 additions & 4 deletions src/config/role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ pub struct Role {
}

impl Role {
pub const EXECUTE: &'static str = "__execute__";
pub const DESCRIBE_COMMAND: &'static str = "__describe_command__";
pub const CODE: &'static str = "__code__";

pub fn for_execute() -> Self {
let os = detect_os();
let (shell, _, _) = detect_shell();
Expand All @@ -29,7 +33,7 @@ impl Role {
_ => "&&",
};
Self {
name: "__execute__".into(),
name: Self::EXECUTE.into(),
prompt: format!(
r#"Provide only {shell} commands for {os} without any description.
If there is a lack of details, provide most logical solution.
Expand All @@ -42,9 +46,9 @@ Do not provide markdown formatting such as ```"#
}
}

pub fn for_describe() -> Self {
pub fn for_describe_command() -> Self {
Self {
name: "__describe__".into(),
name: Self::DESCRIBE_COMMAND.into(),
prompt: r#"Provide a terse, single sentence description of the given shell command.
Describe each argument and option of the command.
Provide short responses in about 80 words.
Expand All @@ -56,7 +60,7 @@ APPLY MARKDOWN formatting when possible."#

pub fn for_code() -> Self {
Self {
name: "__code__".into(),
name: Self::CODE.into(),
prompt: r#"Provide only code as output without any description.
Provide only code in plain text format without Markdown formatting.
Do not include symbols such as ``` or ```python.
Expand Down
28 changes: 13 additions & 15 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod utils;

use crate::cli::Cli;
use crate::config::{Config, GlobalConfig};
use crate::utils::{extract_block, run_command};
use crate::utils::{extract_block, run_command, CODE_BLOCK_RE};

use anyhow::{bail, Result};
use clap::Parser;
Expand Down Expand Up @@ -60,19 +60,17 @@ fn main() -> Result<()> {
if cli.dry_run {
config.write().dry_run = true;
}
if cli.execute {
if let Some(name) = &cli.role {
config.write().set_role(name)?;
} else if cli.execute {
config.write().set_execute_role()?;
} else {
if let Some(name) = &cli.role {
config.write().set_role(name)?;
} else if cli.code {
config.write().set_code_role()?;
}
if let Some(session) = &cli.session {
config
.write()
.start_session(session.as_ref().map(|v| v.as_str()))?;
}
} else if cli.code {
config.write().set_code_role()?;
}
if let Some(session) = &cli.session {
config
.write()
.start_session(session.as_ref().map(|v| v.as_str()))?;
}
if let Some(model) = &cli.model {
config.write().set_model(model)?;
Expand Down Expand Up @@ -154,7 +152,7 @@ fn execute(config: &GlobalConfig, text: &str) -> Result<()> {
let client = init_client(config)?;
config.read().maybe_print_send_tokens(&input);
let mut eval_str = client.send_message(input.clone())?;
if eval_str.contains("```") {
if let Ok(true) = CODE_BLOCK_RE.is_match(&eval_str) {
eval_str = extract_block(&eval_str);
}
config.write().save_message(input, &eval_str)?;
Expand Down Expand Up @@ -192,7 +190,7 @@ fn execute(config: &GlobalConfig, text: &str) -> Result<()> {
}
"D" | "d" => {
if !describe {
config.write().set_describe_role()?;
config.write().set_describe_command_role()?;
}
let input = Input::from_str(&eval_str);
let abort = create_abort_signal();
Expand Down
8 changes: 6 additions & 2 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::env;
use std::process::Command;

lazy_static! {
static ref CODE_BLOCK_RE: Regex = Regex::new(r"(?ms)```\w*(.*?)```").unwrap();
pub static ref CODE_BLOCK_RE: Regex = Regex::new(r"(?ms)```\w*(.*)```").unwrap();
}

pub fn now() -> String {
Expand Down Expand Up @@ -165,7 +165,11 @@ pub fn extract_block(input: &str) -> String {
.map(|m| String::from(m.as_str()))
})
.collect();
output.trim().to_string()
if output.is_empty() {
input.trim().to_string()
} else {
output.trim().to_string()
}
}

#[cfg(test)]
Expand Down