diff --git a/dump-parameter-schema.py b/dump-parameter-schema.py index 185a0fb..d31b7f0 100755 --- a/dump-parameter-schema.py +++ b/dump-parameter-schema.py @@ -11,86 +11,88 @@ def _argparse_parse_args(self: argparse.ArgumentParser, args=None, namespace=None): - """Creates a json string from the argparser instance""" - json_out: Dict[str,Any] = {} - json_out["type"] = "bilingual/monolingual" #TODO "monolingual" or "bilingual" but no idea how to determine this automatically - json_out["description"] = self.description - - # non-simple type so it doesn't end up in json - SWITCH = object() - SUBSTITUTE = object() - - # We need to skip [0], as this is the prepended `--help` - param_dict: Dict[str,Dict] = {} - for argument in self._actions[1:]: - # Skip --help and --version - if isinstance(argument, (argparse._HelpAction, argparse._VersionAction)): - continue - - current_str = { - "help": argument.help, - "required": argument.required - } - - if isinstance(argument, argparse._StoreConstAction): - current_str |= { - "type": "bool", - "default": False - } - elif type(argument.type) == type: - current_str |= { - "type": cast(type, argument.type).__name__, - "default": argument.default - } - elif type(argument.type) == argparse.FileType: - current_str |= { - "type": "str", - "default": "-" - } - elif argument.default is not None: - current_str|= { - "type": type(argument.default).__name__, - "default": argument.default - } - else: - print(f"Unknown type for \"{argument.dest}\": skipped\n{argument!r}\n", file=sys.stderr) - continue - - - # If it is an `--option` type argument - if argument.option_strings: - current_str |= { - SWITCH: argument.option_strings[0], #TODO prefer long names? - SUBSTITUTE: argument.option_strings[0].replace('-','').upper() - } - # or a positional one - else: - current_str |= { - SUBSTITUTE: argument.dest.upper() - } - - if argument.choices is not None: - current_str |= { - "allowed_values": argument.choices - } - - # Add to the parameter dict - param_dict[current_str[SUBSTITUTE]] = current_str - - json_out["parameters"] = param_dict - - json_out["command"] = script_path - for _, value in param_dict.items(): - if not value["required"]: - json_out["command"] += " ${" + value[SUBSTITUTE] + ":+" + value[SWITCH] + (" $" + value[SUBSTITUTE] if value["type"] != "bool" else "") + "}" - else: - json_out["command"] += (" " + value[SWITCH] if SWITCH in value else "") + " $" + value[SUBSTITUTE] - - json.dump(json_out, sys.stdout, indent=4, skipkeys=True) - sys.exit(0) + """Creates a json string from the argparser instance""" + json_out: Dict[str, Any] = {} + json_out["type"] = ( + "bilingual/monolingual" # TODO "monolingual" or "bilingual" but no idea how to determine this automatically + ) + json_out["description"] = self.description + + # non-simple type so it doesn't end up in json + SWITCH = object() + SUBSTITUTE = object() + + # We need to skip [0], as this is the prepended `--help` + param_dict: Dict[str, Dict] = {} + for argument in self._actions[1:]: + # Skip --help and --version + if isinstance(argument, (argparse._HelpAction, argparse._VersionAction)): + continue + + current_str = {"help": argument.help, "required": argument.required} + + if isinstance(argument, argparse._StoreConstAction): + current_str |= {"type": "bool", "default": False} + elif type(argument.type) == type: + current_str |= { + "type": cast(type, argument.type).__name__, + "default": argument.default, + } + elif type(argument.type) == argparse.FileType: + current_str |= {"type": "str", "default": "-"} + 
elif argument.default is not None: + current_str |= { + "type": type(argument.default).__name__, + "default": argument.default, + } + else: + print( + f'Unknown type for "{argument.dest}": skipped\n{argument!r}\n', + file=sys.stderr, + ) + continue + + # If it is an `--option` type argument + if argument.option_strings: + current_str |= { + SWITCH: argument.option_strings[0], # TODO prefer long names? + SUBSTITUTE: argument.option_strings[0].replace("-", "").upper(), + } + # or a positional one + else: + current_str |= {SUBSTITUTE: argument.dest.upper()} + + if argument.choices is not None: + current_str |= {"allowed_values": argument.choices} + + # Add to the parameter dict + param_dict[current_str[SUBSTITUTE]] = current_str + + json_out["parameters"] = param_dict + + json_out["command"] = script_path + for _, value in param_dict.items(): + if not value["required"]: + json_out["command"] += ( + " ${" + + value[SUBSTITUTE] + + ":+" + + value[SWITCH] + + (" $" + value[SUBSTITUTE] if value["type"] != "bool" else "") + + "}" + ) + else: + json_out["command"] += ( + (" " + value[SWITCH] if SWITCH in value else "") + + " $" + + value[SUBSTITUTE] + ) + + json.dump(json_out, sys.stdout, indent=4, skipkeys=True) + sys.exit(0) argparse.ArgumentParser.parse_args = _argparse_parse_args with open(script_path) as fh: - exec(fh.read()) + exec(fh.read()) diff --git a/opuscleaner/_util.py b/opuscleaner/_util.py index e68a2cb..d974baa 100644 --- a/opuscleaner/_util.py +++ b/opuscleaner/_util.py @@ -6,13 +6,14 @@ T = TypeVar("T") + def none_throws(optional: Optional[T], message: str = "Unexpected `None`") -> T: if optional is None: raise AssertionError(message) return optional -def _thread_pool_worker(id:int, exc_queue:SimpleQueue, target, args, kwargs): +def _thread_pool_worker(id: int, exc_queue: SimpleQueue, target, args, kwargs): try: target(*args, **kwargs) exc_queue.put((id, None)) @@ -26,6 +27,7 @@ class ThreadPool: exception as soon as it is received. It is then up to you to stop any other threads. This pool will wait for them when it exits the context. """ + def __init__(self): self.threads = {} self.queue = SimpleQueue() @@ -41,7 +43,8 @@ def start(self, func, *args, **kwargs): "args": args, "kwargs": kwargs, }, - name=func.__name__) + name=func.__name__, + ) self.threads[thread_id].start() def join(self): @@ -54,7 +57,7 @@ def join(self): def __enter__(self): return self - + def __exit__(self, *args, **kwargs): for thread in self.threads.values(): thread.join() @@ -65,17 +68,20 @@ class Cancelled(Exception): """Error raised by CancelableQueue's `put()` or `get()` when `cancel()` was called. """ + pass -T = TypeVar('T') +T = TypeVar("T") + class CancelableQueue(Generic[T]): """SimpleQueue, but when cancel() is called it will release all blocking put() and get() calls and raise `Cancelled`. Also much worse performance than SimpleQueue so don't use for heavy workloads plz. """ - def __init__(self, capacity:Optional[int]=None): + + def __init__(self, capacity: Optional[int] = None): self.capacity = capacity self.size = 0 self.queue = deque() @@ -87,7 +93,11 @@ def put(self, item: T): when `cancel()` was called. 
""" with self.cv: - self.cv.wait_for(lambda: self.cancelled or self.capacity is None or self.size < self.capacity) + self.cv.wait_for( + lambda: self.cancelled + or self.capacity is None + or self.size < self.capacity + ) if self.cancelled: raise Cancelled() self.queue.append(item) @@ -106,9 +116,9 @@ def get(self) -> T: item = self.queue.popleft() self.cv.notify() return item - + def cancel(self): """Makes all calls to `get()` and `put()` raise `Cancelled()`.""" with self.cv: self.cancelled = True - self.cv.notify_all() + self.cv.notify_all() diff --git a/opuscleaner/categories.py b/opuscleaner/categories.py index 9d47e0c..f9ac225 100644 --- a/opuscleaner/categories.py +++ b/opuscleaner/categories.py @@ -8,47 +8,62 @@ class Category(BaseModel): - name: str + name: str - @validator('name') - def name_must_not_be_empty(cls, value:str) -> str: - assert len(value.strip()) > 0, 'must not be empty' - return value.strip() + @validator("name") + def name_must_not_be_empty(cls, value: str) -> str: + assert len(value.strip()) > 0, "must not be empty" + return value.strip() class CategoryMapping(BaseModel): - categories: List[Category] - mapping: Dict[str,List[str]] + categories: List[Category] + mapping: Dict[str, List[str]] - @validator('categories') - def categories_must_be_unique(cls, value:List[Category]) -> List[Category]: - assert len(set(category.name.strip() for category in value)) == len(value), 'categories must have unique names' - return value + @validator("categories") + def categories_must_be_unique(cls, value: List[Category]) -> List[Category]: + assert len(set(category.name.strip() for category in value)) == len( + value + ), "categories must have unique names" + return value - @validator('mapping') - def mapping_must_only_contain_categories(cls, value:Dict[str,List[str]], values: Dict[str,Any], **kwargs) -> Dict[str,List[str]]: - assert len(set(value.keys()) - set(category.name.strip() for category in values.get('categories', ''))) == 0, 'mapping must only contain keys that are defined in `categories`' - return value + @validator("mapping") + def mapping_must_only_contain_categories( + cls, value: Dict[str, List[str]], values: Dict[str, Any], **kwargs + ) -> Dict[str, List[str]]: + assert ( + len( + set(value.keys()) + - set( + category.name.strip() for category in values.get("categories", "") + ) + ) + == 0 + ), "mapping must only contain keys that are defined in `categories`" + return value -def read_categories(fh:TextIO) -> CategoryMapping: - return parse_obj_as(CategoryMapping, json.load(fh)) +def read_categories(fh: TextIO) -> CategoryMapping: + return parse_obj_as(CategoryMapping, json.load(fh)) -def write_categories(mapping:CategoryMapping, fh:TextIO) -> None: - json.dump(mapping.dict(), fh, indent=2) + +def write_categories(mapping: CategoryMapping, fh: TextIO) -> None: + json.dump(mapping.dict(), fh, indent=2) app = FastAPI() -@app.get('/') + +@app.get("/") def get_mapping() -> CategoryMapping: - if os.path.exists(CATEGORIES_PATH): - with open(CATEGORIES_PATH, 'r') as fh: - return read_categories(fh) - else: - return CategoryMapping(categories=DEFAULT_CATEGORIES, mapping=dict()) - -@app.put('/') -def update_categories(body:CategoryMapping) -> None: - with open(CATEGORIES_PATH, 'w') as fh: - write_categories(body, fh) + if os.path.exists(CATEGORIES_PATH): + with open(CATEGORIES_PATH, "r") as fh: + return read_categories(fh) + else: + return CategoryMapping(categories=DEFAULT_CATEGORIES, mapping=dict()) + + +@app.put("/") +def update_categories(body: CategoryMapping) -> None: 
+ with open(CATEGORIES_PATH, "w") as fh: + write_categories(body, fh) diff --git a/opuscleaner/clean.py b/opuscleaner/clean.py index 7c22a66..e011a7d 100755 --- a/opuscleaner/clean.py +++ b/opuscleaner/clean.py @@ -4,6 +4,7 @@ links them together through pipes. Can read from stdin but by default reads the dataset from the same folder as the pipeline configuration file. """ + import argparse import json import os @@ -24,35 +25,50 @@ from opuscleaner import logging from opuscleaner.config import FILTER_PATH -from opuscleaner.filters import list_filters, set_global_filters, filter_format_command, Filter, FilterPipeline, quote, format_shell +from opuscleaner.filters import ( + list_filters, + set_global_filters, + filter_format_command, + Filter, + FilterPipeline, + quote, + format_shell, +) from opuscleaner._util import none_throws, ThreadPool, CancelableQueue, Cancelled # Queue for printing lines to stdout or stderr. None means end of input. -PrintQueue = SimpleQueue[Union[None,bytes]] +PrintQueue = SimpleQueue[Union[None, bytes]] # Control queue for communicating the return code of a child back to the parent. -ControlQueue = SimpleQueue[Tuple[int,int]] +ControlQueue = SimpleQueue[Tuple[int, int]] # Batches to be processed. tuple[batch index,batch path]. None means end of input. # Using a Queue here to limit the maximum capacity. -BatchQueue = CancelableQueue[Union[None,Tuple[int,str]]] +BatchQueue = CancelableQueue[Union[None, Tuple[int, str]]] # Batches to be merged. Same format as BatchQueue -MergeQueue = CancelableQueue[Union[None,Tuple[int,str]]] +MergeQueue = CancelableQueue[Union[None, Tuple[int, str]]] -def load_time(fh:IO[str]) -> Dict[str,float]: +def load_time(fh: IO[str]) -> Dict[str, float]: time = {} for line in fh: - match = re.match(r'^(real|user|sys)\s+(\d+\.\d+)$', line.rstrip('\r\n')) + match = re.match(r"^(real|user|sys)\s+(\d+\.\d+)$", line.rstrip("\r\n")) if match: time[match[1]] = float(match[2]) return time @logging.trace -def babysit_child(n: int, child: Popen, name: str, print_queue: PrintQueue, ctrl_queue: ControlQueue, time_read_fd:Optional[int]=None) -> None: +def babysit_child( + n: int, + child: Popen, + name: str, + print_queue: PrintQueue, + ctrl_queue: ControlQueue, + time_read_fd: Optional[int] = None, +) -> None: """Thread that looks after a child process and passes (and prefixes) all of its stderr to a queue. It will tell the parent thread about the end of the child through the ctrl_queue. @@ -60,20 +76,20 @@ def babysit_child(n: int, child: Popen, name: str, print_queue: PrintQueue, ctrl try: logging.update(n=n, pid=child.pid, args=child.args) - prefix = f'[{name}] '.encode() + prefix = f"[{name}] ".encode() for line in none_throws(child.stderr): print_queue.put(prefix + line) child.wait() - logging.event('child_exited', retval=child.returncode) + logging.event("child_exited", retval=child.returncode) # If the command was wrapped by `time`, we want to read its output as # well. It's written to a separate pipe as to not end up in the stderr # of the main command. 
if time_read_fd is not None: - with os.fdopen(time_read_fd, 'r') as fh: + with os.fdopen(time_read_fd, "r") as fh: logging.update(time=load_time(fh)) finally: ctrl_queue.put((n, child.returncode)) @@ -94,9 +110,10 @@ def print_lines(queue: PrintQueue, fout: IO[bytes]) -> None: fout.flush() -T = TypeVar('T') +T = TypeVar("T") + -def mark_last(iterable: Iterable[T]) -> Iterable[Tuple[bool,T]]: +def mark_last(iterable: Iterable[T]) -> Iterable[Tuple[bool, T]]: it = iter(iterable) curr_el = next(it) while True: @@ -125,24 +142,39 @@ class ProcessPool: processes in the pool and raise an exception on exit. SIGPIPE errors, and errors caused by the pool terminating the process, are ignored. """ + print_prefix: str ctrl_queue: ControlQueue print_queue: PrintQueue - environ: Dict[str,str] + environ: Dict[str, str] children: List[Child] - def __init__(self, print_queue: PrintQueue, *, env:Dict[str,str]={}, print_prefix:str=''): + def __init__( + self, + print_queue: PrintQueue, + *, + env: Dict[str, str] = {}, + print_prefix: str = "", + ): self.print_prefix = print_prefix self.ctrl_queue = SimpleQueue() self.print_queue = print_queue self.environ = dict(env) self.children = [] - def start(self, name:str, cmd: Union[str,List[str]], *, shell:bool=False, time:bool=False, **kwargs) -> Popen: + def start( + self, + name: str, + cmd: Union[str, List[str]], + *, + shell: bool = False, + time: bool = False, + **kwargs, + ) -> Popen: """Set up a process in the pool. Similar to Popen. `name` is used for identifying the process in log messages and exceptions. `time` can be set to True to wrap the process in `/usr/bin/time`. Furthermore all @@ -150,28 +182,27 @@ def start(self, name:str, cmd: Union[str,List[str]], *, shell:bool=False, time:b """ time_read_fd, time_write_fd = None, None - args = ([cmd] if isinstance(cmd, str) else cmd) - + args = [cmd] if isinstance(cmd, str) else cmd + if shell: - args = ['/bin/sh', '-c', *args] # TODO: sorry Windows, Andriod + args = ["/bin/sh", "-c", *args] # TODO: sorry Windows, Andriod # If we're measuring time, prepend `/usr/bin/time` and let it write to # a pipe we will read out later. Massive assumption: that pipe's buffer # will be sufficient for time's output. if time: time_read_fd, time_write_fd = os.pipe() - os.set_inheritable(time_write_fd, True) # TODO is this necessary? - args = ['/usr/bin/time', '-p', '-o', f'/dev/fd/{time_write_fd}', *args] - kwargs['pass_fds'] = (time_write_fd, *kwargs.get('pass_fds', tuple())) - - child = Popen(args, **{ - **kwargs, - 'env': { - **os.environ, - **self.environ, - **(kwargs.get('env') or dict()) - } - }) + os.set_inheritable(time_write_fd, True) # TODO is this necessary? + args = ["/usr/bin/time", "-p", "-o", f"/dev/fd/{time_write_fd}", *args] + kwargs["pass_fds"] = (time_write_fd, *kwargs.get("pass_fds", tuple())) + + child = Popen( + args, + **{ + **kwargs, + "env": {**os.environ, **self.environ, **(kwargs.get("env") or dict())}, + }, + ) # If we have a time pipe, make sure we release our handle of the write # side. We just keep the read side. 
@@ -179,12 +210,15 @@ def start(self, name:str, cmd: Union[str,List[str]], *, shell:bool=False, time:b os.close(time_write_fd) n = len(self.children) - thread = Thread(target=babysit_child, args=[n, child, name, self.print_queue, self.ctrl_queue, time_read_fd]) + thread = Thread( + target=babysit_child, + args=[n, child, name, self.print_queue, self.ctrl_queue, time_read_fd], + ) thread.start() self.children.append(Child(name, child, thread)) return child - def __enter__(self) -> 'ProcessPool': + def __enter__(self) -> "ProcessPool": return self def __exit__(self, err_type, err_inst, _) -> None: @@ -205,7 +239,7 @@ def __exit__(self, err_type, err_inst, _) -> None: child_i, retval = self.ctrl_queue.get() running_children -= 1 - logging.event('child_exit_received', n=child_i, retval=retval) + logging.event("child_exit_received", n=child_i, retval=retval) # Early exit when a process errored out. SIGPIPE is retuned by # processes that can no longer write to the next one. E.g. when @@ -231,7 +265,9 @@ def __exit__(self, err_type, err_inst, _) -> None: # If we broke out of our ctrl_queue loop we did so because there was an issue # with one of the children. Let's raise that to the parent. if not err_inst and problem_child: - raise RuntimeError(f"Child {problem_child.name} (pid {problem_child.process.pid}) exited with {problem_child.process.returncode}") + raise RuntimeError( + f"Child {problem_child.name} (pid {problem_child.process.pid}) exited with {problem_child.process.returncode}" + ) class PipelineStep(NamedTuple): @@ -245,45 +281,63 @@ class Pipeline: set up to execute in a certain environment. A Pipeline can either be dumped as a bash script, or executed on a ProcessPool. """ - def __init__(self, filters:Dict[str,Filter], languages: List[str], pipeline: FilterPipeline): + + def __init__( + self, filters: Dict[str, Filter], languages: List[str], pipeline: FilterPipeline + ): self.steps: List[PipelineStep] = [] # Make sure the path to the python binary (and the installed utils) # is in the PATH variable. If you load a virtualenv this happens by - # default, but if you call it with the virtualenv's python binary + # default, but if you call it with the virtualenv's python binary # directly it wont. pyenv_bin_path = os.path.dirname(sys.executable) - os_env_bin_paths = os.environ.get('PATH', '').split(os.pathsep) - self.env: Optional[Dict[str,str]] = { - **os.environ, - 'PATH': os.pathsep.join([pyenv_bin_path] + os_env_bin_paths) - } if pyenv_bin_path not in os_env_bin_paths else None + os_env_bin_paths = os.environ.get("PATH", "").split(os.pathsep) + self.env: Optional[Dict[str, str]] = ( + {**os.environ, "PATH": os.pathsep.join([pyenv_bin_path] + os_env_bin_paths)} + if pyenv_bin_path not in os_env_bin_paths + else None + ) # Assert we have all filters we need - assert set(step.filter for step in pipeline.filters) - set(filters.keys()) == set() + assert ( + set(step.filter for step in pipeline.filters) - set(filters.keys()) == set() + ) # Make sure the path to the python binary (and the installed utils) # is in the PATH variable. If you load a virtualenv this happens by - # default, but if you call it with the virtualenv's python binary + # default, but if you call it with the virtualenv's python binary # directly it wont. 
pyenv_bin_path = os.path.dirname(sys.executable) - os_env_bin_paths = os.environ.get('PATH', '').split(os.pathsep) - self.env: Optional[Dict[str,str]] = { - **os.environ, - 'PATH': os.pathsep.join([pyenv_bin_path] + os_env_bin_paths) - } if pyenv_bin_path not in os_env_bin_paths else None + os_env_bin_paths = os.environ.get("PATH", "").split(os.pathsep) + self.env: Optional[Dict[str, str]] = ( + {**os.environ, "PATH": os.pathsep.join([pyenv_bin_path] + os_env_bin_paths)} + if pyenv_bin_path not in os_env_bin_paths + else None + ) for step in pipeline.filters: filter_def = filters[step.filter] command_str = filter_format_command(filter_def, step, languages) - self.steps.append(PipelineStep(step.filter, command_str, filter_def.basedir)) - - def run(self, pool:ProcessPool, stdin:IO[bytes], stdout:IO[bytes], *, tee:bool=False, basename:str="", time:bool=False) -> None: + self.steps.append( + PipelineStep(step.filter, command_str, filter_def.basedir) + ) + + def run( + self, + pool: ProcessPool, + stdin: IO[bytes], + stdout: IO[bytes], + *, + tee: bool = False, + basename: str = "", + time: bool = False, + ) -> None: """Set up all the processes on `pool`, processing `stdin` to `stdout`. Note that this function will return as soon as the processes have been set up. You will have to use the ProcessPool to wait for them to finish. Optionally you can `tee` the output of each filter step to a separate - file for debugging (with the name "{basename}.step-{i}.tsv". You can + file for debugging (with the name "{basename}.step-{i}.tsv". You can use `time` two wrap every filter step command in `/usr/bin/time` and the baby sitter will measure how much processing time the filter process used.""" @@ -292,18 +346,21 @@ def run(self, pool:ProcessPool, stdin:IO[bytes], stdout:IO[bytes], *, tee:bool=F return for i, (is_last_step, step) in enumerate(mark_last(self.steps)): - child = pool.start(f'{pool.print_prefix}{i}:{step.name}', step.command, + child = pool.start( + f"{pool.print_prefix}{i}:{step.name}", + step.command, stdin=stdin, stdout=stdout if is_last_step and not tee else PIPE, stderr=PIPE, cwd=step.basedir, env=self.env, shell=True, - time=time) + time=time, + ) # Close our reference to the previous child, now taken over by the next child stdin.close() - + # Set stdin for next step (unless there is none, then child.stdout is None) if not is_last_step and not tee: stdin = none_throws(child.stdout) @@ -311,28 +368,31 @@ def run(self, pool:ProcessPool, stdin:IO[bytes], stdout:IO[bytes], *, tee:bool=F # If we are tee-ing for debug, shunt the output to a separate file # TODO: uncompressed at the moment. Might be trouble. 
if tee: - tee_child = pool.start(f'{pool.print_prefix}{i}:tee', - ['tee', f'{basename}.step-{i}.tsv'], + tee_child = pool.start( + f"{pool.print_prefix}{i}:tee", + ["tee", f"{basename}.step-{i}.tsv"], stdin=stdin, stdout=stdout if is_last_step else PIPE, - stderr=PIPE) + stderr=PIPE, + ) stdin.close() stdin = none_throws(tee_child.stdout) - def dump(self, out:IO[str]) -> None: + def dump(self, out: IO[str]) -> None: """Write this pipeline as a bash script.""" if self.env: for key, val in self.env: - out.write(f'export {key}={quote(format_shell(val))}\n') + out.write(f"export {key}={quote(format_shell(val))}\n") for is_last_step, step in mark_last(self.steps): - out.write(f'(cd {quote(format_shell(step.basedir))} && ({step.command}))') - out.write('\n' if is_last_step else ' |\n') + out.write(f"(cd {quote(format_shell(step.basedir))} && ({step.command}))") + out.write("\n" if is_last_step else " |\n") - -def split_input(parallel: int, batch_queue: BatchQueue, batch_size:int, stdin:IO[bytes]) -> None: +def split_input( + parallel: int, batch_queue: BatchQueue, batch_size: int, stdin: IO[bytes] +) -> None: """Reads data from `stdin` and splits it into chunks of `batch_size` lines. These chunks are stored in temporary files, whose filenames are put onto `batch_queue`. @@ -343,20 +403,20 @@ def split_input(parallel: int, batch_queue: BatchQueue, batch_size:int, stdin:IO while more: fh = NamedTemporaryFile(delete=False) - + lines = 0 while lines < batch_size: line = stdin.readline() - if line == b'': + if line == b"": more = False break fh.write(line) lines += 1 - + fh.close() - try: + try: if lines > 0: batch_queue.put((batch_index, fh.name)) else: @@ -377,7 +437,14 @@ def split_input(parallel: int, batch_queue: BatchQueue, batch_size:int, stdin:IO @logging.trace -def run_pipeline(print_queue:PrintQueue, batch_queue:BatchQueue, merge_queue:MergeQueue, pipeline:Pipeline, *, time:bool=False) -> None: +def run_pipeline( + print_queue: PrintQueue, + batch_queue: BatchQueue, + merge_queue: MergeQueue, + pipeline: Pipeline, + *, + time: bool = False, +) -> None: """Receives an input filename from `batch_queue`, and once that has been processed with `pipeline`, it will post the output filename to `merge_queue`. stderr from any of the filter processes will be forwarded to `print_queue`. @@ -407,9 +474,13 @@ def run_pipeline(print_queue:PrintQueue, batch_queue:BatchQueue, merge_queue:Mer try: # Open chunk file and process pool and run the pipeline with it. # The pool's __exit__() will make us wait till the pipeline is done. - with logging.span('run_pipeline_batch', batch_index=batch_index), \ - open(filename, 'rb') as stdin, \ - ProcessPool(print_queue, env={'TMPDIR': tmpdir}, print_prefix=f'{batch_index}/') as pool: + with logging.span( + "run_pipeline_batch", batch_index=batch_index + ), open(filename, "rb") as stdin, ProcessPool( + print_queue, + env={"TMPDIR": tmpdir}, + print_prefix=f"{batch_index}/", + ) as pool: pipeline.run(pool, stdin, stdout, time=time) stdout.close() @@ -422,16 +493,18 @@ def run_pipeline(print_queue:PrintQueue, batch_queue:BatchQueue, merge_queue:Mer raise except Exception as exc: # Add a bit more info, and re-raise - raise RuntimeError(f'Error while processing batch {batch_index}') from exc + raise RuntimeError( + f"Error while processing batch {batch_index}" + ) from exc finally: # Delete the input file from disk. os.unlink(filename) - + # Tell the merger that they should not be expecting more input from you. 
merge_queue.put(None) -def merge_output(parallel:int, merge_queue:MergeQueue, stdout:IO[bytes]) -> None: +def merge_output(parallel: int, merge_queue: MergeQueue, stdout: IO[bytes]) -> None: """Takes batch filenames and numbers from `merge_queue` and will concatenate files in the order of the batches. If batches arrive out of order, it will wait for the next in order batch to arrive before continuing to concatenate. @@ -446,10 +519,12 @@ def merge_output(parallel:int, merge_queue:MergeQueue, stdout:IO[bytes]) -> None batch_index, filename = next_batch_index, pending_batches[next_batch_index] try: - with logging.span('merge_output_batch', batch_index=batch_index), open(filename, 'rb') as fh: + with logging.span("merge_output_batch", batch_index=batch_index), open( + filename, "rb" + ) as fh: copyfileobj(fh, stdout) except Exception as exc: - raise RuntimeError(f'Error while merging batch {batch_index}') from exc + raise RuntimeError(f"Error while merging batch {batch_index}") from exc finally: os.unlink(filename) @@ -471,11 +546,22 @@ def merge_output(parallel:int, merge_queue:MergeQueue, stdout:IO[bytes]) -> None break if len(pending_batches) and next_batch_index <= max(pending_batches.keys()): - raise RuntimeError(f'Not all batches got merged: {next_batch_index=} <= {max(pending_batches.keys())=}') + raise RuntimeError( + f"Not all batches got merged: {next_batch_index=} <= {max(pending_batches.keys())=}" + ) @logging.trace -def run_parallel(pipeline:Pipeline, stdin:IO[bytes], stdout:IO[bytes], *, parallel:int, batch_size:int, print_queue: PrintQueue, time:bool=False) -> None: +def run_parallel( + pipeline: Pipeline, + stdin: IO[bytes], + stdout: IO[bytes], + *, + parallel: int, + batch_size: int, + print_queue: PrintQueue, + time: bool = False, +) -> None: """Run `parallel` copies of the processing pipeline in parallel, each working on a batch of `batch_size` lines at a time. Batches will be cut from `stdin` and printed to `stdout`, in order. stderr from the filter @@ -493,14 +579,16 @@ def run_parallel(pipeline:Pipeline, stdin:IO[bytes], stdout:IO[bytes], *, parall # Read `batch_queue` for batch filenames, and process them. Put output files # on `merge_queue`. for _ in range(parallel): - pool.start(run_pipeline, print_queue, batch_queue, merge_queue, pipeline, time=time) + pool.start( + run_pipeline, print_queue, batch_queue, merge_queue, pipeline, time=time + ) # Read from `merge_queue` and combine files in order. pool.start(merge_output, parallel, merge_queue, stdout) try: pool.join() - except BaseException: # Note: also catches KeyboardInterrupt + except BaseException: # Note: also catches KeyboardInterrupt batch_queue.cancel() merge_queue.cancel() raise @@ -508,41 +596,103 @@ def run_parallel(pipeline:Pipeline, stdin:IO[bytes], stdout:IO[bytes], *, parall def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--filters', '-f', type=str, default=FILTER_PATH, help='Path to directory with filter specifications') - parser.add_argument('--input', '-i', type=argparse.FileType('rb'), help='Input tsv. 
If unspecified input files are read from filter json; use - to read from stdin') - parser.add_argument('--output', '-o', type=argparse.FileType('wb'), default=sys.stdout.buffer, help='Output tsv (defaults to stdout)') - parser.add_argument('--basedir', '-b', type=str, help='Directory to look for data files when --input is not used (defaults to same as input pipeline file)') - parser.add_argument('--tee', action='store_true', help='Write output after each step to a separate file') - parser.add_argument('--parallel', type=int, default=1, help='Run N parallel copies of the pipeline processing batches') - parser.add_argument('--batch-size', type=int, default=1_000_000, help='Batch size in lines that each parallel copy processes (only if --parallel > 1)') - parser.add_argument('--first', type=int, default=0, help='Limit reading input to the N first lines') - parser.add_argument('--dump', action='store_true', help='Print shell script instead') - parser.add_argument('--trace', type=argparse.FileType('a'), nargs='?', const='/dev/stderr', help='Write tracing JSON to file (defaults to stderr)') - parser.add_argument('--time', action='store_true', help='Measure real/user/sys times for each filter step') - parser.add_argument('pipeline', metavar='PIPELINE', type=argparse.FileType('r'), help='Pipeline steps specification file, e.g. *.filters.json') - parser.add_argument('languages', metavar='LANG', type=str, nargs='*', help='Language codes of the columns in the input TSV. Only used when --input is set') + parser.add_argument( + "--filters", + "-f", + type=str, + default=FILTER_PATH, + help="Path to directory with filter specifications", + ) + parser.add_argument( + "--input", + "-i", + type=argparse.FileType("rb"), + help="Input tsv. If unspecified input files are read from filter json; use - to read from stdin", + ) + parser.add_argument( + "--output", + "-o", + type=argparse.FileType("wb"), + default=sys.stdout.buffer, + help="Output tsv (defaults to stdout)", + ) + parser.add_argument( + "--basedir", + "-b", + type=str, + help="Directory to look for data files when --input is not used (defaults to same as input pipeline file)", + ) + parser.add_argument( + "--tee", + action="store_true", + help="Write output after each step to a separate file", + ) + parser.add_argument( + "--parallel", + type=int, + default=1, + help="Run N parallel copies of the pipeline processing batches", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1_000_000, + help="Batch size in lines that each parallel copy processes (only if --parallel > 1)", + ) + parser.add_argument( + "--first", type=int, default=0, help="Limit reading input to the N first lines" + ) + parser.add_argument( + "--dump", action="store_true", help="Print shell script instead" + ) + parser.add_argument( + "--trace", + type=argparse.FileType("a"), + nargs="?", + const="/dev/stderr", + help="Write tracing JSON to file (defaults to stderr)", + ) + parser.add_argument( + "--time", + action="store_true", + help="Measure real/user/sys times for each filter step", + ) + parser.add_argument( + "pipeline", + metavar="PIPELINE", + type=argparse.FileType("r"), + help="Pipeline steps specification file, e.g. *.filters.json", + ) + parser.add_argument( + "languages", + metavar="LANG", + type=str, + nargs="*", + help="Language codes of the columns in the input TSV. 
Only used when --input is set", + ) args = parser.parse_args() - with logging.Context(file=args.trace), logging.span('main'): + with logging.Context(file=args.trace), logging.span("main"): # default search path for the data files is next to the configuration file # which is the default save location for empty-train. if not args.basedir: args.basedir = os.path.dirname(args.pipeline.name) or os.getcwd() if args.input is not None and not args.languages: - parser.error('When --input is specified, each column\'s LANG has to be specified as well') + parser.error( + "When --input is specified, each column's LANG has to be specified as well" + ) if args.tee and args.parallel > 1: - parser.error('Using --tee is not supported when using --parallel') + parser.error("Using --tee is not supported when using --parallel") if args.time and not args.trace: - parser.error('You need to use --trace to see the output of --time') + parser.error("You need to use --trace to see the output of --time") # load all filter definitions (we need to, to get their name) filters = { - definition.name: definition - for definition in list_filters(args.filters) + definition.name: definition for definition in list_filters(args.filters) } # set_global_filters() provides the filters to the validators in FilterPipeline @@ -550,10 +700,18 @@ def main() -> None: pipeline_config = parse_obj_as(FilterPipeline, json.load(args.pipeline)) # Order of columns. Matches datasets.py:list_datasets(path) - languages: List[str] = args.languages if args.input else [filename.rsplit('.', 2)[1] for filename in pipeline_config.files] + languages: List[str] = ( + args.languages + if args.input + else [filename.rsplit(".", 2)[1] for filename in pipeline_config.files] + ) # Directory plus basename to write debug (`--tee`) files to - basename: str = 'stdin' if args.input else os.path.commonprefix(pipeline_config.files).rstrip('.') + basename: str = ( + "stdin" + if args.input + else os.path.commonprefix(pipeline_config.files).rstrip(".") + ) pipeline = Pipeline(filters, languages, pipeline_config) @@ -561,7 +719,7 @@ def main() -> None: stdin: IO[bytes] # Output of this program - stdout:IO[bytes] = args.output + stdout: IO[bytes] = args.output # If we're just dumping the pipeline, do so to the specified output if args.dump: @@ -570,7 +728,7 @@ def main() -> None: # Queue filled by the babysitters with the stderr of the children, consumed # by `print_lines()` to prevent racing on stderr. - print_queue = SimpleQueue() # type: SimpleQueue[Optional[bytes]] + print_queue = SimpleQueue() # type: SimpleQueue[Optional[bytes]] # First start the print thread so that we get immediate feedback from the # children even if all of them haven't started yet. @@ -586,22 +744,26 @@ def main() -> None: else: # Open `gzunip` for each language file gunzips = [ - pool.start(f'gunzip {filename}', - ['gzip', '-cd', filename], + pool.start( + f"gunzip {filename}", + ["gzip", "-cd", filename], stdout=PIPE, stderr=PIPE, - cwd=args.basedir) + cwd=args.basedir, + ) for filename in pipeline_config.files ] fds = [none_throws(gunzip.stdout).fileno() for gunzip in gunzips] # .. 
and a `paste` to combine them into columns - paste = pool.start('paste', - ['paste'] + [f'/dev/fd/{fd}' for fd in fds], + paste = pool.start( + "paste", + ["paste"] + [f"/dev/fd/{fd}" for fd in fds], stdout=PIPE, stderr=PIPE, - pass_fds=fds) + pass_fds=fds, + ) # Now that `paste` has inherited all the children, close our connection to them for gunzip in gunzips: @@ -611,19 +773,36 @@ def main() -> None: # If we only want the first N lines processed, use `head` to chop those off. if args.first > 0: - head = pool.start('head', - ['head', '-n', str(args.first)], + head = pool.start( + "head", + ["head", "-n", str(args.first)], stdin=stdin, stdout=PIPE, - stderr=PIPE) + stderr=PIPE, + ) - stdin.close() # now taken over by `head`. + stdin.close() # now taken over by `head`. stdin = none_throws(head.stdout) if args.parallel > 1: - run_parallel(pipeline, stdin, stdout, print_queue=print_queue, parallel=args.parallel, batch_size=args.batch_size, time=args.time) + run_parallel( + pipeline, + stdin, + stdout, + print_queue=print_queue, + parallel=args.parallel, + batch_size=args.batch_size, + time=args.time, + ) else: - pipeline.run(pool, stdin, stdout, tee=args.tee, basename=basename, time=args.time) + pipeline.run( + pool, + stdin, + stdout, + tee=args.tee, + basename=basename, + time=args.time, + ) except: # If we didn't cleanly exit all processes, we err as well traceback.print_exc(file=sys.stderr) @@ -634,5 +813,5 @@ def main() -> None: print_thread.join() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/opuscleaner/col.py b/opuscleaner/col.py index a86ec10..3fafeb5 100755 --- a/opuscleaner/col.py +++ b/opuscleaner/col.py @@ -6,119 +6,142 @@ from typing import BinaryIO, Optional, TypeVar, List -queue = SimpleQueue() # type: SimpleQueue[None|list[bytes]] +queue = SimpleQueue() # type: SimpleQueue[None|list[bytes]] T = TypeVar("T") + def none_throws(optional: Optional[T], message: str = "Unexpected `None`") -> T: - if optional is None: - raise AssertionError(message) - return optional + if optional is None: + raise AssertionError(message) + return optional -def parse_columns(text:str) -> List[int]: - return sorted(int(col) for col in text.split(',')) +def parse_columns(text: str) -> List[int]: + return sorted(int(col) for col in text.split(",")) class RaisingThread(Thread): - """Thread that will raise any uncaught exceptions in the thread in the - parent once it joins again.""" - - exception: Optional[Exception] - - def run(self): - self.exception = None - try: - super().run() - except Exception as exc: - self.exception = exc - - def join(self, timeout:Optional[float]=None): - super().join(timeout=timeout) - if self.exception is not None: - raise self.exception - - -def split(columns:List[int], queue:'SimpleQueue[None|list[bytes]]', fin:BinaryIO, fout:BinaryIO): - try: - field_count = None - passthru_columns = [] - for line in fin: - fields = line.rstrip(b'\r\n').split(b'\t') - if field_count is None: - field_count = len(fields) - passthru_columns = [n for n in range(field_count) if n not in columns] - elif field_count != len(fields): - raise RuntimeError(f'line contains a different number of fields: {len(fields)} vs {field_count}') - queue.put([fields[column] for column in passthru_columns]) - for column in columns: - fout.write(fields[column] + b'\n') - except BrokenPipeError: - pass - finally: - try: - fout.close() # might fail if BrokenPipeError - except: - pass - queue.put(None) # End indicator - fin.close() - - -def merge(columns:List[int], 
queue:'SimpleQueue[None|list[bytes]]', fin:BinaryIO, fout:BinaryIO): - try: - while True: - passthru_fields = queue.get() - if passthru_fields is None: - if fin.readline() != b'': - raise RuntimeError('subprocess produced more lines of output than it was given') - break - - passthru_it = iter(passthru_fields) - for column in range(len(passthru_fields) + len(columns)): - if column in columns: - field = fin.readline() - if field == b'': - raise RuntimeError('subprocess produced fewer lines than it was given') - field = field.rstrip(b'\r\n') - else: - field = next(passthru_it) - - if column > 0: - fout.write(b'\t') - fout.write(field) - fout.write(b'\n') - except BrokenPipeError: - pass - finally: - fout.close() - fin.close() + """Thread that will raise any uncaught exceptions in the thread in the + parent once it joins again.""" + + exception: Optional[Exception] + + def run(self): + self.exception = None + try: + super().run() + except Exception as exc: + self.exception = exc + + def join(self, timeout: Optional[float] = None): + super().join(timeout=timeout) + if self.exception is not None: + raise self.exception + + +def split( + columns: List[int], + queue: "SimpleQueue[None|list[bytes]]", + fin: BinaryIO, + fout: BinaryIO, +): + try: + field_count = None + passthru_columns = [] + for line in fin: + fields = line.rstrip(b"\r\n").split(b"\t") + if field_count is None: + field_count = len(fields) + passthru_columns = [n for n in range(field_count) if n not in columns] + elif field_count != len(fields): + raise RuntimeError( + f"line contains a different number of fields: {len(fields)} vs {field_count}" + ) + queue.put([fields[column] for column in passthru_columns]) + for column in columns: + fout.write(fields[column] + b"\n") + except BrokenPipeError: + pass + finally: + try: + fout.close() # might fail if BrokenPipeError + except: + pass + queue.put(None) # End indicator + fin.close() + + +def merge( + columns: List[int], + queue: "SimpleQueue[None|list[bytes]]", + fin: BinaryIO, + fout: BinaryIO, +): + try: + while True: + passthru_fields = queue.get() + if passthru_fields is None: + if fin.readline() != b"": + raise RuntimeError( + "subprocess produced more lines of output than it was given" + ) + break + + passthru_it = iter(passthru_fields) + for column in range(len(passthru_fields) + len(columns)): + if column in columns: + field = fin.readline() + if field == b"": + raise RuntimeError( + "subprocess produced fewer lines than it was given" + ) + field = field.rstrip(b"\r\n") + else: + field = next(passthru_it) + + if column > 0: + fout.write(b"\t") + fout.write(field) + fout.write(b"\n") + except BrokenPipeError: + pass + finally: + fout.close() + fin.close() def main(): - retval = 0 + retval = 0 + + try: + columns = parse_columns(sys.argv[1]) - try: - columns = parse_columns(sys.argv[1]) + child = Popen(sys.argv[2:], stdin=PIPE, stdout=PIPE) - child = Popen(sys.argv[2:], stdin=PIPE, stdout=PIPE) + feeder = RaisingThread( + target=split, + args=[columns, queue, sys.stdin.buffer, none_throws(child).stdin], + ) + feeder.start() - feeder = RaisingThread(target=split, args=[columns, queue, sys.stdin.buffer, none_throws(child).stdin]) - feeder.start() + consumer = RaisingThread( + target=merge, + args=[columns, queue, none_throws(child).stdout, sys.stdout.buffer], + ) + consumer.start() - consumer = RaisingThread(target=merge, args=[columns, queue, none_throws(child).stdout, sys.stdout.buffer]) - consumer.start() + retval = child.wait() - retval = child.wait() - - if retval != 0: - raise 
RuntimeError(f'subprocess exited with status code {retval}') + if retval != 0: + raise RuntimeError(f"subprocess exited with status code {retval}") - feeder.join() - consumer.join() - except Exception as e: - print(f'Error: {e}', file=sys.stderr) - sys.exit(retval or 1) + feeder.join() + consumer.join() + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(retval or 1) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/opuscleaner/config.py b/opuscleaner/config.py index 10a88d6..a4d4137 100644 --- a/opuscleaner/config.py +++ b/opuscleaner/config.py @@ -2,38 +2,33 @@ import sys # Path to data files. Expects to find files named `$DATASET.$LANG.gz`. -DATA_PATH = os.getenv('DATA_PATH', 'data/train-parts/*.*.gz') +DATA_PATH = os.getenv("DATA_PATH", "data/train-parts/*.*.gz") -# Path to the file that defines the categories, and which dataset belongs to +# Path to the file that defines the categories, and which dataset belongs to # which. -CATEGORIES_PATH = os.path.join(os.path.dirname(DATA_PATH), 'categories.json') +CATEGORIES_PATH = os.path.join(os.path.dirname(DATA_PATH), "categories.json") -DEFAULT_CATEGORIES = [ - {'name': 'clean'}, - {'name': 'medium'}, - {'name': 'dirty'} -] +DEFAULT_CATEGORIES = [{"name": "clean"}, {"name": "medium"}, {"name": "dirty"}] # TODO: Derive this from DATA_PATH. The `train-parts` is a mtdata compatibility # thing. I'm now used to also have a data/clean directory there, so keeping it. -DOWNLOAD_PATH = 'data/train-parts' +DOWNLOAD_PATH = "data/train-parts" # glob expression that looks for the filter files. Unfortunately you can't use # commas and {} in this expression. -FILTER_PATH = os.pathsep.join([ - 'filters/**/*.json', - os.path.join(os.path.dirname(__file__), 'filters/**/*.json') -]) +FILTER_PATH = os.pathsep.join( + ["filters/**/*.json", os.path.join(os.path.dirname(__file__), "filters/**/*.json")] +) # col.py is used to apply a monolingual filter to a bilingual dataset. Needs # to be absolute since filters can run from different cwds. -COL_PY = [sys.executable, os.path.join(os.path.dirname(__file__), 'col.py')] +COL_PY = [sys.executable, os.path.join(os.path.dirname(__file__), "col.py")] # Program used to derive a sample from the dataset. Will be called like # `./sample.py -n $SAMPLE_SIZE ...file-per-lang.gz`. Absolute because filters # can specify their own `cwd`. -SAMPLE_PY = [sys.executable, os.path.join(os.path.dirname(__file__), 'sample.py')] +SAMPLE_PY = [sys.executable, os.path.join(os.path.dirname(__file__), "sample.py")] # Size of each of the three sections (head, random sample of middle, tail) of # the dataset sample that we operate on. -SAMPLE_SIZE = int(os.getenv('SAMPLE_SIZE', '1000')) +SAMPLE_SIZE = int(os.getenv("SAMPLE_SIZE", "1000")) diff --git a/opuscleaner/datasets.py b/opuscleaner/datasets.py index b639b02..ebc25e7 100644 --- a/opuscleaner/datasets.py +++ b/opuscleaner/datasets.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Lists datasets given a directory. It works by scanning the directory and looking for gz files.""" + from glob import glob from itertools import groupby from pathlib import Path as Path @@ -8,10 +9,10 @@ from opuscleaner.config import DATA_PATH -def list_datasets(path:str) -> Dict[str,List[Tuple[str,Path]]]: +def list_datasets(path: str) -> Dict[str, List[Tuple[str, Path]]]: """Lists datasets given a directory. Scans the directories and returns a dictionary of the datasets encoutered. 
Dictionary looks like {dataset_name : { lang: path}}""" - root = Path(path.split('*')[0]) + root = Path(path.split("*")[0]) entries = (Path(entry) for entry in glob(path, recursive=True)) @@ -19,22 +20,20 @@ def list_datasets(path:str) -> Dict[str,List[Tuple[str,Path]]]: entry for entry in entries if entry.is_file() - and entry.name.endswith('.gz') - and not entry.name.startswith('.') + and entry.name.endswith(".gz") + and not entry.name.startswith(".") ] datasets = [ (name, list(files)) for name, files in groupby( sorted(files, key=lambda entry: str(entry)), - key=lambda entry: str(entry.relative_to(root)).rsplit('.', 2)[0]) + key=lambda entry: str(entry.relative_to(root)).rsplit(".", 2)[0], + ) ] return { - name: [ - (entry.name.rsplit('.', 2)[1], entry) - for entry in files - ] + name: [(entry.name.rsplit(".", 2)[1], entry) for entry in files] for name, files in datasets } @@ -42,11 +41,12 @@ def list_datasets(path:str) -> Dict[str,List[Tuple[str,Path]]]: def main() -> None: import sys import pprint + if len(sys.argv) == 1: pprint.pprint(list_datasets(DATA_PATH)) else: pprint.pprint(list_datasets(sys.argv[1])) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/opuscleaner/download.py b/opuscleaner/download.py index 522586b..9c3b909 100644 --- a/opuscleaner/download.py +++ b/opuscleaner/download.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Various mtdata dataset downloading utilities""" + import argparse import os import json @@ -33,9 +34,9 @@ class EntryRef(BaseModel): class Entry(EntryRef): corpus: str version: str - langs: Tuple[str,str] - pairs: Optional[int] # Number of sentence pairs - size: Optional[int] # Size on disk in bytes (rounded to lowest 1024) + langs: Tuple[str, str] + pairs: Optional[int] # Number of sentence pairs + size: Optional[int] # Size on disk in bytes (rounded to lowest 1024) @property def basename(self) -> str: @@ -51,48 +52,50 @@ class RemoteEntry(Entry): class DownloadState(Enum): - PENDING = 'pending' - CANCELLED = 'cancelled' - DOWNLOADING = 'downloading' - DOWNLOADED = 'downloaded' - FAILED = 'failed' + PENDING = "pending" + CANCELLED = "cancelled" + DOWNLOADING = "downloading" + DOWNLOADED = "downloaded" + FAILED = "failed" -def get_dataset(entry:RemoteEntry, path:str) -> None: - if entry.url.endswith('.zip'): +def get_dataset(entry: RemoteEntry, path: str) -> None: + if entry.url.endswith(".zip"): get_bilingual_dataset(entry, path) - elif entry.url.endswith('.txt.gz'): + elif entry.url.endswith(".txt.gz"): get_monolingual_dataset(entry, path) else: - raise RuntimeError(f'Unknown dataset file type: {entry.url}') + raise RuntimeError(f"Unknown dataset file type: {entry.url}") + +def get_monolingual_dataset(entry: RemoteEntry, path: str) -> None: + lang = next(lang for lang in entry.langs if lang != "") + assert entry.url.endswith(f"{lang}.txt.gz") -def get_monolingual_dataset(entry:RemoteEntry, path:str) -> None: - lang = next(lang for lang in entry.langs if lang != '') - assert entry.url.endswith(f'{lang}.txt.gz') - # Make sure our path exists os.makedirs(path, exist_ok=True) with TemporaryDirectory(dir=path) as temp_dir: temp_path = os.path.join(temp_dir, os.path.basename(entry.url)) - dest_path = os.path.join(path, f'{entry.basename}.{lang}.gz') + dest_path = os.path.join(path, f"{entry.basename}.{lang}.gz") # Download dataset to temporary file - with urlopen(entry.url) as fh, open(temp_path, 'wb') as fout: + with urlopen(entry.url) as fh, open(temp_path, "wb") as fout: copyfileobj(fh, fout) # move to permanent position 
os.rename(temp_path, dest_path) -def _extract(path:str, name:str, dest:str) -> str: - with ZipFile(path) as archive, archive.open(name) as fin, gzip.open(dest, 'wb') as fout: - copyfileobj(fin, fout, length=2**24) # 16MB blocks +def _extract(path: str, name: str, dest: str) -> str: + with ZipFile(path) as archive, archive.open(name) as fin, gzip.open( + dest, "wb" + ) as fout: + copyfileobj(fin, fout, length=2**24) # 16MB blocks return dest -def get_bilingual_dataset(entry:RemoteEntry, path:str) -> None: +def get_bilingual_dataset(entry: RemoteEntry, path: str) -> None: # List of extensions of the expected files, e.g. `.en-mt.mt` and `.en-mt.en`. suffixes = [f'.{"-".join(entry.langs)}.{lang}' for lang in entry.langs] @@ -113,19 +116,23 @@ def get_bilingual_dataset(entry:RemoteEntry, path:str) -> None: with ZipFile(temp_archive) as archive: for info in archive.filelist: - if info.is_dir() or not any(info.filename.endswith(suffix) for suffix in suffixes): + if info.is_dir() or not any( + info.filename.endswith(suffix) for suffix in suffixes + ): continue # `info.filename` is something like "beepboop.en-nl.en", `lang` will be "en". - _, lang = info.filename.rsplit('.', maxsplit=1) + _, lang = info.filename.rsplit(".", maxsplit=1) - filename = f'{entry.basename}.{lang}.gz' + filename = f"{entry.basename}.{lang}.gz" temp_dest = os.path.join(temp_extracted, filename) data_dest = os.path.join(path, filename) # Extract the file from the zip archive into the temporary directory, and # compress it while we're at it. - future = pool.submit(_extract, temp_archive.name, info.filename, temp_dest) + future = pool.submit( + _extract, temp_archive.name, info.filename, temp_dest + ) futures.append((future, data_dest)) @@ -142,10 +149,10 @@ class EntryDownload: entry: RemoteEntry _child: Optional[Process] - def __init__(self, entry:RemoteEntry): + def __init__(self, entry: RemoteEntry): self.entry = entry self._child = None - + def start(self) -> None: self._child = Process(target=get_dataset, args=(self.entry, DOWNLOAD_PATH)) self._child.start() @@ -173,26 +180,28 @@ def state(self) -> DownloadState: return DownloadState.CANCELLED -DownloadQueue = SimpleQueue#[Optional[EntryDownload]] +DownloadQueue = SimpleQueue # [Optional[EntryDownload]] class Downloader: - def __init__(self, workers:int): + def __init__(self, workers: int): self.queue: DownloadQueue = SimpleQueue() self.threads: List[Thread] = [] for _ in range(workers): - thread = Thread(target=self.__class__.worker_thread, args=[self.queue], daemon=True) + thread = Thread( + target=self.__class__.worker_thread, args=[self.queue], daemon=True + ) thread.start() self.threads.append(thread) - def download(self, entry:RemoteEntry) -> EntryDownload: + def download(self, entry: RemoteEntry) -> EntryDownload: download = EntryDownload(entry=entry) self.queue.put(download) return download @staticmethod - def worker_thread(queue:DownloadQueue) -> None: + def worker_thread(queue: DownloadQueue) -> None: while True: entry = queue.get() if not entry: @@ -208,39 +217,32 @@ class EntryDownloadView(BaseModel): class OpusAPI: endpoint: str - _datasets: Dict[int,Entry] = {} + _datasets: Dict[int, Entry] = {} - def __init__(self, endpoint:str): + def __init__(self, endpoint: str): self.endpoint = endpoint self._datasets = {} def languages(self, lang1: Optional[str] = None) -> List[str]: - query = {'languages': 'True'} + query = {"languages": "True"} if lang1 is not None: - query['source'] = lang1 + query["source"] = lang1 - with 
urlopen(f'{self.endpoint}?{urlencode(query)}') as fh: - return [str(lang) for lang in json.load(fh).get('languages', [])] + with urlopen(f"{self.endpoint}?{urlencode(query)}") as fh: + return [str(lang) for lang in json.load(fh).get("languages", [])] - def get_dataset(self, id:int) -> Entry: + def get_dataset(self, id: int) -> Entry: return self._datasets[id] - def find_datasets(self, lang1:str, lang2:Optional[str]=None) -> List[Entry]: + def find_datasets(self, lang1: str, lang2: Optional[str] = None) -> List[Entry]: if lang2 is None: - query = { - 'source': lang1, - 'preprocessing': 'mono' - } + query = {"source": lang1, "preprocessing": "mono"} else: - query = { - 'source': lang1, - 'target': lang2, - 'preprocessing': 'moses' - } + query = {"source": lang1, "target": lang2, "preprocessing": "moses"} - with urlopen(f'{self.endpoint}?{urlencode(query)}') as fh: - datasets = [cast_entry(entry) for entry in json.load(fh).get('corpora', [])] + with urlopen(f"{self.endpoint}?{urlencode(query)}") as fh: + datasets = [cast_entry(entry) for entry in json.load(fh).get("corpora", [])] # FIXME dirty hack to keep a local copy to be able to do id based lookup # Related: https://github.com/Helsinki-NLP/OPUS-API/issues/3 @@ -252,29 +254,37 @@ def find_datasets(self, lang1:str, lang2:Optional[str]=None) -> List[Entry]: app = FastAPI() -api = OpusAPI('https://opus.nlpl.eu/opusapi/') +api = OpusAPI("https://opus.nlpl.eu/opusapi/") -downloads: Dict[int,EntryDownload] = {} +downloads: Dict[int, EntryDownload] = {} downloader = Downloader(2) datasets_by_id: Dict[int, Entry] = {} -def cast_entry(data:Dict[str,Any]) -> Entry: + +def cast_entry(data: Dict[str, Any]) -> Entry: entry = Entry( - id=int(data['id']), - corpus=str(data['corpus']), - version=str(data['version']), - pairs=int(data['alignment_pairs']) if data.get('alignment_pairs') != '' else None, - size=int(data['size']) * 1024, # FIXME file size but do we care? - langs=(data['source'], data['target']), # FIXME these are messy OPUS-API lang codes :( + id=int(data["id"]), + corpus=str(data["corpus"]), + version=str(data["version"]), + pairs=int(data["alignment_pairs"]) + if data.get("alignment_pairs") != "" + else None, + size=int(data["size"]) * 1024, # FIXME file size but do we care? 
+ langs=( + data["source"], + data["target"], + ), # FIXME these are messy OPUS-API lang codes :( ) paths = set( filename for data_root in [os.path.dirname(DATA_PATH), DOWNLOAD_PATH] for lang in entry.langs - for filename in iglob(os.path.join(data_root, f'{entry.basename}.{lang}.gz'), recursive=True) + for filename in iglob( + os.path.join(data_root, f"{entry.basename}.{lang}.gz"), recursive=True + ) ) # Print search paths @@ -286,48 +296,42 @@ def cast_entry(data:Dict[str,Any]) -> Entry: # )) if paths: - return LocalEntry( - **entry.__dict__, - paths=paths) + return LocalEntry(**entry.__dict__, paths=paths) else: - return RemoteEntry( - **entry.__dict__, - url=str(data['url'])) + return RemoteEntry(**entry.__dict__, url=str(data["url"])) @app.get("/languages/") @app.get("/languages/{lang1}") -def list_languages(lang1:Optional[str] = None) -> List[str]: +def list_languages(lang1: Optional[str] = None) -> List[str]: return sorted(api.languages(lang1)) @app.get("/by-language/{langs}") -def list_datasets(langs:str) -> Iterable[Entry]: - return api.find_datasets(*langs.split('-')) +def list_datasets(langs: str) -> Iterable[Entry]: + return api.find_datasets(*langs.split("-")) -@app.get('/downloads/') +@app.get("/downloads/") def list_downloads() -> Iterable[EntryDownloadView]: return ( - EntryDownloadView( - entry = download.entry, - state = download.state - ) + EntryDownloadView(entry=download.entry, state=download.state) for download in downloads.values() ) -@app.post('/downloads/') +@app.post("/downloads/") def batch_add_downloads(datasets: List[EntryRef]) -> Iterable[EntryDownloadView]: """Batch download requests!""" - needles = set(dataset.id + needles = set( + dataset.id for dataset in datasets if dataset.id not in downloads - or downloads[dataset.id].state in {DownloadState.CANCELLED, DownloadState.FAILED}) + or downloads[dataset.id].state + in {DownloadState.CANCELLED, DownloadState.FAILED} + ) - entries = [ - api.get_dataset(id) for id in needles - ] + entries = [api.get_dataset(id) for id in needles] for entry in entries: assert isinstance(entry, RemoteEntry) @@ -336,81 +340,91 @@ def batch_add_downloads(datasets: List[EntryRef]) -> Iterable[EntryDownloadView] return list_downloads() -@app.delete('/downloads/{dataset_id}') -def cancel_download(dataset_id:int) -> EntryDownloadView: +@app.delete("/downloads/{dataset_id}") +def cancel_download(dataset_id: int) -> EntryDownloadView: """Cancel a download. Removes it from the queue, does not kill the process if download is already happening. """ if dataset_id not in downloads: - raise HTTPException(status_code=404, detail='Download not found') + raise HTTPException(status_code=404, detail="Download not found") download = downloads[dataset_id] download.cancel() - return EntryDownloadView( - entry = download.entry, - state = download.state - ) - + return EntryDownloadView(entry=download.entry, state=download.state) + + LOG = logging.getLogger("download") - + + def main(): - logging.basicConfig(format='%(asctime)s %(levelname)s: %(name)s: %(message)s', \ - datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG) + logging.basicConfig( + format="%(asctime)s %(levelname)s: %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.DEBUG, + ) parser = argparse.ArgumentParser() - parser.add_argument("-d", "--directory", help="Directory to search for categories.json files. 
Defaults to current directory.") - parser.add_argument("-t", "--threads", type=int, help="Threads for downloading", default=4) + parser.add_argument( + "-d", + "--directory", + help="Directory to search for categories.json files. Defaults to current directory.", + ) + parser.add_argument( + "-t", "--threads", type=int, help="Threads for downloading", default=4 + ) args = parser.parse_args() - + root_dir = args.directory if root_dir == None: root_dir = Path.cwd() else: root_dir = Path(root_dir) - + LOG.info(f"Searching for categories.json in {root_dir}") - cat_files = [Path(dirpath,f) for dirpath,_,files in os.walk(root_dir) for f in files if Path(f).name == "categories.json"] + cat_files = [ + Path(dirpath, f) + for dirpath, _, files in os.walk(root_dir) + for f in files + if Path(f).name == "categories.json" + ] LOG.info(f"Found {len(cat_files)} categories.json files") - - entry_cache = {} # caches basename -> entry + + entry_cache = {} # caches basename -> entry downloader = Downloader(workers=args.threads) for cat_file in cat_files: target_dir = cat_file.parent LOG.debug(f"Processing corpora in {cat_file}") cat_data = json.load(open(cat_file)) - for cat_list in cat_data['mapping'].values(): + for cat_list in cat_data["mapping"].values(): for corpus_id in cat_list: entry = entry_cache.get(corpus_id) if entry == None: # cache miss, download the list for this language from opus - l1,l2 = corpus_id.split(".")[-1].split("-") - entries_by_langs = api.find_datasets(l1,l2) + l1, l2 = corpus_id.split(".")[-1].split("-") + entries_by_langs = api.find_datasets(l1, l2) for entry in entries_by_langs: entry_cache[entry.basename] = entry entry = entry_cache.get(corpus_id) if entry == None: - raise RuntimeError(f"Unable to find corpus with basename: {corpus_id}") - #TODO: + raise RuntimeError( + f"Unable to find corpus with basename: {corpus_id}" + ) + # TODO: # - Use downloader for multithreaded downloading (but need to set target dir) # - Do not download if file exists - + if hasattr(entry, "paths"): for source_path in entry.paths: LOG.debug(f"Copying from {source_path}") - shutil.copy(source_path,target_dir) + shutil.copy(source_path, target_dir) else: LOG.debug(f"Downloading corpus {corpus_id}") get_bilingual_dataset(entry, target_dir) - #downloader.download(entry) # Currently downloads to DOWNLOAD_PATH + # downloader.download(entry) # Currently downloads to DOWNLOAD_PATH # # This does not work, because workers do not exit - - #for thread in downloader.threads: - # thread.join() - #import time - #time.sleep(10) - #print(downloader.queue.get()) - - - - + # for thread in downloader.threads: + # thread.join() + # import time + # time.sleep(10) + # print(downloader.queue.get()) diff --git a/opuscleaner/filters.py b/opuscleaner/filters.py index 374ef3a..0da0aff 100644 --- a/opuscleaner/filters.py +++ b/opuscleaner/filters.py @@ -68,7 +68,7 @@ def export(self, value: Any) -> Any: return super().export(value) def default_factory(self) -> Any: - return '' + return "" class FilterParameterList(FilterParameterBase): @@ -77,10 +77,7 @@ class FilterParameterList(FilterParameterBase): default: Optional[List[Any]] def export(self, value: Any) -> Any: - return [ - self.parameter.export(item) - for item in value - ] + return [self.parameter.export(item) for item in value] def default_factory(self) -> Any: return [] @@ -95,15 +92,11 @@ class FilterParameterTuple(FilterParameterBase): def export(self, value: Any) -> Any: return tuple( - parameter.export(val) - for parameter, val in zip(self.parameters, value) + 
parameter.export(val) for parameter, val in zip(self.parameters, value) ) def default_factory(self) -> Any: - return [ - parameter.default_factory() - for parameter in self.parameters - ] + return [parameter.default_factory() for parameter in self.parameters] FilterParameter = Union[ @@ -112,7 +105,7 @@ def default_factory(self) -> Any: FilterParameterBool, FilterParameterStr, FilterParameterList, - FilterParameterTuple + FilterParameterTuple, ] FilterParameterList.update_forward_refs() @@ -121,40 +114,44 @@ def default_factory(self) -> Any: class Filter(BaseModel): type: FilterType - name: str # comes from filename by default + name: str # comes from filename by default description: Optional[str] command: str basedir: str - parameters: Dict[str,FilterParameter] + parameters: Dict[str, FilterParameter] - @validator('parameters') - def check_keys(cls, parameters: Dict[str,Any]) -> Dict[str,Any]: + @validator("parameters") + def check_keys(cls, parameters: Dict[str, Any]) -> Dict[str, Any]: for var_name in parameters.keys(): if not re.match(r"^[a-zA-Z_][a-zA-Z_0-9]*$", var_name): - raise ValueError(f"Parameter name is not a valid bash variable: {var_name}") + raise ValueError( + f"Parameter name is not a valid bash variable: {var_name}" + ) return parameters -_FILTERS: Dict[str,Filter] = {} +_FILTERS: Dict[str, Filter] = {} class FilterStep(BaseModel): filter: str - parameters: Dict[str,Any] + parameters: Dict[str, Any] language: Optional[str] - @validator('filter') - def check_filter(cls, filter_name:str) -> str: + @validator("filter") + def check_filter(cls, filter_name: str) -> str: global _FILTERS if _FILTERS and filter_name not in _FILTERS: - raise ValueError(f'Unknown filter: `{filter_name}`') + raise ValueError(f"Unknown filter: `{filter_name}`") return filter_name - @validator('parameters') - def check_parameters(cls, parameters:Dict[str,Any], values:Dict[str,Any], **kwargs) -> Dict[str,Any]: + @validator("parameters") + def check_parameters( + cls, parameters: Dict[str, Any], values: Dict[str, Any], **kwargs + ) -> Dict[str, Any]: global _FILTERS - if _FILTERS and 'filter' in values: - required = set(_FILTERS[values['filter']].parameters.keys()) + if _FILTERS and "filter" in values: + required = set(_FILTERS[values["filter"]].parameters.keys()) provided = set(parameters.keys()) missing_keys = required - provided @@ -162,27 +159,41 @@ def check_parameters(cls, parameters:Dict[str,Any], values:Dict[str,Any], **kwar warn(f"Missing filter parameters: {' '.join(missing_keys)}") # Just add their default values in that case. parameters |= { - key: parameter.default if hasattr(parameter, 'default') and parameter.default is not None else parameter.default_factory() - for key, parameter in _FILTERS[values['filter']].parameters.items() + key: parameter.default + if hasattr(parameter, "default") and parameter.default is not None + else parameter.default_factory() + for key, parameter in _FILTERS[values["filter"]].parameters.items() if key in missing_keys } - + superfluous_keys = provided - required if superfluous_keys: - warn(f"Provided parameters not supported by the filter: {' '.join(superfluous_keys)}") + warn( + f"Provided parameters not supported by the filter: {' '.join(superfluous_keys)}" + ) # Not doing anything though, might be that we have just loaded an # old version of the filter definition and we don't want to lose # any of these keys. 
return parameters - @validator('language', always=True) - def check_language_is_provided(cls, language:str, values:Dict[str,Any], **kwargs) -> str: - if _FILTERS and 'filter' in values: - if _FILTERS[values['filter']].type == FilterType.BILINGUAL and language is not None: - raise ValueError('Cannot `language` attribute for a bilingual filter') - elif _FILTERS[values['filter']].type == FilterType.MONOLINGUAL and language is None: - raise ValueError('`language` attribute required for a monolingual filter') + @validator("language", always=True) + def check_language_is_provided( + cls, language: str, values: Dict[str, Any], **kwargs + ) -> str: + if _FILTERS and "filter" in values: + if ( + _FILTERS[values["filter"]].type == FilterType.BILINGUAL + and language is not None + ): + raise ValueError("Cannot `language` attribute for a bilingual filter") + elif ( + _FILTERS[values["filter"]].type == FilterType.MONOLINGUAL + and language is None + ): + raise ValueError( + "`language` attribute required for a monolingual filter" + ) return language @@ -192,37 +203,37 @@ class FilterPipeline(BaseModel): filters: List[FilterStep] -def list_filters(paths:str) -> Iterable[Filter]: +def list_filters(paths: str) -> Iterable[Filter]: for path in paths.split(os.pathsep): for filename in glob(path, recursive=True): try: with open(filename) as fh: defaults = { "name": os.path.splitext(os.path.basename(filename))[0], - "basedir": os.path.dirname(filename) + "basedir": os.path.dirname(filename), } yield parse_obj_as(Filter, {**defaults, **json.load(fh)}) except Exception as e: warn(f"Could not parse {filename}: {e}") -def set_global_filters(filters:Iterable[Filter]) -> None: +def set_global_filters(filters: Iterable[Filter]) -> None: global _FILTERS _FILTERS = {filter.name: filter for filter in filters} -def get_global_filters() -> Dict[str,Filter]: +def get_global_filters() -> Dict[str, Filter]: global _FILTERS return _FILTERS -def get_global_filter(name:str) -> Filter: +def get_global_filter(name: str) -> Filter: return get_global_filters()[name] def format_shell(val: Any) -> str: if isinstance(val, bool): - return '1' if val else '' + return "1" if val else "" elif isinstance(val, tuple): raise NotImplementedError() elif isinstance(val, list): @@ -231,11 +242,20 @@ def format_shell(val: Any) -> str: return str(val) -def filter_format_command(filter_definition:Filter, filter_step:FilterStep, langs:List[str], *, path_to_col:List[str]=COL_PY) -> str: +def filter_format_command( + filter_definition: Filter, + filter_step: FilterStep, + langs: List[str], + *, + path_to_col: List[str] = COL_PY, +) -> str: if filter_definition.type == FilterType.BILINGUAL: command = filter_definition.command elif filter_definition.type == FilterType.MONOLINGUAL: - columns = [langs.index(language) for language in none_throws(filter_step.language).split(',')] + columns = [ + langs.index(language) + for language in none_throws(filter_step.language).split(",") + ] command = f'{" ".join(map(quote, path_to_col))} {",".join(map(str, columns))} {filter_definition.command}' else: raise NotImplementedError() @@ -245,10 +265,12 @@ def filter_format_command(filter_definition:Filter, filter_step:FilterStep, lang name: props.export(filter_step.parameters[name]) for name, props in filter_definition.parameters.items() } - if 'PARAMETERS_AS_YAML' in command: - command = f'PARAMETERS_AS_YAML={quote(yaml.safe_dump(params))}; {command}' + if "PARAMETERS_AS_YAML" in command: + command = f"PARAMETERS_AS_YAML={quote(yaml.safe_dump(params))}; {command}" 
else: - vars_setter = '; '.join(f"{k}={quote(format_shell(v))}" for k, v in params.items()) - command = f'{vars_setter}; {command}' + vars_setter = "; ".join( + f"{k}={quote(format_shell(v))}" for k, v in params.items() + ) + command = f"{vars_setter}; {command}" return command diff --git a/opuscleaner/filters/alpha_ratio.py b/opuscleaner/filters/alpha_ratio.py index f140537..d14a913 100755 --- a/opuscleaner/filters/alpha_ratio.py +++ b/opuscleaner/filters/alpha_ratio.py @@ -5,77 +5,124 @@ import re from clean_common import CHARS + def parse_user_args(): """Parse the arguments necessary for this filter""" - parser = argparse.ArgumentParser(description="Filters the lines based on the ratio between alphabetic characters in a line from the language and others") - parser.add_argument("--ratio-words-src", default=0.6, type=float, help='Ratio between words and non words (eg numbers, foreign words) in a src sentence.') - parser.add_argument("--ratio-words-trg", default=0.6, type=float, help='Ratio between words and non words (eg numbers, foreign words) in a trg sentence.') - parser.add_argument("--ratio-alpha-src", default=0.4, type=float, help='Ratio between characters from the src language compared to all characters (eg numbers, emoji, punctuation, etc...)') - parser.add_argument("--ratio-alpha-trg", default=0.4, type=float, help='Ratio between characters from the trg language compared to all characters (eg numbers, emoji, punctuation, etc...)') - parser.add_argument("--src-lang", default="en", type=str, choices=list(CHARS.keys())) + parser = argparse.ArgumentParser( + description="Filters the lines based on the ratio between alphabetic characters in a line from the language and others" + ) + parser.add_argument( + "--ratio-words-src", + default=0.6, + type=float, + help="Ratio between words and non words (eg numbers, foreign words) in a src sentence.", + ) + parser.add_argument( + "--ratio-words-trg", + default=0.6, + type=float, + help="Ratio between words and non words (eg numbers, foreign words) in a trg sentence.", + ) + parser.add_argument( + "--ratio-alpha-src", + default=0.4, + type=float, + help="Ratio between characters from the src language compared to all characters (eg numbers, emoji, punctuation, etc...)", + ) + parser.add_argument( + "--ratio-alpha-trg", + default=0.4, + type=float, + help="Ratio between characters from the trg language compared to all characters (eg numbers, emoji, punctuation, etc...)", + ) + parser.add_argument( + "--src-lang", default="en", type=str, choices=list(CHARS.keys()) + ) parser.add_argument("--trg-lang", type=str, choices=list(CHARS.keys())) - parser.add_argument("--debug", action='store_true') + parser.add_argument("--debug", action="store_true") return parser.parse_args() -def clean_parallel(src_lang: str, ratio_words_src: float, ratio_alpha_src: float,\ -trg_lang: Optional[str], ratio_words_trg: float, ratio_alpha_trg: float,\ - debug: bool = True) -> None: + +def clean_parallel( + src_lang: str, + ratio_words_src: float, + ratio_alpha_src: float, + trg_lang: Optional[str], + ratio_words_trg: float, + ratio_alpha_trg: float, + debug: bool = True, +) -> None: """Cleans the parallel (or monolingual) dataset based on the number of characters""" for line in stdin: - fields = line.rstrip('\r\n').split('\t') + fields = line.rstrip("\r\n").split("\t") if len(fields) == 1: src = fields[-1].strip() trg = None - else: # Assumes that the multiline filter already run + else: # Assumes that the multiline filter already run src = fields[-2].strip() trg = 
fields[-1].strip() if src_lang in CHARS: src_toks = src.split() src_len = len(src_toks) - if src_len==0: + if src_len == 0: if debug: - stderr.write(f'EMPTY_SRC\t{src}\t{trg}\n') + stderr.write(f"EMPTY_SRC\t{src}\t{trg}\n") continue num_words = sum( - [1 if re.match(CHARS[src_lang], t, re.IGNORECASE) else 0 for t in src_toks]) + [ + 1 if re.match(CHARS[src_lang], t, re.IGNORECASE) else 0 + for t in src_toks + ] + ) if num_words / float(src_len) < ratio_words_src: if debug: - stderr.write(f'RATIO_WORDS_SRC\t{src}\t{trg}\n') + stderr.write(f"RATIO_WORDS_SRC\t{src}\t{trg}\n") continue char_alpha = len(re.findall(CHARS[src_lang], src, re.IGNORECASE)) - if char_alpha / float(len(src.replace(' ', ''))) < ratio_alpha_src: + if char_alpha / float(len(src.replace(" ", ""))) < ratio_alpha_src: if debug: - stderr.write(f'RATIO_ALPHA_SRC\t{src}\t{trg}\n') + stderr.write(f"RATIO_ALPHA_SRC\t{src}\t{trg}\n") continue if trg is not None and trg_lang in CHARS: trg_toks = trg.split() trg_len = len(trg_toks) - if trg_len==0: + if trg_len == 0: if debug: - stderr.write(f'EMPTY_TRG\t{src}\t{trg}\n') + stderr.write(f"EMPTY_TRG\t{src}\t{trg}\n") continue num_words = sum( - [1 if re.match(CHARS[trg_lang], t, re.IGNORECASE) else 0 for t in trg_toks]) + [ + 1 if re.match(CHARS[trg_lang], t, re.IGNORECASE) else 0 + for t in trg_toks + ] + ) if num_words / float(trg_len) < ratio_words_trg: if debug: - stderr.write(f'RATIO_WORDS_TRG\t{src}\t{trg}\n') + stderr.write(f"RATIO_WORDS_TRG\t{src}\t{trg}\n") continue char_alpha = len(re.findall(CHARS[trg_lang], trg, re.IGNORECASE)) - if char_alpha / float(len(trg.replace(' ', ''))) < ratio_alpha_trg: + if char_alpha / float(len(trg.replace(" ", ""))) < ratio_alpha_trg: if debug: - stderr.write(f'RATIO_ALPHA_TRG\t{src}\t{trg}\n') + stderr.write(f"RATIO_ALPHA_TRG\t{src}\t{trg}\n") continue # If none of our filters have failed, we're good to go stdout.write(line) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_user_args() - clean_parallel(src_lang=args.src_lang, ratio_words_src=args.ratio_words_src, ratio_alpha_src=args.ratio_alpha_src,\ - trg_lang=args.trg_lang, ratio_words_trg=args.ratio_words_trg, ratio_alpha_trg=args.ratio_alpha_trg, debug=args.debug) + clean_parallel( + src_lang=args.src_lang, + ratio_words_src=args.ratio_words_src, + ratio_alpha_src=args.ratio_alpha_src, + trg_lang=args.trg_lang, + ratio_words_trg=args.ratio_words_trg, + ratio_alpha_trg=args.ratio_alpha_trg, + debug=args.debug, + ) diff --git a/opuscleaner/filters/bifixer_dedupe.py b/opuscleaner/filters/bifixer_dedupe.py index f6a9082..5892f1e 100755 --- a/opuscleaner/filters/bifixer_dedupe.py +++ b/opuscleaner/filters/bifixer_dedupe.py @@ -4,17 +4,18 @@ class SentencePair(NamedTuple): - src: str - trg: str - rank: float + src: str + trg: str + rank: float -best: Dict[str,SentencePair] = {} + +best: Dict[str, SentencePair] = {} for line in sys.stdin: - src, trg, checksum, rank = line.rstrip('\r\n').split('\t') + src, trg, checksum, rank = line.rstrip("\r\n").split("\t") - if checksum not in best or best[checksum].rank < float(rank): - best[checksum] = SentencePair(src, trg, float(rank)) + if checksum not in best or best[checksum].rank < float(rank): + best[checksum] = SentencePair(src, trg, float(rank)) for pair in best.values(): - print(f"{pair.src}\t{pair.trg}", file=sys.stdout) + print(f"{pair.src}\t{pair.trg}", file=sys.stdout) diff --git a/opuscleaner/filters/clean_common.py b/opuscleaner/filters/clean_common.py index d83a01e..afbdf47 100755 --- 
a/opuscleaner/filters/clean_common.py +++ b/opuscleaner/filters/clean_common.py @@ -2,43 +2,43 @@ """Common filtering code to be used by various submodules""" CHARS = { - 'ar': r'[\u0600-\u06FF]', # This is not entirely right, as it also includes farsi symbols and whatnot - 'bg': r'[АаБбВвГгДддЕеЖжЗзИиЙйКкkasЛлМмНнОоПпРрСсТтУуФфХхЦцЧчШшЩщЪъЬьЮюЯя]', - 'bn': r'[\u0980-\u09FF]', # bangla - 'ca': r'[a-zÀàÈèÉéÍíÒòÓóÚúÇç]', - 'cs': r'[a-zÁáČčĎďÉéěÍíŇňÓóŘřŠšŤťÚúůÝýŽž]', - 'da': r'[a-zÆæØøÅå]', - 'de': r'[a-zÄäÖöÜüß]', - 'en': r'[a-z]', - 'el': r'[a-zΑαΒβΓγΔδΕεΖζΗηΘθΙιΚκΛλΜμΝνΞξΟοΠπΡρΣσςΤτΥυΦφΧχΨψΩω]', - 'es': r'[a-zÁáÉéÍíÓóÚúñÑ]', - 'et': r'[a-zÕõÄäÖöÜü]', - 'eu': r'[a-zñÑ]', - 'fi': r'[a-zÅåÄäÖö]', - 'fr': r'[a-zÂâÁáÀàâÇçÉéÈèÊêÓóÒòÔôŒœÜüÛûŸÿ]', - 'ga': r'[abcdefghilmnoprstuáéíóúÁÉÍÓÚ]', - 'gl': r'[a-zÁáÉéÍíÓóÚúÑñ]', - 'hi': r'[\u0900-\u097F]', # devanagari - 'hr': r'[abcčČćĆdđĐefghijklmnoprsšŠtuvzžŽ]', - 'hu': r'[a-zÁáÉéÍíÓóÖöŐőŰű]', - 'hy': r'[\u0530-\u058F]', - 'is': r'[abdefghijklmnoprstuvxyÁáðÐÉéÍíÓóÚúÝýÞþÆæÖö]', - 'it': r'[a-zàÀèÈéÉìÌíÍîÎòÒóÓùÙúÚ]', - 'ko': r'[\uac00-\ud7af]|[\u1100-\u11ff]|[\u3130-\u318f]|[\ua960-\ua97f]|[\ud7b0-\ud7ff]', - 'lt': r'[aąbcČčdeĘęĖėfghiĮįyjklmnoprsŠštuŲųŪūvzŽž]', - 'lv': r'[aĀābcČčdeĒēfgĢģhiĪījkĶķlĻļmnŅņoprsŠštuŪūvzŽž]', - 'mt': r'[abĊċdefĠġghĦħiiejklmnopqrstuvwxŻżz]', - 'nb': r'[a-zÂâÁáÀàâÉéÈèÊêÓóÒòÔôÜüÆæØøÅå]', - 'nl': r'[a-zÂâÁáÀàâÉéÈèÊêÓóÒòÔôÚú]', - 'no': r'[a-zÂâÁáÀàâÉéÈèÊêÓóÒòÔôÜüÆæØøÅå]', - 'nn': r'[a-zÂâÁáÀàâÉéÈèÊêÓóÒòÔôÜüÆæØøÅå]', - 'pl': r'[a-zĄąĆćĘꣳŃńÓ󌜏źŻż]', - 'pt': r'[a-zÂâÁáÀàÃãÇçÉéÈèÊêÍíÌìÓóÒòÔôÕõÚúÙù]', - 'ro': r'[a-zĂăÂâÎîȘșȚț]', - 'ru': r'[а-я]', - 'sk': r'[a-záäÁÄčČďĎžéÉíÍĺĹľĽňŇóÓôÔŕŔšŠťŤúÚýÝžŽ]', - 'sl': r'[abcčČdđĐefghijklmnoprsšŠtuvzžŽ]', - 'sv': r'[a-zÅåÄäÖö]', - 'uk': r'[А-ЩЬЮЯҐЄІЇа-щьюяґєії\'`’ʼ]', - 'zh': r'[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]', + "ar": r"[\u0600-\u06FF]", # This is not entirely right, as it also includes farsi symbols and whatnot + "bg": r"[АаБбВвГгДддЕеЖжЗзИиЙйКкkasЛлМмНнОоПпРрСсТтУуФфХхЦцЧчШшЩщЪъЬьЮюЯя]", + "bn": r"[\u0980-\u09FF]", # bangla + "ca": r"[a-zÀàÈèÉéÍíÒòÓóÚúÇç]", + "cs": r"[a-zÁáČčĎďÉéěÍíŇňÓóŘřŠšŤťÚúůÝýŽž]", + "da": r"[a-zÆæØøÅå]", + "de": r"[a-zÄäÖöÜüß]", + "en": r"[a-z]", + "el": r"[a-zΑαΒβΓγΔδΕεΖζΗηΘθΙιΚκΛλΜμΝνΞξΟοΠπΡρΣσςΤτΥυΦφΧχΨψΩω]", + "es": r"[a-zÁáÉéÍíÓóÚúñÑ]", + "et": r"[a-zÕõÄäÖöÜü]", + "eu": r"[a-zñÑ]", + "fi": r"[a-zÅåÄäÖö]", + "fr": r"[a-zÂâÁáÀàâÇçÉéÈèÊêÓóÒòÔôŒœÜüÛûŸÿ]", + "ga": r"[abcdefghilmnoprstuáéíóúÁÉÍÓÚ]", + "gl": r"[a-zÁáÉéÍíÓóÚúÑñ]", + "hi": r"[\u0900-\u097F]", # devanagari + "hr": r"[abcčČćĆdđĐefghijklmnoprsšŠtuvzžŽ]", + "hu": r"[a-zÁáÉéÍíÓóÖöŐőŰű]", + "hy": r"[\u0530-\u058F]", + "is": r"[abdefghijklmnoprstuvxyÁáðÐÉéÍíÓóÚúÝýÞþÆæÖö]", + "it": r"[a-zàÀèÈéÉìÌíÍîÎòÒóÓùÙúÚ]", + "ko": r"[\uac00-\ud7af]|[\u1100-\u11ff]|[\u3130-\u318f]|[\ua960-\ua97f]|[\ud7b0-\ud7ff]", + "lt": r"[aąbcČčdeĘęĖėfghiĮįyjklmnoprsŠštuŲųŪūvzŽž]", + "lv": r"[aĀābcČčdeĒēfgĢģhiĪījkĶķlĻļmnŅņoprsŠštuŪūvzŽž]", + "mt": r"[abĊċdefĠġghĦħiiejklmnopqrstuvwxŻżz]", + "nb": r"[a-zÂâÁáÀàâÉéÈèÊêÓóÒòÔôÜüÆæØøÅå]", + "nl": r"[a-zÂâÁáÀàâÉéÈèÊêÓóÒòÔôÚú]", + "no": r"[a-zÂâÁáÀàâÉéÈèÊêÓóÒòÔôÜüÆæØøÅå]", + "nn": r"[a-zÂâÁáÀàâÉéÈèÊêÓóÒòÔôÜüÆæØøÅå]", + "pl": r"[a-zĄąĆćĘꣳŃńÓ󌜏źŻż]", + "pt": r"[a-zÂâÁáÀàÃãÇçÉéÈèÊêÍíÌìÓóÒòÔôÕõÚúÙù]", + "ro": r"[a-zĂăÂâÎîȘșȚț]", + "ru": r"[а-я]", + "sk": r"[a-záäÁÄčČďĎžéÉíÍĺĹľĽňŇóÓôÔŕŔšŠťŤúÚýÝžŽ]", + "sl": r"[abcčČdđĐefghijklmnoprsšŠtuvzžŽ]", + "sv": r"[a-zÅåÄäÖö]", + "uk": r"[А-ЩЬЮЯҐЄІЇа-щьюяґєії\'`’ʼ]", + "zh": r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]", } diff --git 
a/opuscleaner/filters/deescape_tsv.py b/opuscleaner/filters/deescape_tsv.py index f76694e..3e969c8 100755 --- a/opuscleaner/filters/deescape_tsv.py +++ b/opuscleaner/filters/deescape_tsv.py @@ -4,9 +4,9 @@ QUOTECHR = ord('"') for line in sys.stdin.buffer: - fields = line.rstrip(b"\r\n").split(b"\t") - for i, field in enumerate(fields): - if len(field) > 0 and field[0] == QUOTECHR and field[-1] == QUOTECHR: - fields[i] = field[1:-1].replace(b'""', b'"') - sys.stdout.buffer.write(b"\t".join(fields)) - sys.stdout.buffer.write(b"\n") + fields = line.rstrip(b"\r\n").split(b"\t") + for i, field in enumerate(fields): + if len(field) > 0 and field[0] == QUOTECHR and field[-1] == QUOTECHR: + fields[i] = field[1:-1].replace(b'""', b'"') + sys.stdout.buffer.write(b"\t".join(fields)) + sys.stdout.buffer.write(b"\n") diff --git a/opuscleaner/filters/fasttext_filter.py b/opuscleaner/filters/fasttext_filter.py index 32a7c69..d67ceb7 100755 --- a/opuscleaner/filters/fasttext_filter.py +++ b/opuscleaner/filters/fasttext_filter.py @@ -30,7 +30,9 @@ def download_model(model_type: str): handle.write(response.content) -def verify_lang(model: fasttext.FastText._FastText, texts: List[str], desired_lang: str, debug: bool) -> List[bool]: +def verify_lang( + model: fasttext.FastText._FastText, texts: List[str], desired_lang: str, debug: bool +) -> List[bool]: # Langs is a list of list - for each row we get a list of identified languages, sorted by their probability. # Future work - using `model.predict(texts, k=10)` get the 10 most probable languages # and do some clever filtering based on the distribution. @@ -42,10 +44,21 @@ def verify_lang(model: fasttext.FastText._FastText, texts: List[str], desired_la def main(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--source-lang", type=str, help="Code of the desired source language.") - parser.add_argument("--target-lang", type=str, help="Code of the desired target language.") - parser.add_argument("--batch-size", type=int, default=16, help="Size of the batch to send the data to fasttext.") + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--source-lang", type=str, help="Code of the desired source language." + ) + parser.add_argument( + "--target-lang", type=str, help="Code of the desired target language." 
+ ) + parser.add_argument( + "--batch-size", + type=int, + default=16, + help="Size of the batch to send the data to fasttext.", + ) parser.add_argument("--debug", action="store_true") parser.add_argument( "--model-type", diff --git a/opuscleaner/filters/fix_elitr_eca.py b/opuscleaner/filters/fix_elitr_eca.py index 84fabb3..ba42e3c 100755 --- a/opuscleaner/filters/fix_elitr_eca.py +++ b/opuscleaner/filters/fix_elitr_eca.py @@ -5,44 +5,45 @@ mapping = [ - ("\u0083", "É"), # DCLARATION - ("\u0088", "à"), # permettre lŐUnion - ("\u0089", "â"), # probants obtenus grce aux - ("\u008d", "ç"), # Recettes perues - ("\u008e", "é"), # dclaration - ("\u008f", "è"), # ci-aprs - ("\u0090", "ê"), # doivent tre exclus - ("\u0091", "ë"), # lŐexercice ont d être - ("\u0094", "î"), # L'Agence reconnat - ("\u0099", "ô"), # xLes contrles - ("\u009d", "ù"), # dans la mesure o il - ("\u009e", "û"), # l'exercice ont d être reportés - ("\u0092", "í"), # Vtor Manuel da SILVA CALDEIRA - ("Ő", "'"), # lŐexercice - ("ă", "'"), # ăAutoriteitÓ - ("Ó", "'"), - ("􏳕", "ë"), # financi􏳕le - ("¬ ", ""), + ("\u0083", "É"), # DCLARATION + ("\u0088", "à"), # permettre lŐUnion + ("\u0089", "â"), # probants obtenus grce aux + ("\u008d", "ç"), # Recettes perues + ("\u008e", "é"), # dclaration + ("\u008f", "è"), # ci-aprs + ("\u0090", "ê"), # doivent tre exclus + ("\u0091", "ë"), # lŐexercice ont d être + ("\u0094", "î"), # L'Agence reconnat + ("\u0099", "ô"), # xLes contrles + ("\u009d", "ù"), # dans la mesure o il + ("\u009e", "û"), # l'exercice ont d être reportés + ("\u0092", "í"), # Vtor Manuel da SILVA CALDEIRA + ("Ő", "'"), # lŐexercice + ("ă", "'"), # ăAutoriteitÓ + ("Ó", "'"), + ("􏳕", "ë"), # financi􏳕le + ("¬ ", ""), ] class Translator: - def __init__(self, mapping): - self.mapping = {entry[0]: entry[1] for entry in mapping} - self.pattern = re.compile('(' + '|'.join(self.mapping.keys()) + ')') - self.callback = lambda match: self.mapping[match[0]] + def __init__(self, mapping): + self.mapping = {entry[0]: entry[1] for entry in mapping} + self.pattern = re.compile("(" + "|".join(self.mapping.keys()) + ")") + self.callback = lambda match: self.mapping[match[0]] - def __call__(self, input): - return re.sub(self.pattern, self.callback, input) + def __call__(self, input): + return re.sub(self.pattern, self.callback, input) def parse_user_args(): - parser = argparse.ArgumentParser(description="Fixes select encoding issues on the French side of the ELITR ECA dataset.") + parser = argparse.ArgumentParser( + description="Fixes select encoding issues on the French side of the ELITR ECA dataset." 
+ ) return parser.parse_args() if __name__ == "__main__": - args = parse_user_args() - for line in sys.stdin: - sys.stdout.write(Translator(mapping)(line)) - + args = parse_user_args() + for line in sys.stdin: + sys.stdout.write(Translator(mapping)(line)) diff --git a/opuscleaner/filters/fix_quotes.py b/opuscleaner/filters/fix_quotes.py index 3cb3c24..d5d4e7a 100755 --- a/opuscleaner/filters/fix_quotes.py +++ b/opuscleaner/filters/fix_quotes.py @@ -2,13 +2,15 @@ import sys import re -def fix(text:str)->str: - return re.sub(r'^[\'‘"“„](.+?)["”;]*$', r'\1', text) + +def fix(text: str) -> str: + return re.sub(r'^[\'‘"“„](.+?)["”;]*$', r"\1", text) + for line in sys.stdin: - fields = line.rstrip("\r\n").split("\t") + fields = line.rstrip("\r\n").split("\t") - fields = [fix(field).strip() for field in fields] + fields = [fix(field).strip() for field in fields] - if all(len(field) > 0 for field in fields): - print("\t".join(fields)) + if all(len(field) > 0 for field in fields): + print("\t".join(fields)) diff --git a/opuscleaner/filters/fix_sent_final_punct.py b/opuscleaner/filters/fix_sent_final_punct.py index 6c3abc9..4a0f211 100755 --- a/opuscleaner/filters/fix_sent_final_punct.py +++ b/opuscleaner/filters/fix_sent_final_punct.py @@ -1,40 +1,98 @@ #!/usr/bin/env python3 import sys -my_punct = {'!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~', '»', '«', '“', '”'} +my_punct = { + "!", + '"', + "#", + "$", + "%", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + ".", + "/", + ":", + ";", + "<", + "=", + ">", + "?", + "@", + "[", + "\\", + "]", + "^", + "_", + "`", + "{", + "|", + "}", + "~", + "»", + "«", + "“", + "”", +} for line in sys.stdin: src, trg = line.rstrip("\r\n").split("\t") if len(src) == 0 or len(trg) == 0: - print(src + '\t' + trg) + print(src + "\t" + trg) continue # Sometimes we have a space between the final letter and the punctuation, which is wrong except if using french quotes - if len(src) >= 2 and src[-1] in my_punct and src[-2] == " " and src[-1] != '»' and src[-1] != '«': + if ( + len(src) >= 2 + and src[-1] in my_punct + and src[-2] == " " + and src[-1] != "»" + and src[-1] != "«" + ): src = src[:-2] + src[-1] - if len(trg) >= 2 and trg[-1] in my_punct and trg[-2] == " " and trg[-1] != '»' and trg[-1] != '«': + if ( + len(trg) >= 2 + and trg[-1] in my_punct + and trg[-2] == " " + and trg[-1] != "»" + and trg[-1] != "«" + ): trg = trg[:-2] + trg[-1] # Sometimes two punctuation marks are swapped... - if len(src) >=2 and len(trg) >= 2 and src[-2] == trg[-1] and src[-1] == trg[-2]: - trg = trg[:-2] + src[-2] + src[-1] + if len(src) >= 2 and len(trg) >= 2 and src[-2] == trg[-1] and src[-1] == trg[-2]: + trg = trg[:-2] + src[-2] + src[-1] # Sometimes they are swapped with space around eg SPACE». 
-> .SPACE» - if len(src) >=3 and src[-1] in my_punct and src[-2] == '»' and src[-3] == ' ': - src = src[:-3] + src[-1] + ' ' + src[-2] - if len(trg) >=3 and trg[-1] in my_punct and trg[-2] == '»' and trg[-3] == ' ': - trg = trg[:-3] + trg[-1] + ' ' + trg[-2] - + if len(src) >= 3 and src[-1] in my_punct and src[-2] == "»" and src[-3] == " ": + src = src[:-3] + src[-1] + " " + src[-2] + if len(trg) >= 3 and trg[-1] in my_punct and trg[-2] == "»" and trg[-3] == " ": + trg = trg[:-3] + trg[-1] + " " + trg[-2] # check for the french quotes special case - if (src[-1] == '»' or src[-1] == '«') and trg[-1] not in my_punct: + if (src[-1] == "»" or src[-1] == "«") and trg[-1] not in my_punct: trg = trg + '"' - elif (trg[-1] == '»' or trg[-1] == '«') and src[-1] not in my_punct: + elif (trg[-1] == "»" or trg[-1] == "«") and src[-1] not in my_punct: src = src + '"' elif src[-1] in my_punct and trg[-1] not in my_punct: trg = trg + src[-1] elif trg[-1] in my_punct and src[-1] not in my_punct: src = src + trg[-1] # Final case. Fix mismatched punctuation on the src and trg. EXCEPT in cases like french quotes. And in cases where we have emdash at the front, as it means spech - elif trg[-1] in my_punct and src[-1] in my_punct and src[-1] != trg[-1] and src[-1] != '»' \ -and src[-1] != '«' and trg[-1] != '»' and trg[-1] != '«' and src[0] != '–' and trg[0] != '–' and src[0] != '—' and trg[0] != '—': + elif ( + trg[-1] in my_punct + and src[-1] in my_punct + and src[-1] != trg[-1] + and src[-1] != "»" + and src[-1] != "«" + and trg[-1] != "»" + and trg[-1] != "«" + and src[0] != "–" + and trg[0] != "–" + and src[0] != "—" + and trg[0] != "—" + ): trg = trg[:-1] + src[-1] - print(src + '\t' + trg) + print(src + "\t" + trg) diff --git a/opuscleaner/filters/fix_un_chinese.py b/opuscleaner/filters/fix_un_chinese.py index 0395dba..89d32f1 100755 --- a/opuscleaner/filters/fix_un_chinese.py +++ b/opuscleaner/filters/fix_un_chinese.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Legacy fix_un.py from Barry. Need to fix it up a bit.""" + import re import sys @@ -9,8 +10,8 @@ for line in sys.stdin: line = line.strip() - if line[-1] == ',': + if line[-1] == ",": line = line[:-1] + "\u3002" - line = line.replace(",", "\uFF0C") + line = line.replace(",", "\uff0c") line = re_final_comma.sub("\u3002", line) print(line) diff --git a/opuscleaner/filters/fix_wiki.py b/opuscleaner/filters/fix_wiki.py index 2ce6668..466f4b2 100755 --- a/opuscleaner/filters/fix_wiki.py +++ b/opuscleaner/filters/fix_wiki.py @@ -7,91 +7,105 @@ class MatchType(Enum): - EXACT = 'exact' - COUNT = 'count' + EXACT = "exact" + COUNT = "count" -Pattern = Tuple[str,str, MatchType] +Pattern = Tuple[str, str, MatchType] # Footnote pattern needs to have the exact same matches on both sides. They're # just things like `[3]` at the end of a word. -FOOTNOTE_PATTERN: Pattern = r'\[[0-9]+\]', r'', MatchType.EXACT +FOOTNOTE_PATTERN: Pattern = r"\[[0-9]+\]", r"", MatchType.EXACT # URL match needs to be an exact match on both sides. If not, we teach the MT # system to translate urls for us and that's not good. -URL_PATTERN: Pattern = r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)', r'', MatchType.EXACT +URL_PATTERN: Pattern = ( + r"https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)", + r"", + MatchType.EXACT, +) # If wiki links, e.g. `[[pagename|Link Label]]` or ``[[pagename]]` don't appear # in equal amount on both sides, replace them with just their label. 
Not # comparing the links themselves since they may link to translated page names. -WIKILINKS_PATTERN: Pattern = r'\[\[(?:.+?\|)?(.+?)\]\]', r'\1', MatchType.COUNT +WIKILINKS_PATTERN: Pattern = r"\[\[(?:.+?\|)?(.+?)\]\]", r"\1", MatchType.COUNT # If header does not appear on both sides, remove it entirely. We only compare # counts because the headers themselves are probably translated. -HEADINGS_PATTERN: Pattern = r'(==+)(.+?)\1', r'', MatchType.COUNT +HEADINGS_PATTERN: Pattern = r"(==+)(.+?)\1", r"", MatchType.COUNT -CODE_PATTERN = r'\.mw-parser-output' # Very specific for OPUS-wikimedia +CODE_PATTERN = r"\.mw-parser-output" # Very specific for OPUS-wikimedia -def find_matches(pattern:Pattern, text:str) -> Set[str]: - return set(match[0] for match in re.finditer(pattern[0], text)) +def find_matches(pattern: Pattern, text: str) -> Set[str]: + return set(match[0] for match in re.finditer(pattern[0], text)) -def filter_matches(pattern:Pattern, text:str) -> str: - return re.sub(pattern[0], pattern[1], text) +def filter_matches(pattern: Pattern, text: str) -> str: + return re.sub(pattern[0], pattern[1], text) -def is_mismatch(pattern:Pattern, fields: List[str]) -> bool: - matches = [find_matches(pattern, field) for field in fields[:2]] - if pattern[2] == MatchType.EXACT: - return len(matches[0] & matches[1]) < len(matches[0] ^ matches[1]) - elif pattern[2] == MatchType.COUNT: - return len(matches[0]) != len(matches[1]) - else: - raise NotImplementedError() +def is_mismatch(pattern: Pattern, fields: List[str]) -> bool: + matches = [find_matches(pattern, field) for field in fields[:2]] + if pattern[2] == MatchType.EXACT: + return len(matches[0] & matches[1]) < len(matches[0] ^ matches[1]) + elif pattern[2] == MatchType.COUNT: + return len(matches[0]) != len(matches[1]) + else: + raise NotImplementedError() def is_code(field: str) -> bool: - return re.search(CODE_PATTERN, field) is not None + return re.search(CODE_PATTERN, field) is not None if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Remove common wiki patterns from sentence pair if they don't match on both sides") - parser.add_argument("--always", action="store_true", help="Always remove patterns") - parser.add_argument("--footnotes", action="store_true", help="Remove footnotes, e.g. 
[1], [2]") - parser.add_argument("--urls", action="store_true", help="Remove url`s") - parser.add_argument("--wikilinks", action="store_true", help="Remove [[wikilinks]]") - parser.add_argument("--code", action="store_true", help="Remove lines that contain code") - parser.add_argument("--headings", action="store_true", help="Remove ==headings==") - parser.add_argument("--remove-empty-lines", action="store_true", help="Remove sentence pairs when one side is empty after filtering") - args = parser.parse_args() - - patterns: List[Pattern] = [] - - if args.footnotes: - patterns.append(FOOTNOTE_PATTERN) - - if args.urls: - patterns.append(URL_PATTERN) - - if args.wikilinks: - patterns.append(WIKILINKS_PATTERN) - - if args.headings: - patterns.append(HEADINGS_PATTERN) - - for n, line in enumerate(sys.stdin, start=1): - fields = line.rstrip("\r\n").split("\t") - - if args.code and any(is_code(field) for field in fields): - continue - - for pattern in patterns: - if args.always or is_mismatch(pattern, fields[:2]): - fields[:2] = [filter_matches(pattern, field) for field in fields[:2]] - - # Make sure we didn't add padding after all that replacing - fields = [field.strip() for field in fields] - - if args.remove_empty_lines and all(len(field) > 0 for field in fields): - print("\t".join(fields)) + parser = argparse.ArgumentParser( + description="Remove common wiki patterns from sentence pair if they don't match on both sides" + ) + parser.add_argument("--always", action="store_true", help="Always remove patterns") + parser.add_argument( + "--footnotes", action="store_true", help="Remove footnotes, e.g. [1], [2]" + ) + parser.add_argument("--urls", action="store_true", help="Remove url`s") + parser.add_argument("--wikilinks", action="store_true", help="Remove [[wikilinks]]") + parser.add_argument( + "--code", action="store_true", help="Remove lines that contain code" + ) + parser.add_argument("--headings", action="store_true", help="Remove ==headings==") + parser.add_argument( + "--remove-empty-lines", + action="store_true", + help="Remove sentence pairs when one side is empty after filtering", + ) + args = parser.parse_args() + + patterns: List[Pattern] = [] + + if args.footnotes: + patterns.append(FOOTNOTE_PATTERN) + + if args.urls: + patterns.append(URL_PATTERN) + + if args.wikilinks: + patterns.append(WIKILINKS_PATTERN) + + if args.headings: + patterns.append(HEADINGS_PATTERN) + + for n, line in enumerate(sys.stdin, start=1): + fields = line.rstrip("\r\n").split("\t") + + if args.code and any(is_code(field) for field in fields): + continue + + for pattern in patterns: + if args.always or is_mismatch(pattern, fields[:2]): + fields[:2] = [filter_matches(pattern, field) for field in fields[:2]] + + # Make sure we didn't add padding after all that replacing + fields = [field.strip() for field in fields] + + if args.remove_empty_lines and all(len(field) > 0 for field in fields): + print("\t".join(fields)) diff --git a/opuscleaner/filters/langid.py b/opuscleaner/filters/langid.py index ad2122b..36dbf7f 100755 --- a/opuscleaner/filters/langid.py +++ b/opuscleaner/filters/langid.py @@ -5,7 +5,7 @@ import pycld2 -# Similar languages, taken from +# Similar languages, taken from # https://github.com/mbanon/fastspell/blob/main/fastspell/config/similar.yaml SIMILAR = { "ca": {"es", "ca"}, @@ -19,10 +19,10 @@ "me": {"bs", "hr", "me", "sr"}, "mk": {"bg", "mk"}, "nb": {"nn", "da", "nb"}, - "nl": {"nl", "af"}, # Maybe also Frisian (fy) and French (fr) because of - # short sentences are often misidentified as one 
of - # those (and honestly cld2 has probably been trained - # with a lot of Dutch in their Frisian corpora.) + "nl": {"nl", "af"}, # Maybe also Frisian (fy) and French (fr) because of + # short sentences are often misidentified as one of + # those (and honestly cld2 has probably been trained + # with a lot of Dutch in their Frisian corpora.) "nn": {"nb", "da", "nn"}, "sk": {"cs", "sk"}, "sr": {"bs", "hr", "me", "sr"}, @@ -30,6 +30,7 @@ LANG_UNKNOWN = "un" + def parse_user_args(): """Parse the arguments necessary for this filter""" parser = argparse.ArgumentParser(description="Langid") @@ -61,8 +62,11 @@ def detect_language_parallel(args: argparse.Namespace, fin: BinaryIO, fout: Bina continue if args.debug: - print(f"Line {n} rejected. Detected '{detected_lang}', expected '{lang}': {field.decode()}", file=stderr) - + print( + f"Line {n} rejected. Detected '{detected_lang}', expected '{lang}': {field.decode()}", + file=stderr, + ) + # Break because no need to look at the other columns. Also will # stop the else clause from being executed! break diff --git a/opuscleaner/filters/laser_similarity.py b/opuscleaner/filters/laser_similarity.py index c798159..04719e8 100755 --- a/opuscleaner/filters/laser_similarity.py +++ b/opuscleaner/filters/laser_similarity.py @@ -12,7 +12,9 @@ from io import TextIOBase -def _compute_similarity(laser: Laser, batch: List[Tuple[str, str]], src_lang: str, tgt_lang: str) -> List[float]: +def _compute_similarity( + laser: Laser, batch: List[Tuple[str, str]], src_lang: str, tgt_lang: str +) -> List[float]: assert len(batch) > 0 embeddings_src = laser.embed_sentences([line[0] for line in batch], lang=src_lang) embeddings_tgt = laser.embed_sentences([line[1] for line in batch], lang=tgt_lang) @@ -23,20 +25,30 @@ def _cosine_sim(emb1: np.ndarray, emb2: np.ndarray) -> np.ndarray: return np.sum(emb1 * emb2, axis=-1) / (norm(emb1, axis=-1) * norm(emb2, axis=-1)) -def interpolate(sample: Iterable[Tuple[int, float]], target:float) -> int: - poly = Polynomial.fit([duration for size, duration in sample], [size for size, duration in sample], 1) +def interpolate(sample: Iterable[Tuple[int, float]], target: float) -> int: + poly = Polynomial.fit( + [duration for size, duration in sample], [size for size, duration in sample], 1 + ) return int(poly(target)), poly class NullIO(TextIOBase): """TextIO that does nothing, as if writing to /dev/null.""" - def write(self, data:str) -> int: + + def write(self, data: str) -> int: return len(data) -T = TypeVar('T') +T = TypeVar("T") + -def chunked(iterable: Iterable[T], *, chunk_size:Optional[int]=None, chunk_time:Optional[float]=None, verbose:Optional[TextIO]=NullIO()) -> Iterable[List[T]]: +def chunked( + iterable: Iterable[T], + *, + chunk_size: Optional[int] = None, + chunk_time: Optional[float] = None, + verbose: Optional[TextIO] = NullIO(), +) -> Iterable[List[T]]: """Self-tuning batching iterator""" it = iter(iterable) @@ -76,23 +88,42 @@ def chunked(iterable: Iterable[T], *, chunk_size:Optional[int]=None, chunk_time: except StopIteration: # No, we've run all the samples. 
Use previous measurements limit, poly = interpolate(measurements, chunk_time) - print(f'Fitted {poly}', file=verbose) + print(f"Fitted {poly}", file=verbose) print(f"Setting chunk size to {limit}", file=verbose) def main(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, - description="Filter a parallel dataset using LASER.") + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Filter a parallel dataset using LASER.", + ) parser.add_argument("--verbose", action="store_true", help="Print tuning info") parser.add_argument("--batch-size", type=int, help="LASER batch size") - parser.add_argument("--batch-latency", type=float, default=10.0, help="Tune batch size to process a batch every N seconds (defaults to 10s, ignored if --batch-size is given)") - parser.add_argument("--src-lang", type=str, required=True, help="Two-letter source language code (ISO 639-1)") - parser.add_argument("--tgt-lang", type=str, required=True, help="Two-letter target language code (ISO 639-1)") - + parser.add_argument( + "--batch-latency", + type=float, + default=10.0, + help="Tune batch size to process a batch every N seconds (defaults to 10s, ignored if --batch-size is given)", + ) + parser.add_argument( + "--src-lang", + type=str, + required=True, + help="Two-letter source language code (ISO 639-1)", + ) + parser.add_argument( + "--tgt-lang", + type=str, + required=True, + help="Two-letter target language code (ISO 639-1)", + ) + group = parser.add_mutually_exclusive_group() group.add_argument("--threshold", type=float, help="Minimum accepted LASER score.") - group.add_argument("--scores", action="store_true", help="Print scores instead of lines") + group.add_argument( + "--scores", action="store_true", help="Print scores instead of lines" + ) args = parser.parse_args() @@ -101,9 +132,19 @@ def main(): laser = Laser() - for batch in chunked(sys.stdin, chunk_size=args.batch_size, chunk_time=args.batch_latency, verbose=sys.stderr if args.verbose else NullIO()): + for batch in chunked( + sys.stdin, + chunk_size=args.batch_size, + chunk_time=args.batch_latency, + verbose=sys.stderr if args.verbose else NullIO(), + ): # TODO error checking of column count? 
- scores = _compute_similarity(laser, [tuple(line.rstrip("\r\n").split("\t")[:2]) for line in batch], args.src_lang, args.tgt_lang) + scores = _compute_similarity( + laser, + [tuple(line.rstrip("\r\n").split("\t")[:2]) for line in batch], + args.src_lang, + args.tgt_lang, + ) if args.scores: for score in scores: @@ -115,4 +156,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/opuscleaner/filters/max_length.py b/opuscleaner/filters/max_length.py index 38a7308..d8f77e1 100755 --- a/opuscleaner/filters/max_length.py +++ b/opuscleaner/filters/max_length.py @@ -5,26 +5,28 @@ def parse_user_args(): """Parse the arguments necessary for this filter""" - parser = argparse.ArgumentParser(description="Filters a parallel or mono dataset based on line lengths") + parser = argparse.ArgumentParser( + description="Filters a parallel or mono dataset based on line lengths" + ) parser.add_argument("--max-length", default=150, type=float) parser.add_argument("--min-length", default=1, type=float) - parser.add_argument("--debug", action='store_true') + parser.add_argument("--debug", action="store_true") return parser.parse_args() -def clean_parallel(max_length: float, min_length: float, debug: bool=True) -> None: +def clean_parallel(max_length: float, min_length: float, debug: bool = True) -> None: """Cleans the parallel or mono dataset based on line lengths""" for line in stdin: - fields = line.strip('\r\n').split('\t') + fields = line.strip("\r\n").split("\t") if len(fields) == 1: src = fields[-1].strip() trg = None - else: # Assumes that the multiline filter already run + else: # Assumes that the multiline filter already run src = fields[-2].strip() trg = fields[-1].strip() srctok = src.split() - srcpass: bool = (len(srctok) <= max_length and len(srctok) >= min_length) + srcpass: bool = len(srctok) <= max_length and len(srctok) >= min_length trgpass: bool @@ -33,15 +35,15 @@ def clean_parallel(max_length: float, min_length: float, debug: bool=True) -> No trgpass = True else: trgtok = trg.split() - trgpass = (len(trgtok) <= max_length and len(trgtok) >= min_length) + trgpass = len(trgtok) <= max_length and len(trgtok) >= min_length # write if srcpass and trgpass: stdout.write(line) elif debug: - stderr.write(f'LENGTH\t{src}\t{trg}\n') + stderr.write(f"LENGTH\t{src}\t{trg}\n") -if __name__ == '__main__': +if __name__ == "__main__": args = parse_user_args() clean_parallel(args.max_length, args.min_length, args.debug) diff --git a/opuscleaner/filters/max_word_length.py b/opuscleaner/filters/max_word_length.py index c4c0844..44034a7 100755 --- a/opuscleaner/filters/max_word_length.py +++ b/opuscleaner/filters/max_word_length.py @@ -6,7 +6,9 @@ def parse_user_args(): """Parse the arguments necessary for this filter""" - parser = argparse.ArgumentParser(description="Filters a parallel dataset based on max word length") + parser = argparse.ArgumentParser( + description="Filters a parallel dataset based on max word length" + ) parser.add_argument("--max-word-length", default=150, type=int) return parser.parse_args() @@ -14,16 +16,18 @@ def parse_user_args(): def clean_parallel(max_word_length: float, fin: TextIO, fout: TextIO) -> None: """Cleans the parallel or mono dataset based on line lengths""" for line in fin: - fields = line.rstrip('\r\n').split('\t') + fields = line.rstrip("\r\n").split("\t") - if any(len(token) > max_word_length + if any( + len(token) > max_word_length for field in fields - for token in field.split(' ')): + for token in field.split(" ") + 
): continue fout.write(line) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_user_args() clean_parallel(args.max_word_length, sys.stdin, sys.stdout) diff --git a/opuscleaner/filters/normalize_whitespace.py b/opuscleaner/filters/normalize_whitespace.py index fc2098d..d03fa5a 100755 --- a/opuscleaner/filters/normalize_whitespace.py +++ b/opuscleaner/filters/normalize_whitespace.py @@ -5,6 +5,7 @@ of whitespaces into a single space """ + import argparse import sys @@ -29,7 +30,9 @@ def clean(collapse): if __name__ == "__main__": parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - "--collapse", action="store_true", - help="Collapse whitespace groups into single spaces") + "--collapse", + action="store_true", + help="Collapse whitespace groups into single spaces", + ) args = parser.parse_args() clean(args.collapse) diff --git a/opuscleaner/filters/num_mismatch.py b/opuscleaner/filters/num_mismatch.py index 5ff26cc..8f5d85d 100755 --- a/opuscleaner/filters/num_mismatch.py +++ b/opuscleaner/filters/num_mismatch.py @@ -5,7 +5,8 @@ from typing import Match, TextIO -NUM_EXPR = re.compile(r""" +NUM_EXPR = re.compile( + r""" (?: (?P(?:(?<=\s)|^)[-+]) # start with a sign if it is not attached to a word, i.e. not - in tid-30 |\b # or start with a word boundary, i.e. not just30 @@ -15,42 +16,53 @@ (?:[\.,:]\d+)* # allow commas or dots, i.e. 300,000.0 but not 3,,,5 ) \b # disallow 30beep, but also 30th and 1st #TODO -""", re.X) +""", + re.X, +) -def normalize(numstr:Match) -> str: - return (numstr['sign'] or '') + re.sub(r'[^\d]+', '*', numstr['value']) # ignore the decimal and digit separators +def normalize(numstr: Match) -> str: + return (numstr["sign"] or "") + re.sub( + r"[^\d]+", "*", numstr["value"] + ) # ignore the decimal and digit separators -def filter_numerical_mismatch(fin: TextIO, fout: TextIO, ratio: float, *, debug: bool = False): - for line in fin: - cols = line.rstrip('\r').split('\t') +def filter_numerical_mismatch( + fin: TextIO, fout: TextIO, ratio: float, *, debug: bool = False +): + for line in fin: + cols = line.rstrip("\r").split("\t") - assert len(cols) >= 2 + assert len(cols) >= 2 - nums_left, nums_right = (set(map(normalize, re.finditer(NUM_EXPR, col))) for col in cols[:2]) + nums_left, nums_right = ( + set(map(normalize, re.finditer(NUM_EXPR, col))) for col in cols[:2] + ) - # Only bother calculating the ratio if there were any numbers to begin with - if nums_left or nums_right: - overlap = nums_left & nums_right - difference = nums_left ^ nums_right + # Only bother calculating the ratio if there were any numbers to begin with + if nums_left or nums_right: + overlap = nums_left & nums_right + difference = nums_left ^ nums_right - # Big > 1.0 number if lots of overlap, small < 1.0 number if lots of differences - line_ratio = (len(overlap) + 1) / (len(difference) + 1) + # Big > 1.0 number if lots of overlap, small < 1.0 number if lots of differences + line_ratio = (len(overlap) + 1) / (len(difference) + 1) - if debug: - print(f"{len(overlap)} / {len(difference)} : {overlap!r} | {difference!r}", file=sys.stderr) + if debug: + print( + f"{len(overlap)} / {len(difference)} : {overlap!r} | {difference!r}", + file=sys.stderr, + ) - if line_ratio < ratio: - continue + if line_ratio < ratio: + continue - fout.write(line) + fout.write(line) -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--ratio', type=float, default=1.0) - parser.add_argument('--debug', action='store_true') - args = 
parser.parse_args() +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ratio", type=float, default=1.0) + parser.add_argument("--debug", action="store_true") + args = parser.parse_args() - filter_numerical_mismatch(sys.stdin, sys.stdout, args.ratio, debug=args.debug) + filter_numerical_mismatch(sys.stdin, sys.stdout, args.ratio, debug=args.debug) diff --git a/opuscleaner/filters/opusfilter/opusfilter-ersatz.py b/opuscleaner/filters/opusfilter/opusfilter-ersatz.py index a915cef..8299b6a 100755 --- a/opuscleaner/filters/opusfilter/opusfilter-ersatz.py +++ b/opuscleaner/filters/opusfilter/opusfilter-ersatz.py @@ -13,51 +13,53 @@ import yaml parser = argparse.ArgumentParser() -parser.add_argument('--quiet', '-q', action='store_true') -parser.add_argument('filter', type=str) -parser.add_argument('config', type=str) +parser.add_argument("--quiet", "-q", action="store_true") +parser.add_argument("filter", type=str) +parser.add_argument("config", type=str) args = parser.parse_args() if args.quiet: - # Filter out warnings from opusfilter about missing env variables for like - # word alignment. - logging.getLogger().setLevel(logging.ERROR) + # Filter out warnings from opusfilter about missing env variables for like + # word alignment. + logging.getLogger().setLevel(logging.ERROR) - # Filter warnings (especially MarkupResemblesLocatorWarning) from - # BeautifulSoup (the HtmlTagFilter) - warnings.filterwarnings('ignore', module='bs4') + # Filter warnings (especially MarkupResemblesLocatorWarning) from + # BeautifulSoup (the HtmlTagFilter) + warnings.filterwarnings("ignore", module="bs4") -module_path, class_name = args.filter.rsplit('.', maxsplit=1) +module_path, class_name = args.filter.rsplit(".", maxsplit=1) config = yaml.safe_load(args.config) if not isinstance(config, dict): - if config: - raise ValueError('config has to be a mapping') + if config: + raise ValueError("config has to be a mapping") - config = dict() + config = dict() mod = importlib.import_module(module_path) filter_cls = getattr(mod, class_name) filter_obj = filter_cls(**config) if isinstance(filter_obj, opusfilter.FilterABC): - def apply_filter(lines): - # Duplicate the iterator into two, one goes into the scorer, one for output - # because scorer could be eating them in chunks. - lines1, lines2 = itertools.tee(lines) - pairs = (line[0:2] for line in lines1) - for line, score in zip(lines2, filter_obj.score(pairs)): - if filter_obj.accept(score): - yield line + + def apply_filter(lines): + # Duplicate the iterator into two, one goes into the scorer, one for output + # because scorer could be eating them in chunks. 
+ lines1, lines2 = itertools.tee(lines) + pairs = (line[0:2] for line in lines1) + for line, score in zip(lines2, filter_obj.score(pairs)): + if filter_obj.accept(score): + yield line elif isinstance(filter_obj, opusfilter.PreprocessorABC): - def apply_filter(pairs): - return filter_obj.process(pairs) + + def apply_filter(pairs): + return filter_obj.process(pairs) else: - raise ValueError('filter class does not implement FilterABC or PreprocessorABC') + raise ValueError("filter class does not implement FilterABC or PreprocessorABC") -lines = (line.rstrip('\r\n').split('\t') for line in sys.stdin) +lines = (line.rstrip("\r\n").split("\t") for line in sys.stdin) for line in apply_filter(lines): - print("\t".join(line)) + print("\t".join(line)) diff --git a/opuscleaner/filters/remove_empty_lines.py b/opuscleaner/filters/remove_empty_lines.py index 1c04240..2201023 100755 --- a/opuscleaner/filters/remove_empty_lines.py +++ b/opuscleaner/filters/remove_empty_lines.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import sys + def main(): for line in sys.stdin: fields = line.strip("\r\n").split("\t") @@ -12,5 +13,6 @@ def main(): if ok: sys.stdout.write(line) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/opuscleaner/filters/remove_frequent_patterns.py b/opuscleaner/filters/remove_frequent_patterns.py index 644ad95..e5a0a73 100755 --- a/opuscleaner/filters/remove_frequent_patterns.py +++ b/opuscleaner/filters/remove_frequent_patterns.py @@ -37,17 +37,31 @@ def load_patterns(file_path: str) -> List[Pattern]: for line in lines: parts = line.split("\t") if len(parts) == 2: - patterns.append(Pattern(group_match=re.compile(parts[0]), replacement=parts[1])) + patterns.append( + Pattern(group_match=re.compile(parts[0]), replacement=parts[1]) + ) elif len(parts) == 3: - patterns.append(Pattern(pattern_on_both_cols=re.compile(parts[0]), group_match=re.compile(parts[1]), replacement=parts[2])) + patterns.append( + Pattern( + pattern_on_both_cols=re.compile(parts[0]), + group_match=re.compile(parts[1]), + replacement=parts[2], + ) + ) else: - raise ValueError(f"Patterns have to have 2 or 3 columns, but got {len(parts)}") + raise ValueError( + f"Patterns have to have 2 or 3 columns, but got {len(parts)}" + ) return patterns def main(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--pattern-file", type=str, help="Path to the file with patterns.") + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--pattern-file", type=str, help="Path to the file with patterns." 
+ ) parser.add_argument("--debug", action="store_true") args = parser.parse_args() patterns = load_patterns(args.pattern_file) @@ -57,7 +71,10 @@ def main(): source, target = line.split("\t", 1) for pattern in patterns: # Either pattern_on_both_cols is not set or it matches the whole line - if pattern.pattern_on_both_cols is None or pattern.pattern_on_both_cols.match(line): + if ( + pattern.pattern_on_both_cols is None + or pattern.pattern_on_both_cols.match(line) + ): source = pattern.group_match.sub(pattern.replacement, source) target = pattern.group_match.sub(pattern.replacement, target) sys.stdout.write(f"{source}\t{target}\n") diff --git a/opuscleaner/filters/segment_chinese.py b/opuscleaner/filters/segment_chinese.py index 05a220d..b1d0cd6 100755 --- a/opuscleaner/filters/segment_chinese.py +++ b/opuscleaner/filters/segment_chinese.py @@ -2,7 +2,7 @@ from sys import stdin import spacy_pkuseg as pkuseg -seg = pkuseg.pkuseg() #load the default model +seg = pkuseg.pkuseg() # load the default model for line in stdin: text = seg.cut(line.strip()) print(" ".join(text)) diff --git a/opuscleaner/filters/segment_japanese.py b/opuscleaner/filters/segment_japanese.py index 52f7b54..c8d1ef9 100755 --- a/opuscleaner/filters/segment_japanese.py +++ b/opuscleaner/filters/segment_japanese.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 -'''Segments Japanese text using the fugashi tokenizer. +"""Segments Japanese text using the fugashi tokenizer. Note the specifics of Japanese tokenization, where verbs are always separate to stem and conjugation part, as well topic or subject particles are split from the nouns. This means that the Japanese sentences would likely be quite a bit longer than the -English ones.''' +English ones.""" + import fugashi from sys import stdin diff --git a/opuscleaner/filters/split_sentences.py b/opuscleaner/filters/split_sentences.py index a5bab9b..7bb47ee 100755 --- a/opuscleaner/filters/split_sentences.py +++ b/opuscleaner/filters/split_sentences.py @@ -5,33 +5,37 @@ from sentence_splitter import SentenceSplitter -def split_sentences_in_bitext(fin: TextIO, fout: TextIO, languages: List[str], keep_unbalanced: bool = False): - splitters = [SentenceSplitter(language=lang) for lang in languages] - - for line in fin: - cols = line.rstrip('\r\n').split('\t') +def split_sentences_in_bitext( + fin: TextIO, fout: TextIO, languages: List[str], keep_unbalanced: bool = False +): + splitters = [SentenceSplitter(language=lang) for lang in languages] - assert len(cols) == len(splitters) + for line in fin: + cols = line.rstrip("\r\n").split("\t") - splitted = [splitter.split(col) for splitter, col in zip(splitters, cols)] + assert len(cols) == len(splitters) - if any(len(col) != len(splitted[0]) for col in splitted[1:]): - if keep_unbalanced: - # Revert back to the input line - splitted = [[col] for col in cols] - else: - # Skip line - continue + splitted = [splitter.split(col) for splitter, col in zip(splitters, cols)] - for cols in zip(*splitted): - fout.write('\t'.join(cols) + '\n') + if any(len(col) != len(splitted[0]) for col in splitted[1:]): + if keep_unbalanced: + # Revert back to the input line + splitted = [[col] for col in cols] + else: + # Skip line + continue + for cols in zip(*splitted): + fout.write("\t".join(cols) + "\n") -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--keep-unbalanced', action='store_true') - parser.add_argument('languages', type=str, nargs='+') - args = parser.parse_args() +if __name__ == "__main__": + parser = 
argparse.ArgumentParser() + parser.add_argument("--keep-unbalanced", action="store_true") + parser.add_argument("languages", type=str, nargs="+") - split_sentences_in_bitext(sys.stdin, sys.stdout, args.languages, args.keep_unbalanced) + args = parser.parse_args() + + split_sentences_in_bitext( + sys.stdin, sys.stdout, args.languages, args.keep_unbalanced + ) diff --git a/opuscleaner/filters/src_trg_ratio.py b/opuscleaner/filters/src_trg_ratio.py index d538b3c..2f831f5 100755 --- a/opuscleaner/filters/src_trg_ratio.py +++ b/opuscleaner/filters/src_trg_ratio.py @@ -7,11 +7,13 @@ def parse_user_args(): """Parse the arguments necessary for this filter""" - parser = argparse.ArgumentParser(description="Filters the lines based on the ratio between num_src_tokens and num_trg_tokens") + parser = argparse.ArgumentParser( + description="Filters the lines based on the ratio between num_src_tokens and num_trg_tokens" + ) parser.add_argument("--ratio-length", default=0.6, type=float) parser.add_argument("--filter-identical", action="store_true") parser.add_argument("--log", action="store_true") - parser.add_argument("--debug", action='store_true') + parser.add_argument("--debug", action="store_true") return parser.parse_args() @@ -39,12 +41,18 @@ def compare_lin(src: List[str], trg: List[str], ratio: float) -> bool: Comparator = Callable[[List[str], List[str], float], bool] -def clean_parallel(ratio: float, filter_identical: bool, *, debug: bool=False, compare: Comparator=compare_lin) -> None: +def clean_parallel( + ratio: float, + filter_identical: bool, + *, + debug: bool = False, + compare: Comparator = compare_lin, +) -> None: """Cleans the parallel dataset based on the ratio of source to target tokens and vice versa""" for line in stdin: - fields = line.rstrip('\r\n').split('\t') + fields = line.rstrip("\r\n").split("\t") if len(fields) != 2: - stderr.write(f'SINGLE/MULTIPLE_LINES\t{line}') + stderr.write(f"SINGLE/MULTIPLE_LINES\t{line}") continue src = fields[0].strip() @@ -53,7 +61,7 @@ def clean_parallel(ratio: float, filter_identical: bool, *, debug: bool=False, c # Remove identical lines if filter_identical and src.lower() == trg.lower(): if debug: - stderr.write(f'IDENTICAL\t{src}\t{trg}\n') + stderr.write(f"IDENTICAL\t{src}\t{trg}\n") continue src_toks = src.split() @@ -61,13 +69,16 @@ def clean_parallel(ratio: float, filter_identical: bool, *, debug: bool=False, c if not compare(src_toks, trg_toks, ratio): if debug: - stderr.write(f'RATIO_LENGTH: {src}\t{trg}\n') + stderr.write(f"RATIO_LENGTH: {src}\t{trg}\n") else: stdout.write(line) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_user_args() - clean_parallel(args.ratio_length, args.filter_identical, + clean_parallel( + args.ratio_length, + args.filter_identical, debug=args.debug, - compare=compare_log if args.log else compare_lin) + compare=compare_log if args.log else compare_lin, + ) diff --git a/opuscleaner/filters/strip_suffix.py b/opuscleaner/filters/strip_suffix.py index 257e7a1..5bf0e02 100755 --- a/opuscleaner/filters/strip_suffix.py +++ b/opuscleaner/filters/strip_suffix.py @@ -7,84 +7,89 @@ def common_suffix(buffer: Iterable[str]) -> str: - iters = [iter(line[::-1]) for line in buffer] - assert len(iters) > 1 - suffix = takewhile(identical, zip(*iters)) - return "".join(t[0] for t in suffix)[::-1] + iters = [iter(line[::-1]) for line in buffer] + assert len(iters) > 1 + suffix = takewhile(identical, zip(*iters)) + return "".join(t[0] for t in suffix)[::-1] -T = TypeVar('T') +T = TypeVar("T") + def 
identical(elements: Iterable[T]) -> bool: - it = iter(elements) - first = next(it) - return all(first == el for el in it) - - -def strip_suffix(lines: Iterable[str], *, minlen: int = 2, minocc: int = 5, counter: Counter=None) -> Iterable[str]: - buffer = deque() - - suffix = "" - - for line in lines: - if suffix and line.endswith(suffix): - assert not buffer, "buffer should been empty" - if counter is not None: - counter[suffix] += 1 - yield line[:-1 * len(suffix)] - - elif suffix: # and not line ends with suffix - assert not buffer, "buffer should been empty" - suffix = "" - buffer.append(line) - - else: # suffix is None - # Make space in the buffer - if len(buffer) == minocc: - yield buffer.popleft() - - buffer.append(line) - - # If our buffer is too small to identify a suffix, don't bother - if len(buffer) < minocc: - continue - - # Try to identify a new common suffix - suffix = common_suffix(buffer) - - # If the suffix is too short, it might as well be nothing - if len(suffix) < minlen: - suffix = "" - - # if found, empty buffer, stripping that suffix - if suffix: - if counter is not None: - counter[suffix] += len(buffer) - while buffer: - line = buffer.popleft() - yield line[:-1 * len(suffix)] - - # Empty buffer - yield from buffer - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--minlen", type=int, default=4) - parser.add_argument("--minocc", type=int, default=5) - parser.add_argument("--count", action="store_true") - args = parser.parse_args() - - lines = (line.rstrip("\r\n") for line in sys.stdin) - - if args.count: - counter = Counter() - else: - counter = None - - for line in strip_suffix(lines, minlen=args.minlen, minocc=args.minocc, counter=counter): - print(line, file=sys.stdout) - - if counter: - for suffix, count in counter.most_common(): - print(f"{suffix}\t{count}", file=sys.stderr) + it = iter(elements) + first = next(it) + return all(first == el for el in it) + + +def strip_suffix( + lines: Iterable[str], *, minlen: int = 2, minocc: int = 5, counter: Counter = None +) -> Iterable[str]: + buffer = deque() + + suffix = "" + + for line in lines: + if suffix and line.endswith(suffix): + assert not buffer, "buffer should been empty" + if counter is not None: + counter[suffix] += 1 + yield line[: -1 * len(suffix)] + + elif suffix: # and not line ends with suffix + assert not buffer, "buffer should been empty" + suffix = "" + buffer.append(line) + + else: # suffix is None + # Make space in the buffer + if len(buffer) == minocc: + yield buffer.popleft() + + buffer.append(line) + + # If our buffer is too small to identify a suffix, don't bother + if len(buffer) < minocc: + continue + + # Try to identify a new common suffix + suffix = common_suffix(buffer) + + # If the suffix is too short, it might as well be nothing + if len(suffix) < minlen: + suffix = "" + + # if found, empty buffer, stripping that suffix + if suffix: + if counter is not None: + counter[suffix] += len(buffer) + while buffer: + line = buffer.popleft() + yield line[: -1 * len(suffix)] + + # Empty buffer + yield from buffer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--minlen", type=int, default=4) + parser.add_argument("--minocc", type=int, default=5) + parser.add_argument("--count", action="store_true") + args = parser.parse_args() + + lines = (line.rstrip("\r\n") for line in sys.stdin) + + if args.count: + counter = Counter() + else: + counter = None + + for line in strip_suffix( + lines, minlen=args.minlen, 
minocc=args.minocc, counter=counter + ): + print(line, file=sys.stdout) + + if counter: + for suffix, count in counter.most_common(): + print(f"{suffix}\t{count}", file=sys.stderr) diff --git a/opuscleaner/filters/test_num_mismatch.py b/opuscleaner/filters/test_num_mismatch.py index 676c13b..3cbe004 100644 --- a/opuscleaner/filters/test_num_mismatch.py +++ b/opuscleaner/filters/test_num_mismatch.py @@ -3,54 +3,59 @@ from num_mismatch import filter_numerical_mismatch + class TestNumMismatch(unittest.TestCase): - def _test(self, line:str, ratio:float, **kwargs) -> bool: - fin = io.StringIO(line) - fout = io.StringIO() - filter_numerical_mismatch(fin, fout, ratio, **kwargs) - return fout.getvalue() == line - - def assertAccept(self, line:str, ratio:float, **kwargs): - """Test that this line is accepted""" - self.assertTrue(self._test(line, ratio, **kwargs)) - - def assertReject(self, line:str, ratio:float, **kwargs): - """Test that this line is rejected""" - self.assertFalse(self._test(line, ratio, **kwargs)) - - def test_match(self): - """Exact matches should be accepted.""" - self.assertAccept('There are 74 cows\t74 cows have we', 1.0) - - def test_accepted_comma_mismatch(self): - """Differences in the decimal separator should be accepted.""" - self.assertAccept('There are 7.4 cows\t7,4 cows have we', 1.0) - - def test_mismatch(self): - """Differences in the number should be rejected.""" - self.assertReject('There are 73 cows\t74 cows have we', 1.0) - - def test_ratio(self): - """Lowering the ratio threshold will accept mismatches""" - line = 'There are 73 cows in 6 fields\tWe have 6 fields with 74 cows' # 2 / 3 = 0.667 - self.assertAccept(line, 0.5) - self.assertReject(line, 1.0) - - def test_prefix_zero(self): - """Numbers like 06 and 6 should be the same. See #89""" - self.assertAccept('These are the same 007 numbers\tThe number is 7', 1.0) - - def test_sign_match(self): - """Signs matter.""" - self.assertAccept('The current temp is -7.5 degrees\tI lost -7.5 points', 1.0) - self.assertReject('The current temp is -7.5 degrees\tI walked 7.5 miles', 1.0) - self.assertReject('The difference is +7.5 degrees\tI walked 7.5 miles', 1.0) # questionable? 
- self.assertReject('The current temp is -7.5 degrees\tI changed the value by +7.5', 1.0) - - def test_word_boundary(self): - self.assertAccept('I am a 30something\tThat just40 should be ignored', 1.0) - - def test_word_boundary_dash(self): - self.assertAccept('-30 is the number\tThe number -30', 1.0) - self.assertAccept('The-number-30\tThe number 30', 1.0) - self.assertReject('Beep-30\tThe number is -30', 1.0) + def _test(self, line: str, ratio: float, **kwargs) -> bool: + fin = io.StringIO(line) + fout = io.StringIO() + filter_numerical_mismatch(fin, fout, ratio, **kwargs) + return fout.getvalue() == line + + def assertAccept(self, line: str, ratio: float, **kwargs): + """Test that this line is accepted""" + self.assertTrue(self._test(line, ratio, **kwargs)) + + def assertReject(self, line: str, ratio: float, **kwargs): + """Test that this line is rejected""" + self.assertFalse(self._test(line, ratio, **kwargs)) + + def test_match(self): + """Exact matches should be accepted.""" + self.assertAccept("There are 74 cows\t74 cows have we", 1.0) + + def test_accepted_comma_mismatch(self): + """Differences in the decimal separator should be accepted.""" + self.assertAccept("There are 7.4 cows\t7,4 cows have we", 1.0) + + def test_mismatch(self): + """Differences in the number should be rejected.""" + self.assertReject("There are 73 cows\t74 cows have we", 1.0) + + def test_ratio(self): + """Lowering the ratio threshold will accept mismatches""" + line = "There are 73 cows in 6 fields\tWe have 6 fields with 74 cows" # 2 / 3 = 0.667 + self.assertAccept(line, 0.5) + self.assertReject(line, 1.0) + + def test_prefix_zero(self): + """Numbers like 06 and 6 should be the same. See #89""" + self.assertAccept("These are the same 007 numbers\tThe number is 7", 1.0) + + def test_sign_match(self): + """Signs matter.""" + self.assertAccept("The current temp is -7.5 degrees\tI lost -7.5 points", 1.0) + self.assertReject("The current temp is -7.5 degrees\tI walked 7.5 miles", 1.0) + self.assertReject( + "The difference is +7.5 degrees\tI walked 7.5 miles", 1.0 + ) # questionable? 
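For context, a minimal sketch (not part of this diff) of how the function under test is driven: filter_numerical_mismatch reads tab-separated source/target pairs from a text stream and only writes through the lines whose numbers agree at or above the given ratio. The import mirrors the test module; the example line and ratio are illustrative.

import io
from num_mismatch import filter_numerical_mismatch  # same import as the tests above

fin = io.StringIO("There are 74 cows\t74 cows have we\n")
fout = io.StringIO()
filter_numerical_mismatch(fin, fout, 1.0)  # ratio 1.0: every number must match
print(fout.getvalue(), end="")             # the line passes, so it is echoed back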
+ self.assertReject( + "The current temp is -7.5 degrees\tI changed the value by +7.5", 1.0 + ) + + def test_word_boundary(self): + self.assertAccept("I am a 30something\tThat just40 should be ignored", 1.0) + + def test_word_boundary_dash(self): + self.assertAccept("-30 is the number\tThe number -30", 1.0) + self.assertAccept("The-number-30\tThe number 30", 1.0) + self.assertReject("Beep-30\tThe number is -30", 1.0) diff --git a/opuscleaner/logging.py b/opuscleaner/logging.py index 5e9ac97..43fbc06 100644 --- a/opuscleaner/logging.py +++ b/opuscleaner/logging.py @@ -10,152 +10,164 @@ def iter_queue(queue): - while True: - task = queue.get() - if task is None: - break - yield task + while True: + task = queue.get() + if task is None: + break + yield task class Span: - """Log the duration (enter and exit) of a set of instructions""" - - def __init__(self, logger:'Logger', name:str, extra:dict=dict()): - self.logger = logger - self.name = name - self.extra = extra - self.span = None - - def __enter__(self) -> 'Span': - self.span = self.logger.push(self.name, **self.extra, type='span', start=time.monotonic_ns()) - return self - - def __exit__(self, typ, value, traceback): - assert self.span is not None - span = self.logger.pop() - assert self.span == span - self.logger.update(self.span, end=time.monotonic_ns(), error=repr(value) if value is not None else None) - - def event(self, name, **kwargs) -> UUID: - return self.logger.event(name, type='event', parent=self.span, **kwargs) + """Log the duration (enter and exit) of a set of instructions""" + + def __init__(self, logger: "Logger", name: str, extra: dict = dict()): + self.logger = logger + self.name = name + self.extra = extra + self.span = None + + def __enter__(self) -> "Span": + self.span = self.logger.push( + self.name, **self.extra, type="span", start=time.monotonic_ns() + ) + return self + + def __exit__(self, typ, value, traceback): + assert self.span is not None + span = self.logger.pop() + assert self.span == span + self.logger.update( + self.span, + end=time.monotonic_ns(), + error=repr(value) if value is not None else None, + ) + + def event(self, name, **kwargs) -> UUID: + return self.logger.event(name, type="event", parent=self.span, **kwargs) class Handler(Protocol): - """Handler receives log records to handle. For example, by writing them as - JSON to a file, or putting them on a queue to be processed in another thread. - """ - def emit(self, record:dict) -> None: - pass + """Handler receives log records to handle. For example, by writing them as + JSON to a file, or putting them on a queue to be processed in another thread. 
+ """ + + def emit(self, record: dict) -> None: + pass class FallbackJSONEncoder(JSONEncoder): - """JSONEncoder that just calls `str(obj)` in case it has no built-in - conversion for the type.""" - def default(self, obj): - return str(obj) + """JSONEncoder that just calls `str(obj)` in case it has no built-in + conversion for the type.""" + + def default(self, obj): + return str(obj) class NullHandler(Handler): - """Handler that does nothing, like printing to /dev/null.""" - def emit(self, record:dict) -> None: - pass + """Handler that does nothing, like printing to /dev/null.""" + + def emit(self, record: dict) -> None: + pass class StreamHandler(Handler): - """Writes log records as JSON lines to a file.""" - def __init__(self, stream:IO[str]): - self.stream = stream - self.encoder = FallbackJSONEncoder() + """Writes log records as JSON lines to a file.""" + + def __init__(self, stream: IO[str]): + self.stream = stream + self.encoder = FallbackJSONEncoder() - def emit(self, record:dict) -> None: - for chunk in self.encoder.iterencode(record): - self.stream.write(chunk) - self.stream.write('\n') + def emit(self, record: dict) -> None: + for chunk in self.encoder.iterencode(record): + self.stream.write(chunk) + self.stream.write("\n") def _queue_to_handler(queue: SimpleQueue, handler: Handler): - for record in iter_queue(queue): - handler.emit(record) + for record in iter_queue(queue): + handler.emit(record) class ThreadEmitter(Handler): - """Handler that puts log records onto a queue""" - def __init__(self, queue): - self.queue = queue + """Handler that puts log records onto a queue""" + + def __init__(self, queue): + self.queue = queue - def emit(self, record:dict): - self.queue.put(record) + def emit(self, record: dict): + self.queue.put(record) class ThreadReceiver: - """Context manager that will run a thread in the background to capture log - records emitted by ThreadEmitter handlers, and forward them to a single - handler. Make these emitters through `make_hander()`. - Note that the handler is run in another thread, so if your handler is writing - to say stderr and you're also writing to stderr on the main thread, you will - need to do some coordination through locking. - """ - def __init__(self, handler:Handler): - self.handler = handler - self.queue = SimpleQueue() - - def __enter__(self): - self.thread = Thread(target=_queue_to_handler, args=[self.queue, self.handler]) - self.thread.start() - return self - - def __exit__(self, typ, value, traceback): - self.queue.put(None) - self.thread.join() - - def make_handler(self) -> Handler: - return ThreadEmitter(self.queue) + """Context manager that will run a thread in the background to capture log + records emitted by ThreadEmitter handlers, and forward them to a single + handler. Make these emitters through `make_hander()`. + Note that the handler is run in another thread, so if your handler is writing + to say stderr and you're also writing to stderr on the main thread, you will + need to do some coordination through locking. 
+ """ + + def __init__(self, handler: Handler): + self.handler = handler + self.queue = SimpleQueue() + + def __enter__(self): + self.thread = Thread(target=_queue_to_handler, args=[self.queue, self.handler]) + self.thread.start() + return self + + def __exit__(self, typ, value, traceback): + self.queue.put(None) + self.thread.join() + + def make_handler(self) -> Handler: + return ThreadEmitter(self.queue) class Logger: - """Logger object that tracks the stack when using spans.""" - handler: Handler - serial: Iterator[int] - stack: deque[UUID] - - def __init__(self, handler:Handler): - self.handler = handler - self.serial = count() - self.stack = deque() - - def span(self, name:str, **kwargs) -> Span: - """Start recording a span. Use as context `with logger.span('name'):`""" - return Span(self, name, kwargs) - - def event(self, name:str, **kwargs) -> UUID: - """Record a singular event. If inside the context of a span, the event will - be associated with that span.""" - event_id = uuid1(get_ident(), next(self.serial)) - self.handler.emit({ - 'id': event_id, - 'parent': self.stack[-1] if len(self.stack) > 0 else None, - 'name': name, - **kwargs, - }) - return event_id - - def update(self, event_id:UUID, **kwargs) -> None: - """Update a particular event, e.g. to add more context.""" - self.handler.emit({ - 'id': event_id, - **kwargs - }) - - def push(self, name:str, **kwargs) -> UUID: - """Emit an event and put it onto the stack. Same a `span().__enter__()`, it - is better to use `with span():` in most cases.""" - event_id = self.event(name, **kwargs) - self.stack.append(event_id) - return event_id - - def pop(self) -> UUID: - """Pops event of the stack.""" - return self.stack.pop() + """Logger object that tracks the stack when using spans.""" + + handler: Handler + serial: Iterator[int] + stack: deque[UUID] + + def __init__(self, handler: Handler): + self.handler = handler + self.serial = count() + self.stack = deque() + + def span(self, name: str, **kwargs) -> Span: + """Start recording a span. Use as context `with logger.span('name'):`""" + return Span(self, name, kwargs) + + def event(self, name: str, **kwargs) -> UUID: + """Record a singular event. If inside the context of a span, the event will + be associated with that span.""" + event_id = uuid1(get_ident(), next(self.serial)) + self.handler.emit( + { + "id": event_id, + "parent": self.stack[-1] if len(self.stack) > 0 else None, + "name": name, + **kwargs, + } + ) + return event_id + + def update(self, event_id: UUID, **kwargs) -> None: + """Update a particular event, e.g. to add more context.""" + self.handler.emit({"id": event_id, **kwargs}) + + def push(self, name: str, **kwargs) -> UUID: + """Emit an event and put it onto the stack. Same a `span().__enter__()`, it + is better to use `with span():` in most cases.""" + event_id = self.event(name, **kwargs) + self.stack.append(event_id) + return event_id + + def pop(self) -> UUID: + """Pops event of the stack.""" + return self.stack.pop() _main_thread_id = main_thread().ident @@ -164,70 +176,70 @@ def pop(self) -> UUID: class Context: - """Logging context. This deals with having multiple loggers on multiple - threads all combine into the same event stream. - Generally you'd have something like `with logger.Context() as ctx: main()` - in your app. You can access the current context's logger through - `logger.get_logger()` as well as `ctx.get_logger()`. 
- """ - def __init__(self, *, file:Optional[IO[str]]=None): - if file: - self.handler = StreamHandler(file) - else: - self.handler = NullHandler() + """Logging context. This deals with having multiple loggers on multiple + threads all combine into the same event stream. + Generally you'd have something like `with logger.Context() as ctx: main()` + in your app. You can access the current context's logger through + `logger.get_logger()` as well as `ctx.get_logger()`. + """ - self.receiver = ThreadReceiver(self.handler) + def __init__(self, *, file: Optional[IO[str]] = None): + if file: + self.handler = StreamHandler(file) + else: + self.handler = NullHandler() - self.loggers = { - _main_thread_id: Logger(self.handler) - } + self.receiver = ThreadReceiver(self.handler) + self.loggers = {_main_thread_id: Logger(self.handler)} - def get_logger(self): - thread_id = get_ident() + def get_logger(self): + thread_id = get_ident() - if thread_id not in self.loggers: - self.loggers[thread_id] = Logger(self.receiver.make_handler()) - self.loggers[thread_id].stack = deque(self.loggers[_main_thread_id].stack) # TODO what about threads starting threads? + if thread_id not in self.loggers: + self.loggers[thread_id] = Logger(self.receiver.make_handler()) + self.loggers[thread_id].stack = deque( + self.loggers[_main_thread_id].stack + ) # TODO what about threads starting threads? - return self.loggers[thread_id] + return self.loggers[thread_id] - def __enter__(self): - global _context - assert _context is None - self.receiver.__enter__() - _context = self - return self + def __enter__(self): + global _context + assert _context is None + self.receiver.__enter__() + _context = self + return self - def __exit__(self, typ, value, traceback): - global _context - assert _context is self - self.receiver.__exit__(typ, value, traceback) - _context = None + def __exit__(self, typ, value, traceback): + global _context + assert _context is self + self.receiver.__exit__(typ, value, traceback) + _context = None def get_logger() -> Logger: - """Shortcut for Context().get_logger()""" - if _context is None: - raise RuntimeError('called get_logger() outside logging context') - return _context.get_logger() + """Shortcut for Context().get_logger()""" + if _context is None: + raise RuntimeError("called get_logger() outside logging context") + return _context.get_logger() -def event(name:str, **kwargs) -> UUID: - """Shortcut for get_logger().event()""" - return get_logger().event(name, **kwargs) +def event(name: str, **kwargs) -> UUID: + """Shortcut for get_logger().event()""" + return get_logger().event(name, **kwargs) def update(**kwargs): - """Shortcut for get_logger().update(current_span, ...)""" - logger = get_logger() - event_id = logger.stack[-1] - logger.update(event_id, **kwargs) + """Shortcut for get_logger().update(current_span, ...)""" + logger = get_logger() + event_id = logger.stack[-1] + logger.update(event_id, **kwargs) -def span(name:str, **kwargs) -> Span: - """Shortcut for get_logger().span()""" - return get_logger().span(name, **kwargs) +def span(name: str, **kwargs) -> Span: + """Shortcut for get_logger().span()""" + return get_logger().span(name, **kwargs) # TODO: once Python3.11: @@ -235,40 +247,45 @@ def span(name:str, **kwargs) -> Span: # R = TypeVar('R') # def trace(fn:Callable[P,R]) -> Callable[P,R] -T = TypeVar('T', bound=Callable) +T = TypeVar("T", bound=Callable) + -def trace(fn:T) -> T: - """Decorator for wrapping each call to this function with - ``` - with get_logger().span(__name__): +def 
trace(fn: T) -> T: + """Decorator for wrapping each call to this function with + ``` + with get_logger().span(__name__): fn() - ``` - """ - @wraps(fn) - def wrapper(*args, **kwargs): - with get_logger().span(fn.__name__): - return fn(*args, **kwargs) - return wrapper # type:ignore - - -T = TypeVar('T') - -def trace_context(cls:Type[T]) -> Type[T]: - """Similar to `@trace`, but for a class with __enter__ and __exit__.""" - class Wrapper(cls): - __span: Span - - def __enter__(self): - self.__span = get_logger().span(cls.__name__).__enter__() - return super().__enter__() - - def __exit__(self, typ, value, traceback): - # add an __exit__ event to make it possible to measure how long the - # wrapped __exit__ actually takes. - self.__span.event('__exit__') - try: - super().__exit__(typ, value, traceback) - finally: - self.__span.__exit__(typ, value, traceback) - - return Wrapper + ``` + """ + + @wraps(fn) + def wrapper(*args, **kwargs): + with get_logger().span(fn.__name__): + return fn(*args, **kwargs) + + return wrapper # type:ignore + + +T = TypeVar("T") + + +def trace_context(cls: Type[T]) -> Type[T]: + """Similar to `@trace`, but for a class with __enter__ and __exit__.""" + + class Wrapper(cls): + __span: Span + + def __enter__(self): + self.__span = get_logger().span(cls.__name__).__enter__() + return super().__enter__() + + def __exit__(self, typ, value, traceback): + # add an __exit__ event to make it possible to measure how long the + # wrapped __exit__ actually takes. + self.__span.event("__exit__") + try: + super().__exit__(typ, value, traceback) + finally: + self.__span.__exit__(typ, value, traceback) + + return Wrapper diff --git a/opuscleaner/opusfilter_compat.py b/opuscleaner/opusfilter_compat.py index 4bba112..447dfb3 100644 --- a/opuscleaner/opusfilter_compat.py +++ b/opuscleaner/opusfilter_compat.py @@ -12,63 +12,71 @@ def encode_env(type_name: str, value: Any) -> str: - if type_name == 'bool': - return '1' if value else '' + if type_name == "bool": + return "1" if value else "" else: return str(value) -def load_filter_definition(filter_name:str) -> Dict: - with open('filters/{filter_name}.json') as fh: - return json.load(fh) +def load_filter_definition(filter_name: str) -> Dict: + with open("filters/{filter_name}.json") as fh: + return json.load(fh) -def generate_filter_command(filter_definition:Dict, parameters:Dict) -> str: - filter_definition = filters[step['filter']] +def generate_filter_command(filter_definition: Dict, parameters: Dict) -> str: + filter_definition = filters[step["filter"]] # List of k=v shell variable definitions filter_params = [ - '{}={}'.format(name, quote(encode_env(props['type'], parameters.get(name, props.get('default', None))))) - for name, props in filter_definition['parameters'].items() + "{}={}".format( + name, + quote( + encode_env( + props["type"], parameters.get(name, props.get("default", None)) + ) + ), + ) + for name, props in filter_definition["parameters"].items() ] # Command, prefixed by variable definitions so they get expanded # correctly in the command bit. - return '; '.join(filter_params + [filter_definition['command']]) + return "; ".join(filter_params + [filter_definition["command"]]) -def patch_environ() -> Optional[Dict[str,str]]: +def patch_environ() -> Optional[Dict[str, str]]: # Make sure the path to the python binary (and the installed utils) # is in the PATH variable. 
If you load a virtualenv this happens by - # default, but if you call it with the virtualenv's python binary + # default, but if you call it with the virtualenv's python binary # directly it wont. pyenv_bin_path = os.path.dirname(sys.executable) - os_env_bin_paths = os.environ.get('PATH', '').split(os.pathsep) - return { - **os.environ, - 'PATH': os.pathsep.join([pyenv_bin_path] + os_env_bin_paths) - } if pyenv_bin_path not in os_env_bin_paths else None + os_env_bin_paths = os.environ.get("PATH", "").split(os.pathsep) + return ( + {**os.environ, "PATH": os.pathsep.join([pyenv_bin_path] + os_env_bin_paths)} + if pyenv_bin_path not in os_env_bin_paths + else None + ) -def feed_child_worker(input_queue:SimpleQueue, stdin): +def feed_child_worker(input_queue: SimpleQueue, stdin): while True: line = input_queue.pop() if line is None: break - stdin.write(line.encode() + b'\n') + stdin.write(line.encode() + b"\n") stdin.close() -def read_child_worker(stdout, output_queue:SimpleQueue): +def read_child_worker(stdout, output_queue: SimpleQueue): for line in stdout: - output_queue.put(line.rstrip(b'\r\n').decode()) + output_queue.put(line.rstrip(b"\r\n").decode()) output_queue.put(None) class OpusCleanerPreprocessor(PreprocessorABC): - def __init__(self, filter:str, parameters:Dict[str,Any], column:int, **kwargs): + def __init__(self, filter: str, parameters: Dict[str, Any], column: int, **kwargs): filter_definition = load_filter_definition(filter) - if filter_definition['type'] != 'monolingual': + if filter_definition["type"] != "monolingual": raise ConfigurationError() self.command = generate_filter_command(filter_definition, parameters) @@ -83,7 +91,9 @@ def process(self, pairs): # Remainder of the columns python -> python column_queue = deque() - child = Popen(self.command, cwd=basedir, stdin=PIPE, stdout=PIPE, env=patch_environ()) + child = Popen( + self.command, cwd=basedir, stdin=PIPE, stdout=PIPE, env=patch_environ() + ) feeder = Thread(target=feed_child_worker, args=[input_queue, child.stdin]) feeder.start() @@ -92,12 +102,12 @@ def process(self, pairs): reader.start() def split(pair): - column_queue.append(pair[:self.column] + pair[self.column+1:]) + column_queue.append(pair[: self.column] + pair[self.column + 1 :]) return pair[self.column] def merge(val): rest = column_queue.popleft() - return rest[:self.column] + [val] + rest[self.column:] + return rest[: self.column] + [val] + rest[self.column :] for pair in pairs: # Push input pair @@ -109,7 +119,7 @@ def merge(val): yield merge(output_queue.get_nowait()) except Empty: break - + # Signal worker to stop input_queue.put(None) feeder.join() @@ -126,7 +136,9 @@ def merge(val): reader.join() if retval != 0: - raise Exception(f'Child process {command} exited with non-zero exit code: {retval}') + raise Exception( + f"Child process {command} exited with non-zero exit code: {retval}" + ) assert len(column_queue) == 0 @@ -134,9 +146,9 @@ def merge(val): class OpusCleanerFilter(FilterABC): """One Big Hack (Tm)""" - def __init__(self, filter:str, parameters:Dict[str,Any], **kwargs): + def __init__(self, filter: str, parameters: Dict[str, Any], **kwargs): filter_definition = load_filter_definition(filter) - if filter_definition['type'] != 'bilingual': + if filter_definition["type"] != "bilingual": raise ConfigurationError() self.command = generate_filter_command(filter_definition, parameters) @@ -148,10 +160,12 @@ def accept(self, score): def score(self, pairs): input_queue = SimpleQueue() output_queue = SimpleQueue() - + input_log = deque() - - 
child = Popen(self.command, cwd=basedir, stdin=PIPE, stdout=PIPE, env=patch_environ()) + + child = Popen( + self.command, cwd=basedir, stdin=PIPE, stdout=PIPE, env=patch_environ() + ) feeder = Thread(target=feed_child_worker, args=[input_queue, child.stdin]) feeder.start() @@ -161,7 +175,7 @@ def score(self, pairs): def record(pair): """Record the hash of the line so we know whether it makes it through the filter""" - line = '\t'.join(pair) + line = "\t".join(pair) input_log.append(xxh32(line).digest()) return line @@ -182,7 +196,7 @@ def catch_up(line): yield from catch_up(output_queue.get_nowait()) except Empty: break - + # Signal worker to stop input_queue.put(None) feeder.join() @@ -199,6 +213,8 @@ def catch_up(line): reader.join() if retval != 0: - raise Exception(f'Child process {command} exited with non-zero exit code: {retval}') + raise Exception( + f"Child process {command} exited with non-zero exit code: {retval}" + ) assert len(column_queue) == 0 diff --git a/opuscleaner/sample.py b/opuscleaner/sample.py index 6edab32..f2f4207 100755 --- a/opuscleaner/sample.py +++ b/opuscleaner/sample.py @@ -11,146 +11,178 @@ from typing import TypeVar, Iterable, List, Tuple -T = TypeVar('T') +T = TypeVar("T") -def reservoir_sample(k:int, it:Iterable[T], *, rand:random.Random = random._inst, sort:bool=False) -> List[T]: - """Take k samples from iterable by reading from start to end. If sort is - True, it will return the selected samples in the order they appeared in. - """ - sample: List[Tuple[int,T]] = [] +def reservoir_sample( + k: int, it: Iterable[T], *, rand: random.Random = random._inst, sort: bool = False +) -> List[T]: + """Take k samples from iterable by reading from start to end. If sort is + True, it will return the selected samples in the order they appeared in. + """ + sample: List[Tuple[int, T]] = [] - numbered_it = enumerate(it) + numbered_it = enumerate(it) - i = 0 + i = 0 - for i, (_, line) in zip(range(k), numbered_it): - sample.append((i, line)) + for i, (_, line) in zip(range(k), numbered_it): + sample.append((i, line)) - w = exp(log(rand.random())/k) + w = exp(log(rand.random()) / k) - try: - while True: - next_i = i + floor(log(rand.random()) / log(1 - w)) + 1 + try: + while True: + next_i = i + floor(log(rand.random()) / log(1 - w)) + 1 - # Skip forward - while i < next_i: - i, line = next(numbered_it) + # Skip forward + while i < next_i: + i, line = next(numbered_it) - sample[rand.randrange(k)] = (i, line) # type:ignore - w = w * exp(log(rand.random()) / k) - except StopIteration: - pass + sample[rand.randrange(k)] = (i, line) # type:ignore + w = w * exp(log(rand.random()) / k) + except StopIteration: + pass - if sort: - return [line for _, line in sorted(sample)] - else: - return [line for _, line in sample] + if sort: + return [line for _, line in sorted(sample)] + else: + return [line for _, line in sample] class Tailer(Iterable[T]): - """Functions as an iterator that returns all but the last K lines. Those lines - you can read from `tail`.""" + """Functions as an iterator that returns all but the last K lines. 
Those lines + you can read from `tail`.""" - def __init__(self, k:int, it:Iterable[T]): - self.sample: List[T] = [] # ring buffer of (at maximum) length k - self.k = k - self.i = 0 - self.it = iter(it) + def __init__(self, k: int, it: Iterable[T]): + self.sample: List[T] = [] # ring buffer of (at maximum) length k + self.k = k + self.i = 0 + self.it = iter(it) - def __iter__(self) -> Iterator[T]: - try: - while self.i < self.k: - self.sample.append(next(self.it)) - self.i += 1 - except StopIteration: - # Oh less than k samples in iterable? :( - return + def __iter__(self) -> Iterator[T]: + try: + while self.i < self.k: + self.sample.append(next(self.it)) + self.i += 1 + except StopIteration: + # Oh less than k samples in iterable? :( + return - for line in self.it: - yield self.sample[self.i % len(self.sample)] - self.sample[self.i % len(self.sample)] = line - self.i += 1 + for line in self.it: + yield self.sample[self.i % len(self.sample)] + self.sample[self.i % len(self.sample)] = line + self.i += 1 - @property - def tail(self) -> List[T]: - # In the scenario where we read less than our tail of data, we just return - # the entire buffer in one go. - if len(self.sample) < self.k: - return self.sample + @property + def tail(self) -> List[T]: + # In the scenario where we read less than our tail of data, we just return + # the entire buffer in one go. + if len(self.sample) < self.k: + return self.sample - return self.sample[(self.i % len(self.sample)):] + self.sample[0:(self.i % len(self.sample))] + return ( + self.sample[(self.i % len(self.sample)) :] + + self.sample[0 : (self.i % len(self.sample))] + ) -def sample(k:int, iterable:Iterable[T], sort:bool=False) -> Iterable[Iterable[T]]: - """Take `k` items from the start, the end and the middle from `iterable`. If - `sort` is True, the items in the middle will be in the order they appeared - in.""" - it = iter(iterable) +def sample(k: int, iterable: Iterable[T], sort: bool = False) -> Iterable[Iterable[T]]: + """Take `k` items from the start, the end and the middle from `iterable`. If + `sort` is True, the items in the middle will be in the order they appeared + in.""" + it = iter(iterable) - yield (val for _, val in zip(range(k), it)) + yield (val for _, val in zip(range(k), it)) - tailer = Tailer(k, it) + tailer = Tailer(k, it) - yield reservoir_sample(k, tailer, sort=sort) + yield reservoir_sample(k, tailer, sort=sort) - yield tailer.tail + yield tailer.tail @contextmanager -def gunzip(path:str) -> Iterator[IO[bytes]]: - """Like gzip.open(), but using external gzip process which for some reason - is a lot faster on macOS.""" - with subprocess.Popen(['gzip', '-cd', path], stdout=subprocess.PIPE) as proc: - assert proc.stdout is not None - yield proc.stdout - - # Context is done with proc.stdout, so we close it. It might be that it - # isn't completely read yet, and thus proc.wait() would block otherwise. 
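As a usage note (not part of the diff), the reservoir_sample() helper reformatted earlier in this file can be driven directly; the generator and seed below are illustrative.

import random

lines = (f"line {i}" for i in range(1_000_000))
# Five lines chosen at random from the whole stream, returned in original order.
picked = reservoir_sample(5, lines, rand=random.Random(42), sort=True)
print(picked)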
- proc.stdout.close() - if proc.wait() != 0: - raise RuntimeError(f'gzip returned error code {proc.returncode} while decompressing {path}') - - -def magic_open_or_stdin(ctx:ExitStack, path:str) -> IO[bytes]: - # TODO ideally we would look at the magic bytes, but that would entail - # consuming the input file partially and then I can't pass the complete - # file onto gzip afterwards - if path.endswith('.gz'): - return ctx.enter_context(gunzip(path)) - elif path == '-': - return sys.stdin.buffer - else: - return ctx.enter_context(open(path, 'rb')) +def gunzip(path: str) -> Iterator[IO[bytes]]: + """Like gzip.open(), but using external gzip process which for some reason + is a lot faster on macOS.""" + with subprocess.Popen(["gzip", "-cd", path], stdout=subprocess.PIPE) as proc: + assert proc.stdout is not None + yield proc.stdout + + # Context is done with proc.stdout, so we close it. It might be that it + # isn't completely read yet, and thus proc.wait() would block otherwise. + proc.stdout.close() + if proc.wait() != 0: + raise RuntimeError( + f"gzip returned error code {proc.returncode} while decompressing {path}" + ) + + +def magic_open_or_stdin(ctx: ExitStack, path: str) -> IO[bytes]: + # TODO ideally we would look at the magic bytes, but that would entail + # consuming the input file partially and then I can't pass the complete + # file onto gzip afterwards + if path.endswith(".gz"): + return ctx.enter_context(gunzip(path)) + elif path == "-": + return sys.stdin.buffer + else: + return ctx.enter_context(open(path, "rb")) def main() -> None: - parser = argparse.ArgumentParser(description="Take a file's head, tail and a random sample from the rest.") - parser.add_argument('-n', dest='lines', type=int, default=10, help="number of lines for each section of the sample") - parser.add_argument('-d', dest='delimiter', type=str, default="\\t", help="column delimiter. Defaults to \\t.") - parser.add_argument('-N', '--line-numbers', action='store_true', help="print line numbers") - parser.add_argument('files', metavar='file', type=str, nargs='*', default=['-'], help="files to sample. Multiple files for multiple columns. Use '-' for stdin. If none, reads from stdin.") - args = parser.parse_args() - - with ExitStack() as ctx: - columns:List[Iterator[bytes]] = [magic_open_or_stdin(ctx, file) for file in args.files] - - if args.line_numbers: - columns.insert(0, (str(i).encode() for i in count())) - - pairs = zip(*columns) - - delimiter = args.delimiter.replace("\\t", "\t").replace("\\n", "\n").encode() - - for section in sample(args.lines, pairs, sort=True): - for pair in section: - for col, entry in enumerate(pair): - if col > 0: - sys.stdout.buffer.write(delimiter) - sys.stdout.buffer.write(entry.rstrip(b"\r\n")) - sys.stdout.buffer.write(b"\n") - sys.stdout.buffer.flush() - - -if __name__ == '__main__': - main() + parser = argparse.ArgumentParser( + description="Take a file's head, tail and a random sample from the rest." + ) + parser.add_argument( + "-n", + dest="lines", + type=int, + default=10, + help="number of lines for each section of the sample", + ) + parser.add_argument( + "-d", + dest="delimiter", + type=str, + default="\\t", + help="column delimiter. Defaults to \\t.", + ) + parser.add_argument( + "-N", "--line-numbers", action="store_true", help="print line numbers" + ) + parser.add_argument( + "files", + metavar="file", + type=str, + nargs="*", + default=["-"], + help="files to sample. Multiple files for multiple columns. Use '-' for stdin. 
If none, reads from stdin.", + ) + args = parser.parse_args() + + with ExitStack() as ctx: + columns: List[Iterator[bytes]] = [ + magic_open_or_stdin(ctx, file) for file in args.files + ] + + if args.line_numbers: + columns.insert(0, (str(i).encode() for i in count())) + + pairs = zip(*columns) + + delimiter = args.delimiter.replace("\\t", "\t").replace("\\n", "\n").encode() + + for section in sample(args.lines, pairs, sort=True): + for pair in section: + for col, entry in enumerate(pair): + if col > 0: + sys.stdout.buffer.write(delimiter) + sys.stdout.buffer.write(entry.rstrip(b"\r\n")) + sys.stdout.buffer.write(b"\n") + sys.stdout.buffer.flush() + + +if __name__ == "__main__": + main() diff --git a/opuscleaner/server.py b/opuscleaner/server.py index afcdae0..286d63b 100644 --- a/opuscleaner/server.py +++ b/opuscleaner/server.py @@ -24,20 +24,33 @@ from opuscleaner.config import DATA_PATH, FILTER_PATH, SAMPLE_PY, SAMPLE_SIZE from opuscleaner.datasets import list_datasets, Path from opuscleaner.download import app as download_app -from opuscleaner.filters import filter_format_command, get_global_filter, get_global_filters, set_global_filters, list_filters, FilterType, FilterStep, FilterPipeline +from opuscleaner.filters import ( + filter_format_command, + get_global_filter, + get_global_filters, + set_global_filters, + list_filters, + FilterType, + FilterStep, + FilterPipeline, +) import mimetypes -mimetypes.add_type('application/javascript', '.js') +mimetypes.add_type("application/javascript", ".js") -FRONTEND_PATH = next(iter(path - for path in [ - os.path.join(os.path.dirname(__file__), 'frontend'), - os.path.join(os.path.dirname(__file__), '../frontend/dist'), - ] - if os.path.exists(path) -)) + +FRONTEND_PATH = next( + iter( + path + for path in [ + os.path.join(os.path.dirname(__file__), "frontend"), + os.path.join(os.path.dirname(__file__), "../frontend/dist"), + ] + if os.path.exists(path) + ) +) class Column(BaseModel): @@ -53,17 +66,18 @@ class Dataset(BaseModel): class FilterPipelinePatch(BaseModel): """A list of changes to a filter pipeline (used when updating filters)""" + filters: List[FilterStep] -def dataset_path(name:str, template:str) -> str: +def dataset_path(name: str, template: str) -> str: # TODO: fix this hack to get the file path from the name this is silly we # should just use get_dataset(name).path or something - root = DATA_PATH.split('*')[0] + root = DATA_PATH.split("*")[0] # If the dataset name is a subdirectory, do some hacky shit to get to a # .sample.gz file in said subdirectory. 
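To make the mapping concrete, a hedged worked example of what the dataset_path()/sample_path() helpers in this hunk produce; "data/train-parts/*.gz" is only an assumed value, the real DATA_PATH comes from opuscleaner.config.

DATA_PATH = "data/train-parts/*.gz"   # assumption for the sake of the example
root = DATA_PATH.split("*")[0]        # "data/train-parts/"

# dataset_path("foo-bar", "{}.filters.json")  -> data/train-parts/foo-bar.filters.json
# sample_path("foo-bar", ["en", "de"])        -> data/train-parts/.sample.foo-bar.de.en
print(root + "{}.filters.json".format("foo-bar"))
print(root + ".sample.{}.{}".format("foo-bar", ".".join(sorted(["en", "de"]))))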
- parts = name.rsplit('/', maxsplit=2) + parts = name.rsplit("/", maxsplit=2) if len(parts) == 2: root = os.path.join(root, parts[0]) filename = parts[1] @@ -73,38 +87,40 @@ def dataset_path(name:str, template:str) -> str: return os.path.join(root, template.format(filename)) -def sample_path(name:str, langs:Iterable[str]) -> str: - languages = '.'.join(sorted(langs)) - return dataset_path(name, f'.sample.{{}}.{languages}') +def sample_path(name: str, langs: Iterable[str]) -> str: + languages = ".".join(sorted(langs)) + return dataset_path(name, f".sample.{{}}.{languages}") -def filter_configuration_path(name:str) -> str: - return dataset_path(name, '{}.filters.json') +def filter_configuration_path(name: str) -> str: + return dataset_path(name, "{}.filters.json") -async def compute_sample(name:str, columns:List[Tuple[str,Path]]) -> None: +async def compute_sample(name: str, columns: List[Tuple[str, Path]]) -> None: langs = [lang for lang, _ in columns] with TemporaryFile() as tempfile: proc = await asyncio.subprocess.create_subprocess_exec( *SAMPLE_PY, - '-n', str(SAMPLE_SIZE), + "-n", + str(SAMPLE_SIZE), *[str(file.resolve()) for _, file in columns], stdout=tempfile, - stderr=asyncio.subprocess.PIPE) + stderr=asyncio.subprocess.PIPE, + ) _, stderr = await proc.communicate() if proc.returncode != 0: - raise Exception(f'sample.py returned {proc.returncode}: {stderr.decode()}') + raise Exception(f"sample.py returned {proc.returncode}: {stderr.decode()}") tempfile.seek(0) - with open(sample_path(name, langs), 'wb') as fdest: + with open(sample_path(name, langs), "wb") as fdest: copyfileobj(tempfile, fdest) class FilterOutput(NamedTuple): - langs: List[str] # order of columns + langs: List[str] # order of columns returncode: int stdout: bytes stderr: bytes @@ -112,36 +128,40 @@ class FilterOutput(NamedTuple): class ParsedFilterOutput(BaseModel): """JSON serializable version of FilterOutput that has stdout parsed into - an array of dicts, with a field per language. + an array of dicts, with a field per language. 
""" + returncode: int - stdout: List[Dict[str,str]] + stdout: List[Dict[str, str]] stderr: str - def __init__(self, output:FilterOutput): + def __init__(self, output: FilterOutput): lines = [] - for lineno, line in enumerate(output.stdout.rstrip(b'\r\n').split(b'\n'), start=1): + for lineno, line in enumerate( + output.stdout.rstrip(b"\r\n").split(b"\n"), start=1 + ): values = [] - for colno, field in enumerate(line.rstrip(b'\r').split(b'\t'), start=1): + for colno, field in enumerate(line.rstrip(b"\r").split(b"\t"), start=1): try: values.append(field.decode()) except UnicodeDecodeError as e: - values.append(f'[Error: Cannot decode line {lineno} column {colno}: {e!s}]') - lines.append(dict(zip_longest(output.langs, values, fillvalue=''))) + values.append( + f"[Error: Cannot decode line {lineno} column {colno}: {e!s}]" + ) + lines.append(dict(zip_longest(output.langs, values, fillvalue=""))) super().__init__( - returncode=output.returncode, - stdout=lines, - stderr=output.stderr.decode()) + returncode=output.returncode, stdout=lines, stderr=output.stderr.decode() + ) class SampleCacheEntry(NamedTuple): checksum: bytes - future: asyncio.Task#[FilterOutput] + future: asyncio.Task # [FilterOutput] -sample_cache: Dict[str,List[SampleCacheEntry]] = {} +sample_cache: Dict[str, List[SampleCacheEntry]] = {} def cache_hash(obj: Any, seed: bytes = bytes()) -> bytes: @@ -150,19 +170,23 @@ def cache_hash(obj: Any, seed: bytes = bytes()) -> bytes: return impl.digest() -async def get_dataset_sample(name:str, columns:List[Tuple[str,Path]]) -> FilterOutput: +async def get_dataset_sample( + name: str, columns: List[Tuple[str, Path]] +) -> FilterOutput: langs = [lang for lang, _ in columns] if not os.path.exists(sample_path(name, langs)): await compute_sample(name, columns) - with open(sample_path(name, langs), 'rb') as fh: + with open(sample_path(name, langs), "rb") as fh: stdout = fh.read() return FilterOutput([lang for lang, _ in columns], 0, stdout, bytes()) -async def exec_filter_step(filter_step: FilterStep, langs: List[str], input: bytes) -> Tuple[bytes,bytes]: +async def exec_filter_step( + filter_step: FilterStep, langs: List[str], input: bytes +) -> Tuple[bytes, bytes]: filter_definition = get_global_filter(filter_step.filter) command = filter_format_command(filter_definition, filter_step, langs) @@ -172,18 +196,21 @@ async def exec_filter_step(filter_step: FilterStep, langs: List[str], input: byt # default, but if you call it with the virtualenv's python binary # directly it wont. pyenv_bin_path = os.path.dirname(sys.executable) - os_env_bin_paths = os.environ.get('PATH', '').split(os.pathsep) - filter_env = { - **os.environ, - 'PATH': os.pathsep.join([pyenv_bin_path] + os_env_bin_paths) - } if pyenv_bin_path not in os_env_bin_paths else None + os_env_bin_paths = os.environ.get("PATH", "").split(os.pathsep) + filter_env = ( + {**os.environ, "PATH": os.pathsep.join([pyenv_bin_path] + os_env_bin_paths)} + if pyenv_bin_path not in os_env_bin_paths + else None + ) - p_filter = await asyncio.create_subprocess_shell(command, + p_filter = await asyncio.create_subprocess_shell( + command, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=filter_definition.basedir, - env=filter_env) + env=filter_env, + ) # Check exit codes, testing most obvious problems first. 
stdout, stderr = await p_filter.communicate(input=input) @@ -191,21 +218,24 @@ async def exec_filter_step(filter_step: FilterStep, langs: List[str], input: byt return FilterOutput(langs, p_filter.returncode, stdout, stderr) -def cancel_cached_tasks(name:str, offset:int) -> None: - assert offset > 0 # offset == 0 is sample.py itself, and should never be cancelled through this method. +def cancel_cached_tasks(name: str, offset: int) -> None: + assert ( + offset > 0 + ) # offset == 0 is sample.py itself, and should never be cancelled through this method. for entry in sample_cache[name][offset:]: entry.future.cancel() del sample_cache[name][offset:] -async def get_sample(name:str, filters:List[FilterStep]) -> AsyncIterator[FilterOutput]: - columns: List[Tuple[str,Path]] = list_datasets(DATA_PATH)[name] +async def get_sample( + name: str, filters: List[FilterStep] +) -> AsyncIterator[FilterOutput]: + columns: List[Tuple[str, Path]] = list_datasets(DATA_PATH)[name] langs = [lang for lang, _ in columns] - checksum = cache_hash([ - (name, str(path), path.stat().st_mtime) - for name, path in columns - ]) + checksum = cache_hash( + [(name, str(path), path.stat().st_mtime) for name, path in columns] + ) # If we don't have a sample stored, generate one. Doing it in bytes because # it might save us parsing utf-8 (also assumptions! It it utf-8?) @@ -217,7 +247,7 @@ async def get_sample(name:str, filters:List[FilterStep]) -> AsyncIterator[Filter sample_cache[name] = [ SampleCacheEntry( checksum=checksum, - future=asyncio.create_task(get_dataset_sample(name, columns)) + future=asyncio.create_task(get_dataset_sample(name, columns)), ) ] @@ -237,18 +267,24 @@ async def get_sample(name:str, filters:List[FilterStep]) -> AsyncIterator[Filter # - stdin input (via checksum of previous step) checksum = cache_hash( jsonable_encoder(filter_step), - cache_hash(jsonable_encoder(filter_definition), - sample_cache[name][i-1].checksum)) + cache_hash( + jsonable_encoder(filter_definition), sample_cache[name][i - 1].checksum + ), + ) # If we do not have a cache entry for this point if len(sample_cache[name]) <= i or sample_cache[name][i].checksum != checksum: # Invalidate all the cache after this step cancel_cached_tasks(name, i) - sample_cache[name].append(SampleCacheEntry( - checksum=checksum, - future=asyncio.create_task(exec_filter_step(filter_step, langs, sample.stdout)) - )) + sample_cache[name].append( + SampleCacheEntry( + checksum=checksum, + future=asyncio.create_task( + exec_filter_step(filter_step, langs, sample.stdout) + ), + ) + ) assert len(sample_cache[name]) == i + 1 @@ -259,19 +295,19 @@ async def get_sample(name:str, filters:List[FilterStep]) -> AsyncIterator[Filter # Return the (partially) filtered sample yield sample - # if there are additional steps left in the cache, remove them if len(sample_cache[name]) > len(filters) + 1: cancel_cached_tasks(name, len(filters) + 1) -def stream_jsonl(iterable:AsyncIterator[Any]) -> StreamingResponse: +def stream_jsonl(iterable: AsyncIterator[Any]) -> StreamingResponse: return StreamingResponse( ( - json.dumps(jsonable_encoder(line), separators=(',', ':')).encode() + b"\n" + json.dumps(jsonable_encoder(line), separators=(",", ":")).encode() + b"\n" async for line in iterable ), - media_type='application/json') + media_type="application/json", + ) @asynccontextmanager @@ -289,56 +325,64 @@ async def lifespan(app: FastAPI): app.add_middleware(GZipMiddleware, minimum_size=512) -@app.get('/api/datasets/') + +@app.get("/api/datasets/") def api_list_datasets() -> 
List[Dataset]: return [ - Dataset(name=name, columns=[ - Column(lang=lang, path=file.name, size=file.stat().st_size) - for lang, file in columns - ]) + Dataset( + name=name, + columns=[ + Column(lang=lang, path=file.name, size=file.stat().st_size) + for lang, file in columns + ], + ) for name, columns in list_datasets(DATA_PATH).items() ] -@app.get('/api/datasets/{name:path}/') -def api_get_dataset(name:str) -> Dataset: +@app.get("/api/datasets/{name:path}/") +def api_get_dataset(name: str) -> Dataset: columns = list_datasets(DATA_PATH).get(name) if not columns: - raise HTTPException(status_code=404, detail='Dataset not found') + raise HTTPException(status_code=404, detail="Dataset not found") - return Dataset(name=name, columns=[ - Column(lang=lang, path=file.name, size=file.stat().st_size) - for lang, file in columns - ]) + return Dataset( + name=name, + columns=[ + Column(lang=lang, path=file.name, size=file.stat().st_size) + for lang, file in columns + ], + ) -@app.get('/api/datasets/{name:path}/sample') -def api_get_sample(name:str) -> Response: - return stream_jsonl(ParsedFilterOutput(output) async for output in get_sample(name, [])) +@app.get("/api/datasets/{name:path}/sample") +def api_get_sample(name: str) -> Response: + return stream_jsonl( + ParsedFilterOutput(output) async for output in get_sample(name, []) + ) -@app.post('/api/datasets/{name:path}/sample') -def api_get_filtered_sample(name:str, filters:List[FilterStep]) -> Response: - return stream_jsonl(ParsedFilterOutput(output) async for output in get_sample(name, filters)) +@app.post("/api/datasets/{name:path}/sample") +def api_get_filtered_sample(name: str, filters: List[FilterStep]) -> Response: + return stream_jsonl( + ParsedFilterOutput(output) async for output in get_sample(name, filters) + ) -def make_pipeline(name:str, filters:List[FilterStep] = []) -> FilterPipeline: +def make_pipeline(name: str, filters: List[FilterStep] = []) -> FilterPipeline: columns = list_datasets(DATA_PATH)[name] return FilterPipeline( - version=1, - files=[file.name for _, file in columns], - filters=filters + version=1, files=[file.name for _, file in columns], filters=filters ) -@app.get('/api/datasets/{name:path}/configuration.json') -def api_get_dataset_filters(name:str) -> FilterPipeline: - +@app.get("/api/datasets/{name:path}/configuration.json") +def api_get_dataset_filters(name: str) -> FilterPipeline: if not os.path.exists(filter_configuration_path(name)): return make_pipeline(name) - with open(filter_configuration_path(name), 'r') as fh: + with open(filter_configuration_path(name), "r") as fh: data = json.load(fh) try: return parse_obj_as(FilterPipeline, data) @@ -351,119 +395,140 @@ def api_get_dataset_filters(name:str) -> FilterPipeline: return make_pipeline(name) - -@app.patch('/api/datasets/{name:path}/configuration.json') -def api_update_dataset_filters(name:str, patch:FilterPipelinePatch): +@app.patch("/api/datasets/{name:path}/configuration.json") +def api_update_dataset_filters(name: str, patch: FilterPipelinePatch): pipeline = make_pipeline(name, patch.filters) - with open(filter_configuration_path(name), 'w') as fh: + with open(filter_configuration_path(name), "w") as fh: return json.dump(pipeline.dict(), fh, indent=2) -@app.get('/api/datasets/{name:path}/configuration-for-opusfilter.yaml') -def api_get_dataset_filters_as_openfilter(name:str) -> Response: +@app.get("/api/datasets/{name:path}/configuration-for-opusfilter.yaml") +def api_get_dataset_filters_as_openfilter(name: str) -> Response: if not 
os.path.exists(filter_configuration_path(name)): - raise HTTPException(status_code=404, detail='Dataset not found') + raise HTTPException(status_code=404, detail="Dataset not found") - with open(filter_configuration_path(name), 'r') as fh: + with open(filter_configuration_path(name), "r") as fh: data = json.load(fh) pipeline = parse_obj_as(FilterPipeline, data) - opusfilter_config: Dict[str,Any] = { - 'steps': [] - } + opusfilter_config: Dict[str, Any] = {"steps": []} input_files = pipeline.files preprocess_steps = [] - filter_steps: List[Dict[str,Any]] = [] + filter_steps: List[Dict[str, Any]] = [] for step in pipeline.filters: - if (match := re.search(r'\bopusfilter\.preprocessors\.(\w+)\b', get_global_filter(step.filter).command)): - preprocess_steps.append({ - str(match.group(1)): step.parameters - }) - elif (match := re.search(r'\bopusfilter\.filters\.(\w+)\b', get_global_filter(step.filter).command)): - filter_steps.append({ - str(match.group(1)): step.parameters - }) + if match := re.search( + r"\bopusfilter\.preprocessors\.(\w+)\b", + get_global_filter(step.filter).command, + ): + preprocess_steps.append({str(match.group(1)): step.parameters}) + elif match := re.search( + r"\bopusfilter\.filters\.(\w+)\b", get_global_filter(step.filter).command + ): + filter_steps.append({str(match.group(1)): step.parameters}) elif get_global_filter(step.filter).type == FilterType.BILINGUAL: - filter_steps.append({ - 'OpusCleanerFilter': { - 'filter': step.filter, - 'parameters': step.parameters - }, - 'module': 'opuscleaner.opusfilter_compat' - }) + filter_steps.append( + { + "OpusCleanerFilter": { + "filter": step.filter, + "parameters": step.parameters, + }, + "module": "opuscleaner.opusfilter_compat", + } + ) elif get_global_filter(step.filter).type == FilterType.MONOLINGUAL: - filter_steps.append({ - 'OpusCleanerFilter': { - 'filter': step.filter, - 'parameters': step.parameters - }, - 'module': 'opuscleaner.opusfilter_compat' - }) + filter_steps.append( + { + "OpusCleanerFilter": { + "filter": step.filter, + "parameters": step.parameters, + }, + "module": "opuscleaner.opusfilter_compat", + } + ) else: - raise ValueError(f'Cannot convert "{step.filter}" to opusfilter configuration') + raise ValueError( + f'Cannot convert "{step.filter}" to opusfilter configuration' + ) if preprocess_steps: output_files = [ - os.path.join(os.path.dirname(file), 'preprocessed.' + os.path.basename(file)) + os.path.join( + os.path.dirname(file), "preprocessed." + os.path.basename(file) + ) for file in pipeline.files ] - opusfilter_config['steps'].append({ - 'type': 'preprocess', - 'parameters': { - 'inputs': input_files, - 'outputs': output_files, - 'preprocessors': preprocess_steps + opusfilter_config["steps"].append( + { + "type": "preprocess", + "parameters": { + "inputs": input_files, + "outputs": output_files, + "preprocessors": preprocess_steps, + }, } - }) + ) input_files = output_files if filter_steps: output_files = [ - os.path.join(os.path.dirname(file), 'filtered.' + os.path.basename(file)) + os.path.join(os.path.dirname(file), "filtered." 
+ os.path.basename(file)) for file in pipeline.files ] - opusfilter_config['steps'].append({ - 'type': 'filter', - 'parameters': { - 'inputs': input_files, - 'outputs': output_files, - 'filters': filter_steps + opusfilter_config["steps"].append( + { + "type": "filter", + "parameters": { + "inputs": input_files, + "outputs": output_files, + "filters": filter_steps, + }, } - }) + ) input_files = output_files - return Response(yaml.safe_dump(opusfilter_config, sort_keys=False), media_type='application/yaml') + return Response( + yaml.safe_dump(opusfilter_config, sort_keys=False), + media_type="application/yaml", + ) -@app.get('/api/filters/') +@app.get("/api/filters/") def api_get_filters(): set_global_filters(list_filters(FILTER_PATH)) return get_global_filters() -@app.get('/') +@app.get("/") def redirect_to_interface(): - return RedirectResponse('/frontend/index.html') + return RedirectResponse("/frontend/index.html") + +app.mount("/frontend/", StaticFiles(directory=FRONTEND_PATH, html=True), name="static") -app.mount('/frontend/', StaticFiles(directory=FRONTEND_PATH, html=True), name='static') +app.mount("/api/download/", download_app) -app.mount('/api/download/', download_app) +app.mount("/api/categories/", categories_app) -app.mount('/api/categories/', categories_app) def main_serve(args): import uvicorn - uvicorn.run('opuscleaner.server:app', host=args.host, port=args.port, reload=args.reload, log_level='info') + + uvicorn.run( + "opuscleaner.server:app", + host=args.host, + port=args.port, + reload=args.reload, + log_level="info", + ) async def sample_all_datasets(args): @@ -475,9 +540,16 @@ async def sample_all_datasets(args): print(f"Sampling {name}...", file=sys.stderr) tasks.append([name, columns]) - for task, result in zip(tasks, await asyncio.gather(*[compute_sample(*task) for task in tasks], return_exceptions=True)): + for task, result in zip( + tasks, + await asyncio.gather( + *[compute_sample(*task) for task in tasks], return_exceptions=True + ), + ): if isinstance(result, Exception): - print(f"Could not compute sample for {task[0]}: {result!s}", file=sys.stderr) + print( + f"Could not compute sample for {task[0]}: {result!s}", file=sys.stderr + ) def main_sample(args): @@ -485,32 +557,51 @@ def main_sample(args): def main_list_commands(args): - print("Error: No command specified.\n\n" - "Available commands:\n" - " serve run webserver\n" - " sample sample all datasets\n" - "", file=sys.stderr) + print( + "Error: No command specified.\n\n" + "Available commands:\n" + " serve run webserver\n" + " sample sample all datasets\n" + "", + file=sys.stderr, + ) sys.exit(1) def main(argv=sys.argv): import argparse - parser = argparse.ArgumentParser(description='Fill up those seats on your empty train.') + parser = argparse.ArgumentParser( + description="Fill up those seats on your empty train." + ) parser.set_defaults(func=main_list_commands) subparsers = parser.add_subparsers() - parser_serve = subparsers.add_parser('serve') - parser_serve.add_argument('--host', type=str, default='127.0.0.1', help='Bind socket to this host. (default: 127.0.0.1)') - parser_serve.add_argument('-p', '--port', type=int, default=8000, help='Bind socket to this port. (default: 8000)') - parser_serve.add_argument('--reload', action='store_true', help='Enable auto-reload.') + parser_serve = subparsers.add_parser("serve") + parser_serve.add_argument( + "--host", + type=str, + default="127.0.0.1", + help="Bind socket to this host. 
(default: 127.0.0.1)", + ) + parser_serve.add_argument( + "-p", + "--port", + type=int, + default=8000, + help="Bind socket to this port. (default: 8000)", + ) + parser_serve.add_argument( + "--reload", action="store_true", help="Enable auto-reload." + ) parser_serve.set_defaults(func=main_serve) - parser_sample = subparsers.add_parser('sample') + parser_sample = subparsers.add_parser("sample") parser_sample.set_defaults(func=main_sample) args = parser.parse_args() args.func(args) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/opuscleaner/threshold.py b/opuscleaner/threshold.py index 4834bef..7b34494 100755 --- a/opuscleaner/threshold.py +++ b/opuscleaner/threshold.py @@ -3,7 +3,7 @@ It passes every line of input onto the scorer program (unless --cache is specified and the line is already in the cache) that generates a score. If ---threshold is specified (optionally with an operator specified, default is +--threshold is specified (optionally with an operator specified, default is greater or equal, c.f. --ge) then the line is printed if the threshold is met. If no threshold is specified, the score is added as the first column to the output. @@ -12,6 +12,7 @@ program, you also change the path to the cache. Otherwise you'll get scores from the scorer that was run with different arguments. """ + import sys import os import signal @@ -29,241 +30,271 @@ class Entry: - """Cache entry. Only an object with single property so we can easily update - it by reference, really. - """ - __slots__ = ['score'] - score: Optional[float] + """Cache entry. Only an object with single property so we can easily update + it by reference, really. + """ + + __slots__ = ["score"] + score: Optional[float] - def __init__(self, score: Optional[float] = None): - self.score = score + def __init__(self, score: Optional[float] = None): + self.score = score class Cache: - """Just a subset of dict[] really!""" - def __init__(self): - self.entries = {} + """Just a subset of dict[] really!""" + + def __init__(self): + self.entries = {} + + def __contains__(self, key: bytes) -> bool: + return key in self.entries - def __contains__(self, key: bytes) -> bool: - return key in self.entries - - def __getitem__(self, key: bytes) -> Entry: - return self.entries[key] + def __getitem__(self, key: bytes) -> Entry: + return self.entries[key] - def __setitem__(self, key: bytes, value: Entry): - self.entries[key] = value + def __setitem__(self, key: bytes, value: Entry): + self.entries[key] = value - def __enter__(self): - return self + def __enter__(self): + return self - def __exit__(self, *args): - return + def __exit__(self, *args): + return class PersistentEntry: - """Mimics Entry(), but with a setter that updates the persistent cache.""" - __slots__ = ['cache', 'key', '_score'] - cache: "PersistentCache" - key: bytes + """Mimics Entry(), but with a setter that updates the persistent cache.""" - def __init__(self, cache, key: bytes, score: Optional[float] = None): - self.cache = cache - self.key = key - self._score = score + __slots__ = ["cache", "key", "_score"] + cache: "PersistentCache" + key: bytes - @property - def score(self) -> Optional[float]: - return self._score + def __init__(self, cache, key: bytes, score: Optional[float] = None): + self.cache = cache + self.key = key + self._score = score - @score.setter - def score(self, score: float): - self._score = score - self.cache._write(self.key, score) + @property + def score(self) -> Optional[float]: + return self._score + + @score.setter + def 
score(self, score: float): + self._score = score + self.cache._write(self.key, score) class PersistentCache(Cache): - """Similar to Cache, but will also look at the database file that's on disk - when queried for known entries. - """ - __slots__ = ['entries', 'db', '_backing'] + """Similar to Cache, but will also look at the database file that's on disk + when queried for known entries. + """ + + __slots__ = ["entries", "db", "_backing"] - def __init__(self, path: str): - self.db = dbm.open(path, 'c') # was 'cfu' create, fast, unlocked (TODO unlocked?!) but that only works if the gnu backend is used - self.entries: Dict[bytes,PersistentEntry] = {} + def __init__(self, path: str): + self.db = dbm.open( + path, "c" + ) # was 'cfu' create, fast, unlocked (TODO unlocked?!) but that only works if the gnu backend is used + self.entries: Dict[bytes, PersistentEntry] = {} - def __enter__(self): - self._backing = self.db.__enter__() - return self + def __enter__(self): + self._backing = self.db.__enter__() + return self - def __exit__(self, *args): - self._backing.__exit__(*args) + def __exit__(self, *args): + self._backing.__exit__(*args) - def __contains__(self, key: bytes): - return key in self.entries or key in self._backing + def __contains__(self, key: bytes): + return key in self.entries or key in self._backing - def __getitem__(self, key: bytes) -> PersistentEntry: - if key not in self.entries: - score = self._decode(self._backing[key]) - self.entries[key] = PersistentEntry(self, key, score) - return self.entries[key] + def __getitem__(self, key: bytes) -> PersistentEntry: + if key not in self.entries: + score = self._decode(self._backing[key]) + self.entries[key] = PersistentEntry(self, key, score) + return self.entries[key] - def __setitem__(self, key: bytes, value: Entry): - self.entries[key] = PersistentEntry(self, key, value.score) + def __setitem__(self, key: bytes, value: Entry): + self.entries[key] = PersistentEntry(self, key, value.score) - def _write(self, key: bytes, value: float): - self._backing[key] = self._encode(value) + def _write(self, key: bytes, value: float): + self._backing[key] = self._encode(value) - def _encode(self, value: float) -> bytes: - return struct.pack(' bytes: + return struct.pack(" float: - return struct.unpack(' float: + return struct.unpack(" T: - """Runtime cast of `Optional[T]` into `T`. Will raise an AssertionError if - the argument was indeed `None`. - """ - if optional is None: - raise ValueError(message) - return optional + """Runtime cast of `Optional[T]` into `T`. Will raise an AssertionError if + the argument was indeed `None`. + """ + if optional is None: + raise ValueError(message) + return optional def exit_on_throw(fn): - """Wraps thread main function so that an exception thrown in the thread - will terminate the entire process. - """ - @wraps(fn) - def wrapper(*args, **kwargs): - try: - return fn(*args, **kwargs) - except: - print_exc(file=sys.stderr) - os.kill(os.getpid(), signal.SIGKILL) - return wrapper + """Wraps thread main function so that an exception thrown in the thread + will terminate the entire process. + """ + + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except: + print_exc(file=sys.stderr) + os.kill(os.getpid(), signal.SIGKILL) + + return wrapper def feed_child(queue, fin, fchild, cache): - """Thread that reads each line from stream `fin`, and will put its score - `Entry` onto queue `queue`. If a line is a duplicate, it will use the entry - from the previous occurrence. 
If a line is new, it will also feed it to - child process `fchild` so a score can be calculated. Because the order of - the queue and the order of feeding fchild are the same, the - `threshold_scores` thread will know how to link them back together. - """ - derive_key = lambda val: xxh32(val).digest() - - try: - for line in fin: - key = derive_key(line) - # if key not in cache, we've never seen the sentence - if key not in cache: - fchild.write(line) - cache[key] = Entry() - - queue.put((line, cache[key])) - fchild.close() - except BrokenPipeError: - pass - finally: - queue.put(None) # End indicator - fin.close() + """Thread that reads each line from stream `fin`, and will put its score + `Entry` onto queue `queue`. If a line is a duplicate, it will use the entry + from the previous occurrence. If a line is new, it will also feed it to + child process `fchild` so a score can be calculated. Because the order of + the queue and the order of feeding fchild are the same, the + `threshold_scores` thread will know how to link them back together. + """ + derive_key = lambda val: xxh32(val).digest() + + try: + for line in fin: + key = derive_key(line) + # if key not in cache, we've never seen the sentence + if key not in cache: + fchild.write(line) + cache[key] = Entry() + + queue.put((line, cache[key])) + fchild.close() + except BrokenPipeError: + pass + finally: + queue.put(None) # End indicator + fin.close() def threshold_scores(queue, fchild, fout, threshold, operator): - """Thread that reads the queue and, depending on the threshold, will write - the line to output. It will also read any missing scores from the child - `fchild`. Because this is the only thread reading & writing to Entry objects - no locks are necessary. - """ - try: - while True: - item = queue.get() - - # Poison - if item is None: - break - - # If no score yet, get it from the child - if item[1].score is None: - item[1].score = float(fchild.readline()) - - # If no threshold is specified, print everything and prefix it with the score - if threshold is None: - fout.write(str(item[1].score).encode() + b'\t' + item[0]) - - # Otherwise only print the actual line if threshold is met - elif operator(item[1].score, threshold): - fout.write(item[0]) - - # TODO: test somehow that child has stopped producing? Reading from `fchild` - # should at this point return EOF since its stdin is already closed. - fout.close() - except BrokenPipeError: - pass - finally: - fchild.close() + """Thread that reads the queue and, depending on the threshold, will write + the line to output. It will also read any missing scores from the child + `fchild`. Because this is the only thread reading & writing to Entry objects + no locks are necessary. + """ + try: + while True: + item = queue.get() + + # Poison + if item is None: + break + + # If no score yet, get it from the child + if item[1].score is None: + item[1].score = float(fchild.readline()) + + # If no threshold is specified, print everything and prefix it with the score + if threshold is None: + fout.write(str(item[1].score).encode() + b"\t" + item[0]) + + # Otherwise only print the actual line if threshold is met + elif operator(item[1].score, threshold): + fout.write(item[0]) + + # TODO: test somehow that child has stopped producing? Reading from `fchild` + # should at this point return EOF since its stdin is already closed. 
+ fout.close() + except BrokenPipeError: + pass + finally: + fchild.close() def open_cache(path: Optional[str]) -> Cache: - """Instantiates a cache type based on the path (or None) given.""" - if path: - return PersistentCache(path) - else: - return Cache() + """Instantiates a cache type based on the path (or None) given.""" + if path: + return PersistentCache(path) + else: + return Cache() def main(): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('threshold', type=float, help='Threshold (b) to compare score to.') - parser.add_argument('scorer', type=str, nargs='+', help='Scorer program (a) and arguments.') - parser.add_argument('--cache', '-c', type=str, help='Path to cache database.') - - ops = parser.add_mutually_exclusive_group() - ops.set_defaults(operator=operator.ge) # default to --ge - for name in ['lt', 'le', 'eq', 'ne', 'ge', 'gt']: - ops.add_argument(f'--{name}', dest='operator', action='store_const', const=getattr(operator, name), help=getattr(operator, name).__doc__) - - args, scorer_args = parser.parse_known_args() - - # TODO: Make this Popen call only necessary if there was any need for it, - # i.e. not all sentences could be scored by just the cache. I'm tempted to - # add yet another wrapper program that only starts the process once input - # is readable from stdin and then just re-attaches stdin to the child? Bit - # like how inetd works. Or should this be a task for the downstream scorer - # i.e. only load the model once input is received? - child = Popen(args.scorer + scorer_args, stdin=PIPE, stdout=PIPE) - - queue = SimpleQueue() # type: SimpleQueue[tuple[bytes,Entry]] - - try: - with open_cache(args.cache) as cache: - # Reads stdin, writes it to queue, and possibly to child for scoring. - feeder = Thread(target=exit_on_throw(feed_child), args=[queue, sys.stdin.buffer, child.stdin, cache]) - feeder.start() - - # Reads queue, writes to stdout, reading scores from child if necessary. - consumer = Thread(target=exit_on_throw(threshold_scores), args=[queue, child.stdout, sys.stdout.buffer, args.threshold, args.operator]) - consumer.start() - - # Feeder will be done at this point - feeder.join() - - # Consumer will be done once it read the last None from the queue. - consumer.join() - - # Feeder will close child.stdin when all input is processed, which should - # cause child to terminate. - except: - none_throws(child.stdin).close() - finally: - sys.stderr.close() - retval = child.wait() - sys.exit(retval) - - -if __name__ == '__main__': - main() + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "threshold", type=float, help="Threshold (b) to compare score to." + ) + parser.add_argument( + "scorer", type=str, nargs="+", help="Scorer program (a) and arguments." + ) + parser.add_argument("--cache", "-c", type=str, help="Path to cache database.") + + ops = parser.add_mutually_exclusive_group() + ops.set_defaults(operator=operator.ge) # default to --ge + for name in ["lt", "le", "eq", "ne", "ge", "gt"]: + ops.add_argument( + f"--{name}", + dest="operator", + action="store_const", + const=getattr(operator, name), + help=getattr(operator, name).__doc__, + ) + + args, scorer_args = parser.parse_known_args() + + # TODO: Make this Popen call only necessary if there was any need for it, + # i.e. not all sentences could be scored by just the cache. I'm tempted to + # add yet another wrapper program that only starts the process once input + # is readable from stdin and then just re-attaches stdin to the child? 
Bit + # like how inetd works. Or should this be a task for the downstream scorer + # i.e. only load the model once input is received? + child = Popen(args.scorer + scorer_args, stdin=PIPE, stdout=PIPE) + + queue = SimpleQueue() # type: SimpleQueue[tuple[bytes,Entry]] + + try: + with open_cache(args.cache) as cache: + # Reads stdin, writes it to queue, and possibly to child for scoring. + feeder = Thread( + target=exit_on_throw(feed_child), + args=[queue, sys.stdin.buffer, child.stdin, cache], + ) + feeder.start() + + # Reads queue, writes to stdout, reading scores from child if necessary. + consumer = Thread( + target=exit_on_throw(threshold_scores), + args=[ + queue, + child.stdout, + sys.stdout.buffer, + args.threshold, + args.operator, + ], + ) + consumer.start() + + # Feeder will be done at this point + feeder.join() + + # Consumer will be done once it read the last None from the queue. + consumer.join() + + # Feeder will close child.stdin when all input is processed, which should + # cause child to terminate. + except: + none_throws(child.stdin).close() + finally: + sys.stderr.close() + retval = child.wait() + sys.exit(retval) + + +if __name__ == "__main__": + main() diff --git a/placeholders/placeholders.py b/placeholders/placeholders.py index 26123a3..d54fbf6 100755 --- a/placeholders/placeholders.py +++ b/placeholders/placeholders.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """An encoder/decoder for placeholders in python3, using spm vocabulary""" + import sys from typing import List, Dict, Tuple from dataclasses import dataclass @@ -11,33 +12,65 @@ parser = argparse.ArgumentParser() -parser.add_argument('-c', '--config', type=str, help='Path to yaml configuration file, required for encoding') -parser.add_argument('-m', '--mappings_file', type=str, default="mappings.yml", help='Path to the mappings, one yaml entry per line.') -parser.add_argument('-s', '--seed', type=int, default=None, help='Seed for random number generator.') -parser.add_argument('-n', '--no-mapping', action="store_true", help='Do not dump a mapping file for decoding. Useful for training') -parser.add_argument('-t', '--strict', action="store_true", help="Only generate a placeholder if there's equal number on the source and target side of each (assumes TSV input).") +parser.add_argument( + "-c", + "--config", + type=str, + help="Path to yaml configuration file, required for encoding", +) +parser.add_argument( + "-m", + "--mappings_file", + type=str, + default="mappings.yml", + help="Path to the mappings, one yaml entry per line.", +) +parser.add_argument( + "-s", "--seed", type=int, default=None, help="Seed for random number generator." +) +parser.add_argument( + "-n", + "--no-mapping", + action="store_true", + help="Do not dump a mapping file for decoding. 
Useful for training", +) +parser.add_argument( + "-t", + "--strict", + action="store_true", + help="Only generate a placeholder if there's equal number on the source and target side of each (assumes TSV input).", +) mutex_group_1 = parser.add_mutually_exclusive_group(required=True) -mutex_group_1.add_argument('--decode', action='store_true') -mutex_group_1.add_argument('--encode', action='store_true') -mutex_group_1.add_argument('--dump_placeholders', action='store_true', help='Check to print placeholders out') +mutex_group_1.add_argument("--decode", action="store_true") +mutex_group_1.add_argument("--encode", action="store_true") +mutex_group_1.add_argument( + "--dump_placeholders", action="store_true", help="Check to print placeholders out" +) + @dataclass class Rule: """Just a wrapper for regex rules""" + pattern: str + @dataclass class Configuration: """Object holding the yaml config""" + def __init__(self, config_file, dump_placeholders: bool): - with open(config_file, 'r') as config_handle: + with open(config_file, "r") as config_handle: my_config = yaml.safe_load(config_handle) # Parse - self.rules = [Rule(regex) for regex in my_config['regexes']] - self.placeholder_symbol = my_config.get('placeholder-symbol', '@') - self.num_placeholders = my_config.get('num-placeholders', 20) - self.placeholders = [self.placeholder_symbol[:-1] + str(i) + self.placeholder_symbol[-1] for i in range(self.num_placeholders)] + self.rules = [Rule(regex) for regex in my_config["regexes"]] + self.placeholder_symbol = my_config.get("placeholder-symbol", "@") + self.num_placeholders = my_config.get("num-placeholders", 20) + self.placeholders = [ + self.placeholder_symbol[:-1] + str(i) + self.placeholder_symbol[-1] + for i in range(self.num_placeholders) + ] # Add a rule that escapes patterns that look like a placeholder already # TODO: this will match placeholders we can't reach because `num_placeholders` might be smaller @@ -45,45 +78,63 @@ def __init__(self, config_file, dump_placeholders: bool): # available placeholders. This will only ever happen when the placeholder symbol itself is in the # text we are trying to encode. We don't expect this to happen otherwise, but we have a testcase # for it in our example input. - self.rules.append(Rule(pattern=re.escape(self.placeholder_symbol) + r'\d+')) + self.rules.append(Rule(pattern=re.escape(self.placeholder_symbol) + r"\d+")) # During encoding assert that we have vocab - if not dump_placeholders and 'vocab' in my_config: - vocab = my_config['vocab'] + if not dump_placeholders and "vocab" in my_config: + vocab = my_config["vocab"] self.sp = SentencePieceProcessor(vocab) # Ensure that the placeholder symbol doesn't contain unk anywhere (including in the numbers) for placeholder in self.placeholders: - for token_proto in self.sp.encode(placeholder, out_type='immutable_proto').pieces: + for token_proto in self.sp.encode( + placeholder, out_type="immutable_proto" + ).pieces: if token_proto.id == self.sp.unk_id(): - sys.stderr.write("The unk token is contained within the placeholder: " + str(token_proto.surface) + - " which will cause all sorts of trouble. Please choose a different one.\n") + sys.stderr.write( + "The unk token is contained within the placeholder: " + + str(token_proto.surface) + + " which will cause all sorts of trouble. 
Please choose a different one.\n" + ) sys.exit(1) else: self.sp = None class Encoder: - '''Encodes spm strings''' - def __init__(self, placeholders: List[str], spm_vocab: SentencePieceProcessor, rules: List[Rule], strict: bool, *, random: Random = Random()): + """Encodes spm strings""" + + def __init__( + self, + placeholders: List[str], + spm_vocab: SentencePieceProcessor, + rules: List[Rule], + strict: bool, + *, + random: Random = Random(), + ): self.placeholders = placeholders self.sp = spm_vocab self.rules = rules - self.unk_id = self.sp.unk_id() + self.unk_id = self.sp.unk_id() self.random = random - self.strict = strict # Use strict mode, only making replacements when the same amount of tokens are on the source and the target side + self.strict = strict # Use strict mode, only making replacements when the same amount of tokens are on the source and the target side # Compile rules into one mega-pattern - self.rule_pattern = re.compile('|'.join('(?:{})'.format(rule.pattern) for rule in self.rules)) + self.rule_pattern = re.compile( + "|".join("(?:{})".format(rule.pattern) for rule in self.rules) + ) def make_placeholders(self, inputline: str) -> Tuple[str, Dict[str, str]]: """Replaces strings that match the regex patterns from the config file and words that cause the appearance of """ - my_placeholders = list(self.placeholders) # For each line start with the full set of placeholders + my_placeholders = list( + self.placeholders + ) # For each line start with the full set of placeholders self.random.shuffle(my_placeholders) - replacements: Dict[str,str] = {} + replacements: Dict[str, str] = {} def generate_random_placeholder() -> str: """Generates random number in range defined by `num_placeholders` argparse argument @@ -102,24 +153,26 @@ def replace_one(token) -> str: return token # Remove line ending - inputline = inputline.rstrip('\r\n') + inputline = inputline.rstrip("\r\n") # use regex rules - inputline = re.sub(self.rule_pattern, lambda match: replace_one(match.group()), inputline) + inputline = re.sub( + self.rule_pattern, lambda match: replace_one(match.group()), inputline + ) # check for - input_proto = self.sp.encode(inputline, out_type='immutable_proto') + input_proto = self.sp.encode(inputline, out_type="immutable_proto") inputline = "" for token_proto in input_proto.pieces: token = token_proto.surface if token_proto.id == self.unk_id: token = replace_one(token_proto.surface) inputline += token - inputline += '\n' + inputline += "\n" # Check if strict rules apply if self.strict: - src, trg = inputline.split('\t') + src, trg = inputline.split("\t") for mytoken, myreplacement in replacements.items(): if src.count(myreplacement) != trg.count(myreplacement): # We have a mismatch placeholder on source and target @@ -128,20 +181,29 @@ def replace_one(token) -> str: return (inputline, dict((v, k) for k, v in replacements.items())) -def encode(my_placeholders: List[str], my_sp: SentencePieceProcessor, my_rules: List[Rule], strict: bool, *, random: Random, no_mapping: bool) -> None: - '''Encodes everything form stdin, dumping it to stdout and dumping a file with - all replacements''' +def encode( + my_placeholders: List[str], + my_sp: SentencePieceProcessor, + my_rules: List[Rule], + strict: bool, + *, + random: Random, + no_mapping: bool, +) -> None: + """Encodes everything form stdin, dumping it to stdout and dumping a file with + all replacements""" encoder = Encoder(my_placeholders, my_sp, my_rules, strict, random=random) - if no_mapping: # Do not produce any mappings as we are 
going to just use it during training + if ( + no_mapping + ): # Do not produce any mappings as we are going to just use it during training for line in sys.stdin: encoded_line, _ = encoder.make_placeholders(line) - sys.stdout.write(encoded_line) # Write the encoded line to stdout + sys.stdout.write(encoded_line) # Write the encoded line to stdout else: - with open(args.mappings_file, 'w') as yamlout: + with open(args.mappings_file, "w") as yamlout: for counter, line in enumerate(sys.stdin): - encoded_line, mappings = encoder.make_placeholders(line) - sys.stdout.write(encoded_line) # Write the encoded line to stdout + sys.stdout.write(encoded_line) # Write the encoded line to stdout # Keep track of which sentence has what replacement mappings via a yaml config sent_mapping = {counter: mappings} @@ -149,10 +211,9 @@ def encode(my_placeholders: List[str], my_sp: SentencePieceProcessor, my_rules: yamlout.flush() - def decode() -> None: """Decodes a string from stdin, given a mappings file and spits it to stdout""" - with open(args.mappings_file, 'r') as mappings: + with open(args.mappings_file, "r") as mappings: placeholder_lines = yaml.safe_load(mappings) for counter, line in enumerate(sys.stdin): try: @@ -161,12 +222,15 @@ def decode() -> None: line = line.replace(placeholder, my_placeholders[placeholder]) sys.stdout.write(line) except KeyError as e: - sys.stderr.write(f'Input line {counter + 1} contains a placeholder {e.args[0]} but there is no mapping for it.') + sys.stderr.write( + f"Input line {counter + 1} contains a placeholder {e.args[0]} but there is no mapping for it." + ) sys.exit(1) except IndexError: sys.stderr.write("The mappings file contains less lines than the input.") sys.exit(1) + if __name__ == "__main__": args = parser.parse_args() @@ -179,7 +243,13 @@ def decode() -> None: print(" ".join(config.placeholders)) sys.exit(0) elif args.encode: - encode(config.placeholders, config.sp, config.rules, args.strict, random=random, no_mapping=args.no_mapping) + encode( + config.placeholders, + config.sp, + config.rules, + args.strict, + random=random, + no_mapping=args.no_mapping, + ) else: decode() - diff --git a/test/test_clean.py b/test/test_clean.py index 5781f79..42af9e4 100644 --- a/test/test_clean.py +++ b/test/test_clean.py @@ -11,177 +11,175 @@ from tempfile import TemporaryFile, NamedTemporaryFile -TEST_CWD = Path(os.path.join(os.path.dirname(__file__), 'deeper')) +TEST_CWD = Path(os.path.join(os.path.dirname(__file__), "deeper")) -FILES = [ - "bible-uedin-v1.de-en.de.gz", - "bible-uedin-v1.de-en.en.gz" -] +FILES = ["bible-uedin-v1.de-en.de.gz", "bible-uedin-v1.de-en.en.gz"] SCENARIOS = { - 'single': [], - 'parallel': ['--parallel', '2', '--batch-size', '32000'], # parallel + "single": [], + "parallel": ["--parallel", "2", "--batch-size", "32000"], # parallel } -def parse_record(line:str) -> dict: - return json.loads(line) +def parse_record(line: str) -> dict: + return json.loads(line) -def accumulate_records(records_it:Iterable[dict]) -> Iterable[dict]: - records = defaultdict(dict) - for record in records_it: - records[record['id']].update(record) - return records.values() -class TestClean(unittest.TestCase): - def _run(self, args:List[str]): - proc = subprocess.Popen( - args=[sys.executable, '-m', 'opuscleaner.clean'] + args, - cwd=TEST_CWD, # so it can find filters - env={ - 'PYTHONPATH': os.path.join(os.path.dirname(__file__), '..') # so it can find opuscleaner code - }, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - out, err = proc.communicate() - proc.wait() 
- return out, err, proc.returncode - - def test_simple(self): - """Test that clean runs""" - config = { - "version": 1, - "files": FILES, - "filters": [ - { - "filter": "deescape_tsv", - "parameters": {}, - "language": None - } - ] - } - with NamedTemporaryFile(mode='w', dir=TEST_CWD / 'data/train-parts') as fh: - json.dump(config, fh) - fh.flush() - for mode, args in SCENARIOS.items(): - with self.subTest(mode=mode): - out, _, retval = self._run([*args, fh.name]) - self.assertEqual(out.count(b'\n'), 62195) - self.assertEqual(retval, 0) - - def test_filter_fail(self): - """Test that clean returns a non-zero exit code if a filter fails""" - config = { - "version": 1, - "files": FILES, - "filters": [ - { - "filter": "fail", - "parameters": { - "EXITCODE": "42" - }, - "language": "de" - } - ] - } - with NamedTemporaryFile(mode='w', dir=TEST_CWD / 'data/train-parts') as fh: - json.dump(config, fh) - fh.flush() - - for mode, args in SCENARIOS.items(): - with self.subTest(mode=mode): - out, err, retval = self._run([*args, fh.name]) - self.assertEqual(out.count(b'\n'), 0) - self.assertIn(b'subprocess exited with status code 42', err) - self.assertNotEqual(retval, 0) - - def test_stdin(self): - """Test that clean runs""" - config = { - "version": 1, - "files": FILES, - "filters": [ - { - "filter": "deescape_tsv", - "parameters": {}, - "language": None - } - ] - } - with NamedTemporaryFile(mode='w', dir=TEST_CWD / 'data/train-parts') as fconf: - json.dump(config, fconf) - fconf.flush() - with TemporaryFile('w+b') as fdata: - # Concatenate the dataset together as if it was made with `paste <(gzip -cd a) <(gzip -cd b)` - with ExitStack() as ctx: - fhs = [ - ctx.enter_context(gzip.open(TEST_CWD / 'data/train-parts' / filename)) - for filename in FILES - ] - fdata.writelines( - b"\t".join(col.rstrip(b"\r\n") for col in line) + b"\n" - for line in zip(*fhs) - ) - fdata.flush() - - for mode, args in SCENARIOS.items(): - with self.subTest(mode=mode): - # Reset dataset input - fdata.seek(0) - - # Run cleaner with `--input -` and pass the data through stdin - proc_clean = subprocess.Popen( - args=[sys.executable, '-m', 'opuscleaner.clean', *args, '--input', '-', fconf.name, 'de', 'en'], - cwd=TEST_CWD, - env={ - 'PYTHONPATH': os.path.join(os.path.dirname(__file__), '..') # so it can find opuscleaner code - }, - stdin=fdata, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - - out, err = proc_clean.communicate() - retval = proc_clean.wait() - - # Help with debugging - if retval != 0: - print(err, file=sys.stderr) - - self.assertEqual(out.count(b'\n'), 62195) - self.assertEqual(retval, 0) - - def test_trace_time(self): - """Test that running with --trace and --time will give us time records""" - config = { - "version": 1, - "files": FILES, - "filters": [ - { - "filter": "deescape_tsv", - "parameters": {}, - "language": None - } - ] - } - with NamedTemporaryFile(mode='w', dir=TEST_CWD / 'data/train-parts') as fh: - json.dump(config, fh) - fh.flush() - for mode, args in SCENARIOS.items(): - with self.subTest(mode=mode), NamedTemporaryFile(mode='r+') as trace_fh: - out, err, retval = self._run(['--trace', trace_fh.name, '--time', *args, fh.name]) - self.assertEqual(err, b'') - self.assertEqual(out.count(b'\n'), 62195) - self.assertEqual(retval, 0) - - time_records = [ - record - for record in accumulate_records(parse_record(line) for line in trace_fh) - if 'time' in record - ] - - self.assertGreater(len(time_records), 0) - for record in time_records: - self.assertEqual(set(record['time'].keys()), {'user', 
'real', 'sys'}) - self.assertGreater(record['time']['real'], 0.0) +def accumulate_records(records_it: Iterable[dict]) -> Iterable[dict]: + records = defaultdict(dict) + for record in records_it: + records[record["id"]].update(record) + return records.values() + +class TestClean(unittest.TestCase): + def _run(self, args: List[str]): + proc = subprocess.Popen( + args=[sys.executable, "-m", "opuscleaner.clean"] + args, + cwd=TEST_CWD, # so it can find filters + env={ + "PYTHONPATH": os.path.join( + os.path.dirname(__file__), ".." + ) # so it can find opuscleaner code + }, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + out, err = proc.communicate() + proc.wait() + return out, err, proc.returncode + + def test_simple(self): + """Test that clean runs""" + config = { + "version": 1, + "files": FILES, + "filters": [{"filter": "deescape_tsv", "parameters": {}, "language": None}], + } + with NamedTemporaryFile(mode="w", dir=TEST_CWD / "data/train-parts") as fh: + json.dump(config, fh) + fh.flush() + for mode, args in SCENARIOS.items(): + with self.subTest(mode=mode): + out, _, retval = self._run([*args, fh.name]) + self.assertEqual(out.count(b"\n"), 62195) + self.assertEqual(retval, 0) + + def test_filter_fail(self): + """Test that clean returns a non-zero exit code if a filter fails""" + config = { + "version": 1, + "files": FILES, + "filters": [ + {"filter": "fail", "parameters": {"EXITCODE": "42"}, "language": "de"} + ], + } + with NamedTemporaryFile(mode="w", dir=TEST_CWD / "data/train-parts") as fh: + json.dump(config, fh) + fh.flush() + + for mode, args in SCENARIOS.items(): + with self.subTest(mode=mode): + out, err, retval = self._run([*args, fh.name]) + self.assertEqual(out.count(b"\n"), 0) + self.assertIn(b"subprocess exited with status code 42", err) + self.assertNotEqual(retval, 0) + + def test_stdin(self): + """Test that clean runs""" + config = { + "version": 1, + "files": FILES, + "filters": [{"filter": "deescape_tsv", "parameters": {}, "language": None}], + } + with NamedTemporaryFile(mode="w", dir=TEST_CWD / "data/train-parts") as fconf: + json.dump(config, fconf) + fconf.flush() + with TemporaryFile("w+b") as fdata: + # Concatenate the dataset together as if it was made with `paste <(gzip -cd a) <(gzip -cd b)` + with ExitStack() as ctx: + fhs = [ + ctx.enter_context( + gzip.open(TEST_CWD / "data/train-parts" / filename) + ) + for filename in FILES + ] + fdata.writelines( + b"\t".join(col.rstrip(b"\r\n") for col in line) + b"\n" + for line in zip(*fhs) + ) + fdata.flush() + + for mode, args in SCENARIOS.items(): + with self.subTest(mode=mode): + # Reset dataset input + fdata.seek(0) + + # Run cleaner with `--input -` and pass the data through stdin + proc_clean = subprocess.Popen( + args=[ + sys.executable, + "-m", + "opuscleaner.clean", + *args, + "--input", + "-", + fconf.name, + "de", + "en", + ], + cwd=TEST_CWD, + env={ + "PYTHONPATH": os.path.join( + os.path.dirname(__file__), ".." 
+ ) # so it can find opuscleaner code + }, + stdin=fdata, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + out, err = proc_clean.communicate() + retval = proc_clean.wait() + + # Help with debugging + if retval != 0: + print(err, file=sys.stderr) + + self.assertEqual(out.count(b"\n"), 62195) + self.assertEqual(retval, 0) + + def test_trace_time(self): + """Test that running with --trace and --time will give us time records""" + config = { + "version": 1, + "files": FILES, + "filters": [{"filter": "deescape_tsv", "parameters": {}, "language": None}], + } + with NamedTemporaryFile(mode="w", dir=TEST_CWD / "data/train-parts") as fh: + json.dump(config, fh) + fh.flush() + for mode, args in SCENARIOS.items(): + with self.subTest(mode=mode), NamedTemporaryFile(mode="r+") as trace_fh: + out, err, retval = self._run( + ["--trace", trace_fh.name, "--time", *args, fh.name] + ) + self.assertEqual(err, b"") + self.assertEqual(out.count(b"\n"), 62195) + self.assertEqual(retval, 0) + + time_records = [ + record + for record in accumulate_records( + parse_record(line) for line in trace_fh + ) + if "time" in record + ] + + self.assertGreater(len(time_records), 0) + for record in time_records: + self.assertEqual( + set(record["time"].keys()), {"user", "real", "sys"} + ) + self.assertGreater(record["time"]["real"], 0.0) diff --git a/test/test_col.py b/test/test_col.py index 3e73f3c..0f33d9a 100644 --- a/test/test_col.py +++ b/test/test_col.py @@ -7,159 +7,176 @@ from opuscleaner.config import COL_PY -TEST_INPUT = "".join([ - "Hello\tHallo\n", - "Goodbye\tBye\n", - "Beep\t\n", - "\t\n", - "beep\tboop\n", - "\tboop\n", -]) +TEST_INPUT = "".join( + [ + "Hello\tHallo\n", + "Goodbye\tBye\n", + "Beep\t\n", + "\t\n", + "beep\tboop\n", + "\tboop\n", + ] +) -TEST_INPUT_SANE = "".join([ - "Hello\tHallo\n", - "Goodbye\tTot ziens\n", - "Monitor\tComputerscherm\n", - "Outside world\tBuitenwereld\n", -]) +TEST_INPUT_SANE = "".join( + [ + "Hello\tHallo\n", + "Goodbye\tTot ziens\n", + "Monitor\tComputerscherm\n", + "Outside world\tBuitenwereld\n", + ] +) -TEST_INPUT_COL_MISSING = "".join([ - *TEST_INPUT, - "single-col\n", - *TEST_INPUT_SANE -]) +TEST_INPUT_COL_MISSING = "".join([*TEST_INPUT, "single-col\n", *TEST_INPUT_SANE]) -TEST_INPUT_COL_OVERFLOW = "".join([ - *TEST_INPUT, - "triple-col\ttriple-col\ttriple-col\n", - *TEST_INPUT_SANE -]) +TEST_INPUT_COL_OVERFLOW = "".join( + [*TEST_INPUT, "triple-col\ttriple-col\ttriple-col\n", *TEST_INPUT_SANE] +) class TestCol(unittest.TestCase): - def _run(self, args:List[str], input:str) -> Tuple[str,str,int]: - proc = subprocess.Popen(COL_PY + args, - text=True, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - out, err = proc.communicate(input) - proc.stdin.close() - proc.wait() - return out, err, proc.returncode - - def test_reproduce_sane(self): - """Sane input should not be a problem.""" - reproduce = dedent(""" + def _run(self, args: List[str], input: str) -> Tuple[str, str, int]: + proc = subprocess.Popen( + COL_PY + args, + text=True, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, err = proc.communicate(input) + proc.stdin.close() + proc.wait() + return out, err, proc.returncode + + def test_reproduce_sane(self): + """Sane input should not be a problem.""" + reproduce = dedent(""" import sys for line in sys.stdin: sys.stdout.write(line) """) - out, err, retval = self._run(['0', sys.executable, '-u', '-c', reproduce], TEST_INPUT_SANE) - self.assertEqual(out, TEST_INPUT_SANE) - 
self.assertEqual(err, ''), - self.assertEqual(retval, 0) + out, err, retval = self._run( + ["0", sys.executable, "-u", "-c", reproduce], TEST_INPUT_SANE + ) + self.assertEqual(out, TEST_INPUT_SANE) + (self.assertEqual(err, ""),) + self.assertEqual(retval, 0) - def test_reproduce_streaming(self): - """Test that subprocess that reads one line, writes one line, works""" - reproduce = dedent(""" + def test_reproduce_streaming(self): + """Test that subprocess that reads one line, writes one line, works""" + reproduce = dedent(""" import sys for line in sys.stdin: sys.stdout.write(line) """) - out, err, retval = self._run(['0', sys.executable, '-u', '-c', reproduce], TEST_INPUT) - self.assertEqual(out, TEST_INPUT) - self.assertEqual(err, ''), - self.assertEqual(retval, 0) + out, err, retval = self._run( + ["0", sys.executable, "-u", "-c", reproduce], TEST_INPUT + ) + self.assertEqual(out, TEST_INPUT) + (self.assertEqual(err, ""),) + self.assertEqual(retval, 0) - def test_reproduce_buffering(self): - """Test that a subprocess that reads the entire input to memory before generating output works.""" - reproduce = dedent(""" + def test_reproduce_buffering(self): + """Test that a subprocess that reads the entire input to memory before generating output works.""" + reproduce = dedent(""" import sys sys.stdout.write(sys.stdin.read()) """) - for colset in ('0', '1', '0,1'): - with self.subTest(colset=colset): - out, err, retval = self._run([colset, sys.executable, '-c', reproduce], TEST_INPUT) - self.assertEqual(out, TEST_INPUT) - self.assertEqual(err, ''), - self.assertEqual(retval, 0) - - def test_overproduce(self): - """Test that an overproducing program is caught""" - overproduce = dedent(""" + for colset in ("0", "1", "0,1"): + with self.subTest(colset=colset): + out, err, retval = self._run( + [colset, sys.executable, "-c", reproduce], TEST_INPUT + ) + self.assertEqual(out, TEST_INPUT) + (self.assertEqual(err, ""),) + self.assertEqual(retval, 0) + + def test_overproduce(self): + """Test that an overproducing program is caught""" + overproduce = dedent(""" import sys for line in sys.stdin: sys.stdout.write(line) sys.stdout.write(line) """) - out, err, retval = self._run(['0', sys.executable, '-c', overproduce], TEST_INPUT) - self.assertIn('subprocess produced more lines of output than it was given', err) - self.assertNotEqual(retval, 0) + out, err, retval = self._run( + ["0", sys.executable, "-c", overproduce], TEST_INPUT + ) + self.assertIn("subprocess produced more lines of output than it was given", err) + self.assertNotEqual(retval, 0) - def test_underproduce(self): - """Test that an underproducing program is caught""" - underproduce = dedent(""" + def test_underproduce(self): + """Test that an underproducing program is caught""" + underproduce = dedent(""" import sys for n, line in enumerate(sys.stdin): if n % 2 == 0: sys.stdout.write(line) """) - out, err, retval = self._run(['0', sys.executable, '-c', underproduce], TEST_INPUT) - self.assertIn('subprocess produced fewer lines than it was given', err) - self.assertNotEqual(retval, 0) + out, err, retval = self._run( + ["0", sys.executable, "-c", underproduce], TEST_INPUT + ) + self.assertIn("subprocess produced fewer lines than it was given", err) + self.assertNotEqual(retval, 0) - def test_error_incorrect_subprocess(self): - """Test that an unclean exit from a subprocess is caught.""" - underproduce = dedent(""" + def test_error_incorrect_subprocess(self): + """Test that an unclean exit from a subprocess is caught.""" + underproduce = 
dedent(""" import sys sys.exit(42) """) - out, err, retval = self._run(['0', sys.executable, '-c', underproduce], TEST_INPUT) - self.assertEqual(retval, 42) + out, err, retval = self._run( + ["0", sys.executable, "-c", underproduce], TEST_INPUT + ) + self.assertEqual(retval, 42) - def test_error_correct_subprocess(self): - """Test that an unclean exit from a subprocess is caught even if the output looks sane.""" - underproduce = dedent(""" + def test_error_correct_subprocess(self): + """Test that an unclean exit from a subprocess is caught even if the output looks sane.""" + underproduce = dedent(""" import sys for line in sys.stdin: sys.stdout.write(line) sys.exit(42) """) - out, err, retval = self._run(['0', sys.executable, '-c', underproduce], TEST_INPUT) - self.assertEqual(retval, 42) - self.assertIn('subprocess exited with status code 42', err) + out, err, retval = self._run( + ["0", sys.executable, "-c", underproduce], TEST_INPUT + ) + self.assertEqual(retval, 42) + self.assertIn("subprocess exited with status code 42", err) - def test_error_col_missing(self): - """A missing column in the input should raise an error""" - reproduce = dedent(""" + def test_error_col_missing(self): + """A missing column in the input should raise an error""" + reproduce = dedent(""" import sys for line in sys.stdin: sys.stdout.write(line) """) - out, err, retval = self._run(['1', sys.executable, '-u', '-c', reproduce], TEST_INPUT_COL_MISSING) - self.assertEqual(retval, 1) - self.assertIn('line contains a different number of fields', err) - - def test_error_col_overflow(self): - """A line with too many columns should raise an error""" - reproduce = dedent(""" + out, err, retval = self._run( + ["1", sys.executable, "-u", "-c", reproduce], TEST_INPUT_COL_MISSING + ) + self.assertEqual(retval, 1) + self.assertIn("line contains a different number of fields", err) + + def test_error_col_overflow(self): + """A line with too many columns should raise an error""" + reproduce = dedent(""" import sys for line in sys.stdin: sys.stdout.write(line) """) - out, err, retval = self._run(['1', sys.executable, '-u', '-c', reproduce], TEST_INPUT_COL_OVERFLOW) - self.assertEqual(retval, 1) - self.assertIn('line contains a different number of fields', err) - + out, err, retval = self._run( + ["1", sys.executable, "-u", "-c", reproduce], TEST_INPUT_COL_OVERFLOW + ) + self.assertEqual(retval, 1) + self.assertIn("line contains a different number of fields", err) diff --git a/utils/dedup/hash-seg.py b/utils/dedup/hash-seg.py index 255b325..4482c43 100755 --- a/utils/dedup/hash-seg.py +++ b/utils/dedup/hash-seg.py @@ -6,18 +6,19 @@ import sys parser = ArgumentParser() -parser.add_argument('-a', '--aggressive', action='store_true', default=False) +parser.add_argument("-a", "--aggressive", action="store_true", default=False) args = parser.parse_args() # Translate table to remove non alphabetic characters -tbl = [chr(i) for i in range(sys.maxunicode) if not cat(chr(i)).startswith('L')] -remove_non_alpha = str.maketrans('', '', ''.join(tbl)) +tbl = [chr(i) for i in range(sys.maxunicode) if not cat(chr(i)).startswith("L")] +remove_non_alpha = str.maketrans("", "", "".join(tbl)) + def main(): shashes, thashes = set(), set() for line in sys.stdin: - sline = line.rstrip('\n') - parts = sline.split('\t') + sline = line.rstrip("\n") + parts = sline.split("\t") src = parts[0] trg = parts[1] diff --git a/utils/dedup/superdedup.py b/utils/dedup/superdedup.py index 333eb30..91bf7f7 100755 --- a/utils/dedup/superdedup.py +++ 
b/utils/dedup/superdedup.py
@@ -3,17 +3,18 @@
 import os
 import pickle
 
+
 def main():
     shashes, thashes = set(), set()
     # Try to old existing hashes
-    if os.path.isfile('shashes.pickle'):
-        with open('shashes.pickle', 'rb') as f:
+    if os.path.isfile("shashes.pickle"):
+        with open("shashes.pickle", "rb") as f:
             shashes = pickle.load(f)
-    if os.path.isfile('thashes.pickle'):
-        with open('thashes.pickle', 'rb') as f:
+    if os.path.isfile("thashes.pickle"):
+        with open("thashes.pickle", "rb") as f:
             thashes = pickle.load(f)
     for line in sys.stdin:
-        parts = line.rstrip("\n").split('\t')
+        parts = line.rstrip("\n").split("\t")
 
         src_hash = parts[2]
         trg_hash = parts[3]
@@ -23,10 +24,11 @@ def main():
             shashes.add(src_hash)
             thashes.add(trg_hash)
     # Write a list of seen hashes
-    with open('shashes.pickle','wb') as f:
-        pickle.dump(shashes, f)
-    with open('thashes.pickle','wb') as f:
-        pickle.dump(thashes, f)
+    with open("shashes.pickle", "wb") as f:
+        pickle.dump(shashes, f)
+    with open("thashes.pickle", "wb") as f:
+        pickle.dump(thashes, f)
+
 
 if __name__ == "__main__":
     main()
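
The hunks above change formatting only — quote style, line wrapping, indentation and blank lines — not behaviour. A quick way to convince yourself that such a pass is behaviour-neutral is to compare the AST of each touched file before and after formatting: the AST dump is insensitive to quoting, wrapping, trailing commas and comments. The sketch below is an illustration only, not part of this patch; the two file paths on the command line are placeholders for any pre- and post-format pair.

#!/usr/bin/env python3
"""Illustrative sketch: check that two versions of a Python file differ
only in formatting by comparing their ASTs. File paths are placeholders."""
import ast
import sys


def normalized_dump(path: str) -> str:
    # ast.parse() discards comments; ast.dump() with attributes excluded is
    # insensitive to quote style, line breaks and indentation, so two sources
    # that differ only in formatting produce identical dumps.
    with open(path) as fh:
        tree = ast.parse(fh.read(), filename=path)
    return ast.dump(tree, include_attributes=False)


if __name__ == "__main__":
    before, after = sys.argv[1], sys.argv[2]
    if normalized_dump(before) == normalized_dump(after):
        print("OK: change is formatting-only")
    else:
        print("WARNING: ASTs differ; change is not purely cosmetic")
        sys.exit(1)

Invoked as, say, `python check_format_only.py before.py after.py` (hypothetical file names), it exits non-zero if the two versions differ in more than formatting.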