🐛 fix(GenerationContext.py): fix import statements and add support for captioning engine

 feat(GenerationContext.py): add support for captioning engine in the GenerationContext class
The `moviepy.editor` module is now imported as `mp` and `gradio` as `gr` for readability. The `GenerationContext` class gains a `captioningengine` parameter and initializes a matching `captioningengine` attribute. `setup_dir` now creates the output directory, and `get_file_path` returns file paths relative to it. `process` gains captioning steps: the generated script is kept in `self.script`, the result of `ttsengine.synthesize` is stored in the new `timed_script` attribute, and the captioning engine produces the caption clips stored in `captions` (left empty when no captioning engine is selected). The final video is rendered with the `moviepy` library and saved as "final.mp4" in the output directory.
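For illustration only (not part of the diff below), a minimal sketch of how the new captioning step and the moviepy render tail fit together. The import path, option order, prompt-free stand-in context, and file names here are assumptions, not code from this commit:

import moviepy.editor as mp
from types import SimpleNamespace
from src.engines.CaptioningEngine import SimpleCaptioningEngine  # import path assumed

# Configure the captioning engine with the five options exposed in the UI:
# font, font size, stroke width, font color, stroke color.
engine = SimpleCaptioningEngine(["Arial", 110, 4, "#ffffff", "#000000"])

# In the real pipeline ctx is the GenerationContext; a stand-in with the
# attributes get_captions() reads (width, timed_script) is enough here.
engine.ctx = SimpleNamespace(
    width=1080,
    height=1920,
    timed_script=[
        {"start": 0.0, "end": 0.4, "text": "Hello"},
        {"start": 0.4, "end": 0.9, "text": "world."},
    ],
)
captions = engine.get_captions()  # timed, positioned TextClips (needs ImageMagick and the font)

# Same render tail as GenerationContext.process(); file names are placeholders.
clip = mp.CompositeVideoClip(captions, size=(1080, 1920))
clip = clip.set_audio(mp.AudioFileClip("tts.wav"))
clip.write_videofile("final.mp4", fps=60)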
This commit is contained in:
2024-02-17 18:47:30 +01:00
parent eedbc99121
commit e3229518d4
12 changed files with 261 additions and 34 deletions

View File

@@ -1,6 +1,7 @@
import moviepy
import moviepy.editor as mp
import time
import os
import gradio as gr
from .. import engines
from ..utils.prompting import get_prompt
@@ -8,31 +9,64 @@ from ..utils.prompting import get_prompt
class GenerationContext:
def __init__(
self, powerfulllmengine, simplellmengine, scriptengine, ttsengine
self,
powerfulllmengine,
simplellmengine,
scriptengine,
ttsengine,
captioningengine,
) -> None:
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine[0]
self.powerfulllmengine.ctx = self
self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine
self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine[0]
self.simplellmengine.ctx = self
self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine
self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine[0]
self.scriptengine.ctx = self
self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine
self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine[0]
self.ttsengine.ctx = self
self.captioningengine: engines.CaptioningEngine.BaseCaptioningEngine = (
captioningengine[0]
)
self.captioningengine.ctx = self
def setup_dir(self):
self.dir = f"output/{time.time()}"
os.makedirs(self.dir)
def get_file_path(self, name: str) -> str:
return os.path.join(self.dir, name)
def process(self):
# IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation.
# ⚠️ IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation.
progress = gr.Progress()
self.width, self.height = (
1080,
1920,
) # TODO: Add support for custom resolution; for now it's TikTok's resolution
self.setup_dir()
script = self.scriptengine.generate()
self.script = self.scriptengine.generate()
timed_script = self.ttsengine.synthesize(script, self.get_file_path("tts.wav"))
self.timed_script = self.ttsengine.synthesize(
self.script, self.get_file_path("tts.wav")
)
if not isinstance(self.captioningengine, engines.NoneEngine):
self.captions = self.captioningengine.get_captions()
else:
self.captions = []
# add any other processing steps here
# we render to a file called final.mp4
# using moviepy CompositeVideoClip
clip = mp.CompositeVideoClip(self.captions, size=(self.width, self.height))
audio = mp.AudioFileClip(self.get_file_path("tts.wav"))
clip = clip.set_audio(audio)
clip.write_videofile(self.get_file_path("final.mp4"), fps=60)

View File

@@ -1,8 +0,0 @@
system: |-
You will receive from the user a textual script and its captions. Since the captions have been generated through STT, they might contain some errors. Your task is to fix these transcription errors and return the corrected captions, keeping the timestamped format.
Please return valid JSON output, with no extra characters or comments, nor any code blocks.
chat: |-
{script}
{captions}

View File

@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from ..BaseEngine import BaseEngine
from moviepy.editor import TextClip
class BaseCaptioningEngine(BaseEngine):
@abstractmethod
def get_captions(self) -> list[TextClip]:
...

View File

@@ -0,0 +1,95 @@
import gradio as gr
from moviepy.editor import TextClip
from PIL import ImageFont
from . import BaseCaptioningEngine
class SimpleCaptioningEngine(BaseCaptioningEngine):
name = "SimpleCaptioningEngine"
description = "A basic captioning engine with nothing too fancy."
num_options = 5
def __init__(self, options: list[list | tuple | str | int | float | bool | None]):
self.font = options[0]
self.font_size = options[1]
self.stroke_width = options[2]
self.font_color = options[3]
self.stroke_color = options[4]
super().__init__()
def build_caption_object(self, text: str, start: float, end: float) -> TextClip:
return TextClip(
text,
fontsize=self.font_size,
color=self.font_color,
font=self.font,
stroke_color=self.stroke_color,
stroke_width=self.stroke_width,
method="caption",
size=(self.ctx.width / 3 * 2, None),
).set_position(('center', 0.65), relative=True).set_start(start).set_duration(end - start)
def ends_with_punctuation(self, text: str) -> bool:
punctuations = (".", "?", "!", ",", ":", ";")
return text.strip().endswith(tuple(punctuations))
def get_captions(self) -> list[TextClip]:
# 3 words per 1000 px of width, we do the math
max_words = int(self.ctx.width / 1000 * 3)
clips = []
words = (
self.ctx.timed_script.copy()
) # List of dicts with "start", "end", and "text"
current_line = ""
current_start = words[0]["start"]
current_end = words[0]["end"]
for i, word in enumerate(words):
# Use PIL to measure the text size
line_with_new_word = (
current_line + " " + word["text"] if current_line else word["text"]
)
pause = self.ends_with_punctuation(current_line.strip())
if len(line_with_new_word.split(" ")) > max_words or pause:
clips.append(self.build_caption_object(current_line.strip(), current_start, current_end))
current_line = word["text"] # Start a new line with the current word
current_start = word["start"]
current_end = word["end"]
else:
# If the line isn't too long, add the word to the current line
current_line = line_with_new_word
current_end = word["end"]
# Don't forget to add the last line if it exists
if current_line:
clips.append(
self.build_caption_object(current_line.strip(), current_start, words[-1]["end"])
)
return clips
@classmethod
def get_options(cls) -> list:
with gr.Column() as font_options:
with gr.Group():
font = gr.Dropdown(
label="Font",
choices=TextClip.list('font'),
value="Arial",
)
font_size = gr.Number(
label="Font Size",
minimum=70,
maximum=200,
step=1,
value=110,
)
font_color = gr.ColorPicker(label="Font Color", value="#ffffff")
with gr.Column() as font_stroke_options:
with gr.Group():
font_stroke_width = gr.Number(
label="Stroke Width",
minimum=0,
maximum=40,
step=1,
value=4,
)
font_stroke_color = gr.ColorPicker(label="Stroke Color", value="#000000")
return [font, font_size, font_stroke_width, font_color, font_stroke_color]

View File

@@ -0,0 +1,2 @@
from .BaseCaptioningEngine import BaseCaptioningEngine
from .SimpleCaptioningEngine import SimpleCaptioningEngine

View File

@@ -38,7 +38,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
{"role": "system", "content": system_prompt},
{"role": "user", "content": chat_prompt},
],
max_tokens=max_tokens,
max_tokens=int(max_tokens) if max_tokens else openai._types.NOT_GIVEN,
temperature=temperature,
top_p=top_p,
frequency_penalty=frequency_penalty,

src/engines/NoneEngine.py Normal file
View File

@@ -0,0 +1,14 @@
from . import BaseEngine
class NoneEngine(BaseEngine):
num_options = 0
name = "None"
description = "No engine selected"
def __init__(self):
pass
@classmethod
def get_options(cls):
return []

View File

@@ -16,8 +16,22 @@ class Word(TypedDict):
class BaseTTSEngine(BaseEngine):
@abstractmethod
def synthesize(self, text: str, path: str) -> str:
def synthesize(self, text: str, path: str) -> list[Word]:
pass
def remove_punctuation(self, text: str) -> str:
return text.translate(str.maketrans("", "", ".,!?;:"))
def fix_captions(self, script: str, captions: list[Word]) -> list[Word]:
script = script.split(" ")
new_captions = []
for i, word in enumerate(script):
original_word = self.remove_punctuation(word.lower())
stt_word = self.remove_punctuation(captions[i]["text"].lower())
if stt_word in original_word:
captions[i]["text"] = word
new_captions.append(captions[i])
# TODO: handle the case where the STT output contains an extra word that is not in the original script
return new_captions
def time_with_whisper(self, path: str) -> list[Word]:
"""
@@ -46,7 +60,7 @@ class BaseTTSEngine(BaseEngine):
"""
device = "cuda" if is_available() else "cpu"
audio = wt.load_audio(path)
model = wt.load_model("tiny", device=device)
model = wt.load_model("small", device=device)
result = wt.transcribe(model=model, audio=audio)
results = [word for chunk in result["segments"] for word in chunk["words"]]

View File

@@ -5,8 +5,9 @@ import os
import torch
from .BaseTTSEngine import BaseTTSEngine
from .BaseTTSEngine import BaseTTSEngine, Word
from ...utils.prompting import get_prompt
class CoquiTTSEngine(BaseTTSEngine):
voices = [
@@ -122,8 +123,10 @@ class CoquiTTSEngine(BaseTTSEngine):
)
if self.to_force_duration:
self.force_duration(float(self.duration), path)
return self.time_with_whisper(path)
@classmethod
def get_options(cls) -> list:
options = [
@@ -131,7 +134,7 @@ class CoquiTTSEngine(BaseTTSEngine):
label="Voice",
choices=cls.voices,
max_choices=1,
value=cls.voices[0],
value="Damien Black",
),
gr.Dropdown(
label="Language",
@@ -145,6 +148,7 @@ class CoquiTTSEngine(BaseTTSEngine):
label="Force duration",
info="Force the duration of the generated audio to be at most the specified value",
value=False,
show_label=True,
)
duration = gr.Number(
label="Duration [s]", value=57, step=1, minimum=10, visible=False

View File

@@ -0,0 +1,32 @@
system: |-
You will receive from the user a textual script and its captions. Since the captions have been generated through STT, they might contain some errors. Your task is to fix these transcription errors and return the corrected captions, keeping the timestamped format.
Please return valid JSON output, with no extra characters or comments, nor any code blocks.
The corrections you should make are:
- Fix any spelling errors
- Fix any grammar errors
- If a punctuation mark is not the same as in the script, change it to match the script. However, there should still be punctuation marks; they do not count toward the one word per "text" field rule.
- Turn any number or symbol that is spelled out into its numerical or symbolic representation (e.g. "one" -> "1", "percent" -> "%", "dollar" -> "$", etc.)
- Add capitalization at the beginning of each SENTENCE if it is missing (not each "text" tag; only when multiple tags form a sentence!), but do not create or infer sentences. Only capitalize a sentence that is already there and is missing capitalization.
- You are NOT allowed to change the timestamps at any cost, nor to reorganize the captions in any way. Your sole role is to fix transcription errors. Nothing else.
- You should not add new words. If a sentence feels wrong in the original script, you should not change it, but keep it as is, and if needed make the captions match the script, even if the script does not feel correct.
The response format should be a json object as follows:
{
"captions": [
{
"start": 0,
"end": 1000.000,
"text": "This is the first caption."
},
{
"start": 1000.000,
"end": 2000.023,
"text": "This is the second caption."
},
etc...]
}
chat: |-
{script}
{captions}
Remember that each "text" field should contain ONLY ONE WORD and should be changed ONLY IF NEEDED; otherwise copy it as is, with no changes to the text or the timestamps! The "text" field should NEVER BE a full sentence. The transcript is meant to be precise at the word level, so you should not change the words, or it will be pointless.

View File

@@ -1,14 +1,39 @@
from typing import TypedDict
from .BaseEngine import BaseEngine
from .NoneEngine import NoneEngine
from . import TTSEngine
from . import ScriptEngine
from . import LLMEngine
from . import CaptioningEngine
ENGINES = {
"SimpleLLMEngine": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
"PowerfulLLMEngine": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
"TTSEngine": [TTSEngine.CoquiTTSEngine, TTSEngine.ElevenLabsTTSEngine],
"ScriptEngine": [
ScriptEngine.ShowerThoughtsScriptEngine,
ScriptEngine.CustomScriptEngine,
],
class EngineDict(TypedDict):
classes: list[BaseEngine]
multiple: bool
ENGINES: dict[str, EngineDict] = {
"SimpleLLMEngine": {
"classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
"multiple": False,
},
"PowerfulLLMEngine": {
"classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
"multiple": False,
},
"ScriptEngine": {
"classes": [
ScriptEngine.ShowerThoughtsScriptEngine,
ScriptEngine.CustomScriptEngine,
],
"multiple": False,
},
"TTSEngine": {
"classes": [TTSEngine.CoquiTTSEngine, TTSEngine.ElevenLabsTTSEngine],
"multiple": False,
},
"CaptioningEngine": {
"classes": [CaptioningEngine.SimpleCaptioningEngine, NoneEngine],
"multiple": False,
},
}
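As a side note, a tiny sketch (not in this commit; import path assumed, and it assumes every engine class exposes a `name` attribute as the ones shown here do) of how UI code might iterate the restructured registry, relying only on the new "classes" and "multiple" keys:

from src.engines import ENGINES  # import path assumed

# List the available implementations for each engine slot.
for slot, entry in ENGINES.items():
    names = [cls.name for cls in entry["classes"]]
    print(f"{slot}: {', '.join(names)} (multiple selection: {entry['multiple']})")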

View File

@@ -8,10 +8,15 @@ class Prompt(TypedDict):
chat: str
def get_prompt(name, *, location="src/chore/prompts") -> tuple[str, str]:
path = os.path.join(os.getcwd(), location, f"{name}.yaml")
def get_prompt(name, *, location: str = "src/chore/prompts", by_file_location: str | None = None) -> tuple[str, str]:
if by_file_location:
path = os.path.join(
os.path.dirname(os.path.abspath(by_file_location)), "prompts", f"{name}.yaml"
)
else:
path = os.path.join(os.getcwd(), location, f"{name}.yaml")
if not os.path.exists(path):
raise FileNotFoundError(f"Prompt file {path} does not exist.")
with open(path, "r") as file:
prompt: Prompt = yaml.safe_load(file)
return prompt["system"], prompt["chat"]
return prompt["system"], prompt["chat"]
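One possible call site for the new `by_file_location` argument (hypothetical prompt name and calling module, not shown in this commit): an engine can load a prompt from a `prompts/` folder that sits next to its own file:

from ...utils.prompting import get_prompt  # same relative import the engines use

# Resolves to <directory of this module>/prompts/fix_captions.yaml
system_prompt, chat_prompt = get_prompt("fix_captions", by_file_location=__file__)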