diff --git a/src/chore/GenerationContext.py b/src/chore/GenerationContext.py
index dbba4b1..5701692 100644
--- a/src/chore/GenerationContext.py
+++ b/src/chore/GenerationContext.py
@@ -1,6 +1,7 @@
-import moviepy
+import moviepy.editor as mp
 import time
 import os
+import gradio as gr
 
 from .. import engines
 from ..utils.prompting import get_prompt
@@ -8,31 +9,64 @@ from ..utils.prompting import get_prompt
 
 class GenerationContext:
     def __init__(
-        self, powerfulllmengine, simplellmengine, scriptengine, ttsengine
+        self,
+        powerfulllmengine,
+        simplellmengine,
+        scriptengine,
+        ttsengine,
+        captioningengine,
     ) -> None:
-        self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine
+        self.powerfulllmengine: engines.LLMEngine.BaseLLMEngine = powerfulllmengine[0]
         self.powerfulllmengine.ctx = self
-        self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine
+        self.simplellmengine: engines.LLMEngine.BaseLLMEngine = simplellmengine[0]
         self.simplellmengine.ctx = self
-        self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine
+        self.scriptengine: engines.ScriptEngine.BaseScriptEngine = scriptengine[0]
         self.scriptengine.ctx = self
-        self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine
+        self.ttsengine: engines.TTSEngine.BaseTTSEngine = ttsengine[0]
         self.ttsengine.ctx = self
 
+        self.captioningengine: engines.CaptioningEngine.BaseCaptioningEngine = (
+            captioningengine[0]
+        )
+        self.captioningengine.ctx = self
+
     def setup_dir(self):
         self.dir = f"output/{time.time()}"
         os.makedirs(self.dir)
-    
+
     def get_file_path(self, name: str) -> str:
         return os.path.join(self.dir, name)
 
     def process(self):
-        # IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes, if not there is an issue with the engine implementation.
+        # ⚠️ IMPORTANT NOTE: All methods called here are expected to be defined as abstract methods in the base classes; if not, there is an issue with the engine implementation.
+
+        progress = gr.Progress()
+        self.width, self.height = (
+            1080,
+            1920,
+        )  # TODO: add support for custom resolutions; for now this is TikTok's 9:16 resolution
 
         self.setup_dir()
 
-        script = self.scriptengine.generate()
+        self.script = self.scriptengine.generate()
 
-        timed_script = self.ttsengine.synthesize(script, self.get_file_path("tts.wav"))
+        self.timed_script = self.ttsengine.synthesize(
+            self.script, self.get_file_path("tts.wav")
+        )
+
+        if not isinstance(self.captioningengine, engines.NoneEngine):
+            self.captions = self.captioningengine.get_captions()
+        else:
+            self.captions = []
+
+        # add any other processing steps here
+
+        # Render the result to final.mp4 using moviepy's CompositeVideoClip
+        clip = mp.CompositeVideoClip(self.captions, size=(self.width, self.height))
+        audio = mp.AudioFileClip(self.get_file_path("tts.wav"))
+        clip = clip.set_audio(audio)
+        clip.write_videofile(self.get_file_path("final.mp4"), fps=60)
diff --git a/src/chore/prompts/fix_captions.yaml b/src/chore/prompts/fix_captions.yaml
deleted file mode 100644
index e9d0e56..0000000
--- a/src/chore/prompts/fix_captions.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-system: |-
-  You will recieve from the user a textual script and its captions. Since the captions have been generated trough stt, they might contain some errors. Your task is to fix theese transcription errors and return the corrected captions, keeping the timestamped format.
-  Please return valid json output, with no extra characters or comments, nor any codeblocks.
-
-chat: |-
-  {script}
-
-  {captions}
\ No newline at end of file
diff --git a/src/engines/CaptioningEngine/BaseCaptioningEngine.py b/src/engines/CaptioningEngine/BaseCaptioningEngine.py
new file mode 100644
index 0000000..b5731a8
--- /dev/null
+++ b/src/engines/CaptioningEngine/BaseCaptioningEngine.py
@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+from ..BaseEngine import BaseEngine
+
+from moviepy.editor import TextClip
+
+
+class BaseCaptioningEngine(BaseEngine):
+    @abstractmethod
+    def get_captions(self) -> list[TextClip]:
+        ...
diff --git a/src/engines/CaptioningEngine/SimpleCaptioningEngine.py b/src/engines/CaptioningEngine/SimpleCaptioningEngine.py
new file mode 100644
index 0000000..89cc748
--- /dev/null
+++ b/src/engines/CaptioningEngine/SimpleCaptioningEngine.py
@@ -0,0 +1,95 @@
+import gradio as gr
+from moviepy.editor import TextClip
+from PIL import ImageFont
+from . import BaseCaptioningEngine
+
+
+class SimpleCaptioningEngine(BaseCaptioningEngine):
+    name = "SimpleCaptioningEngine"
+    description = "A basic captioning engine with nothing too fancy."
+    num_options = 5
+
+    def __init__(self, options: list[list | tuple | str | int | float | bool | None]):
+        self.font = options[0]
+        self.font_size = options[1]
+        self.stroke_width = options[2]
+        self.font_color = options[3]
+        self.stroke_color = options[4]
+
+        super().__init__()
+
+    def build_caption_object(self, text: str, start: float, end: float) -> TextClip:
+        return (
+            TextClip(
+                text,
+                fontsize=self.font_size,
+                color=self.font_color,
+                font=self.font,
+                stroke_width=self.stroke_width,
+                stroke_color=self.stroke_color,
+                method="caption",
+                size=(self.ctx.width / 3 * 2, None),
+            )
+            .set_position(("center", 0.65), relative=True)
+            .set_start(start)
+            .set_duration(end - start)
+        )
+
+    def ends_with_punctuation(self, text: str) -> bool:
+        punctuations = (".", "?", "!", ",", ":", ";")
+        return text.strip().endswith(punctuations)
+
+    def get_captions(self) -> list[TextClip]:
+        # Roughly 3 words per 1000 px of width
+        max_words = int(self.ctx.width / 1000 * 3)
+
+        clips = []
+        words = (
+            self.ctx.timed_script.copy()
+        )  # List of dicts with "start", "end", and "text"
+        current_line = ""
+        current_start = words[0]["start"]
+        current_end = words[0]["end"]
+        for i, word in enumerate(words):
+            # Build the candidate line with the next word appended
+            line_with_new_word = (
+                current_line + " " + word["text"] if current_line else word["text"]
+            )
+            pause = self.ends_with_punctuation(current_line.strip())
+
+            if len(line_with_new_word.split(" ")) > max_words or pause:
+                clips.append(
+                    self.build_caption_object(
+                        current_line.strip(), current_start, current_end
+                    )
+                )
+                current_line = word["text"]  # Start a new line with the current word
+                current_start = word["start"]
+                current_end = word["end"]
+            else:
+                # If the line isn't too long, add the word to the current line
+                current_line = line_with_new_word
+                current_end = word["end"]
+        # Don't forget to add the last line if it exists
+        if current_line:
+            clips.append(
+                self.build_caption_object(
+                    current_line.strip(), current_start, words[-1]["end"]
+                )
+            )
+
+        return clips
+
+    @classmethod
+    def get_options(cls) -> list:
+        with gr.Column() as font_options:
+            with gr.Group():
+                font = gr.Dropdown(
+                    label="Font",
+                    choices=TextClip.list("font"),
+                    value="Arial",
+                )
+                font_size = gr.Number(
+                    label="Font Size",
+                    minimum=70,
+                    maximum=200,
+                    step=1,
+                    value=110,
+                )
+                font_color = gr.ColorPicker(label="Font Color", value="#ffffff")
+        with gr.Column() as font_stroke_options:
+            with gr.Group():
+                font_stroke_width = gr.Number(
+                    label="Stroke Width",
+                    minimum=0,
+                    maximum=40,
+                    step=1,
+                    value=4,
+                )
+                font_stroke_color = gr.ColorPicker(label="Stroke Color", value="#000000")
+        return [font, font_size, font_stroke_width, font_color, font_stroke_color]
diff --git a/src/engines/CaptioningEngine/__init__.py b/src/engines/CaptioningEngine/__init__.py
new file mode 100644
index 0000000..52f1a28
--- /dev/null
+++ b/src/engines/CaptioningEngine/__init__.py
@@ -0,0 +1,2 @@
+from .BaseCaptioningEngine import BaseCaptioningEngine
+from .SimpleCaptioningEngine import SimpleCaptioningEngine
diff --git a/src/engines/LLMEngine/OpenaiLLMEngine.py b/src/engines/LLMEngine/OpenaiLLMEngine.py
index b5df62f..7e380c1 100644
--- a/src/engines/LLMEngine/OpenaiLLMEngine.py
+++ b/src/engines/LLMEngine/OpenaiLLMEngine.py
@@ -38,7 +38,7 @@ class OpenaiLLMEngine(BaseLLMEngine):
                 {"role": "system", "content": system_prompt},
                 {"role": "user", "content": chat_prompt},
             ],
-            max_tokens=max_tokens,
+            max_tokens=int(max_tokens) if max_tokens else openai._types.NOT_GIVEN,
             temperature=temperature,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
diff --git a/src/engines/NoneEngine.py b/src/engines/NoneEngine.py
new file mode 100644
index 0000000..6c828b5
--- /dev/null
+++ b/src/engines/NoneEngine.py
@@ -0,0 +1,14 @@
+from .BaseEngine import BaseEngine
+
+
+class NoneEngine(BaseEngine):
+    num_options = 0
+    name = "None"
+    description = "No engine selected"
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_options(cls):
+        return []
diff --git a/src/engines/TTSEngine/BaseTTSEngine.py b/src/engines/TTSEngine/BaseTTSEngine.py
index 302bec8..6d883f3 100644
--- a/src/engines/TTSEngine/BaseTTSEngine.py
+++ b/src/engines/TTSEngine/BaseTTSEngine.py
@@ -16,8 +16,23 @@ class Word(TypedDict):
 
 class BaseTTSEngine(BaseEngine):
     @abstractmethod
-    def synthesize(self, text: str, path: str) -> str:
+    def synthesize(self, text: str, path: str) -> list[Word]:
         pass
+
+    def remove_punctuation(self, text: str) -> str:
+        return text.translate(str.maketrans("", "", ".,!?;:"))
+
+    def fix_captions(self, script: str, captions: list[Word]) -> list[Word]:
+        script = script.split(" ")
+        new_captions = []
+        for i, word in enumerate(script):
+            original_word = self.remove_punctuation(word.lower())
+            stt_word = self.remove_punctuation(captions[i]["text"].lower())
+            if stt_word in original_word:
+                captions[i]["text"] = word
+                new_captions.append(captions[i])
+            # TODO: handle the case where the STT output has an extra word that is not in the original script
+        return new_captions
 
     def time_with_whisper(self, path: str) -> list[Word]:
         """
@@ -46,7 +60,7 @@ class BaseTTSEngine(BaseEngine):
         """
         device = "cuda" if is_available() else "cpu"
         audio = wt.load_audio(path)
-        model = wt.load_model("tiny", device=device)
+        model = wt.load_model("small", device=device)
         result = wt.transcribe(model=model, audio=audio)
 
         results = [word for chunk in result["segments"] for word in chunk["words"]]
diff --git a/src/engines/TTSEngine/CoquiTTSEngine.py b/src/engines/TTSEngine/CoquiTTSEngine.py
index 9635092..56a4bef 100644
--- a/src/engines/TTSEngine/CoquiTTSEngine.py
+++ b/src/engines/TTSEngine/CoquiTTSEngine.py
@@ -5,8 +5,9 @@
 import os
 
 import torch
 
-from .BaseTTSEngine import BaseTTSEngine
+from .BaseTTSEngine import BaseTTSEngine, Word
+from ...utils.prompting import get_prompt
 
 class CoquiTTSEngine(BaseTTSEngine):
     voices = [
@@ -122,8 +123,10 @@ class CoquiTTSEngine(BaseTTSEngine):
         )
         if self.to_force_duration:
             self.force_duration(float(self.duration), path)
+        return self.time_with_whisper(path)
+
     @classmethod
     def get_options(cls) -> list:
         options = [
@@ -131,7 +134,7 @@ class CoquiTTSEngine(BaseTTSEngine):
                 label="Voice",
                 choices=cls.voices,
                 max_choices=1,
-                value=cls.voices[0],
+                value="Damien Black",
             ),
             gr.Dropdown(
                 label="Language",
@@ -145,6 +148,7 @@ class CoquiTTSEngine(BaseTTSEngine):
                 label="Force duration",
                 info="Force the duration of the generated audio to be at most the specified value",
                 value=False,
+                show_label=True,
             )
             duration = gr.Number(
                 label="Duration [s]", value=57, step=1, minimum=10, visible=False
diff --git a/src/engines/TTSEngine/prompts/fix_captions.yaml b/src/engines/TTSEngine/prompts/fix_captions.yaml
new file mode 100644
index 0000000..d146bd5
--- /dev/null
+++ b/src/engines/TTSEngine/prompts/fix_captions.yaml
@@ -0,0 +1,32 @@
+system: |-
+  You will receive from the user a textual script and its captions. Since the captions have been generated through STT, they might contain some errors. Your task is to fix these transcription errors and return the corrected captions, keeping the timestamped format.
+  Please return valid JSON output, with no extra characters or comments, nor any code blocks.
+  The errors / corrections you should make are:
+  - Fix any spelling errors
+  - Fix any grammar errors
+  - If a punctuation mark is not the same as in the script, change it to match the script. However, there should still be punctuation marks. They do not count toward the one-word-per-"text"-field rule.
+  - Turn any number or symbol that is spelled out into its numerical or symbolic representation (e.g. "one" -> "1", "percent" -> "%", "dollar" -> "$", etc.)
+  - Add capitalization at the beginning of each SENTENCE if missing (not each "text" tag, only when multiple tags form a sentence!), but do not create or infer sentences. Only capitalize a sentence that is already there and is missing its capitalization.
+  - You are NOT allowed to change the timestamps under any circumstances, nor to reorganize the captions in any way. Your sole role is to fix transcription errors. Nothing else.
+  - You should not add new words. If a sentence feels wrong in the original script, you should not change it, but keep it as is, and if needed make the captions match the script, even if the script does not feel correct.
+  The response format should be a JSON object as follows:
+  {
+    "captions": [
+      {
+        "start": 0,
+        "end": 1000.000,
+        "text": "This is the first caption."
+      },
+      {
+        "start": 1000.000,
+        "end": 2000.023,
+        "text": "This is the second caption."
+      },
+      etc...]
+  }
+chat: |-
+  {script}
+
+  {captions}
+
+  Remember that each "text" field should contain ONLY ONE WORD and should be changed ONLY IF NEEDED; otherwise just copy it as is, with no changes to the text or the timestamps! The "text" field should NEVER be a full sentence. The transcript is meant to be precise at the word level, so you should not merge or split words, or it will be pointless.
\ No newline at end of file
diff --git a/src/engines/__init__.py b/src/engines/__init__.py
index a2f548e..4cfce71 100644
--- a/src/engines/__init__.py
+++ b/src/engines/__init__.py
@@ -1,14 +1,39 @@
+from typing import TypedDict
 from .BaseEngine import BaseEngine
+from .NoneEngine import NoneEngine
 from . import TTSEngine
 from . import ScriptEngine
 from . import LLMEngine
+from . import CaptioningEngine
 
-ENGINES = {
-    "SimpleLLMEngine": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
-    "PowerfulLLMEngine": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
-    "TTSEngine": [TTSEngine.CoquiTTSEngine, TTSEngine.ElevenLabsTTSEngine],
-    "ScriptEngine": [
-        ScriptEngine.ShowerThoughtsScriptEngine,
-        ScriptEngine.CustomScriptEngine,
-    ],
+
+class EngineDict(TypedDict):
+    classes: list[type[BaseEngine]]
+    multiple: bool
+
+
+ENGINES: dict[str, EngineDict] = {
+    "SimpleLLMEngine": {
+        "classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
+        "multiple": False,
+    },
+    "PowerfulLLMEngine": {
+        "classes": [LLMEngine.OpenaiLLMEngine, LLMEngine.AnthropicLLMEngine],
+        "multiple": False,
+    },
+    "ScriptEngine": {
+        "classes": [
+            ScriptEngine.ShowerThoughtsScriptEngine,
+            ScriptEngine.CustomScriptEngine,
+        ],
+        "multiple": False,
+    },
+    "TTSEngine": {
+        "classes": [TTSEngine.CoquiTTSEngine, TTSEngine.ElevenLabsTTSEngine],
+        "multiple": False,
+    },
+    "CaptioningEngine": {
+        "classes": [CaptioningEngine.SimpleCaptioningEngine, NoneEngine],
+        "multiple": False,
+    },
 }
diff --git a/src/utils/prompting.py b/src/utils/prompting.py
index 5bcd774..06c9e1a 100644
--- a/src/utils/prompting.py
+++ b/src/utils/prompting.py
@@ -8,10 +8,15 @@ class Prompt(TypedDict):
     chat: str
 
 
-def get_prompt(name, *, location="src/chore/prompts") -> tuple[str, str]:
-    path = os.path.join(os.getcwd(), location, f"{name}.yaml")
+def get_prompt(name, *, location: str = "src/chore/prompts", by_file_location: str | None = None) -> tuple[str, str]:
+    if by_file_location:
+        path = os.path.join(
+            os.path.dirname(os.path.abspath(by_file_location)), "prompts", f"{name}.yaml"
+        )
+    else:
+        path = os.path.join(os.getcwd(), location, f"{name}.yaml")
     if not os.path.exists(path):
         raise FileNotFoundError(f"Prompt file {path} does not exist.")
     with open(path, "r") as file:
         prompt: Prompt = yaml.safe_load(file)
-    return prompt["system"], prompt["chat"]
\ No newline at end of file
+    return prompt["system"], prompt["chat"]
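
Review note (not part of the diff): below is a minimal sketch of how the new SimpleCaptioningEngine could be exercised on its own, assuming the package is importable as src.engines, that BaseEngine.__init__ takes no arguments, and that ctx only needs the width and timed_script attributes that get_captions() reads. Normally GenerationContext sets ctx and receives each selected engine wrapped in a one-element list, as the [0] indexing in its __init__ suggests. Option order follows get_options(): font, font size, stroke width, font color, stroke color. Building the TextClips still requires moviepy's ImageMagick backend.

from types import SimpleNamespace

from src.engines.CaptioningEngine import SimpleCaptioningEngine

# Option values are illustrative; in the app they come from the Gradio form.
engine = SimpleCaptioningEngine(["Arial", 110, 4, "#ffffff", "#000000"])

# Stand-in for a GenerationContext, exposing only what get_captions() needs.
engine.ctx = SimpleNamespace(
    width=1080,
    height=1920,
    timed_script=[
        {"text": "Captions", "start": 0.0, "end": 0.5},
        {"text": "are", "start": 0.5, "end": 0.7},
        {"text": "grouped", "start": 0.7, "end": 1.1},
        {"text": "per", "start": 1.1, "end": 1.3},
        {"text": "line.", "start": 1.3, "end": 1.8},
    ],
)

clips = engine.get_captions()  # one TextClip per caption line, already timed and positioned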