Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add api, distributed setup #135

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
97455fc
add pubsub
alex-dixon Aug 12, 2024
a85bd64
Merge branch 'main' into pubsub
alex-dixon Aug 12, 2024
03b8d6a
remove unused imports
alex-dixon Aug 12, 2024
b228fb0
slight cleanup
alex-dixon Aug 12, 2024
9daa336
commit latest
alex-dixon Aug 14, 2024
5ed6825
add write invocation api
alex-dixon Aug 15, 2024
6f92642
avoid union none
alex-dixon Aug 15, 2024
7f27826
optional instead of union none
alex-dixon Aug 16, 2024
94f0a69
replace store with client
alex-dixon Aug 18, 2024
2e2a159
fixes, some local + api mode qa
alex-dixon Aug 18, 2024
a78a7d6
dockerize
alex-dixon Aug 18, 2024
9064f3a
minor cleanup
alex-dixon Aug 19, 2024
e4f1525
fix write trace race
alex-dixon Aug 23, 2024
1fa16d7
Merge branch 'MadcowD:main' into pubsub
alex-dixon Aug 23, 2024
6a07b6c
build studio with latest stable node
alex-dixon Aug 23, 2024
1133e40
checkpoint toward distributed multimodal
alex-dixon Sep 7, 2024
882960f
tests passing
alex-dixon Sep 7, 2024
7c84f6d
try fix string formatting
alex-dixon Sep 7, 2024
4a64fbb
more string formatting
alex-dixon Sep 7, 2024
0be7110
mon dieu
alex-dixon Sep 7, 2024
6eac3fa
more fstrings
alex-dixon Sep 7, 2024
fd3f47e
working with studio in docker compose
alex-dixon Sep 7, 2024
71e9fd6
uses is a list
alex-dixon Sep 9, 2024
fe7b0d2
Merge branch 'main' into pubsub
alex-dixon Sep 9, 2024
3c1ed4b
no _ start of Field name
alex-dixon Sep 9, 2024
59b960f
default openai client
alex-dixon Sep 9, 2024
30658be
Merge branch 'main' into pubsub
alex-dixon Sep 18, 2024
99ec47b
poetry lock
alex-dixon Sep 18, 2024
ceeadd2
fix tests
alex-dixon Sep 18, 2024
c47a107
Merge branch 'main' into pubsub
alex-dixon Sep 21, 2024
e7c7710
refactor
alex-dixon Sep 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12.2
2 changes: 1 addition & 1 deletion examples/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_html_of_url(url: str) -> str:
get_html_of_url
]

