Skip to content

Commit

Permalink
feat: added azure model name matching to find_model
Browse files Browse the repository at this point in the history
  • Loading branch information
NP4567-dev committed Jul 30, 2024
1 parent 347661e commit 6953c34
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
| Hugging Face Hub | `huggingface-hub` | [Guide for Hugging Face Hub :octicons-link-16:](tutorial/providers/huggingface_hub.md) |
| LiteLLM | `litellm` | [Guide for LiteLLM :octicons-link-16:](tutorial/providers/litellm.md) |
| Mistral AI | `mistralai` | [Guide for Mistral AI :octicons-link-16:](tutorial/providers/mistralai.md) |
| OpenAI | `openai` | [Guide for OpenAI :octicons-link-16:](tutorial/providers/openai.md) |
| OpenAI | `openai` | [Guide for OpenAI (including Azure) :octicons-link-16:](tutorial/providers/openai.md) |


## Chat Completions
Expand Down
28 changes: 28 additions & 0 deletions docs/tutorial/providers/openai.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,34 @@ Integrating EcoLogits with your applications does not alter the standard outputs

asyncio.run(main())
```
### Azure OpenAI example

Under the hood it is the same function that is called by the Azure OpenAI client. Hence the impacts attribute will automatically be added to the response object.

=== "Azure"
```python
from ecologits import EcoLogits
from openai import AzureOpenAI

# Initialize EcoLogits
EcoLogits.init()

client = AzureOpenAI(
azure_endpoint= "http://myazureendpoint",
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("OPENAI_API_VERSION"),
)

response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Tell me a funny joke!"}
]
)

# Get estimated environmental impacts of the inference
print(response.impacts)
```

### Streaming example

Expand Down
15 changes: 4 additions & 11 deletions ecologits/model_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@ class Model:


class ModelRepository:

def __init__(self, models: list[Model]) -> None:
self.__models = models

def find_model(self, provider: str, model_name: str) -> Optional[Model]:
for model in self.__models:
# To handle specific LiteLLM calling (e.g., mistral/mistral-small)
if model.provider == provider and model.name in model_name:
if model.provider == provider and (model.name in model_name or model.name.replace(".", "") == model_name):
return model
return None

Expand All @@ -51,28 +50,22 @@ def find_provider(self, model_name: str) -> Optional[str]:
@classmethod
def from_csv(cls, filepath: Optional[str] = None) -> "ModelRepository":
if filepath is None:
filepath = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "data", "models.csv"
)
filepath = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "models.csv")
models = []
with open(filepath) as fd:
csv = DictReader(fd)
for row in csv:
total_parameters = None
total_parameters_range = None
if ";" in row["total_parameters"]:
total_parameters_range = [
float(p) for p in row["total_parameters"].split(";")
]
total_parameters_range = [float(p) for p in row["total_parameters"].split(";")]
elif row["total_parameters"] != "":
total_parameters = float(row["total_parameters"])

active_parameters = None
active_parameters_range = None
if ";" in row["active_parameters"]:
active_parameters_range = [
float(p) for p in row["active_parameters"].split(";")
]
active_parameters_range = [float(p) for p in row["active_parameters"].split(";")]
elif row["active_parameters"] != "":
active_parameters = float(row["active_parameters"])

Expand Down
4 changes: 4 additions & 0 deletions tests/test_model_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def test_create_model_repository_from_scratch():
])
assert models.find_model(provider="provider-test", model_name="model-test")

def test_find_azure_openai_model():
models = ModelRepository.from_csv()
assert models.find_model(provider="openai", model_name="gpt-35-turbo").name =="gpt-3.5-turbo"


def test_find_unknown_provider():
models = ModelRepository.from_csv()
Expand Down

0 comments on commit 6953c34

Please sign in to comment.