Add support for vision in LLM engines

2024-05-22 15:56:17 +02:00
parent 4fa347242b
commit 801a275d14
3 changed files with 47 additions and 11 deletions


@@ -26,7 +26,8 @@ class AnthropicLLMEngine(BaseLLMEngine):
     def generate(
         self,
         system_prompt: str,
-        chat_prompt: str,
+        chat_prompt: str = "",
+        messages: list = [],
         max_tokens: int = 1024,
         temperature: float = 1.0,
         json_mode: bool = False,
@@ -36,9 +37,11 @@ class AnthropicLLMEngine(BaseLLMEngine):
     ) -> str | dict:
         tries = 0
         while tries < 2:
-            messages = [
-                {"role": "user", "content": chat_prompt},
-            ]
+            if chat_prompt:
+                messages = [
+                    {"role": "user", "content": chat_prompt},
+                    *messages,
+                ]
             if json_mode:
                 # anthropic does not officially support JSON mode, but we can bias the output towards a JSON-like format
                 messages.append({"role": "assistant", "content": "{"})
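
A note on the prefill trick above: because the assistant turn is pre-seeded with "{", the model continues the JSON object, and the caller has to put that brace back before parsing. A minimal sketch (variable names are illustrative; the engine's actual response handling presumably lives later in this file):

    import json

    # raw_completion is the text the API returns after the prefilled "{" turn.
    raw_completion = '"title": "Example", "score": 0.9}'
    parsed = json.loads("{" + raw_completion)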
@@ -47,7 +50,7 @@ class AnthropicLLMEngine(BaseLLMEngine):
             messages=messages,
             model=self.model,
             top_p=top_p,
-            temperature=temperature,
+            temperature=min(temperature, 1.0),  # Anthropic rejects temperatures above 1.0
             system=system_prompt,
         )
@@ -93,3 +96,11 @@ class AnthropicLLMEngine(BaseLLMEngine):
             return gr.update(value=api_key)
         save.click(save_api_key, inputs=[api_key_input])
+
+    @property
+    def supports_vision(self) -> bool:
+        return self.model in ["claude-3-opus-20240229", "claude-3-sonnet-20240229"]
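
For illustration, here is how a caller could pass an image through the new messages parameter. It assumes Anthropic's content-block format for images, the same shape the OpenAI engine translates below; the constructor call and file handling are hypothetical, not part of this commit:

    import base64

    engine = AnthropicLLMEngine()  # hypothetical setup; assumes a vision-capable model

    with open("chart.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    response = engine.generate(
        system_prompt="You are a helpful assistant.",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": image_b64,
                        },
                    },
                ],
            }
        ],
    )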


@@ -4,11 +4,14 @@ from ..BaseEngine import BaseEngine
 class BaseLLMEngine(BaseEngine):
+    supports_vision = False
+
     @abstractmethod
     def generate(
         self,
         system_prompt: str,
-        chat_prompt: str,
+        chat_prompt: str = "",
+        messages: list[dict] = [],
         max_tokens: int = 512,
         temperature: float = 1.0,
         json_mode: bool = False,
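
A sketch of how engine-agnostic calling code might consult the new flag before attaching image content (the helper below is hypothetical, not part of this commit):

    def attach_image_if_supported(
        engine: BaseLLMEngine, messages: list[dict], image_block: dict
    ) -> list[dict]:
        # Engines whose model cannot handle images get the text-only messages unchanged.
        if not engine.supports_vision:
            return messages
        return messages + [{"role": "user", "content": [image_block]}]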


@@ -8,8 +8,8 @@ from .BaseLLMEngine import BaseLLMEngine
 OPENAI_POSSIBLE_MODELS = [  # These are the OpenAI models supporting force_json
     "gpt-3.5-turbo-0125",
     "gpt-4-turbo-preview",
+    "gpt-4-turbo",
+    "gpt-4o",
 ]
@@ -29,7 +29,8 @@ class OpenaiLLMEngine(BaseLLMEngine):
     def generate(
         self,
         system_prompt: str,
-        chat_prompt: str,
+        chat_prompt: str = "",
+        messages: list = [],
         max_tokens: int = 512,
         temperature: float = 1.0,
         json_mode: bool = False,
@@ -40,19 +41,36 @@ class OpenaiLLMEngine(BaseLLMEngine):
         logging.info(
             f"Generating with OpenAI model {self.model} and system prompt: \n{system_prompt} and chat prompt: \n{chat_prompt[0:100]}..."
         )
+        if chat_prompt:
+            messages = [
+                {"role": "user", "content": chat_prompt},
+                *messages,
+            ]
+        # Rewrite Anthropic-style image blocks into OpenAI's image_url format.
+        for message in messages:
+            if isinstance(message["content"], list):
+                for j, content in enumerate(message["content"]):
+                    if content["type"] == "image":
+                        message["content"][j] = {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:{content['source']['media_type']};base64,{content['source']['data']}",
+                            },
+                        }
         response = self.client.chat.completions.create(
             model=self.model,
             messages=[
                 {"role": "system", "content": system_prompt},
-                {"role": "user", "content": chat_prompt},
+                *messages,
             ],
-            max_tokens=int(max_tokens) if max_tokens else openai._types.NOT_GIVEN,
+            max_tokens=int(max_tokens) if max_tokens else openai.NOT_GIVEN,
             temperature=temperature,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             response_format=(
-                {"type": "json_object"} if json_mode else openai._types.NOT_GIVEN
+                {"type": "json_object"} if json_mode else openai.NOT_GIVEN
             ),
         )
         return (
@@ -88,3 +106,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
             return gr.update(value=api_key)
         save.click(save_api_key, inputs=[api_key_input])
+
+    @property
+    def supports_vision(self) -> bool:
+        return self.model in ["gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4o"]
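
To make the image translation in generate concrete: the loop rewrites Anthropic-style image blocks into OpenAI's data-URL form. A standalone sketch of the same conversion (base64 payload truncated for illustration):

    # Input: Anthropic-style content block, as accepted by both engines.
    anthropic_block = {
        "type": "image",
        "source": {"type": "base64", "media_type": "image/png", "data": "iVBORw0KGgo..."},
    }

    # Output: the equivalent OpenAI block built before the API call.
    openai_block = {
        "type": "image_url",
        "image_url": {
            "url": f"data:{anthropic_block['source']['media_type']};base64,{anthropic_block['source']['data']}"
        },
    }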