From 2ccd8023aa5b96b5ad8d3ae655f2dae497417c39 Mon Sep 17 00:00:00 2001 From: frederik-uni <147479464+frederik-uni@users.noreply.github.com> Date: Sun, 17 Nov 2024 00:40:57 +0100 Subject: [PATCH] split manga_translator.py into multiple files --- manga_translator/__main__.py | 6 +- manga_translator/manga_translator.py | 891 +-------------------------- manga_translator/mode/api.py | 290 +++++++++ manga_translator/mode/local.py | 201 ++++++ manga_translator/mode/web.py | 151 +++++ manga_translator/mode/ws.py | 264 ++++++++ 6 files changed, 916 insertions(+), 887 deletions(-) create mode 100644 manga_translator/mode/api.py create mode 100644 manga_translator/mode/local.py create mode 100644 manga_translator/mode/web.py create mode 100644 manga_translator/mode/ws.py diff --git a/manga_translator/__main__.py b/manga_translator/__main__.py index c4f58003..c0e01e50 100644 --- a/manga_translator/__main__.py +++ b/manga_translator/__main__.py @@ -5,9 +5,6 @@ from .manga_translator import ( MangaTranslator, - MangaTranslatorWeb, - MangaTranslatorWS, - MangaTranslatorAPI, set_main_logger, ) from .args import parser @@ -71,14 +68,17 @@ async def dispatch(args: Namespace): await dispatch(args.host, args.port, translation_params=args_dict) elif args.mode == 'web_client': + from manga_translator.mode.web import MangaTranslatorWeb translator = MangaTranslatorWeb(args_dict) await translator.listen(args_dict) elif args.mode == 'ws': + from manga_translator.mode.ws import MangaTranslatorWS translator = MangaTranslatorWS(args_dict) await translator.listen(args_dict) elif args.mode == 'api': + from manga_translator.mode.api import MangaTranslatorAPI translator = MangaTranslatorAPI(args_dict) await translator.listen(args_dict) diff --git a/manga_translator/manga_translator.py b/manga_translator/manga_translator.py index 85dbe430..cb17850e 100644 --- a/manga_translator/manga_translator.py +++ b/manga_translator/manga_translator.py @@ -1,65 +1,44 @@ -import asyncio -import base64 -import io - import cv2 -from aiohttp.web_middlewares import middleware from omegaconf import OmegaConf import langcodes import langdetect -import requests import os import re import torch -import time import logging import numpy as np from PIL import Image -from typing import List, Tuple, Union -from aiohttp import web -from marshmallow import Schema, fields, ValidationError +from typing import Union -from manga_translator.utils.threading import Throttler -from .args import DEFAULT_ARGS, translator_chain +from .args import DEFAULT_ARGS from .utils import ( BASE_PATH, LANGUAGE_ORIENTATION_PRESETS, ModelWrapper, Context, - PriorityLock, load_image, dump_image, - replace_prefix, visualize_textblocks, - add_file_logger, - remove_file_logger, is_valuable_text, - rgb2hex, hex2rgb, - get_color_name, - natural_sort, sort_regions, ) -from .detection import DETECTORS, dispatch as dispatch_detection, prepare as prepare_detection +from .detection import dispatch as dispatch_detection, prepare as prepare_detection from .upscaling import dispatch as dispatch_upscaling, prepare as prepare_upscaling, UPSCALERS -from .ocr import OCRS, dispatch as dispatch_ocr, prepare as prepare_ocr +from .ocr import dispatch as dispatch_ocr, prepare as prepare_ocr from .textline_merge import dispatch as dispatch_textline_merge from .mask_refinement import dispatch as dispatch_mask_refinement -from .inpainting import INPAINTERS, dispatch as dispatch_inpainting, prepare as prepare_inpainting +from .inpainting import dispatch as dispatch_inpainting, prepare as prepare_inpainting from .translators import ( - TRANSLATORS, - VALID_LANGUAGES, LANGDETECT_MAP, - LanguageUnsupportedException, TranslatorChain, dispatch as dispatch_translation, prepare as prepare_translation, ) from .colorization import dispatch as dispatch_colorization, prepare as prepare_colorization from .rendering import dispatch as dispatch_rendering, dispatch_eng_render -from .save import save_result # Will be overwritten by __main__.py if module is being run directly (with python -m) logger = logging.getLogger('manga_translator') @@ -78,7 +57,7 @@ class TranslationInterrupt(Exception): pass -class MangaTranslator(): +class MangaTranslator: def __init__(self, params: dict = None): self._progress_hooks = [] @@ -117,158 +96,6 @@ def parse_init_params(self, params: dict): def using_gpu(self): return self.device.startswith('cuda') or self.device == 'mps' - async def translate_path(self, path: str, dest: str = None, params: dict[str, Union[int, str]] = None): - """ - Translates an image or folder (recursively) specified through the path. - """ - if not os.path.exists(path): - raise FileNotFoundError(path) - path = os.path.abspath(os.path.expanduser(path)) - dest = os.path.abspath(os.path.expanduser(dest)) if dest else '' - params = params or {} - - # Handle format - file_ext = params.get('format') - if params.get('save_quality', 100) < 100: - if not params.get('format'): - file_ext = 'jpg' - elif params.get('format') != 'jpg': - raise ValueError('--save-quality of lower than 100 is only supported for .jpg files') - - if os.path.isfile(path): - # Determine destination file path - if not dest: - # Use the same folder as the source - p, ext = os.path.splitext(path) - _dest = f'{p}-translated.{file_ext or ext[1:]}' - elif not os.path.basename(dest): - p, ext = os.path.splitext(os.path.basename(path)) - # If the folders differ use the original filename from the source - if os.path.dirname(path) != dest: - _dest = os.path.join(dest, f'{p}.{file_ext or ext[1:]}') - else: - _dest = os.path.join(dest, f'{p}-translated.{file_ext or ext[1:]}') - else: - p, ext = os.path.splitext(dest) - _dest = f'{p}.{file_ext or ext[1:]}' - await self.translate_file(path, _dest, params) - - elif os.path.isdir(path): - # Determine destination folder path - if path[-1] == '\\' or path[-1] == '/': - path = path[:-1] - _dest = dest or path + '-translated' - if os.path.exists(_dest) and not os.path.isdir(_dest): - raise FileExistsError(_dest) - - translated_count = 0 - for root, subdirs, files in os.walk(path): - files = natural_sort(files) - dest_root = replace_prefix(root, path, _dest) - os.makedirs(dest_root, exist_ok=True) - for f in files: - if f.lower() == '.thumb': - continue - - file_path = os.path.join(root, f) - output_dest = replace_prefix(file_path, path, _dest) - p, ext = os.path.splitext(output_dest) - output_dest = f'{p}.{file_ext or ext[1:]}' - - if await self.translate_file(file_path, output_dest, params): - translated_count += 1 - if translated_count == 0: - logger.info('No further untranslated files found. Use --overwrite to write over existing translations.') - else: - logger.info(f'Done. Translated {translated_count} image{"" if translated_count == 1 else "s"}') - - async def translate_file(self, path: str, dest: str, params: dict): - if not params.get('overwrite') and os.path.exists(dest): - logger.info( - f'Skipping as already translated: "{dest}". Use --overwrite to overwrite existing translations.') - await self._report_progress('saved', True) - return True - - logger.info(f'Translating: "{path}"') - - # Turn dict to context to make values also accessible through params. - params = params or {} - ctx = Context(**params) - self._preprocess_params(ctx) - - attempts = 0 - while ctx.attempts == -1 or attempts < ctx.attempts + 1: - if attempts > 0: - logger.info(f'Retrying translation! Attempt {attempts}' - + (f' of {ctx.attempts}' if ctx.attempts != -1 else '')) - try: - return await self._translate_file(path, dest, ctx) - - except TranslationInterrupt: - break - except Exception as e: - if isinstance(e, LanguageUnsupportedException): - await self._report_progress('error-lang', True) - else: - await self._report_progress('error', True) - if not self.ignore_errors and not (ctx.attempts == -1 or attempts < ctx.attempts): - raise - else: - logger.error(f'{e.__class__.__name__}: {e}', - exc_info=e if self.verbose else None) - attempts += 1 - return False - - async def _translate_file(self, path: str, dest: str, ctx: Context) -> bool: - if path.endswith('.txt'): - with open(path, 'r') as f: - queries = f.read().split('\n') - translated_sentences = \ - await dispatch_translation(ctx.translator, queries, ctx.use_mtpe, ctx, - 'cpu' if self._gpu_limited_memory else self.device) - p, ext = os.path.splitext(dest) - if ext != '.txt': - dest = p + '.txt' - logger.info(f'Saving "{dest}"') - with open(dest, 'w') as f: - f.write('\n'.join(translated_sentences)) - return True - - # TODO: Add .gif handler - - else: # Treat as image - try: - img = Image.open(path) - img.verify() - img = Image.open(path) - except Exception: - logger.warn(f'Failed to open image: {path}') - return False - - ctx = await self.translate(img, ctx) - result = ctx.result - - # Save result - if ctx.skip_no_text and not ctx.text_regions: - logger.debug('Not saving due to --skip-no-text') - return True - if result: - logger.info(f'Saving "{dest}"') - save_result(result, dest, ctx) - await self._report_progress('saved', True) - - if ctx.save_text or ctx.save_text_file or ctx.prep_manual: - if ctx.prep_manual: - # Save original image next to translated - p, ext = os.path.splitext(dest) - img_filename = p + '-orig' + ext - img_path = os.path.join(os.path.dirname(dest), img_filename) - img.save(img_path, quality=ctx.save_quality) - if ctx.text_regions: - self._save_text_to_file(path, ctx) - return True - return False - async def translate(self, image: Image.Image, params: Union[dict, Context] = None) -> Context: """ Translates a PIL image from a manga. Returns dict with result and intermediates of translation. @@ -771,708 +598,4 @@ async def ph(state, finished): elif state in LOG_MESSAGES_ERROR: logger.error(LOG_MESSAGES_ERROR[state]) - self.add_progress_hook(ph) - - def _save_text_to_file(self, image_path: str, ctx: Context): - cached_colors = [] - - def identify_colors(fg_rgb: List[int]): - idx = 0 - for rgb, _ in cached_colors: - # If similar color already saved - if abs(rgb[0] - fg_rgb[0]) + abs(rgb[1] - fg_rgb[1]) + abs(rgb[2] - fg_rgb[2]) < 50: - break - else: - idx += 1 - else: - cached_colors.append((fg_rgb, get_color_name(fg_rgb))) - return idx + 1, cached_colors[idx][1] - - s = f'\n[{image_path}]\n' - for i, region in enumerate(ctx.text_regions): - fore, back = region.get_font_colors() - color_id, color_name = identify_colors(fore) - - s += f'\n-- {i + 1} --\n' - s += f'color: #{color_id}: {color_name} (fg, bg: {rgb2hex(*fore)} {rgb2hex(*back)})\n' - s += f'text: {region.text}\n' - s += f'trans: {region.translation}\n' - for line in region.lines: - s += f'coords: {list(line.ravel())}\n' - s += '\n' - - text_output_file = ctx.text_output_file - if not text_output_file: - text_output_file = os.path.splitext(image_path)[0] + '_translations.txt' - - with open(text_output_file, 'a', encoding='utf-8') as f: - f.write(s) - - -class MangaTranslatorWeb(MangaTranslator): - """ - Translator client that executes tasks on behalf of the webserver in web_main.py. - """ - - def __init__(self, params: dict = None): - super().__init__(params) - self.host = params.get('host', '127.0.0.1') - if self.host == '0.0.0.0': - self.host = '127.0.0.1' - self.port = params.get('port', 5003) - self.nonce = params.get('nonce', '') - self.ignore_errors = params.get('ignore_errors', True) - self._task_id = None - self._params = None - - async def _init_connection(self): - available_translators = [] - from .translators import MissingAPIKeyException, get_translator - for key in TRANSLATORS: - try: - get_translator(key) - available_translators.append(key) - except MissingAPIKeyException: - pass - - data = { - 'nonce': self.nonce, - 'capabilities': { - 'translators': available_translators, - }, - } - requests.post(f'http://{self.host}:{self.port}/connect-internal', json=data) - - async def _send_state(self, state: str, finished: bool): - # wait for translation to be saved first (bad solution?) - finished = finished and not state == 'finished' - while True: - try: - data = { - 'task_id': self._task_id, - 'nonce': self.nonce, - 'state': state, - 'finished': finished, - } - requests.post(f'http://{self.host}:{self.port}/task-update-internal', json=data, timeout=20) - break - except Exception: - # if translation is finished server has to know - if finished: - continue - else: - break - - def _get_task(self): - try: - rjson = requests.get(f'http://{self.host}:{self.port}/task-internal?nonce={self.nonce}', - timeout=3600).json() - return rjson.get('task_id'), rjson.get('data') - except Exception: - return None, None - - async def listen(self, translation_params: dict = None): - """ - Listens for translation tasks from web server. - """ - logger.info('Waiting for translation tasks') - - await self._init_connection() - self.add_progress_hook(self._send_state) - - while True: - self._task_id, self._params = self._get_task() - if self._params and 'exit' in self._params: - break - if not (self._task_id and self._params): - await asyncio.sleep(0.1) - continue - - self.result_sub_folder = self._task_id - logger.info(f'Processing task {self._task_id}') - if translation_params is not None: - # Combine default params with params chosen by webserver - for p, default_value in translation_params.items(): - current_value = self._params.get(p) - self._params[p] = current_value if current_value is not None else default_value - if self.verbose: - # Write log file - log_file = self._result_path('log.txt') - add_file_logger(log_file) - - # final.png will be renamed if format param is set - await self.translate_path(self._result_path('input.png'), self._result_path('final.png'), - params=self._params) - print() - - if self.verbose: - remove_file_logger(log_file) - self._task_id = None - self._params = None - self.result_sub_folder = '' - - async def _run_text_translation(self, ctx: Context): - # Run machine translation as reference for manual translation (if `--translator=none` is not set) - text_regions = await super()._run_text_translation(ctx) - - if ctx.get('manual', False): - logger.info('Waiting for user input from manual translation') - requests.post(f'http://{self.host}:{self.port}/request-manual-internal', json={ - 'task_id': self._task_id, - 'nonce': self.nonce, - 'texts': [r.text for r in text_regions], - 'translations': [r.translation for r in text_regions], - }, timeout=20) - - # wait for at most 1 hour for manual translation - wait_until = time.time() + 3600 - while time.time() < wait_until: - ret = requests.post(f'http://{self.host}:{self.port}/get-manual-result-internal', json={ - 'task_id': self._task_id, - 'nonce': self.nonce - }, timeout=20).json() - if 'result' in ret: - manual_translations = ret['result'] - if isinstance(manual_translations, str): - if manual_translations == 'error': - return [] - i = 0 - for translation in manual_translations: - if not translation.strip(): - text_regions.pop(i) - i = i - 1 - else: - text_regions[i].translation = translation - text_regions[i].target_lang = ctx.translator.langs[-1] - i = i + 1 - break - elif 'cancel' in ret: - return 'cancel' - await asyncio.sleep(0.1) - return text_regions - - -class MangaTranslatorWS(MangaTranslator): - def __init__(self, params: dict = None): - super().__init__(params) - self.url = params.get('ws_url') - self.secret = params.get('ws_secret', os.getenv('WS_SECRET', '')) - self.ignore_errors = params.get('ignore_errors', True) - - self._task_id = None - self._websocket = None - - async def listen(self, translation_params: dict = None): - from threading import Thread - import io - import aioshutil - from aiofiles import os - import websockets - from .server import ws_pb2 - - self._server_loop = asyncio.new_event_loop() - self.task_lock = PriorityLock() - self.counter = 0 - - async def _send_and_yield(websocket, msg): - # send message and yield control to the event loop (to actually send the message) - await websocket.send(msg) - await asyncio.sleep(0) - - send_throttler = Throttler(0.2) - send_and_yield = send_throttler.wrap(_send_and_yield) - - async def sync_state(state, finished): - if self._websocket is None: - return - msg = ws_pb2.WebSocketMessage() - msg.status.id = self._task_id - msg.status.status = state - self._server_loop.call_soon_threadsafe( - asyncio.create_task, - send_and_yield(self._websocket, msg.SerializeToString()) - ) - - self.add_progress_hook(sync_state) - - async def translate(task_id, websocket, image, params): - async with self.task_lock((1 << 31) - params['ws_count']): - self._task_id = task_id - self._websocket = websocket - result = await self.translate(image, params) - self._task_id = None - self._websocket = None - return result - - async def server_send_status(websocket, task_id, status): - msg = ws_pb2.WebSocketMessage() - msg.status.id = task_id - msg.status.status = status - await websocket.send(msg.SerializeToString()) - await asyncio.sleep(0) - - async def server_process_inner(main_loop, logger_task, session, websocket, task) -> Tuple[bool, bool]: - logger_task.info(f'-- Processing task {task.id}') - await server_send_status(websocket, task.id, 'pending') - - if self.verbose: - await aioshutil.rmtree(f'result/{task.id}', ignore_errors=True) - await os.makedirs(f'result/{task.id}', exist_ok=True) - - params = { - 'target_lang': task.target_language, - 'skip_lang': task.skip_language, - 'detector': task.detector, - 'direction': task.direction, - 'translator': task.translator, - 'size': task.size, - 'ws_event_loop': asyncio.get_event_loop(), - 'ws_count': self.counter, - } - self.counter += 1 - - logger_task.info(f'-- Downloading image from {task.source_image}') - await server_send_status(websocket, task.id, 'downloading') - async with session.get(task.source_image) as resp: - if resp.status == 200: - source_image = await resp.read() - else: - msg = ws_pb2.WebSocketMessage() - msg.status.id = task.id - msg.status.status = 'error-download' - await websocket.send(msg.SerializeToString()) - await asyncio.sleep(0) - return False, False - - logger_task.info(f'-- Translating image') - if translation_params: - for p, default_value in translation_params.items(): - current_value = params.get(p) - params[p] = current_value if current_value is not None else default_value - - image = Image.open(io.BytesIO(source_image)) - - (ori_w, ori_h) = image.size - if max(ori_h, ori_w) > 1200: - params['upscale_ratio'] = 1 - - await server_send_status(websocket, task.id, 'preparing') - # translation_dict = await self.translate(image, params) - translation_dict = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe( - translate(task.id, websocket, image, params), - main_loop - ) - ) - await send_throttler.flush() - - output: Image.Image = translation_dict.result - if output is not None: - await server_send_status(websocket, task.id, 'saving') - - output = output.resize((ori_w, ori_h), resample=Image.LANCZOS) - - img = io.BytesIO() - output.save(img, format='PNG') - if self.verbose: - output.save(self._result_path('ws_final.png')) - - img_bytes = img.getvalue() - logger_task.info(f'-- Uploading result to {task.translation_mask}') - await server_send_status(websocket, task.id, 'uploading') - async with session.put(task.translation_mask, data=img_bytes) as resp: - if resp.status != 200: - logger_task.error(f'-- Failed to upload result:') - logger_task.error(f'{resp.status}: {resp.reason}') - msg = ws_pb2.WebSocketMessage() - msg.status.id = task.id - msg.status.status = 'error-upload' - await websocket.send(msg.SerializeToString()) - await asyncio.sleep(0) - return False, False - - return True, output is not None - - async def server_process(main_loop, session, websocket, task) -> bool: - logger_task = logger.getChild(f'{task.id}') - try: - (success, has_translation_mask) = await server_process_inner(main_loop, logger_task, session, websocket, - task) - except Exception as e: - logger_task.error(f'-- Task failed with exception:') - logger_task.error(f'{e.__class__.__name__}: {e}', exc_info=e if self.verbose else None) - (success, has_translation_mask) = False, False - finally: - result = ws_pb2.WebSocketMessage() - result.finish_task.id = task.id - result.finish_task.success = success - result.finish_task.has_translation_mask = has_translation_mask - await websocket.send(result.SerializeToString()) - await asyncio.sleep(0) - logger_task.info(f'-- Task finished') - - async def async_server_thread(main_loop): - from aiohttp import ClientSession, ClientTimeout - timeout = ClientTimeout(total=30) - async with ClientSession(timeout=timeout) as session: - logger_conn = logger.getChild('connection') - if self.verbose: - logger_conn.setLevel(logging.DEBUG) - async for websocket in websockets.connect( - self.url, - extra_headers={ - 'x-secret': self.secret, - }, - max_size=1_000_000, - logger=logger_conn - ): - bg_tasks = set() - try: - logger.info('-- Connected to websocket server') - - async for raw in websocket: - # logger.info(f'Got message: {raw}') - msg = ws_pb2.WebSocketMessage() - msg.ParseFromString(raw) - if msg.WhichOneof('message') == 'new_task': - task = msg.new_task - bg_task = asyncio.create_task(server_process(main_loop, session, websocket, task)) - bg_tasks.add(bg_task) - bg_task.add_done_callback(bg_tasks.discard) - - except Exception as e: - logger.error(f'{e.__class__.__name__}: {e}', exc_info=e if self.verbose else None) - - finally: - logger.info('-- Disconnected from websocket server') - for bg_task in bg_tasks: - bg_task.cancel() - - def server_thread(future, main_loop, server_loop): - asyncio.set_event_loop(server_loop) - try: - server_loop.run_until_complete(async_server_thread(main_loop)) - finally: - future.set_result(None) - - future = asyncio.Future() - Thread( - target=server_thread, - args=(future, asyncio.get_running_loop(), self._server_loop), - daemon=True - ).start() - - # create a future that is never done - await future - - async def _run_text_translation(self, ctx: Context): - coroutine = super()._run_text_translation(ctx) - if ctx.translator.has_offline(): - return await coroutine - else: - task_id = self._task_id - websocket = self._websocket - await self.task_lock.release() - result = await asyncio.wrap_future( - asyncio.run_coroutine_threadsafe( - coroutine, - ctx.ws_event_loop - ) - ) - await self.task_lock.acquire((1 << 30) - ctx.ws_count) - self._task_id = task_id - self._websocket = websocket - return result - - async def _run_text_rendering(self, ctx: Context): - render_mask = (ctx.mask >= 127).astype(np.uint8)[:, :, None] - - output = await super()._run_text_rendering(ctx) - render_mask[np.sum(ctx.img_rgb != output, axis=2) > 0] = 1 - ctx.render_mask = render_mask - if self.verbose: - cv2.imwrite(self._result_path('ws_render_in.png'), cv2.cvtColor(ctx.img_rgb, cv2.COLOR_RGB2BGR)) - cv2.imwrite(self._result_path('ws_render_out.png'), cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) - cv2.imwrite(self._result_path('ws_mask.png'), render_mask * 255) - - # only keep sections in mask - if self.verbose: - cv2.imwrite(self._result_path('ws_inmask.png'), cv2.cvtColor(ctx.img_rgb, cv2.COLOR_RGB2BGRA) * render_mask) - output = cv2.cvtColor(output, cv2.COLOR_RGB2RGBA) * render_mask - if self.verbose: - cv2.imwrite(self._result_path('ws_output.png'), cv2.cvtColor(output, cv2.COLOR_RGBA2BGRA) * render_mask) - - return output - - -# Experimental. May be replaced by a refactored server/web_main.py in the future. -class MangaTranslatorAPI(MangaTranslator): - def __init__(self, params: dict = None): - import nest_asyncio - nest_asyncio.apply() - super().__init__(params) - self.host = params.get('host', '127.0.0.1') - self.port = params.get('port', '5003') - self.log_web = params.get('log_web', False) - self.ignore_errors = params.get('ignore_errors', True) - self._task_id = None - self._params = None - self.params = params - self.queue = [] - - async def wait_queue(self, id: int): - while self.queue[0] != id: - await asyncio.sleep(0.05) - - def remove_from_queue(self, id: int): - self.queue.remove(id) - - def generate_id(self): - try: - x = max(self.queue) - except: - x = 0 - return x + 1 - - def middleware_factory(self): - @middleware - async def sample_middleware(request, handler): - id = self.generate_id() - self.queue.append(id) - try: - await self.wait_queue(id) - except Exception as e: - print(e) - try: - # todo make cancellable - response = await handler(request) - except: - response = web.json_response({'error': "Internal Server Error", 'status': 500}, - status=500) - # Handle cases where a user leaves the queue, request fails, or is completed - try: - self.remove_from_queue(id) - except Exception as e: - print(e) - return response - - return sample_middleware - - async def get_file(self, image, base64Images, url) -> Image: - if image is not None: - content = image.file.read() - elif base64Images is not None: - base64Images = base64Images - if base64Images.__contains__('base64,'): - base64Images = base64Images.split('base64,')[1] - content = base64.b64decode(base64Images) - elif url is not None: - from aiohttp import ClientSession - async with ClientSession() as session: - async with session.get(url) as resp: - if resp.status == 200: - content = await resp.read() - else: - return web.json_response({'status': 'error'}) - else: - raise ValidationError("donest exist") - img = Image.open(io.BytesIO(content)) - - img.verify() - img = Image.open(io.BytesIO(content)) - if img.width * img.height > 8000 ** 2: - raise ValidationError("to large") - return img - - async def listen(self, translation_params: dict = None): - self.params = translation_params - app = web.Application(client_max_size=1024 * 1024 * 50, middlewares=[self.middleware_factory()]) - - routes = web.RouteTableDef() - run_until_state = '' - - async def hook(state, finished): - if run_until_state and run_until_state == state and not finished: - raise TranslationInterrupt() - - self.add_progress_hook(hook) - - @routes.post("/get_text") - async def text_api(req): - nonlocal run_until_state - run_until_state = 'translating' - return await self.err_handling(self.run_translate, req, self.format_translate) - - @routes.post("/translate") - async def translate_api(req): - nonlocal run_until_state - run_until_state = 'after-translating' - return await self.err_handling(self.run_translate, req, self.format_translate) - - @routes.post("/inpaint_translate") - async def inpaint_translate_api(req): - nonlocal run_until_state - run_until_state = 'rendering' - return await self.err_handling(self.run_translate, req, self.format_translate) - - @routes.post("/colorize_translate") - async def colorize_translate_api(req): - nonlocal run_until_state - run_until_state = 'rendering' - return await self.err_handling(self.run_translate, req, self.format_translate, True) - - # #@routes.post("/file") - # async def file_api(req): - # #TODO: return file - # return await self.err_handling(self.file_exec, req, None) - - app.add_routes(routes) - web.run_app(app, host=self.host, port=self.port) - - async def run_translate(self, translation_params, img): - return await self.translate(img, translation_params) - - async def err_handling(self, func, req, format, ri=False): - try: - if req.content_type == 'application/json' or req.content_type == 'multipart/form-data': - if req.content_type == 'application/json': - d = await req.json() - else: - d = await req.post() - schema = self.PostSchema() - data = schema.load(d) - if 'translator_chain' in data: - data['translator_chain'] = translator_chain(data['translator_chain']) - if 'selective_translation' in data: - data['selective_translation'] = translator_chain(data['selective_translation']) - ctx = Context(**dict(self.params, **data)) - self._preprocess_params(ctx) - if data.get('image') is None and data.get('base64Images') is None and data.get('url') is None: - return web.json_response({'error': "Missing input", 'status': 422}) - fil = await self.get_file(data.get('image'), data.get('base64Images'), data.get('url')) - if 'image' in data: - del data['image'] - if 'base64Images' in data: - del data['base64Images'] - if 'url' in data: - del data['url'] - attempts = 0 - while ctx.attempts == -1 or attempts <= ctx.attempts: - if attempts > 0: - logger.info(f'Retrying translation! Attempt {attempts}' + ( - f' of {ctx.attempts}' if ctx.attempts != -1 else '')) - try: - await func(ctx, fil) - break - except TranslationInterrupt: - break - except Exception as e: - print(e) - attempts += 1 - if ctx.attempts != -1 and attempts > ctx.attempts: - return web.json_response({'error': "Internal Server Error", 'status': 500}, - status=500) - try: - return format(ctx, ri) - except Exception as e: - print(e) - return web.json_response({'error': "Failed to format", 'status': 500}, - status=500) - else: - return web.json_response({'error': "Wrong content type: " + req.content_type, 'status': 415}, - status=415) - except ValueError as e: - print(e) - return web.json_response({'error': "Wrong input type", 'status': 422}, status=422) - - except ValidationError as e: - print(e) - return web.json_response({'error': "Input invalid", 'status': 422}, status=422) - - def format_translate(self, ctx: Context, return_image: bool): - text_regions = ctx.text_regions - inpaint = ctx.img_inpainted - results = [] - if 'overlay_ext' in ctx: - overlay_ext = ctx['overlay_ext'] - else: - overlay_ext = 'jpg' - for i, blk in enumerate(text_regions): - minX, minY, maxX, maxY = blk.xyxy - if 'translations' in ctx: - trans = {key: value[i] for key, value in ctx['translations'].items()} - else: - trans = {} - trans["originalText"] = text_regions[i].text - if inpaint is not None: - overlay = inpaint[minY:maxY, minX:maxX] - - retval, buffer = cv2.imencode('.' + overlay_ext, overlay) - jpg_as_text = base64.b64encode(buffer) - background = "data:image/" + overlay_ext + ";base64," + jpg_as_text.decode("utf-8") - else: - background = None - text_region = text_regions[i] - text_region.adjust_bg_color = False - color1, color2 = text_region.get_font_colors() - - results.append({ - 'text': trans, - 'minX': int(minX), - 'minY': int(minY), - 'maxX': int(maxX), - 'maxY': int(maxY), - 'textColor': { - 'fg': color1.tolist(), - 'bg': color2.tolist() - }, - 'language': text_regions[i].source_lang, - 'background': background - }) - if return_image and ctx.img_colorized is not None: - retval, buffer = cv2.imencode('.' + overlay_ext, np.array(ctx.img_colorized)) - jpg_as_text = base64.b64encode(buffer) - img = "data:image/" + overlay_ext + ";base64," + jpg_as_text.decode("utf-8") - else: - img = None - return web.json_response({'details': results, 'img': img}) - - class PostSchema(Schema): - target_lang = fields.Str(required=False, validate=lambda a: a.upper() in VALID_LANGUAGES) - detector = fields.Str(required=False, validate=lambda a: a.lower() in DETECTORS) - ocr = fields.Str(required=False, validate=lambda a: a.lower() in OCRS) - inpainter = fields.Str(required=False, validate=lambda a: a.lower() in INPAINTERS) - upscaler = fields.Str(required=False, validate=lambda a: a.lower() in UPSCALERS) - translator = fields.Str(required=False, validate=lambda a: a.lower() in TRANSLATORS) - direction = fields.Str(required=False, validate=lambda a: a.lower() in {'auto', 'h', 'v'}) - skip_language = fields.Str(required=False) - upscale_ratio = fields.Integer(required=False) - translator_chain = fields.Str(required=False) - selective_translation = fields.Str(required=False) - attempts = fields.Integer(required=False) - detection_size = fields.Integer(required=False) - text_threshold = fields.Float(required=False) - box_threshold = fields.Float(required=False) - unclip_ratio = fields.Float(required=False) - inpainting_size = fields.Integer(required=False) - det_rotate = fields.Bool(required=False) - det_auto_rotate = fields.Bool(required=False) - det_invert = fields.Bool(required=False) - det_gamma_correct = fields.Bool(required=False) - min_text_length = fields.Integer(required=False) - colorization_size = fields.Integer(required=False) - denoise_sigma = fields.Integer(required=False) - mask_dilation_offset = fields.Integer(required=False) - ignore_bubble = fields.Integer(required=False) - gpt_config = fields.String(required=False) - filter_text = fields.String(required=False) - - # api specific - overlay_ext = fields.Str(required=False) - base64Images = fields.Raw(required=False) - image = fields.Raw(required=False) - url = fields.Raw(required=False) - - # no functionality except preventing errors when given - fingerprint = fields.Raw(required=False) - clientUuid = fields.Raw(required=False) + self.add_progress_hook(ph) \ No newline at end of file diff --git a/manga_translator/mode/api.py b/manga_translator/mode/api.py new file mode 100644 index 00000000..bfe71256 --- /dev/null +++ b/manga_translator/mode/api.py @@ -0,0 +1,290 @@ +# Experimental. May be replaced by a refactored server/web_main.py in the future. +import asyncio +import base64 +import io + +import cv2 +import numpy as np +from PIL import Image +from aiohttp import web +from aiohttp.web_middlewares import middleware +from marshmallow import fields, Schema, ValidationError + +from manga_translator import MangaTranslator, Context, UPSCALERS, TranslationInterrupt, logger +from manga_translator.args import translator_chain +from manga_translator.detection import DETECTORS +from manga_translator.inpainting import INPAINTERS +from manga_translator.ocr import OCRS +from manga_translator.translators import VALID_LANGUAGES, TRANSLATORS + + +class MangaTranslatorAPI(MangaTranslator): + def __init__(self, params: dict = None): + import nest_asyncio + nest_asyncio.apply() + super().__init__(params) + self.host = params.get('host', '127.0.0.1') + self.port = params.get('port', '5003') + self.log_web = params.get('log_web', False) + self.ignore_errors = params.get('ignore_errors', True) + self._task_id = None + self._params = None + self.params = params + self.queue = [] + + async def wait_queue(self, id: int): + while self.queue[0] != id: + await asyncio.sleep(0.05) + + def remove_from_queue(self, id: int): + self.queue.remove(id) + + def generate_id(self): + try: + x = max(self.queue) + except: + x = 0 + return x + 1 + + def middleware_factory(self): + @middleware + async def sample_middleware(request, handler): + id = self.generate_id() + self.queue.append(id) + try: + await self.wait_queue(id) + except Exception as e: + print(e) + try: + # todo make cancellable + response = await handler(request) + except: + response = web.json_response({'error': "Internal Server Error", 'status': 500}, + status=500) + # Handle cases where a user leaves the queue, request fails, or is completed + try: + self.remove_from_queue(id) + except Exception as e: + print(e) + return response + + return sample_middleware + + async def get_file(self, image, base64Images, url) -> Image: + if image is not None: + content = image.file.read() + elif base64Images is not None: + base64Images = base64Images + if base64Images.__contains__('base64,'): + base64Images = base64Images.split('base64,')[1] + content = base64.b64decode(base64Images) + elif url is not None: + from aiohttp import ClientSession + async with ClientSession() as session: + async with session.get(url) as resp: + if resp.status == 200: + content = await resp.read() + else: + return web.json_response({'status': 'error'}) + else: + raise ValidationError("donest exist") + img = Image.open(io.BytesIO(content)) + + img.verify() + img = Image.open(io.BytesIO(content)) + if img.width * img.height > 8000 ** 2: + raise ValidationError("to large") + return img + + async def listen(self, translation_params: dict = None): + self.params = translation_params + app = web.Application(client_max_size=1024 * 1024 * 50, middlewares=[self.middleware_factory()]) + + routes = web.RouteTableDef() + run_until_state = '' + + async def hook(state, finished): + if run_until_state and run_until_state == state and not finished: + raise TranslationInterrupt() + + self.add_progress_hook(hook) + + @routes.post("/get_text") + async def text_api(req): + nonlocal run_until_state + run_until_state = 'translating' + return await self.err_handling(self.run_translate, req, self.format_translate) + + @routes.post("/translate") + async def translate_api(req): + nonlocal run_until_state + run_until_state = 'after-translating' + return await self.err_handling(self.run_translate, req, self.format_translate) + + @routes.post("/inpaint_translate") + async def inpaint_translate_api(req): + nonlocal run_until_state + run_until_state = 'rendering' + return await self.err_handling(self.run_translate, req, self.format_translate) + + @routes.post("/colorize_translate") + async def colorize_translate_api(req): + nonlocal run_until_state + run_until_state = 'rendering' + return await self.err_handling(self.run_translate, req, self.format_translate, True) + + # #@routes.post("/file") + # async def file_api(req): + # #TODO: return file + # return await self.err_handling(self.file_exec, req, None) + + app.add_routes(routes) + web.run_app(app, host=self.host, port=self.port) + + async def run_translate(self, translation_params, img): + return await self.translate(img, translation_params) + + async def err_handling(self, func, req, format, ri=False): + try: + if req.content_type == 'application/json' or req.content_type == 'multipart/form-data': + if req.content_type == 'application/json': + d = await req.json() + else: + d = await req.post() + schema = self.PostSchema() + data = schema.load(d) + if 'translator_chain' in data: + data['translator_chain'] = translator_chain(data['translator_chain']) + if 'selective_translation' in data: + data['selective_translation'] = translator_chain(data['selective_translation']) + ctx = Context(**dict(self.params, **data)) + self._preprocess_params(ctx) + if data.get('image') is None and data.get('base64Images') is None and data.get('url') is None: + return web.json_response({'error': "Missing input", 'status': 422}) + fil = await self.get_file(data.get('image'), data.get('base64Images'), data.get('url')) + if 'image' in data: + del data['image'] + if 'base64Images' in data: + del data['base64Images'] + if 'url' in data: + del data['url'] + attempts = 0 + while ctx.attempts == -1 or attempts <= ctx.attempts: + if attempts > 0: + logger.info(f'Retrying translation! Attempt {attempts}' + ( + f' of {ctx.attempts}' if ctx.attempts != -1 else '')) + try: + await func(ctx, fil) + break + except TranslationInterrupt: + break + except Exception as e: + print(e) + attempts += 1 + if ctx.attempts != -1 and attempts > ctx.attempts: + return web.json_response({'error': "Internal Server Error", 'status': 500}, + status=500) + try: + return format(ctx, ri) + except Exception as e: + print(e) + return web.json_response({'error': "Failed to format", 'status': 500}, + status=500) + else: + return web.json_response({'error': "Wrong content type: " + req.content_type, 'status': 415}, + status=415) + except ValueError as e: + print(e) + return web.json_response({'error': "Wrong input type", 'status': 422}, status=422) + + except ValidationError as e: + print(e) + return web.json_response({'error': "Input invalid", 'status': 422}, status=422) + + def format_translate(self, ctx: Context, return_image: bool): + text_regions = ctx.text_regions + inpaint = ctx.img_inpainted + results = [] + if 'overlay_ext' in ctx: + overlay_ext = ctx['overlay_ext'] + else: + overlay_ext = 'jpg' + for i, blk in enumerate(text_regions): + minX, minY, maxX, maxY = blk.xyxy + if 'translations' in ctx: + trans = {key: value[i] for key, value in ctx['translations'].items()} + else: + trans = {} + trans["originalText"] = text_regions[i].text + if inpaint is not None: + overlay = inpaint[minY:maxY, minX:maxX] + + retval, buffer = cv2.imencode('.' + overlay_ext, overlay) + jpg_as_text = base64.b64encode(buffer) + background = "data:image/" + overlay_ext + ";base64," + jpg_as_text.decode("utf-8") + else: + background = None + text_region = text_regions[i] + text_region.adjust_bg_color = False + color1, color2 = text_region.get_font_colors() + + results.append({ + 'text': trans, + 'minX': int(minX), + 'minY': int(minY), + 'maxX': int(maxX), + 'maxY': int(maxY), + 'textColor': { + 'fg': color1.tolist(), + 'bg': color2.tolist() + }, + 'language': text_regions[i].source_lang, + 'background': background + }) + if return_image and ctx.img_colorized is not None: + retval, buffer = cv2.imencode('.' + overlay_ext, np.array(ctx.img_colorized)) + jpg_as_text = base64.b64encode(buffer) + img = "data:image/" + overlay_ext + ";base64," + jpg_as_text.decode("utf-8") + else: + img = None + return web.json_response({'details': results, 'img': img}) + + class PostSchema(Schema): + target_lang = fields.Str(required=False, validate=lambda a: a.upper() in VALID_LANGUAGES) + detector = fields.Str(required=False, validate=lambda a: a.lower() in DETECTORS) + ocr = fields.Str(required=False, validate=lambda a: a.lower() in OCRS) + inpainter = fields.Str(required=False, validate=lambda a: a.lower() in INPAINTERS) + upscaler = fields.Str(required=False, validate=lambda a: a.lower() in UPSCALERS) + translator = fields.Str(required=False, validate=lambda a: a.lower() in TRANSLATORS) + direction = fields.Str(required=False, validate=lambda a: a.lower() in {'auto', 'h', 'v'}) + skip_language = fields.Str(required=False) + upscale_ratio = fields.Integer(required=False) + translator_chain = fields.Str(required=False) + selective_translation = fields.Str(required=False) + attempts = fields.Integer(required=False) + detection_size = fields.Integer(required=False) + text_threshold = fields.Float(required=False) + box_threshold = fields.Float(required=False) + unclip_ratio = fields.Float(required=False) + inpainting_size = fields.Integer(required=False) + det_rotate = fields.Bool(required=False) + det_auto_rotate = fields.Bool(required=False) + det_invert = fields.Bool(required=False) + det_gamma_correct = fields.Bool(required=False) + min_text_length = fields.Integer(required=False) + colorization_size = fields.Integer(required=False) + denoise_sigma = fields.Integer(required=False) + mask_dilation_offset = fields.Integer(required=False) + ignore_bubble = fields.Integer(required=False) + gpt_config = fields.String(required=False) + filter_text = fields.String(required=False) + + # api specific + overlay_ext = fields.Str(required=False) + base64Images = fields.Raw(required=False) + image = fields.Raw(required=False) + url = fields.Raw(required=False) + + # no functionality except preventing errors when given + fingerprint = fields.Raw(required=False) + clientUuid = fields.Raw(required=False) \ No newline at end of file diff --git a/manga_translator/mode/local.py b/manga_translator/mode/local.py new file mode 100644 index 00000000..05e87b41 --- /dev/null +++ b/manga_translator/mode/local.py @@ -0,0 +1,201 @@ +import os +from typing import Union, List + +from PIL import Image + +from manga_translator import MangaTranslator, logger, Context, TranslationInterrupt +from ..save import save_result +from ..translators import ( + LanguageUnsupportedException, + dispatch as dispatch_translation, +) +from ..utils import natural_sort, replace_prefix, get_color_name, rgb2hex + + +class MangaTranslatorLocal(MangaTranslator): + async def translate_path(self, path: str, dest: str = None, params: dict[str, Union[int, str]] = None): + """ + Translates an image or folder (recursively) specified through the path. + """ + if not os.path.exists(path): + raise FileNotFoundError(path) + path = os.path.abspath(os.path.expanduser(path)) + dest = os.path.abspath(os.path.expanduser(dest)) if dest else '' + params = params or {} + + # Handle format + file_ext = params.get('format') + if params.get('save_quality', 100) < 100: + if not params.get('format'): + file_ext = 'jpg' + elif params.get('format') != 'jpg': + raise ValueError('--save-quality of lower than 100 is only supported for .jpg files') + + if os.path.isfile(path): + # Determine destination file path + if not dest: + # Use the same folder as the source + p, ext = os.path.splitext(path) + _dest = f'{p}-translated.{file_ext or ext[1:]}' + elif not os.path.basename(dest): + p, ext = os.path.splitext(os.path.basename(path)) + # If the folders differ use the original filename from the source + if os.path.dirname(path) != dest: + _dest = os.path.join(dest, f'{p}.{file_ext or ext[1:]}') + else: + _dest = os.path.join(dest, f'{p}-translated.{file_ext or ext[1:]}') + else: + p, ext = os.path.splitext(dest) + _dest = f'{p}.{file_ext or ext[1:]}' + await self.translate_file(path, _dest, params) + + elif os.path.isdir(path): + # Determine destination folder path + if path[-1] == '\\' or path[-1] == '/': + path = path[:-1] + _dest = dest or path + '-translated' + if os.path.exists(_dest) and not os.path.isdir(_dest): + raise FileExistsError(_dest) + + translated_count = 0 + for root, subdirs, files in os.walk(path): + files = natural_sort(files) + dest_root = replace_prefix(root, path, _dest) + os.makedirs(dest_root, exist_ok=True) + for f in files: + if f.lower() == '.thumb': + continue + + file_path = os.path.join(root, f) + output_dest = replace_prefix(file_path, path, _dest) + p, ext = os.path.splitext(output_dest) + output_dest = f'{p}.{file_ext or ext[1:]}' + + if await self.translate_file(file_path, output_dest, params): + translated_count += 1 + if translated_count == 0: + logger.info('No further untranslated files found. Use --overwrite to write over existing translations.') + else: + logger.info(f'Done. Translated {translated_count} image{"" if translated_count == 1 else "s"}') + + async def translate_file(self, path: str, dest: str, params: dict): + if not params.get('overwrite') and os.path.exists(dest): + logger.info( + f'Skipping as already translated: "{dest}". Use --overwrite to overwrite existing translations.') + await self._report_progress('saved', True) + return True + + logger.info(f'Translating: "{path}"') + + # Turn dict to context to make values also accessible through params. + params = params or {} + ctx = Context(**params) + self._preprocess_params(ctx) + + attempts = 0 + while ctx.attempts == -1 or attempts < ctx.attempts + 1: + if attempts > 0: + logger.info(f'Retrying translation! Attempt {attempts}' + + (f' of {ctx.attempts}' if ctx.attempts != -1 else '')) + try: + return await self._translate_file(path, dest, ctx) + + except TranslationInterrupt: + break + except Exception as e: + if isinstance(e, LanguageUnsupportedException): + await self._report_progress('error-lang', True) + else: + await self._report_progress('error', True) + if not self.ignore_errors and not (ctx.attempts == -1 or attempts < ctx.attempts): + raise + else: + logger.error(f'{e.__class__.__name__}: {e}', + exc_info=e if self.verbose else None) + attempts += 1 + return False + + async def _translate_file(self, path: str, dest: str, ctx: Context) -> bool: + if path.endswith('.txt'): + with open(path, 'r') as f: + queries = f.read().split('\n') + translated_sentences = \ + await dispatch_translation(ctx.translator, queries, ctx.use_mtpe, ctx, + 'cpu' if self._gpu_limited_memory else self.device) + p, ext = os.path.splitext(dest) + if ext != '.txt': + dest = p + '.txt' + logger.info(f'Saving "{dest}"') + with open(dest, 'w') as f: + f.write('\n'.join(translated_sentences)) + return True + + # TODO: Add .gif handler + + else: # Treat as image + try: + img = Image.open(path) + img.verify() + img = Image.open(path) + except Exception: + logger.warn(f'Failed to open image: {path}') + return False + + ctx = await self.translate(img, ctx) + result = ctx.result + + # Save result + if ctx.skip_no_text and not ctx.text_regions: + logger.debug('Not saving due to --skip-no-text') + return True + if result: + logger.info(f'Saving "{dest}"') + save_result(result, dest, ctx) + await self._report_progress('saved', True) + + if ctx.save_text or ctx.save_text_file or ctx.prep_manual: + if ctx.prep_manual: + # Save original image next to translated + p, ext = os.path.splitext(dest) + img_filename = p + '-orig' + ext + img_path = os.path.join(os.path.dirname(dest), img_filename) + img.save(img_path, quality=ctx.save_quality) + if ctx.text_regions: + self._save_text_to_file(path, ctx) + return True + return False + + def _save_text_to_file(self, image_path: str, ctx: Context): + cached_colors = [] + + def identify_colors(fg_rgb: List[int]): + idx = 0 + for rgb, _ in cached_colors: + # If similar color already saved + if abs(rgb[0] - fg_rgb[0]) + abs(rgb[1] - fg_rgb[1]) + abs(rgb[2] - fg_rgb[2]) < 50: + break + else: + idx += 1 + else: + cached_colors.append((fg_rgb, get_color_name(fg_rgb))) + return idx + 1, cached_colors[idx][1] + + s = f'\n[{image_path}]\n' + for i, region in enumerate(ctx.text_regions): + fore, back = region.get_font_colors() + color_id, color_name = identify_colors(fore) + + s += f'\n-- {i + 1} --\n' + s += f'color: #{color_id}: {color_name} (fg, bg: {rgb2hex(*fore)} {rgb2hex(*back)})\n' + s += f'text: {region.text}\n' + s += f'trans: {region.translation}\n' + for line in region.lines: + s += f'coords: {list(line.ravel())}\n' + s += '\n' + + text_output_file = ctx.text_output_file + if not text_output_file: + text_output_file = os.path.splitext(image_path)[0] + '_translations.txt' + + with open(text_output_file, 'a', encoding='utf-8') as f: + f.write(s) \ No newline at end of file diff --git a/manga_translator/mode/web.py b/manga_translator/mode/web.py new file mode 100644 index 00000000..2acc9faf --- /dev/null +++ b/manga_translator/mode/web.py @@ -0,0 +1,151 @@ +import asyncio +import time + +import requests + +from manga_translator import MangaTranslator, logger, Context +from manga_translator.translators import TRANSLATORS +from manga_translator.utils import add_file_logger, remove_file_logger + + +class MangaTranslatorWeb(MangaTranslator): + """ + Translator client that executes tasks on behalf of the webserver in web_main.py. + """ + + def __init__(self, params: dict = None): + super().__init__(params) + self.host = params.get('host', '127.0.0.1') + if self.host == '0.0.0.0': + self.host = '127.0.0.1' + self.port = params.get('port', 5003) + self.nonce = params.get('nonce', '') + self.ignore_errors = params.get('ignore_errors', True) + self._task_id = None + self._params = None + + async def _init_connection(self): + available_translators = [] + from ..translators import MissingAPIKeyException, get_translator + for key in TRANSLATORS: + try: + get_translator(key) + available_translators.append(key) + except MissingAPIKeyException: + pass + + data = { + 'nonce': self.nonce, + 'capabilities': { + 'translators': available_translators, + }, + } + requests.post(f'http://{self.host}:{self.port}/connect-internal', json=data) + + async def _send_state(self, state: str, finished: bool): + # wait for translation to be saved first (bad solution?) + finished = finished and not state == 'finished' + while True: + try: + data = { + 'task_id': self._task_id, + 'nonce': self.nonce, + 'state': state, + 'finished': finished, + } + requests.post(f'http://{self.host}:{self.port}/task-update-internal', json=data, timeout=20) + break + except Exception: + # if translation is finished server has to know + if finished: + continue + else: + break + + def _get_task(self): + try: + rjson = requests.get(f'http://{self.host}:{self.port}/task-internal?nonce={self.nonce}', + timeout=3600).json() + return rjson.get('task_id'), rjson.get('data') + except Exception: + return None, None + + async def listen(self, translation_params: dict = None): + """ + Listens for translation tasks from web server. + """ + logger.info('Waiting for translation tasks') + + await self._init_connection() + self.add_progress_hook(self._send_state) + + while True: + self._task_id, self._params = self._get_task() + if self._params and 'exit' in self._params: + break + if not (self._task_id and self._params): + await asyncio.sleep(0.1) + continue + + self.result_sub_folder = self._task_id + logger.info(f'Processing task {self._task_id}') + if translation_params is not None: + # Combine default params with params chosen by webserver + for p, default_value in translation_params.items(): + current_value = self._params.get(p) + self._params[p] = current_value if current_value is not None else default_value + if self.verbose: + # Write log file + log_file = self._result_path('log.txt') + add_file_logger(log_file) + + # final.png will be renamed if format param is set + await self.translate_path(self._result_path('input.png'), self._result_path('final.png'), + params=self._params) + print() + + if self.verbose: + remove_file_logger(log_file) + self._task_id = None + self._params = None + self.result_sub_folder = '' + + async def _run_text_translation(self, ctx: Context): + # Run machine translation as reference for manual translation (if `--translator=none` is not set) + text_regions = await super()._run_text_translation(ctx) + + if ctx.get('manual', False): + logger.info('Waiting for user input from manual translation') + requests.post(f'http://{self.host}:{self.port}/request-manual-internal', json={ + 'task_id': self._task_id, + 'nonce': self.nonce, + 'texts': [r.text for r in text_regions], + 'translations': [r.translation for r in text_regions], + }, timeout=20) + + # wait for at most 1 hour for manual translation + wait_until = time.time() + 3600 + while time.time() < wait_until: + ret = requests.post(f'http://{self.host}:{self.port}/get-manual-result-internal', json={ + 'task_id': self._task_id, + 'nonce': self.nonce + }, timeout=20).json() + if 'result' in ret: + manual_translations = ret['result'] + if isinstance(manual_translations, str): + if manual_translations == 'error': + return [] + i = 0 + for translation in manual_translations: + if not translation.strip(): + text_regions.pop(i) + i = i - 1 + else: + text_regions[i].translation = translation + text_regions[i].target_lang = ctx.translator.langs[-1] + i = i + 1 + break + elif 'cancel' in ret: + return 'cancel' + await asyncio.sleep(0.1) + return text_regions diff --git a/manga_translator/mode/ws.py b/manga_translator/mode/ws.py new file mode 100644 index 00000000..d696058f --- /dev/null +++ b/manga_translator/mode/ws.py @@ -0,0 +1,264 @@ +import asyncio +import logging +import os +from typing import Tuple + +import cv2 +import numpy as np +from PIL import Image + +from manga_translator import logger, Context, MangaTranslator +from manga_translator.utils import PriorityLock, Throttler + + +class MangaTranslatorWS(MangaTranslator): + def __init__(self, params: dict = None): + super().__init__(params) + self.url = params.get('ws_url') + self.secret = params.get('ws_secret', os.getenv('WS_SECRET', '')) + self.ignore_errors = params.get('ignore_errors', True) + + self._task_id = None + self._websocket = None + + async def listen(self, translation_params: dict = None): + from threading import Thread + import io + import aioshutil + from aiofiles import os + import websockets + from ..server import ws_pb2 + + self._server_loop = asyncio.new_event_loop() + self.task_lock = PriorityLock() + self.counter = 0 + + async def _send_and_yield(websocket, msg): + # send message and yield control to the event loop (to actually send the message) + await websocket.send(msg) + await asyncio.sleep(0) + + send_throttler = Throttler(0.2) + send_and_yield = send_throttler.wrap(_send_and_yield) + + async def sync_state(state, finished): + if self._websocket is None: + return + msg = ws_pb2.WebSocketMessage() + msg.status.id = self._task_id + msg.status.status = state + self._server_loop.call_soon_threadsafe( + asyncio.create_task, + send_and_yield(self._websocket, msg.SerializeToString()) + ) + + self.add_progress_hook(sync_state) + + async def translate(task_id, websocket, image, params): + async with self.task_lock((1 << 31) - params['ws_count']): + self._task_id = task_id + self._websocket = websocket + result = await self.translate(image, params) + self._task_id = None + self._websocket = None + return result + + async def server_send_status(websocket, task_id, status): + msg = ws_pb2.WebSocketMessage() + msg.status.id = task_id + msg.status.status = status + await websocket.send(msg.SerializeToString()) + await asyncio.sleep(0) + + async def server_process_inner(main_loop, logger_task, session, websocket, task) -> Tuple[bool, bool]: + logger_task.info(f'-- Processing task {task.id}') + await server_send_status(websocket, task.id, 'pending') + + if self.verbose: + await aioshutil.rmtree(f'result/{task.id}', ignore_errors=True) + await os.makedirs(f'result/{task.id}', exist_ok=True) + + params = { + 'target_lang': task.target_language, + 'skip_lang': task.skip_language, + 'detector': task.detector, + 'direction': task.direction, + 'translator': task.translator, + 'size': task.size, + 'ws_event_loop': asyncio.get_event_loop(), + 'ws_count': self.counter, + } + self.counter += 1 + + logger_task.info(f'-- Downloading image from {task.source_image}') + await server_send_status(websocket, task.id, 'downloading') + async with session.get(task.source_image) as resp: + if resp.status == 200: + source_image = await resp.read() + else: + msg = ws_pb2.WebSocketMessage() + msg.status.id = task.id + msg.status.status = 'error-download' + await websocket.send(msg.SerializeToString()) + await asyncio.sleep(0) + return False, False + + logger_task.info(f'-- Translating image') + if translation_params: + for p, default_value in translation_params.items(): + current_value = params.get(p) + params[p] = current_value if current_value is not None else default_value + + image = Image.open(io.BytesIO(source_image)) + + (ori_w, ori_h) = image.size + if max(ori_h, ori_w) > 1200: + params['upscale_ratio'] = 1 + + await server_send_status(websocket, task.id, 'preparing') + # translation_dict = await self.translate(image, params) + translation_dict = await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe( + translate(task.id, websocket, image, params), + main_loop + ) + ) + await send_throttler.flush() + + output: Image.Image = translation_dict.result + if output is not None: + await server_send_status(websocket, task.id, 'saving') + + output = output.resize((ori_w, ori_h), resample=Image.LANCZOS) + + img = io.BytesIO() + output.save(img, format='PNG') + if self.verbose: + output.save(self._result_path('ws_final.png')) + + img_bytes = img.getvalue() + logger_task.info(f'-- Uploading result to {task.translation_mask}') + await server_send_status(websocket, task.id, 'uploading') + async with session.put(task.translation_mask, data=img_bytes) as resp: + if resp.status != 200: + logger_task.error(f'-- Failed to upload result:') + logger_task.error(f'{resp.status}: {resp.reason}') + msg = ws_pb2.WebSocketMessage() + msg.status.id = task.id + msg.status.status = 'error-upload' + await websocket.send(msg.SerializeToString()) + await asyncio.sleep(0) + return False, False + + return True, output is not None + + async def server_process(main_loop, session, websocket, task) -> bool: + logger_task = logger.getChild(f'{task.id}') + try: + (success, has_translation_mask) = await server_process_inner(main_loop, logger_task, session, websocket, + task) + except Exception as e: + logger_task.error(f'-- Task failed with exception:') + logger_task.error(f'{e.__class__.__name__}: {e}', exc_info=e if self.verbose else None) + (success, has_translation_mask) = False, False + finally: + result = ws_pb2.WebSocketMessage() + result.finish_task.id = task.id + result.finish_task.success = success + result.finish_task.has_translation_mask = has_translation_mask + await websocket.send(result.SerializeToString()) + await asyncio.sleep(0) + logger_task.info(f'-- Task finished') + + async def async_server_thread(main_loop): + from aiohttp import ClientSession, ClientTimeout + timeout = ClientTimeout(total=30) + async with ClientSession(timeout=timeout) as session: + logger_conn = logger.getChild('connection') + if self.verbose: + logger_conn.setLevel(logging.DEBUG) + async for websocket in websockets.connect( + self.url, + extra_headers={ + 'x-secret': self.secret, + }, + max_size=1_000_000, + logger=logger_conn + ): + bg_tasks = set() + try: + logger.info('-- Connected to websocket server') + + async for raw in websocket: + # logger.info(f'Got message: {raw}') + msg = ws_pb2.WebSocketMessage() + msg.ParseFromString(raw) + if msg.WhichOneof('message') == 'new_task': + task = msg.new_task + bg_task = asyncio.create_task(server_process(main_loop, session, websocket, task)) + bg_tasks.add(bg_task) + bg_task.add_done_callback(bg_tasks.discard) + + except Exception as e: + logger.error(f'{e.__class__.__name__}: {e}', exc_info=e if self.verbose else None) + + finally: + logger.info('-- Disconnected from websocket server') + for bg_task in bg_tasks: + bg_task.cancel() + + def server_thread(future, main_loop, server_loop): + asyncio.set_event_loop(server_loop) + try: + server_loop.run_until_complete(async_server_thread(main_loop)) + finally: + future.set_result(None) + + future = asyncio.Future() + Thread( + target=server_thread, + args=(future, asyncio.get_running_loop(), self._server_loop), + daemon=True + ).start() + + # create a future that is never done + await future + + async def _run_text_translation(self, ctx: Context): + coroutine = super()._run_text_translation(ctx) + if ctx.translator.has_offline(): + return await coroutine + else: + task_id = self._task_id + websocket = self._websocket + await self.task_lock.release() + result = await asyncio.wrap_future( + asyncio.run_coroutine_threadsafe( + coroutine, + ctx.ws_event_loop + ) + ) + await self.task_lock.acquire((1 << 30) - ctx.ws_count) + self._task_id = task_id + self._websocket = websocket + return result + + async def _run_text_rendering(self, ctx: Context): + render_mask = (ctx.mask >= 127).astype(np.uint8)[:, :, None] + + output = await super()._run_text_rendering(ctx) + render_mask[np.sum(ctx.img_rgb != output, axis=2) > 0] = 1 + ctx.render_mask = render_mask + if self.verbose: + cv2.imwrite(self._result_path('ws_render_in.png'), cv2.cvtColor(ctx.img_rgb, cv2.COLOR_RGB2BGR)) + cv2.imwrite(self._result_path('ws_render_out.png'), cv2.cvtColor(output, cv2.COLOR_RGB2BGR)) + cv2.imwrite(self._result_path('ws_mask.png'), render_mask * 255) + + # only keep sections in mask + if self.verbose: + cv2.imwrite(self._result_path('ws_inmask.png'), cv2.cvtColor(ctx.img_rgb, cv2.COLOR_RGB2BGRA) * render_mask) + output = cv2.cvtColor(output, cv2.COLOR_RGB2RGBA) * render_mask + if self.verbose: + cv2.imwrite(self._result_path('ws_output.png'), cv2.cvtColor(output, cv2.COLOR_RGBA2BGRA) * render_mask) + + return output