From 280df1dd67cba0262d10775553ed8903aa4c136f Mon Sep 17 00:00:00 2001 From: Paillat Date: Tue, 20 Feb 2024 14:47:54 +0100 Subject: [PATCH] Formatting --- .../AssetsEngine/AssetsEngineSelector.py | 16 ++++--- src/engines/AssetsEngine/BaseAssetsEngine.py | 1 - src/engines/AssetsEngine/__init__.py | 2 +- .../SimpleCaptioningEngine.py | 45 ++++++++++++------- src/engines/SettingsEngine/SettingsEngine.py | 9 ++-- src/engines/SettingsEngine/__init__.py | 2 +- src/engines/TTSEngine/BaseTTSEngine.py | 4 +- src/utils/prompting.py | 10 +++-- 8 files changed, 58 insertions(+), 31 deletions(-) diff --git a/src/engines/AssetsEngine/AssetsEngineSelector.py b/src/engines/AssetsEngine/AssetsEngineSelector.py index b01cd40..632caac 100644 --- a/src/engines/AssetsEngine/AssetsEngineSelector.py +++ b/src/engines/AssetsEngine/AssetsEngineSelector.py @@ -2,6 +2,8 @@ import json from ...utils.prompting import get_prompt from ...chore import GenerationContext + + class AssetsEngineSelector: def __init__(self): self.ctx: GenerationContext @@ -9,12 +11,16 @@ class AssetsEngineSelector: def get_assets(self): system_prompt, chat_prompt = get_prompt("assets", by_file_location=__file__) engines_descriptors = "" - + for engine in self.ctx.assetsengine: - engines_descriptors += f"name: '{engine.name}'\n{json.dumps(engine.specification)}\n" - + engines_descriptors += ( + f"name: '{engine.name}'\n{json.dumps(engine.specification)}\n" + ) + system_prompt = system_prompt.replace("{engines}", engines_descriptors) - chat_prompt = chat_prompt.replace("{caption}", json.dumps(self.ctx.timed_script)) + chat_prompt = chat_prompt.replace( + "{caption}", json.dumps(self.ctx.timed_script) + ) assets = self.ctx.powerfulllmengine.generate( system_prompt=system_prompt, @@ -27,4 +33,4 @@ class AssetsEngineSelector: assets_opts = [asset for asset in assets if asset["engine"] == engine.name] assets_opts = [asset["args"] for asset in assets_opts] clips.extend(engine.get_assets(assets_opts)) - return clips \ No newline at end of file + return clips diff --git a/src/engines/AssetsEngine/BaseAssetsEngine.py b/src/engines/AssetsEngine/BaseAssetsEngine.py index 4b7f95b..9079c76 100644 --- a/src/engines/AssetsEngine/BaseAssetsEngine.py +++ b/src/engines/AssetsEngine/BaseAssetsEngine.py @@ -4,7 +4,6 @@ from typing import TypedDict from moviepy.editor import ImageClip, VideoFileClip - class BaseAssetsEngine(BaseEngine): """ The base class for all assets engines. diff --git a/src/engines/AssetsEngine/__init__.py b/src/engines/AssetsEngine/__init__.py index c714cff..b626d98 100644 --- a/src/engines/AssetsEngine/__init__.py +++ b/src/engines/AssetsEngine/__init__.py @@ -1,3 +1,3 @@ from .BaseAssetsEngine import BaseAssetsEngine from .DallEAssetsEngine import DallEAssetsEngine -from .AssetsEngineSelector import AssetsEngineSelector \ No newline at end of file +from .AssetsEngineSelector import AssetsEngineSelector diff --git a/src/engines/CaptioningEngine/SimpleCaptioningEngine.py b/src/engines/CaptioningEngine/SimpleCaptioningEngine.py index bae0fcc..60d7958 100644 --- a/src/engines/CaptioningEngine/SimpleCaptioningEngine.py +++ b/src/engines/CaptioningEngine/SimpleCaptioningEngine.py @@ -17,23 +17,30 @@ class SimpleCaptioningEngine(BaseCaptioningEngine): self.stroke_color = options[4] super().__init__() + def build_caption_object(self, text: str, start: float, end: float) -> TextClip: - return TextClip( - text, - fontsize=self.font_size, - color=self.font_color, - font=self.font, - stroke_color=self.stroke_color, - stroke_width=self.stroke_width, - method="caption", - size=(self.ctx.width /3 * 2, None), - ).set_position(('center', 0.65), relative=True).set_start(start).set_duration(end - start) + return ( + TextClip( + text, + fontsize=self.font_size, + color=self.font_color, + font=self.font, + stroke_color=self.stroke_color, + stroke_width=self.stroke_width, + method="caption", + size=(self.ctx.width / 3 * 2, None), + ) + .set_position(("center", 0.65), relative=True) + .set_start(start) + .set_duration(end - start) + ) + def ends_with_punctuation(self, text: str) -> bool: punctuations = (".", "?", "!", ",", ":", ";") return text.strip().endswith(tuple(punctuations)) def get_captions(self) -> list[TextClip]: - #3 words per 1000 px, we do the math + # 3 words per 1000 px, we do the math max_words = int(self.ctx.width / 1000 * 3) clips = [] @@ -51,7 +58,11 @@ class SimpleCaptioningEngine(BaseCaptioningEngine): pause = self.ends_with_punctuation(current_line.strip()) if len(line_with_new_word.split(" ")) > max_words or pause: - clips.append(self.build_caption_object(current_line.strip(), current_start, current_end)) + clips.append( + self.build_caption_object( + current_line.strip(), current_start, current_end + ) + ) current_line = word["text"] # Start a new line with the current word current_start = word["start"] current_end = word["end"] @@ -62,7 +73,9 @@ class SimpleCaptioningEngine(BaseCaptioningEngine): # Don't forget to add the last line if it exists if current_line: clips.append( - self.build_caption_object(current_line.strip(), current_start, words[-1]["end"]) + self.build_caption_object( + current_line.strip(), current_start, words[-1]["end"] + ) ) return clips @@ -73,7 +86,7 @@ class SimpleCaptioningEngine(BaseCaptioningEngine): with gr.Group(): font = gr.Dropdown( label="Font", - choices=TextClip.list('font'), + choices=TextClip.list("font"), value="Arial", ) font_size = gr.Number( @@ -93,5 +106,7 @@ class SimpleCaptioningEngine(BaseCaptioningEngine): step=1, value=6, ) - font_stroke_color = gr.ColorPicker(label="Stroke Color", value="#000000") + font_stroke_color = gr.ColorPicker( + label="Stroke Color", value="#000000" + ) return [font, font_size, font_stroke_width, font_color, font_stroke_color] diff --git a/src/engines/SettingsEngine/SettingsEngine.py b/src/engines/SettingsEngine/SettingsEngine.py index 69f5688..024e498 100644 --- a/src/engines/SettingsEngine/SettingsEngine.py +++ b/src/engines/SettingsEngine/SettingsEngine.py @@ -16,10 +16,13 @@ class SettingsEngine(BaseEngine): def load(self): self.ctx.width = self.width self.ctx.height = self.height + @classmethod def get_options(cls): - #minimum is 720p, maximum is 4k, default is portrait hd + # minimum is 720p, maximum is 4k, default is portrait hd width = gr.Number(value=1080, minimum=720, maximum=3840, label="Width", step=1) - height = gr.Number(value=1920, minimum=720, maximum=3840, label="Height", step=1) + height = gr.Number( + value=1920, minimum=720, maximum=3840, label="Height", step=1 + ) - return [width, height] \ No newline at end of file + return [width, height] diff --git a/src/engines/SettingsEngine/__init__.py b/src/engines/SettingsEngine/__init__.py index e87071d..e9f74ba 100644 --- a/src/engines/SettingsEngine/__init__.py +++ b/src/engines/SettingsEngine/__init__.py @@ -1 +1 @@ -from .SettingsEngine import SettingsEngine \ No newline at end of file +from .SettingsEngine import SettingsEngine diff --git a/src/engines/TTSEngine/BaseTTSEngine.py b/src/engines/TTSEngine/BaseTTSEngine.py index 6d883f3..3e7aa1e 100644 --- a/src/engines/TTSEngine/BaseTTSEngine.py +++ b/src/engines/TTSEngine/BaseTTSEngine.py @@ -18,7 +18,7 @@ class BaseTTSEngine(BaseEngine): @abstractmethod def synthesize(self, text: str, path: str) -> list[Word]: pass - + def remove_punctuation(self, text: str) -> str: return text.translate(str.maketrans("", "", ".,!?;:")) @@ -31,7 +31,7 @@ class BaseTTSEngine(BaseEngine): if stt_word in original_word: captions[i]["text"] = word new_captions.append(captions[i]) - #elif there is a word more in the stt than in the original, we + # elif there is a word more in the stt than in the original, we def time_with_whisper(self, path: str) -> list[Word]: """ diff --git a/src/utils/prompting.py b/src/utils/prompting.py index 06c9e1a..72c9257 100644 --- a/src/utils/prompting.py +++ b/src/utils/prompting.py @@ -8,11 +8,15 @@ class Prompt(TypedDict): chat: str -def get_prompt(name, *, location: str="src/chore/prompts", by_file_location: str = None) -> tuple[str, str]: +def get_prompt( + name, *, location: str = "src/chore/prompts", by_file_location: str = None +) -> tuple[str, str]: if by_file_location: path = os.path.join( - os.path.dirname(os.path.abspath(by_file_location)), "prompts", f"{name}.yaml" - ) + os.path.dirname(os.path.abspath(by_file_location)), + "prompts", + f"{name}.yaml", + ) else: path = os.path.join(os.getcwd(), location, f"{name}.yaml") if not os.path.exists(path):