From 1f93bff261048c09cb8d23fb89ffb66cf08495b2 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Sat, 16 Nov 2024 15:40:09 -0800 Subject: [PATCH] Make sure threading without ParallelExecutor fails (if settings are accessed) (#1810) * Make sure threading fails if it doesn't use ParallelExecutor * Retain DEFAULT_CONFIG --- dsp/utils/settings.py | 21 ++++++++++----------- dspy/utils/parallelizer.py | 8 ++++++-- tests/conftest.py | 5 +++-- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index 621d4fb73..76943784a 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -1,7 +1,7 @@ +import copy import threading -from contextlib import contextmanager -from copy import deepcopy +from contextlib import contextmanager from dsp.utils.utils import dotdict DEFAULT_CONFIG = dotdict( @@ -49,17 +49,16 @@ def __new__(cls): # TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts # eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker) # downstream operations like dsp.retrieve would use configs from the defined pipeline. 
- - # make a deepcopy of the default config to avoid modifying the default config - cls._instance.__append(deepcopy(DEFAULT_CONFIG)) + config = copy.deepcopy(DEFAULT_CONFIG) + cls._instance.__append(config) return cls._instance @property def config(self): thread_id = threading.get_ident() - if thread_id not in self.stack_by_thread: - self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()] + # if thread_id not in self.stack_by_thread: + # self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()] return self.stack_by_thread[thread_id][-1] def __getattr__(self, name): @@ -73,14 +72,14 @@ def __getattr__(self, name): def __append(self, config): thread_id = threading.get_ident() - if thread_id not in self.stack_by_thread: - self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()] + # if thread_id not in self.stack_by_thread: + # self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()] self.stack_by_thread[thread_id].append(config) def __pop(self): thread_id = threading.get_ident() - if thread_id in self.stack_by_thread: - self.stack_by_thread[thread_id].pop() + # if thread_id in self.stack_by_thread: + self.stack_by_thread[thread_id].pop() def configure(self, inherit_config: bool = True, **kwargs): """Set configuration settings. diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py index 0437b0031..27983632b 100644 --- a/dspy/utils/parallelizer.py +++ b/dspy/utils/parallelizer.py @@ -47,16 +47,20 @@ def _wrap_function(self, function): def wrapped(item, parent_id=None): thread_stacks = dspy.settings.stack_by_thread current_thread_id = threading.get_ident() - creating_new_thread = current_thread_id not in thread_stacks + assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid + if creating_new_thread: - # If we have a parent thread ID, copy its stack + # If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack? 
if parent_id and parent_id in thread_stacks: thread_stacks[current_thread_id] = list(thread_stacks[parent_id]) else: thread_stacks[current_thread_id] = list(dspy.settings.main_stack) + # TODO: Consider the behavior below. + # import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1])) + try: return function(item) except Exception as e: diff --git a/tests/conftest.py b/tests/conftest.py index 9470a79c5..ab90e8236 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import pytest - import dspy +import copy + from dsp.utils.settings import DEFAULT_CONFIG @@ -10,7 +11,7 @@ def clear_settings(): yield - dspy.settings.configure(**DEFAULT_CONFIG, inherit_config=False) + dspy.settings.configure(**copy.deepcopy(DEFAULT_CONFIG), inherit_config=False) @pytest.fixture