mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 09:16:19 +00:00
Some changes
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import gradio as gr
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
import orjson
|
||||
|
||||
from .BaseLLMEngine import BaseLLMEngine
|
||||
@@ -17,6 +18,10 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
|
||||
def __init__(self, options: list) -> None:
|
||||
self.model = options[0]
|
||||
api_key = self.retrieve_setting(identifier="openai_api_key")
|
||||
if not api_key:
|
||||
raise ValueError("OpenAI API key is not set.")
|
||||
self.client = OpenAI(api_key=api_key["api_key"])
|
||||
super().__init__()
|
||||
|
||||
def generate(
|
||||
@@ -30,7 +35,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
frequency_penalty: float = 0,
|
||||
presence_penalty: float = 0,
|
||||
) -> str | dict:
|
||||
response = openai.chat.completions.create(
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
@@ -61,3 +66,21 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
value=OPENAI_POSSIBLE_MODELS[0],
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_settings(cls):
|
||||
current_api_key = cls.retrieve_setting(identifier="openai_api_key")
|
||||
current_api_key = current_api_key["api_key"] if current_api_key else ""
|
||||
api_key_input = gr.Textbox(
|
||||
label="OpenAI API Key",
|
||||
type="password",
|
||||
value=current_api_key,
|
||||
)
|
||||
save = gr.Button("Save")
|
||||
|
||||
def save_api_key(api_key: str):
|
||||
cls.store_setting(identifier="openai_api_key", data={"api_key": api_key})
|
||||
gr.Info("API key saved successfully.")
|
||||
return gr.update(value=api_key)
|
||||
|
||||
save.click(save_api_key, inputs=[api_key_input])
|
||||
|
||||
Reference in New Issue
Block a user