Integrate PALM 2 LLM into SmartGPT #44

Open

wants to merge 6 commits into main
1 change: 1 addition & 0 deletions src/auto/mod.rs
@@ -118,6 +118,7 @@ pub struct ParsedResponse<T> {
    raw: String
}

#[allow(dead_code)]
pub fn try_parse_yaml<T : DeserializeOwned>(llm: &LLM, tries: usize, max_tokens: Option<u16>) -> Result<ParsedResponse<T>, Box<dyn Error>> {
    try_parse_base(llm, tries, max_tokens, "yml", |str| serde_yaml::from_str(str).map_err(|el| Box::new(el) as Box<dyn Error>))
}
5 changes: 3 additions & 2 deletions src/config/mod.rs
@@ -4,7 +4,7 @@ use colored::Colorize;
use serde::{Serialize, Deserialize};
use serde_json::Value;

- use crate::{CommandContext, LLM, Plugin, create_browse, create_google, create_filesystem, create_shutdown, create_wolfram, create_chatgpt, create_news, create_wikipedia, create_none, LLMProvider, create_model_chatgpt, Agents, LLMModel, create_model_llama, AgentInfo, MemoryProvider, create_memory_local, create_memory_qdrant, MemorySystem, create_memory_redis};
+ use crate::{CommandContext, LLM, Plugin, create_browse, create_google, create_filesystem, create_shutdown, create_wolfram, create_chatgpt, create_news, create_wikipedia, create_none, LLMProvider, create_model_chatgpt, Agents, LLMModel, create_model_llama, AgentInfo, MemoryProvider, create_memory_local, create_memory_qdrant, MemorySystem, create_memory_redis, create_model_palm2};

mod default;
pub use default::*;
@@ -102,7 +102,8 @@ pub fn list_plugins() -> Vec<Plugin> {
pub fn create_llm_providers() -> Vec<Box<dyn LLMProvider>> {
    vec![
        create_model_chatgpt(),
-       create_model_llama()
+       create_model_llama(),
+       create_model_palm2()
    ]
}

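With the provider registered above, a config loader can resolve PaLM 2 by name. The lookup below is a hypothetical sketch rather than code from this PR; it only assumes the create_llm_providers and LLMProvider::get_name signatures shown in this diff, with both in scope.

// Hypothetical helper: pick the PaLM 2 provider out of the registered list by name.
fn find_palm2_provider() -> Option<Box<dyn LLMProvider>> {
    create_llm_providers()
        .into_iter()
        .find(|provider| provider.get_name() == "palm2")
}
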
2 changes: 2 additions & 0 deletions src/llms/mod.rs
@@ -1,8 +1,10 @@
mod chatgpt;
mod palm2;
mod local;

pub use chatgpt::*;
pub use local::*;
pub use palm2::*;
use tokio::runtime::Runtime;

use std::{error::Error, fmt::Display};
70 changes: 70 additions & 0 deletions src/llms/palm2/api.rs
@@ -0,0 +1,70 @@
use reqwest::{Client};
use std::error::Error;

use crate::{CountTokensRequest, TokenCountResponse, EmbedTextRequest, EmbeddingResponse, Embedding, MessagePrompt, GenerateMessageResponse, GenerateTextRequest, GenerateTextResponse, GCPModel, ListModelResponse};

pub struct ApiClient {
    client: Client,
    base_url: String,
    api_key: String
}

impl ApiClient {
    pub fn new(base_url: String, api_key: String) -> Self {
        Self {
            client: Client::new(),
            base_url,
            api_key
        }
    }

    pub async fn count_message_tokens(&self, model: &str, message: CountTokensRequest) -> Result<TokenCountResponse, Box<dyn Error>> {
        let url = format!("{}/v1beta2/models/{}:countMessageTokens?key={}", self.base_url, model, self.api_key);
        let response = self.client.post(&url).json(&message).send().await?;

        let token_count: TokenCountResponse = response.json().await?;
        Ok(token_count)
    }

    pub async fn embed_text(&self, model: &str, message: EmbedTextRequest) -> Result<Vec<f32>, Box<dyn Error>> {
        let url = format!("{}/v1beta2/models/{}:embedText?key={}", self.base_url, model, self.api_key);
        let response = self.client.post(&url).json(&message).send().await?;

        let embedding: EmbeddingResponse = response.json().await?;
        Ok(embedding.embedding.unwrap_or(Embedding {
            value: vec![]
        }).value)
    }

    pub async fn generate_message(&self, model: &str, prompt: MessagePrompt) -> Result<GenerateMessageResponse, Box<dyn Error>> {
        let url = format!("{}/v1beta2/models/{}:generateMessage?key={}", self.base_url, model, self.api_key);
        let response = self.client.post(&url).json(&prompt).send().await?;

        let message: GenerateMessageResponse = response.json().await?;
        Ok(message)
    }

    pub async fn generate_text(&self, model: &str, message: GenerateTextRequest) -> Result<GenerateTextResponse, Box<dyn Error>> {
        let url = format!("{}/v1beta2/models/{}:generateText?key={}", self.base_url, model, self.api_key);
        let response = self.client.post(&url).json(&message).send().await?;

        let text_response: GenerateTextResponse = response.json().await?;
        Ok(text_response)
    }

    pub async fn get_model(&self, name: &str) -> Result<GCPModel, Box<dyn Error>> {
        let url = format!("{}/v1beta2/models/{}?key={}", self.base_url, name, self.api_key);
        let response = self.client.get(&url).send().await?;

        let model: GCPModel = response.json().await?;
        Ok(model)
    }

    pub async fn list_models(&self) -> Result<ListModelResponse, Box<dyn Error>> {
        let url = format!("{}/v1beta2/models?key={}", self.base_url, self.api_key);
        let response = self.client.get(&url).send().await?;

        let models: ListModelResponse = response.json().await?;
        Ok(models)
    }
}
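
For reviewers, a minimal sketch of driving ApiClient directly, not part of the diff: the GenerateTextRequest field names mirror the ones used in system.rs below, the PALM_API_KEY environment variable is invented for the example, and the crate's ApiClient, GenerateTextRequest, and TextPrompt types are assumed to be in scope.

use std::error::Error;
use tokio::runtime::Runtime;

fn main() -> Result<(), Box<dyn Error>> {
    // Hypothetical standalone driver; the API key comes from an assumed env var.
    let client = ApiClient::new(
        "https://generativelanguage.googleapis.com".to_owned(),
        std::env::var("PALM_API_KEY")?,
    );

    let request = GenerateTextRequest {
        prompt: TextPrompt { text: "Write a haiku about Rust.".to_string() },
        safety_settings: vec![],
        stop_sequences: vec![],
        temperature: 0.7,
        candidate_count: 1,
        max_output_tokens: 256,
        top_p: 0.95,
        top_k: 40,
    };

    // Blocks on the async call the same way system.rs does, via a Tokio runtime.
    let response = Runtime::new()?.block_on(client.generate_text("text-bison-001", request))?;
    for candidate in response.candidates.unwrap_or_default() {
        println!("{}", candidate.output);
    }

    Ok(())
}
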
7 changes: 7 additions & 0 deletions src/llms/palm2/mod.rs
@@ -0,0 +1,7 @@
mod system;
mod api;
mod types;

pub use system::*;
pub use api::*;
pub use types::*;
139 changes: 139 additions & 0 deletions src/llms/palm2/system.rs
@@ -0,0 +1,139 @@
use std::{error::Error, thread::sleep, time::Duration};

use async_trait::async_trait;
use serde::{Serialize, Deserialize};
use serde_json::Value;
use tokio::runtime::Runtime;

use crate::{LLMProvider, Message, LLMModel, ApiClient, MessagePrompt, PALMMessage, GenerateTextRequest, TextPrompt, EmbedTextRequest, CountTokensRequest, GenerateTextResponse};

pub struct PALM2 {
    pub model: String,
    pub embedding_model: String,
    pub client: ApiClient
}

#[async_trait]
impl LLMModel for PALM2 {
    async fn get_response(&self, messages: &[Message], max_tokens: Option<u16>, temperature: Option<f32>) -> Result<String, Box<dyn Error>> {
        let palm_messages_string = messages
            .iter()
            .map(|el| el.content())
            .collect::<Vec<&str>>()
            .join("\n");

        let text_request = GenerateTextRequest {
            prompt: TextPrompt {
                text: palm_messages_string
            },
            safety_settings: vec![],
            stop_sequences: vec![],
            temperature: temperature.unwrap_or(1.0) as f64,
            candidate_count: 1,
            max_output_tokens: max_tokens.unwrap_or(1000) as i32,
            top_p: 0.95,
            top_k: 40,
        };

        let response_message: GenerateTextResponse = self.client.generate_text(&self.model, text_request).await?;

        let response = response_message.candidates.unwrap_or(vec![]);

        let response_text = response
            .iter()
            .map(|el| el.output.clone())
            .collect::<Vec<String>>()
            .join(" ");

        Ok(response_text)
    }

    async fn get_base_embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error>> {
        let embedding_response = self.client.embed_text(&self.embedding_model,
            EmbedTextRequest {
                text: text.to_string()
            }).await?;

        Ok(embedding_response)
    }

    fn get_tokens_remaining(&self, messages: &[Message]) -> Result<usize, Box<dyn Error>> {
        let all_messages: Vec<PALMMessage> = messages.iter().map(|el| PALMMessage {
            author: None,
            content: el.content().to_string(),
            citation_metadata: None
        }).collect::<Vec<PALMMessage>>();

        let count_tokens_request = CountTokensRequest {
            prompt: MessagePrompt { context: "".to_string(), examples: vec![], messages: all_messages }
        };

        let runtime = tokio::runtime::Runtime::new()?;

        let gcp_model = runtime.block_on(self.client.get_model(&self.model))?;
        let token_count = runtime.block_on(self.client.count_message_tokens(&self.model, count_tokens_request))?;
        let max_tokens = gcp_model.input_token_limit;

        let tokens_remaining = max_tokens.checked_sub(token_count.token_count.unwrap_or(0) as i32)
            .ok_or_else(|| "Token count exceeded the maximum limit.")?;

        Ok(tokens_remaining as usize)
    }
}

#[derive(Serialize, Deserialize)]
pub struct PALM2Config {
    #[serde(rename = "api key")] pub api_key: String,
    pub model: Option<String>,
    #[serde(rename = "api base")] pub api_base: Option<String>,
    #[serde(rename = "embedding model")] pub embedding_model: Option<String>,
}

pub struct PALM2Provider;

#[async_trait]
impl LLMProvider for PALM2Provider {
    fn is_enabled(&self) -> bool {
        true
    }

    fn get_name(&self) -> &str {
        "palm2"
    }

    fn create(&self, value: Value) -> Result<Box<dyn LLMModel>, Box<dyn Error>> {
        let rt = Runtime::new().expect("Failed to create Tokio runtime");

        let config: PALM2Config = serde_json::from_value(value)?;

        let client = ApiClient::new(config.api_base.unwrap_or("https://generativelanguage.googleapis.com".to_owned()), config.api_key);

        let all_messages: Vec<PALMMessage> = vec![PALMMessage {
            author: None,
            content: "count my tokens please palm",
            citation_metadata: None
        }];

        // Easy way to immediately test the API call on startup
        let models_response = rt.block_on(async {
            client.count_message_tokens("text-bison-001", CountTokensRequest {
                prompt: MessagePrompt { context: "".to_string(), examples: vec![], messages: all_messages }
            }).await
        })?;

        println!("model: {:?}", models_response);
        sleep(Duration::new(10, 0));

        Ok(Box::new(PALM2 {
            model: config.model.unwrap_or("text-bison-001".to_string()),
            embedding_model: config.embedding_model.unwrap_or("embedding-gecko-001".to_string()),
            client
        }))
    }
}

pub fn create_model_palm2() -> Box<dyn LLMProvider> {
    Box::new(PALM2Provider)
}
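
As a companion to the provider above, a hedged end-to-end sketch: the JSON keys follow the #[serde(rename = "...")] attributes on PALM2Config, the PALM_API_KEY environment variable is invented for the example, and create_model_palm2 plus the LLMModel trait are assumed to be in scope.

use std::error::Error;
use serde_json::json;
use tokio::runtime::Runtime;

fn main() -> Result<(), Box<dyn Error>> {
    let provider = create_model_palm2();

    // Keys match the serde renames on PALM2Config; "api base" is optional and omitted here.
    let model = provider.create(json!({
        "api key": std::env::var("PALM_API_KEY")?,
        "model": "text-bison-001",
        "embedding model": "embedding-gecko-001"
    }))?;

    // Exercise the LLMModel trait implemented by PALM2 with a single embedding request.
    let embedding = Runtime::new()?.block_on(model.get_base_embed("hello from SmartGPT"))?;
    println!("embedding dimensions: {}", embedding.len());

    Ok(())
}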
