mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
✨ Add support for vision in LLM engines
This commit is contained in:
@@ -26,7 +26,8 @@ class AnthropicLLMEngine(BaseLLMEngine):
|
||||
def generate(
|
||||
self,
|
||||
system_prompt: str,
|
||||
chat_prompt: str,
|
||||
chat_prompt: str = "",
|
||||
messages: list = [],
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 1.0,
|
||||
json_mode: bool = False,
|
||||
@@ -36,9 +37,11 @@ class AnthropicLLMEngine(BaseLLMEngine):
|
||||
) -> str | dict:
|
||||
tries = 0
|
||||
while tries < 2:
|
||||
messages = [
|
||||
{"role": "user", "content": chat_prompt},
|
||||
]
|
||||
if chat_prompt:
|
||||
messages = [
|
||||
{"role": "user", "content": chat_prompt},
|
||||
*messages,
|
||||
]
|
||||
if json_mode:
|
||||
# anthropic does not officially support JSON mode, but we can bias the output towards a JSON-like format
|
||||
messages.append({"role": "assistant", "content": "{"})
|
||||
@@ -47,7 +50,7 @@ class AnthropicLLMEngine(BaseLLMEngine):
|
||||
messages=messages,
|
||||
model=self.model,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
temperature=temperature if temperature <= 1.0 else 1.0,
|
||||
system=system_prompt,
|
||||
)
|
||||
|
||||
@@ -93,3 +96,11 @@ class AnthropicLLMEngine(BaseLLMEngine):
|
||||
return gr.update(value=api_key)
|
||||
|
||||
save.click(save_api_key, inputs=[api_key_input])
|
||||
|
||||
@property
def supports_vision(self) -> bool:
    """Tell whether the selected Anthropic model is on the vision-capable list.

    Returns:
        True when ``self.model`` equals one of the model identifiers listed
        below, False for any other value.
    """
    # A membership test already evaluates to a bool, so wrapping it in
    # ``True if ... else False`` adds nothing.
    return self.model in (
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
    )
|
||||
|
||||
@@ -4,11 +4,14 @@ from ..BaseEngine import BaseEngine
|
||||
|
||||
|
||||
class BaseLLMEngine(BaseEngine):
|
||||
supports_vision = False
|
||||
|
||||
@abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
system_prompt: str,
|
||||
chat_prompt: str,
|
||||
chat_prompt: str = "",
|
||||
messages: list[dict] = [],
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 1.0,
|
||||
json_mode: bool = False,
|
||||
|
||||
@@ -8,8 +8,8 @@ from .BaseLLMEngine import BaseLLMEngine
|
||||
|
||||
OPENAI_POSSIBLE_MODELS = [ # These are the OpenAI models that support force_json
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4o"
|
||||
]
|
||||
|
||||
|
||||
@@ -29,7 +29,8 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
def generate(
|
||||
self,
|
||||
system_prompt: str,
|
||||
chat_prompt: str,
|
||||
chat_prompt: str = "",
|
||||
messages: list = [],
|
||||
max_tokens: int = 512,
|
||||
temperature: float = 1.0,
|
||||
json_mode: bool = False,
|
||||
@@ -40,19 +41,36 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
logging.info(
|
||||
f"Generating with OpenAI model {self.model} and system prompt: \n{system_prompt} and chat prompt: \n{chat_prompt[0:100]}..."
|
||||
)
|
||||
if chat_prompt:
|
||||
messages = [
|
||||
{"role": "user", "content": chat_prompt},
|
||||
*messages,
|
||||
]
|
||||
for i, message in enumerate(messages):
|
||||
if type(message["content"]) is list:
|
||||
for i, content in enumerate(message["content"]):
|
||||
if content["type"] == "image":
|
||||
message["content"][i] = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{content['source']['media_type']};base64,{content['source']['data']}",
|
||||
},
|
||||
}
|
||||
messages[i] = message
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": chat_prompt},
|
||||
*messages,
|
||||
],
|
||||
max_tokens=int(max_tokens) if max_tokens else openai._types.NOT_GIVEN,
|
||||
max_tokens=int(max_tokens) if max_tokens else openai.NOT_GIVEN,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=(
|
||||
{"type": "json_object"} if json_mode else openai._types.NOT_GIVEN
|
||||
{"type": "json_object"} if json_mode else openai.NOT_GIVEN
|
||||
),
|
||||
)
|
||||
return (
|
||||
@@ -88,3 +106,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
return gr.update(value=api_key)
|
||||
|
||||
save.click(save_api_key, inputs=[api_key_input])
|
||||
|
||||
@property
def supports_vision(self) -> bool:
    """Tell whether the selected OpenAI model is on the vision-capable list.

    Returns:
        True when ``self.model`` equals one of the model identifiers listed
        below, False for any other value.
    """
    # The ``in`` test is already a bool; no conditional expression needed.
    return self.model in ("gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4o")
|
||||
|
||||
Reference in New Issue
Block a user