@ell.l(model="gpt-4o", temperature=0.1)
@ell.lm(model="gpt-4o", temperature=0.1)
def tool_user(task: str) -> List[Any]:
return [
ell.system(
Expand Down
35 changes: 32 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ include = [
]

[tool.poetry.dependencies]
python = ">=3.9"
python = ">=3.9,<4"
fastapi = "^0.111.1"
numpy = "^2.0.1"
dill = "^0.3.8"
Expand All @@ -39,6 +39,7 @@ typing-extensions = "^4.12.2"

black = "^24.8.0"
psycopg2 = "^2.9.9"
aiomqtt = "^2.3.0"
[tool.poetry.group.dev.dependencies]
pytest = "^8.3.2"

Expand Down
10 changes: 5 additions & 5 deletions src/ell/__version__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
try:
from importlib.metadata import version
except ImportError:
from importlib_metadata import version
from importlib.metadata import version, PackageNotFoundError

__version__ = version("ell")
try:
__version__ = version("ell")
except PackageNotFoundError:
__version__ = "unknown"
4 changes: 1 addition & 3 deletions src/ell/store.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Optional, Dict, List, Set, Union
from ell.lstr import lstr
from typing import Any, Optional, Dict, List, Set
from ell.types import InvocableLM, SerializedLMP, Invocation, SerializedLStr


Expand Down
37 changes: 29 additions & 8 deletions src/ell/stores/sql.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from datetime import datetime, timedelta
import json
import asyncio
import ell.store
import os
from typing import Any, Optional, Dict, List, Set, Union
from ell.studio.pubsub import PubSub
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLStr
from sqlalchemy import func, and_
from sqlalchemy.sql import text
from sqlmodel import Session, SQLModel, create_engine, select
import ell.store
from typing import Any, Optional, Dict, List, Set
from datetime import datetime, timedelta
import cattrs
import numpy as np
from sqlalchemy.sql import text
from ell.types import InvocationTrace, SerializedLMP, Invocation, SerializedLMPUses, SerializedLStr, utc_now
from ell.lstr import lstr
from sqlalchemy import or_, func, and_, extract, case

class SQLStore(ell.store.Store):
def __init__(self, db_uri: str):
Expand Down Expand Up @@ -237,7 +240,7 @@ def get_all_traces_leading_to(self, session: Session, invocation_id: str) -> Lis

# Convert the dictionary values back to a list
return list(unique_traces.values())

def get_invocations_aggregate(self, session: Session, lmp_filters: Dict[str, Any] = None, filters: Dict[str, Any] = None, days: int = 30) -> Dict[str, Any]:
# Calculate the start date for the graph data
start_date = datetime.utcnow() - timedelta(days=days)
Expand All @@ -255,7 +258,7 @@ def get_invocations_aggregate(self, session: Session, lmp_filters: Dict[str, Any
if filters:
base_subquery = base_subquery.filter(and_(*[getattr(Invocation, k) == v for k, v in filters.items()]))


data = session.exec(base_subquery).all()

# Calculate aggregate metrics
Expand Down Expand Up @@ -291,4 +294,22 @@ def __init__(self, storage_dir: str):
class PostgresStore(SQLStore):
def __init__(self, db_uri: str):
super().__init__(db_uri)




class SQLStorePublisher(SQLStore):
def __init__(self, db_uri: str, pubsub: PubSub):
self.pubsub = pubsub
super().__init__(db_uri)

def write_lmp(self, serialized_lmp: SerializedLMP, uses: Dict[str, Any]) -> Optional[Any]:
super().write_lmp(serialized_lmp, uses)
# todo. return result from write lmp so we can check if it was created or alredy exists
asyncio.create_task(self.pubsub.publish(f"lmp/{serialized_lmp.lmp_id}/created", serialized_lmp))
return None

def write_invocation(self, invocation: Invocation, results: List[SerializedLStr], consumes: Set[str]) -> Optional[Any]:
super().write_invocation(invocation, results, consumes)
asyncio.create_task(self.pubsub.publish(f"lmp/{invocation.lmp_id}/invoked", invocation))
return None

25 changes: 19 additions & 6 deletions src/ell/studio/__main__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
import asyncio
import os
from fastapi import FastAPI
import uvicorn
from argparse import ArgumentParser
from ell.studio.config import Config
from ell.studio.logger import setup_logging
from ell.studio.server import create_app
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from watchfiles import awatch
import time


def main():
setup_logging()
parser = ArgumentParser(description="ELL Studio Data Server")
parser.add_argument("--storage-dir" , default=None,
help="Directory for filesystem serializer storage (default: current directory)")
parser.add_argument("--pg-connection-string", default=None,
help="PostgreSQL connection string (default: None)")
parser.add_argument("--mqtt-connection-string", default=None,
help="MQTT connection string (default: None)")
parser.add_argument("--host", default="127.0.0.1", help="Host to run the server on")
parser.add_argument("--port", type=int, default=8080, help="Port to run the server on")
parser.add_argument("--dev", action="store_true", help="Run in development mode")
args = parser.parse_args()

config = Config.create(storage_dir=args.storage_dir,
pg_connection_string=args.pg_connection_string)
config = Config(
storage_dir=args.storage_dir,
pg_connection_string=args.pg_connection_string,
mqtt_connection_string=args.mqtt_connection_string
)

app = create_app(config)

if not args.dev:
Expand All @@ -36,7 +44,7 @@ async def serve_react_app(full_path: str):

db_path = os.path.join(args.storage_dir, "ell.db")

async def db_watcher(db_path, app):
async def db_watcher(db_path: str, app: FastAPI):
last_stat = None

while True:
Expand Down Expand Up @@ -76,8 +84,13 @@ async def db_watcher(db_path, app):

config = uvicorn.Config(app=app, port=args.port, loop=loop)
server = uvicorn.Server(config)
loop.create_task(server.serve())
loop.create_task(db_watcher(db_path, app))

tasks = []
tasks.append(loop.create_task(server.serve()))

if args.storage_dir:
tasks.append(loop.create_task(db_watcher(db_path, app)))

loop.run_forever()

if __name__ == "__main__":
Expand Down
38 changes: 22 additions & 16 deletions src/ell/studio/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import lru_cache
import json
import os
from typing import Optional
from pydantic import BaseModel
Expand All @@ -10,31 +11,36 @@

# todo. maybe we default storage dir and other things in the future to a well-known location
# like ~/.ell or something
@lru_cache
@lru_cache(maxsize=1)
def ell_home() -> str:
return os.path.join(os.path.expanduser("~"), ".ell")


class Config(BaseModel):
pg_connection_string: Optional[str] = None
storage_dir: Optional[str] = None

@classmethod
def create(
cls,
storage_dir: Optional[str] = None,
pg_connection_string: Optional[str] = None,
) -> 'Config':
pg_connection_string = pg_connection_string or os.getenv("ELL_PG_CONNECTION_STRING")
storage_dir = storage_dir or os.getenv("ELL_STORAGE_DIR")

mqtt_connection_string: Optional[str] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)

def model_post_init(self, __context):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should let us do Config() instead of Config.create() and have the same functionality

# Storage
self.pg_connection_string = self.pg_connection_string or os.getenv(
"ELL_PG_CONNECTION_STRING")
self.storage_dir = self.storage_dir or os.getenv("ELL_STORAGE_DIR")
# Enforce that we use either sqlite or postgres, but not both
if pg_connection_string is not None and storage_dir is not None:
if self.pg_connection_string is not None and self.storage_dir is not None:
raise ValueError("Cannot use both sqlite and postgres")

# For now, fall back to sqlite if no PostgreSQL connection string is provided
if pg_connection_string is None and storage_dir is None:
if self.pg_connection_string is None and self.storage_dir is None:
# This intends to honor the default we had set in the CLI
storage_dir = os.getcwd()
# todo. better default?
self.storage_dir = os.getcwd()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m ending up with SQLite dbs in various places. Maybe we standardize on ell home as a default?


# Pubsub
self.mqtt_connection_string = self.mqtt_connection_string or os.getenv("ELL_MQTT_CONNECTION_STRING")

logger.info(f"Resolved config: {json.dumps(self.model_dump(), indent=2)}")

return cls(pg_connection_string=pg_connection_string, storage_dir=storage_dir)
18 changes: 0 additions & 18 deletions src/ell/studio/connection_manager.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me know if you want to restore this. Broadcast was replicated in the new “pub sub” interfaces. connection moved to the ws handler.

This file was deleted.

34 changes: 34 additions & 0 deletions src/ell/studio/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
from colorama import Fore, Style, init

def setup_logging(level: int = logging.INFO):
alex-dixon marked this conversation as resolved.
Show resolved Hide resolved
# Initialize colorama for cross-platform colored output
init(autoreset=True)

# Create a custom formatter
class ColoredFormatter(logging.Formatter):
FORMATS = {
logging.DEBUG: Fore.CYAN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.INFO: Fore.GREEN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.WARNING: Fore.YELLOW + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.ERROR: Fore.RED + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL,
logging.CRITICAL: Fore.RED + Style.BRIGHT + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL
}

def format(self, record):
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S")
return formatter.format(record)

# Create and configure the logger
logger = logging.getLogger("ell")
logger.setLevel(level)

# Create console handler and set formatter
console_handler = logging.StreamHandler()
console_handler.setFormatter(ColoredFormatter())

# Add the handler to the logger
logger.addHandler(console_handler)

return logger
Loading
Loading