From 801a275d14d798bc4a059c340b25722f8660dfa3 Mon Sep 17 00:00:00 2001
From: Paillat
Date: Wed, 22 May 2024 15:56:17 +0200
Subject: [PATCH] :sparkles: Add support for vision in LLM engines

---
 src/engines/LLMEngine/AnthropicLLMEngine.py | 21 ++++++++++----
 src/engines/LLMEngine/BaseLLMEngine.py      |  5 +++-
 src/engines/LLMEngine/OpenaiLLMEngine.py    | 32 +++++++++++++++++----
 3 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/src/engines/LLMEngine/AnthropicLLMEngine.py b/src/engines/LLMEngine/AnthropicLLMEngine.py
index 4d94d14..4a0bbe6 100644
--- a/src/engines/LLMEngine/AnthropicLLMEngine.py
+++ b/src/engines/LLMEngine/AnthropicLLMEngine.py
@@ -26,7 +26,8 @@ class AnthropicLLMEngine(BaseLLMEngine):
     def generate(
         self,
         system_prompt: str,
-        chat_prompt: str,
+        chat_prompt: str = "",
+        messages: list[dict] = [],
         max_tokens: int = 1024,
         temperature: float = 1.0,
         json_mode: bool = False,
@@ -36,9 +37,11 @@ class AnthropicLLMEngine(BaseLLMEngine):
     ) -> str | dict:
         tries = 0
         while tries < 2:
-            messages = [
-                {"role": "user", "content": chat_prompt},
-            ]
+            if chat_prompt:
+                messages = [
+                    {"role": "user", "content": chat_prompt},
+                    *messages,
+                ]
             if json_mode:
                 # anthropic does not officially support JSON mode, but we can bias the output towards a JSON-like format
                 messages.append({"role": "assistant", "content": "{"})
@@ -47,7 +50,7 @@
             messages=messages,
             model=self.model,
             top_p=top_p,
-            temperature=temperature,
+            temperature=min(temperature, 1.0),  # Anthropic caps temperature at 1.0
             system=system_prompt,
         )
 
@@ -93,3 +96,11 @@
             return gr.update(value=api_key)
 
         save.click(save_api_key, inputs=[api_key_input])
+
+    @property
+    def supports_vision(self) -> bool:
+        return (
+            True
+            if self.model in ["claude-3-opus-20240229", "claude-3-sonnet-20240229"]
+            else False
+        )
diff --git a/src/engines/LLMEngine/BaseLLMEngine.py b/src/engines/LLMEngine/BaseLLMEngine.py
index 2816b89..e5f2f1b 100644
--- a/src/engines/LLMEngine/BaseLLMEngine.py
+++ b/src/engines/LLMEngine/BaseLLMEngine.py
@@ -4,11 +4,14 @@ from ..BaseEngine import BaseEngine
 
 
 class BaseLLMEngine(BaseEngine):
+    supports_vision = False
+
     @abstractmethod
     def generate(
         self,
         system_prompt: str,
-        chat_prompt: str,
+        chat_prompt: str = "",
+        messages: list[dict] = [],
         max_tokens: int = 512,
         temperature: float = 1.0,
         json_mode: bool = False,
diff --git a/src/engines/LLMEngine/OpenaiLLMEngine.py b/src/engines/LLMEngine/OpenaiLLMEngine.py
index be8b8fb..6a198ee 100644
--- a/src/engines/LLMEngine/OpenaiLLMEngine.py
+++ b/src/engines/LLMEngine/OpenaiLLMEngine.py
@@ -8,8 +8,8 @@ from .BaseLLMEngine import BaseLLMEngine
 
 OPENAI_POSSIBLE_MODELS = [
     # These are the openai models supporting force_json
     "gpt-3.5-turbo-0125",
-    "gpt-4-turbo-preview",
     "gpt-4-turbo",
+    "gpt-4o",
 ]
 
@@ -29,7 +29,8 @@ class OpenaiLLMEngine(BaseLLMEngine):
     def generate(
         self,
         system_prompt: str,
-        chat_prompt: str,
+        chat_prompt: str = "",
+        messages: list[dict] = [],
         max_tokens: int = 512,
         temperature: float = 1.0,
         json_mode: bool = False,
@@ -40,19 +41,36 @@
         logging.info(
             f"Generating with OpenAI model {self.model} and system prompt: \n{system_prompt} and chat prompt: \n{chat_prompt[0:100]}..."
         )
+        if chat_prompt:
+            messages = [
+                {"role": "user", "content": chat_prompt},
+                *messages,
+            ]
+        for i, message in enumerate(messages):  # rewrite Anthropic-style image blocks for OpenAI
+            if isinstance(message["content"], list):
+                for j, content in enumerate(message["content"]):  # j, not i: don't shadow the outer index
+                    if content["type"] == "image":
+                        message["content"][j] = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:{content['source']['media_type']};base64,{content['source']['data']}",
+                            },
+                        }
+            messages[i] = message
+
         response = self.client.chat.completions.create(
             model=self.model,
             messages=[
                 {"role": "system", "content": system_prompt},
-                {"role": "user", "content": chat_prompt},
+                *messages,
             ],
-            max_tokens=int(max_tokens) if max_tokens else openai._types.NOT_GIVEN,
+            max_tokens=int(max_tokens) if max_tokens else openai.NOT_GIVEN,
             temperature=temperature,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             response_format=(
-                {"type": "json_object"} if json_mode else openai._types.NOT_GIVEN
+                {"type": "json_object"} if json_mode else openai.NOT_GIVEN
             ),
         )
         return (
@@ -88,3 +106,7 @@
             return gr.update(value=api_key)
 
         save.click(save_api_key, inputs=[api_key_input])
+
+    @property
+    def supports_vision(self) -> bool:
+        return self.model in ["gpt-4-turbo", "gpt-4o"]  # gpt-4-turbo-preview is text-only
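
Usage sketch (not part of the diff; the constructor call and image file below
are hypothetical, and the import path is assumed from the diffstat): callers
can now pass Anthropic-style content blocks in messages, and
OpenaiLLMEngine.generate() rewrites each "image" block into an OpenAI
"image_url" data URI before calling the API.

    import base64

    from src.engines.LLMEngine.OpenaiLLMEngine import OpenaiLLMEngine  # path assumed

    # Base64-encode the image, matching the "source" block the engines read.
    with open("photo.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()

    engine = OpenaiLLMEngine()  # constructor arguments omitted (hypothetical)
    if engine.supports_vision:
        reply = engine.generate(
            system_prompt="You are a helpful assistant.",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe this image."},
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/jpeg",
                                "data": image_b64,
                            },
                        },
                    ],
                }
            ],
        )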