Skip to content

Commit

Permalink
Make sure threading fails without ParallelExecutor fails (if settings…
Browse files Browse the repository at this point in the history
… are accessed) (#1810)

* Make sure threading fails if it doesn't use ParallelExecutor

* Retain DEFAULT_CONFIG
  • Loading branch information
okhat authored Nov 16, 2024
1 parent 3d36efb commit 1f93bff
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
21 changes: 10 additions & 11 deletions dsp/utils/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import threading
from contextlib import contextmanager
from copy import deepcopy

from contextlib import contextmanager
from dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
Expand Down Expand Up @@ -49,17 +49,16 @@ def __new__(cls):
# TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts
# eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker)
# downstream operations like dsp.retrieve would use configs from the defined pipeline.

# make a deepcopy of the default config to avoid modifying the default config
cls._instance.__append(deepcopy(DEFAULT_CONFIG))
config = copy.deepcopy(DEFAULT_CONFIG)
cls._instance.__append(config)

return cls._instance

@property
def config(self):
thread_id = threading.get_ident()
if thread_id not in self.stack_by_thread:
self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
# if thread_id not in self.stack_by_thread:
# self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
return self.stack_by_thread[thread_id][-1]

def __getattr__(self, name):
Expand All @@ -73,14 +72,14 @@ def __getattr__(self, name):

def __append(self, config):
thread_id = threading.get_ident()
if thread_id not in self.stack_by_thread:
self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
# if thread_id not in self.stack_by_thread:
# self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
self.stack_by_thread[thread_id].append(config)

def __pop(self):
thread_id = threading.get_ident()
if thread_id in self.stack_by_thread:
self.stack_by_thread[thread_id].pop()
# if thread_id in self.stack_by_thread:
self.stack_by_thread[thread_id].pop()

def configure(self, inherit_config: bool = True, **kwargs):
"""Set configuration settings.
Expand Down
8 changes: 6 additions & 2 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,20 @@ def _wrap_function(self, function):
def wrapped(item, parent_id=None):
thread_stacks = dspy.settings.stack_by_thread
current_thread_id = threading.get_ident()

creating_new_thread = current_thread_id not in thread_stacks

assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid

if creating_new_thread:
# If we have a parent thread ID, copy its stack
# If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack?
if parent_id and parent_id in thread_stacks:
thread_stacks[current_thread_id] = list(thread_stacks[parent_id])
else:
thread_stacks[current_thread_id] = list(dspy.settings.main_stack)

# TODO: Consider the behavior below.
# import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1]))

try:
return function(item)
except Exception as e:
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

import dspy
import copy

from dsp.utils.settings import DEFAULT_CONFIG


Expand All @@ -10,7 +11,7 @@ def clear_settings():

yield

dspy.settings.configure(**DEFAULT_CONFIG, inherit_config=False)
dspy.settings.configure(**copy.deepcopy(DEFAULT_CONFIG), inherit_config=False)


@pytest.fixture
Expand Down

0 comments on commit 1f93bff

Please sign in to comment.