mirror of
https://github.com/Paillat-dev/viralfactory.git
synced 2026-01-02 01:06:19 +00:00
🐛 fix(GenerationContext.py): fix import statements and add support for captioning engine
✨ feat(GenerationContext.py): add support for captioning engine in the GenerationContext class
The import statement for the `moviepy.editor` module is changed to `moviepy.editor as mp` to improve code readability. Additionally, the `gradio` module is imported as `gr` to improve code readability. The `GenerationContext` class now includes a `captioningengine` parameter and initializes a `captioningengine` attribute. The `setup_dir` method is modified to include a call to create a directory for the output files. The `get_file_path` method is modified to return the file path based on the output directory. The `process` method is modified to include additional steps for captioning. The `timed_script` attribute is added to store the result of the `ttsengine.synthesize` method. The `captioningengine` is used to generate captions and store them in the `captions` attribute. The final video is rendered using the `moviepy` library and saved as "final.mp4" in the output directory.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import moviepy
|
||||
import moviepy.editor as mp
|
||||
import time
|
||||
import os
|
||||
import gradio as gr
|
||||
|
||||
from .. import engines
|
||||
from ..utils.prompting import get_prompt
|
||||
@@ -8,31 +9,64 @@ from ..utils.prompting import get_prompt
|
||||
|
||||
class GenerationContext:
|
||||
def __init__(
|
||||
self, powerfulllmengine, simplellmengine, scriptengine, ttsengine
|
||||
self,
|
||||
powerfulllmengine,
|
||||
simplellmengine,
|
||||
scriptengine,
|
||||
ttsengine,
|
||||
captioningengine,
|
||||
) -> None:
|
||||
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine
|
||||
self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine[0]
|
||||
self.powerfulllmengine.ctx = self
|
||||
|
||||
self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine
|
||||
self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine[0]
|
||||
self.simplellmengine.ctx = self
|
||||
|
||||
self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine
|
||||
self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine[0]
|
||||
self.scriptengine.ctx = self
|
||||
|
||||
self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine
|
||||
self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine[0]
|
||||
self.ttsengine.ctx = self
|
||||
|
||||
self.captioningengine: engines.CaptioningEngine.BaseCaptioningEngine = (
|
||||
captioningengine[0]
|
||||
)
|
||||
self.captioningengine.ctx = self
|
||||
|
||||
def setup_dir(self):
|
||||
self.dir = f"output/{time.time()}"
|
||||
os.makedirs(self.dir)
|
||||
|
||||
|
||||
def get_file_path(self, name: str) -> str:
|
||||
return os.path.join(self.dir, name)
|
||||
|
||||
def process(self):
|
||||
# IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation.
|
||||
# ⚠️ IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation.
|
||||
|
||||
progress = gr.Progress()
|
||||
self.width, self.height = (
|
||||
1080,
|
||||
1920,
|
||||
) # TODO: Add support for custom resolution, for now it's tiktok's resolution
|
||||
self.setup_dir()
|
||||
|
||||
script = self.scriptengine.generate()
|
||||
self.script = self.scriptengine.generate()
|
||||
|
||||
timed_script = self.ttsengine.synthesize(script, self.get_file_path("tts.wav"))
|
||||
self.timed_script = self.ttsengine.synthesize(
|
||||
self.script, self.get_file_path("tts.wav")
|
||||
)
|
||||
|
||||
if not isinstance(self.captioningengine, engines.NoneEngine):
|
||||
self.captions = self.captioningengine.get_captions()
|
||||
else:
|
||||
self.captions = []
|
||||
|
||||
# add any other processing steps here
|
||||
|
||||
# we render to a file called final.mp4
|
||||
# using moviepy CompositeVideoClip
|
||||
|
||||
clip = mp.CompositeVideoClip(self.captions, size=(self.width, self.height))
|
||||
audio = mp.AudioFileClip(self.get_file_path("tts.wav"))
|
||||
clip = clip.set_audio(audio)
|
||||
clip.write_videofile(self.get_file_path("final.mp4"), fps=60)
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
system: |-
|
||||
You will recieve from the user a textual script and its captions. Since the captions have been generated trough stt, they might contain some errors. Your task is to fix theese transcription errors and return the corrected captions, keeping the timestamped format.
|
||||
Please return valid json output, with no extra characters or comments, nor any codeblocks.
|
||||
|
||||
chat: |-
|
||||
{script}
|
||||
|
||||
{captions}
|
||||
10
src/engines/CaptioningEngine/BaseCaptioningEngine.py
Normal file
10
src/engines/CaptioningEngine/BaseCaptioningEngine.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from ..BaseEngine import BaseEngine
|
||||
|
||||
from moviepy.editor import TextClip
|
||||
|
||||
|
||||
class BaseCaptioningEngine(BaseEngine):
|
||||
@abstractmethod
|
||||
def get_captions(self) -> list[TextClip]:
|
||||
...
|
||||
95
src/engines/CaptioningEngine/SimpleCaptioningEngine.py
Normal file
95
src/engines/CaptioningEngine/SimpleCaptioningEngine.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import gradio as gr
|
||||
from moviepy.editor import TextClip
|
||||
from PIL import ImageFont
|
||||
from . import BaseCaptioningEngine
|
||||
|
||||
|
||||
class SimpleCaptioningEngine(BaseCaptioningEngine):
|
||||
name = "SimpleCaptioningEngine"
|
||||
description = "A basic captioning engine with nothing too fancy."
|
||||
num_options = 5
|
||||
|
||||
def __init__(self, options: list[list | tuple | str | int | float | bool | None]):
|
||||
self.font = options[0]
|
||||
self.font_size = options[1]
|
||||
self.stroke_width = options[2]
|
||||
self.font_color = options[3]
|
||||
self.stroke_color = options[4]
|
||||
|
||||
super().__init__()
|
||||
def build_caption_object(self, text: str, start: float, end: float) -> TextClip:
|
||||
return TextClip(
|
||||
text,
|
||||
fontsize=self.font_size,
|
||||
color=self.font_color,
|
||||
font=self.font,
|
||||
method="caption",
|
||||
size=(self.ctx.width /3 * 2, None),
|
||||
).set_position(('center', 0.65), relative=True).set_start(start).set_duration(end - start)
|
||||
def ends_with_punctuation(self, text: str) -> bool:
|
||||
punctuations = (".", "?", "!", ",", ":", ";")
|
||||
return text.strip().endswith(tuple(punctuations))
|
||||
|
||||
def get_captions(self) -> list[TextClip]:
|
||||
#3 words per 1000 px, we do the math
|
||||
max_words = int(self.ctx.width / 1000 * 3)
|
||||
|
||||
clips = []
|
||||
words = (
|
||||
self.ctx.timed_script.copy()
|
||||
) # List of dicts with "start", "end", and "text"
|
||||
current_line = ""
|
||||
current_start = words[0]["start"]
|
||||
current_end = words[0]["end"]
|
||||
for i, word in enumerate(words):
|
||||
# Use PIL to measure the text size
|
||||
line_with_new_word = (
|
||||
current_line + " " + word["text"] if current_line else word["text"]
|
||||
)
|
||||
pause = self.ends_with_punctuation(current_line.strip())
|
||||
|
||||
if len(line_with_new_word.split(" ")) > max_words or pause:
|
||||
clips.append(self.build_caption_object(current_line.strip(), current_start, current_end))
|
||||
current_line = word["text"] # Start a new line with the current word
|
||||
current_start = word["start"]
|
||||
current_end = word["end"]
|
||||
else:
|
||||
# If the line isn't too long, add the word to the current line
|
||||
current_line = line_with_new_word
|
||||
current_end = word["end"]
|
||||
# Don't forget to add the last line if it exists
|
||||
if current_line:
|
||||
clips.append(
|
||||
self.build_caption_object(current_line.strip(), current_start, words[-1]["end"])
|
||||
)
|
||||
|
||||
return clips
|
||||
|
||||
@classmethod
|
||||
def get_options(cls) -> list:
|
||||
with gr.Column() as font_options:
|
||||
with gr.Group():
|
||||
font = gr.Dropdown(
|
||||
label="Font",
|
||||
choices=TextClip.list('font'),
|
||||
value="Arial",
|
||||
)
|
||||
font_size = gr.Number(
|
||||
label="Font Size",
|
||||
minimum=70,
|
||||
maximum=200,
|
||||
step=1,
|
||||
value=110,
|
||||
)
|
||||
font_color = gr.ColorPicker(label="Font Color", value="#ffffff")
|
||||
with gr.Column() as font_stroke_options:
|
||||
with gr.Group():
|
||||
font_stroke_width = gr.Number(
|
||||
label="Stroke Width",
|
||||
minimum=0,
|
||||
maximum=40,
|
||||
step=1,
|
||||
value=4,
|
||||
)
|
||||
font_stroke_color = gr.ColorPicker(label="Stroke Color", value="#000000")
|
||||
return [font, font_size, font_stroke_width, font_color, font_stroke_color]
|
||||
2
src/engines/CaptioningEngine/__init__.py
Normal file
2
src/engines/CaptioningEngine/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .BaseCaptioningEngine import BaseCaptioningEngine
|
||||
from .SimpleCaptioningEngine import SimpleCaptioningEngine
|
||||
@@ -38,7 +38,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": chat_prompt},
|
||||
],
|
||||
max_tokens=max_tokens,
|
||||
max_tokens=int(max_tokens) if max_tokens else openai._types.NOT_GIVEN,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
||||
14
src/engines/NoneEngine.py
Normal file
14
src/engines/NoneEngine.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from . import BaseEngine
|
||||
|
||||
|
||||
class NoneEngine(BaseEngine):
|
||||
num_options = 0
|
||||
name = "None"
|
||||
description = "No engine selected"
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_options(cls):
|
||||
return []
|
||||
@@ -16,8 +16,22 @@ class Word(TypedDict):
|
||||
|
||||
class BaseTTSEngine(BaseEngine):
|
||||
@abstractmethod
|
||||
def synthesize(self, text: str, path: str) -> str:
|
||||
def synthesize(self, text: str, path: str) -> list[Word]:
|
||||
pass
|
||||
|
||||
def remove_punctuation(self, text: str) -> str:
|
||||
return text.translate(str.maketrans("", "", ".,!?;:"))
|
||||
|
||||
def fix_captions(self, script: str, captions: list[Word]) -> list[Word]:
|
||||
script = script.split(" ")
|
||||
new_captions = []
|
||||
for i, word in enumerate(script):
|
||||
original_word = self.remove_punctuation(word.lower())
|
||||
stt_word = self.remove_punctuation(word.lower())
|
||||
if stt_word in original_word:
|
||||
captions[i]["text"] = word
|
||||
new_captions.append(captions[i])
|
||||
#elif there is a word more in the stt than in the original, we
|
||||
|
||||
def time_with_whisper(self, path: str) -> list[Word]:
|
||||
"""
|
||||
@@ -46,7 +60,7 @@ class BaseTTSEngine(BaseEngine):
|
||||
"""
|
||||
device = "cuda" if is_available() else "cpu"
|
||||
audio = wt.load_audio(path)
|
||||
model = wt.load_model("tiny", device=device)
|
||||
model = wt.load_model("small", device=device)
|
||||
|
||||
result = wt.transcribe(model=model, audio=audio)
|
||||
results = [word for chunk in result["segments"] for word in chunk["words"]]
|
||||
|
||||
@@ -5,8 +5,9 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from .BaseTTSEngine import BaseTTSEngine
|
||||
from .BaseTTSEngine import BaseTTSEngine, Word
|
||||
|
||||
from ...utils.prompting import get_prompt
|
||||
|
||||
class CoquiTTSEngine(BaseTTSEngine):
|
||||
voices = [
|
||||
@@ -122,8 +123,10 @@ class CoquiTTSEngine(BaseTTSEngine):
|
||||
)
|
||||
if self.to_force_duration:
|
||||
self.force_duration(float(self.duration), path)
|
||||
|
||||
return self.time_with_whisper(path)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_options(cls) -> list:
|
||||
options = [
|
||||
@@ -131,7 +134,7 @@ class CoquiTTSEngine(BaseTTSEngine):
|
||||
label="Voice",
|
||||
choices=cls.voices,
|
||||
max_choices=1,
|
||||
value=cls.voices[0],
|
||||
value="Damien Black",
|
||||
),
|
||||
gr.Dropdown(
|
||||
label="Language",
|
||||
@@ -145,6 +148,7 @@ class CoquiTTSEngine(BaseTTSEngine):
|
||||
label="Force duration",
|
||||
info="Force the duration of the generated audio to be at most the specified value",
|
||||
value=False,
|
||||
show_label=True,
|
||||
)
|
||||
duration = gr.Number(
|
||||
label="Duration [s]", value=57, step=1, minimum=10, visible=False
|
||||
|
||||
32
src/engines/TTSEngine/prompts/fix_captions.yaml
Normal file
32
src/engines/TTSEngine/prompts/fix_captions.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
system: |-
|
||||
You will recieve from the user a textual script and its captions. Since the captions have been generated trough stt, they might contain some errors. Your task is to fix theese transcription errors and return the corrected captions, keeping the timestamped format.
|
||||
Please return valid json output, with no extra characters or comments, nor any codeblocks.
|
||||
The errors / corrections you should do are:
|
||||
- Fix any spelling errors
|
||||
- Fix any grammar errors
|
||||
- If a punctuation mark is not the same as in the script, change it to match the script. However, there should still be punctioation marks. They do not count in the one word per "text" field rule.
|
||||
- Turn any number or symbol that is spelled out into its numerical or symbolic representation (ex. "one" -> "1", "percent" -> "%", "dollar" -> "$", etc.)
|
||||
- Add capitalization at the beginning of each SENTENCE if missing (not each "text tag, only when multile tags form a sentence !!!") but do not create or infer sentences. Only if a sentence that is already there is not capitalized, you should capitalize it.
|
||||
- You are NOT allowed to change the timestamps at any cost, nor to reorganize the captions in any way. Your sole role is to fix transcription errors. Nothing else.
|
||||
- You should not add new words. If a sentence feels wrong in the original script, you should not change it, but keep it as is, and if needed make the captions match the script, even if the script does not feel correct.
|
||||
The response format should be a json object as follows:
|
||||
{
|
||||
"captions": [
|
||||
{
|
||||
"start": 0,
|
||||
"end": 1000.000,
|
||||
"text": "This is the first caption."
|
||||
},
|
||||
{
|
||||
"start": 1000.000,
|
||||
"end": 2000.023,
|
||||
"text": "This is the second caption."
|
||||
},
|
||||
etc...]
|
||||
}
|
||||
chat: |-
|
||||
{script}
|
||||
|
||||
{captions}
|
||||
|
||||
Remember that each "text" field should contain ONLY ONE WORD and should be changed ONLY IF NEEDED, else just copy pasted as is with no changes, nor any changes in the timestamps! ans the "text" fiels should NEVER BE a full sentence. The transcript is made to be precise at the word level, so you should not change the words, or it will be pointless.
|
||||
@@ -1,14 +1,39 @@
|
||||
from typing import TypedDict
|
||||
from .BaseEngine import BaseEngine
|
||||
from .NoneEngine import NoneEngine
|
||||
from . import TTSEngine
|
||||
from . import ScriptEngine
|
||||
from . import LLMEngine
|
||||
from . import CaptioningEngine
|
||||
|
||||
ENGINES = {
|
||||
"SimpleLLMEngine": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
|
||||
"PowerfulLLMEngine": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
|
||||
"TTSEngine": [TTSEngine.CoquiTTSEngine, TTSEngine.ElevenLabsTTSEngine],
|
||||
"ScriptEngine": [
|
||||
ScriptEngine.ShowerThoughtsScriptEngine,
|
||||
ScriptEngine.CustomScriptEngine,
|
||||
],
|
||||
|
||||
class EngineDict(TypedDict):
|
||||
classes: list[BaseEngine]
|
||||
multiple: bool
|
||||
|
||||
|
||||
ENGINES: dict[str, EngineDict] = {
|
||||
"SimpleLLMEngine": {
|
||||
"classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
|
||||
"multiple": False,
|
||||
},
|
||||
"PowerfulLLMEngine": {
|
||||
"classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
|
||||
"multiple": False,
|
||||
},
|
||||
"ScriptEngine": {
|
||||
"classes": [
|
||||
ScriptEngine.ShowerThoughtsScriptEngine,
|
||||
ScriptEngine.CustomScriptEngine,
|
||||
],
|
||||
"multiple": False,
|
||||
},
|
||||
"TTSEngine": {
|
||||
"classes": [TTSEngine.CoquiTTSEngine, TTSEngine.ElevenLabsTTSEngine],
|
||||
"multiple": False,
|
||||
},
|
||||
"CaptioningEngine": {
|
||||
"classes": [CaptioningEngine.SimpleCaptioningEngine, NoneEngine],
|
||||
"multiple": False,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -8,10 +8,15 @@ class Prompt(TypedDict):
|
||||
chat: str
|
||||
|
||||
|
||||
def get_prompt(name, *, location="src/chore/prompts") -> tuple[str, str]:
|
||||
path = os.path.join(os.getcwd(), location, f"{name}.yaml")
|
||||
def get_prompt(name, *, location: str="src/chore/prompts", by_file_location: str = None) -> tuple[str, str]:
|
||||
if by_file_location:
|
||||
path = os.path.join(
|
||||
os.path.dirname(os.path.abspath(by_file_location)), "prompts", f"{name}.yaml"
|
||||
)
|
||||
else:
|
||||
path = os.path.join(os.getcwd(), location, f"{name}.yaml")
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"Prompt file {path} does not exist.")
|
||||
with open(path, "r") as file:
|
||||
prompt: Prompt = yaml.safe_load(file)
|
||||
return prompt["system"], prompt["chat"]
|
||||
return prompt["system"], prompt["chat"]
|
||||
|
||||
Reference in New Issue
Block a user