fix(GenerationContext.py): fix typo in variable name powerfulllmengine to powerfulllmengine for better readability

feat(GenerationContext.py): add setup_dir method to create a directory for output files with a timestamp
feat(GenerationContext.py): call setup_dir method before generating script and synthesizing audio to ensure output directory exists
feat(prompts/fix_captions.yaml): add a new prompt file to provide instructions for fixing captions
fix(BaseTTSEngine.py): add force_duration method to adjust audio clip duration if it exceeds a specified duration
feat(CoquiTTSEngine.py): add options for forcing duration and specifying duration in the UI
feat(utils/prompting.py): add get_prompt function to load prompt files from a specified location
fix(gradio_ui.py): set equal_height=True for engine_rows to ensure consistent height for engine options
This commit is contained in:
2024-02-15 12:27:13 +01:00
parent 9f88e6d069
commit 57bcf0af8e
7 changed files with 73 additions and 9 deletions

View File

@@ -1,17 +1,31 @@
import moviepy import moviepy
import time
import os
from .. import engines from .. import engines
from ..utils.prompting import get_prompt
class GenerationContext: class GenerationContext:
def __init__(self, llmengine: engines.LLMEngine.BaseLLMEngine, scriptengine: engines.ScriptEngine.BaseScriptEngine, ttsengine: engines.TTSEngine.BaseTTSEngine) -> None: def __init__(self, powerfulllmengine: engines.LLMEngine.BaseLLMEngine, simplellmengine: engines.LLMEngine.BaseLLMEngine, scriptengine: engines.ScriptEngine.BaseScriptEngine, ttsengine: engines.TTSEngine.BaseTTSEngine) -> None:
self.llmengine = llmengine self.powerfulllmengine = powerfulllmengine
self.llmengine.ctx = self self.powerfulllmengine.ctx = self
self.simplellmengine = simplellmengine
self.simplellmengine.ctx = self
self.scriptengine = scriptengine self.scriptengine = scriptengine
self.scriptengine.ctx = self self.scriptengine.ctx = self
self.ttsengine = ttsengine self.ttsengine = ttsengine
self.ttsengine.ctx = self self.ttsengine.ctx = self
def setup_dir(self):
self.dir = f"output/{time.time()}"
os.makedirs(self.dir)
def process(self): def process(self):
timed_script = self.scriptengine.generate() self.setup_dir()
script = self.scriptengine.generate()
timed_script = self.ttsengine.synthesize(script, self.dir)

View File

@@ -0,0 +1,8 @@
system: |-
You will recieve from the user a textual script and its captions. Since the captions have been generated trough stt, they might contain some errors. Your task is to fix theese transcription errors and return the corrected captions, keeping the timestamped format.
Please return valid json output, with no extra characters or comments, nor any codeblocks.
chat: |-
{script}
{captions}

View File

@@ -1,10 +1,23 @@
import moviepy.editor as mp
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
# Assuming BaseEngine is defined elsewhere in your project
from ..BaseEngine import BaseEngine from ..BaseEngine import BaseEngine
class BaseTTSEngine(BaseEngine): class BaseTTSEngine(BaseEngine):
pass
@abstractmethod @abstractmethod
def synthesize(self, text: str, path: str) -> str: def synthesize(self, text: str, path: str) -> str:
pass pass
def force_duration(self, duration: float, path: str):
audio_clip = mp.AudioFileClip(path)
if audio_clip.duration > duration:
speed_factor = audio_clip.duration / duration
new_audio = audio_clip.fx(mp.vfx.speedx, speed_factor, final_duration=duration)
new_audio.write_audiofile(path, codec='libmp3lame')
audio_clip.close()

View File

@@ -90,13 +90,15 @@ class CoquiTTSEngine(BaseTTSEngine):
"ko", # Korean "ko", # Korean
"hi", # Hindi "hi", # Hindi
] ]
num_options = 2 num_options = 4
def __init__(self, options: list): def __init__(self, options: list):
super().__init__() super().__init__()
self.voice = options[0][0] self.voice = options[0][0]
self.language = options[1][0] self.language = options[1][0]
self.to_force_duration = options[2][0]
self.duration = options[3]
os.environ["COQUI_TOS_AGREED"] = "1" os.environ["COQUI_TOS_AGREED"] = "1"
@@ -106,11 +108,13 @@ class CoquiTTSEngine(BaseTTSEngine):
def synthesize(self, text: str, path: str) -> str: def synthesize(self, text: str, path: str) -> str:
# self.tts.tts_to_file(text=text, file_path=path, lang=self.language, speaker=self.voice) # self.tts.tts_to_file(text=text, file_path=path, lang=self.language, speaker=self.voice)
if self.to_force_duration:
self.force_duration(float(self.duration), path)
return path return path
@classmethod @classmethod
def get_options(cls) -> list: def get_options(cls) -> list:
return [ options = [
gr.Dropdown( gr.Dropdown(
label="Voice", label="Voice",
choices=cls.voices, choices=cls.voices,
@@ -124,3 +128,13 @@ class CoquiTTSEngine(BaseTTSEngine):
value=cls.languages[0], value=cls.languages[0],
), ),
] ]
duration_checkbox = gr.Checkbox(value=False)
duration = gr.Number(label="Duration", value=57, step=1, minimum=10, visible=False)
duration_switch = lambda x: gr.update(visible=x)
duration_checkbox.change(duration_switch, inputs=[duration_checkbox], outputs=[duration])
duration_checkbox_group = gr.CheckboxGroup([duration_checkbox], label="Force duration")
options.append(duration_checkbox_group)
options.append(duration)
return options

1
src/utils/__init__.py Normal file
View File

@@ -0,0 +1 @@
from . import prompting

14
src/utils/prompting.py Normal file
View File

@@ -0,0 +1,14 @@
import yaml
import os
from typing import TypedDict
class Prompt(TypedDict):
system: str
chat: str
def get_prompt(name, *, location = "src/chore/prompts") -> Prompt:
path = os.path.join(os.getcwd(), location, f"{name}.yaml")
if not os.path.exists(path):
raise FileNotFoundError(f"Prompt file {path} does not exist.")
with open(path, "r") as file:
return yaml.safe_load(file)

View File

@@ -52,7 +52,7 @@ class GenerateUI:
inputs.append(engine_dropdown) inputs.append(engine_dropdown)
engine_rows = [] engine_rows = []
for i, engine in enumerate(engines): for i, engine in enumerate(engines):
with gr.Row(visible=(i == 0)) as engine_row: with gr.Row(equal_height=True, visible=(i == 0)) as engine_row:
engine_rows.append(engine_row) engine_rows.append(engine_row)
options = engine.get_options() options = engine.get_options()
inputs.extend(options) inputs.extend(options)