Skip to content

Commit

Permalink
add cache
Browse files Browse the repository at this point in the history
  • Loading branch information
dh1011 committed Oct 30, 2024
1 parent 36c55ca commit e871bb4
Showing 1 changed file with 104 additions and 18 deletions.
122 changes: 104 additions & 18 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod shell;
mod model;

use std::collections::HashMap;
use std::io::{self, Write};
use std::fs;
use std::process::Command as ProcessCommand;
Expand All @@ -17,7 +18,6 @@ struct Config {
max_tokens: i32
}


fn main() -> Result<(), Box<dyn std::error::Error>> {
let matches = Command::new("llm-term")
.version("1.0")
Expand All @@ -32,6 +32,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.long("config")
.help("Run configuration setup")
.action(clap::ArgAction::SetTrue))
.arg(
Arg::new("disable-cache")
.long("disable-cache")
.help("Disable cache and always query the LLM")
.action(clap::ArgAction::SetTrue),
)
.get_matches();

let config_path = get_default_config_path().expect("Failed to get default config path");
Expand All @@ -46,32 +52,46 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

let config = load_or_create_config(&config_path)?;

let cache_path = get_cache_path()?;
let mut cache = load_cache(&cache_path)?;

if let Some(prompt) = matches.get_one::<String>("prompt") {
match &config.model.llm_get_command(&config, prompt.as_str()) {
Ok(Some(command)) => {
println!("{}", &command.cyan().bold());
let disable_cache = matches.get_flag("disable-cache");

if !disable_cache {
if let Some(cached_command) = cache.get(prompt) {
println!("{}", "This command exists in cache".yellow());
println!("{}", cached_command.cyan().bold());
println!("{}", "Do you want to execute this command? (y/n)".yellow());

let mut user_input = String::new();
io::stdin().read_line(&mut user_input)?;

if user_input.trim().to_lowercase() == "y" {
let (shell_cmd, shell_arg) = Shell::detect().to_shell_command_and_command_arg();

match ProcessCommand::new(shell_cmd).arg(shell_arg).arg(&command).output() {
Ok(output) => {
println!("{}", "Command output:".green().bold());
io::stdout().write_all(&output.stdout)?;
io::stderr().write_all(&output.stderr)?;
}
Err(e) => eprintln!("{}", format!("Failed to execute command: {}", e).red()),
}
execute_command(cached_command)?;
} else {
println!("{}", "Command execution cancelled.".yellow());
println!("{}", "Do you want to invalidate the cache? (y/n)".yellow());
user_input.clear();
io::stdin().read_line(&mut user_input)?;

if user_input.trim().to_lowercase() == "y" {
// Invalidate cache
cache.remove(prompt);
save_cache(&cache_path, &cache)?;
// Proceed to get command from LLM
get_command_from_llm(&config, &mut cache, &cache_path, prompt)?;
} else {
println!("{}", "Command execution cancelled.".yellow());
}
}
},
Ok(None) => println!("{}", "No command could be generated.".yellow()),
Err(e) => eprintln!("{}", format!("Error: {}", e).red()),
return Ok(());
} else {
// Not in cache, proceed to get command from LLM
get_command_from_llm(&config, &mut cache, &cache_path, prompt)?;
}
} else {
// Cache is disabled, proceed to get command from LLM
get_command_from_llm(&config, &mut cache, &cache_path, prompt)?;
}
} else {
println!("{}", "Please provide a prompt or use --config to set up the configuration.".yellow());
Expand Down Expand Up @@ -129,4 +149,70 @@ fn create_config() -> Result<Config, io::Error> {
model,
max_tokens,
})
}

/// Location of the on-disk command cache: `cache.json` stored next to the
/// running binary.
fn get_cache_path() -> Result<PathBuf, Box<dyn std::error::Error>> {
    let dir = std::env::current_exe()?
        .parent()
        .map(PathBuf::from)
        .ok_or("Failed to get executable directory")?;
    Ok(dir.join("cache.json"))
}

fn load_cache(path: &PathBuf) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
if let Ok(content) = fs::read_to_string(path) {
Ok(serde_json::from_str(&content)?)
} else {
Ok(HashMap::new())
}
}

fn save_cache(path: &PathBuf, cache: &HashMap<String, String>) -> Result<(), Box<dyn std::error::Error>> {
let content = serde_json::to_string_pretty(&cache)?;
fs::write(path, content)?;
Ok(())
}

/// Ask the configured LLM for a shell command matching `prompt`, offer to run
/// it, and record the result in the on-disk cache.
///
/// The command is cached even when the user declines to execute it, so a
/// repeated prompt can be answered from cache without another LLM round-trip.
/// LLM errors are printed rather than propagated; only I/O failures (stdin,
/// cache write, command spawn) bubble up as `Err`.
///
/// `prompt` is `&str` rather than `&String` (clippy `ptr_arg`); existing
/// callers passing `&String` continue to work via deref coercion.
fn get_command_from_llm(
    config: &Config,
    cache: &mut HashMap<String, String>,
    cache_path: &PathBuf,
    prompt: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    match &config.model.llm_get_command(config, prompt) {
        Ok(Some(command)) => {
            println!("{}", &command.cyan().bold());
            println!("{}", "Do you want to execute this command? (y/n)".yellow());

            let mut user_input = String::new();
            io::stdin().read_line(&mut user_input)?;

            if user_input.trim().to_lowercase() == "y" {
                execute_command(command)?;
            } else {
                println!("{}", "Command execution cancelled.".yellow());
            }

            // Save to cache regardless of the execution choice, so the next
            // identical prompt is served without querying the LLM again.
            cache.insert(prompt.to_string(), command.clone());
            save_cache(cache_path, cache)?;
        },
        Ok(None) => println!("{}", "No command could be generated.".yellow()),
        Err(e) => eprintln!("{}", format!("Error: {}", e).red()),
    }

    Ok(())
}

/// Run `command` through the detected shell and forward the captured
/// stdout/stderr to the user's terminal.
///
/// Note: `.output()` waits for completion and captures both streams, so the
/// command cannot prompt interactively. A command that spawns but exits
/// non-zero is not treated as an error here — its stderr is simply printed;
/// only a failure to launch the shell itself is reported.
fn execute_command(command: &str) -> Result<(), Box<dyn std::error::Error>> {
    let (shell_cmd, shell_arg) = Shell::detect().to_shell_command_and_command_arg();

    // `command` is already `&str`; passing it directly avoids the needless
    // extra borrow (`&command`) flagged by clippy.
    match ProcessCommand::new(shell_cmd).arg(shell_arg).arg(command).output() {
        Ok(output) => {
            println!("{}", "Command output:".green().bold());
            io::stdout().write_all(&output.stdout)?;
            io::stderr().write_all(&output.stderr)?;
        }
        Err(e) => eprintln!("{}", format!("Failed to execute command: {}", e).red()),
    }

    Ok(())
}

0 comments on commit e871bb4

Please sign in to comment.