diff --git a/guidance/models/_azure_openai.py b/guidance/models/_azure_openai.py index 9a5420e51..3493d9d7f 100644 --- a/guidance/models/_azure_openai.py +++ b/guidance/models/_azure_openai.py @@ -66,14 +66,14 @@ def __init__( if api_key is None and azure_ad_token_provider is None: raise ValueError("Please provide either api_key or azure_ad_token_provider") - + parsed_url = urlparse(azure_endpoint) if azure_deployment is None: parts = pathlib.Path(parsed_url.path).parts if len(parts) > 2: azure_deployment = parts[3] - + parsed_query = parse_qs(parsed_url.query) api_version = ( version @@ -103,3 +103,6 @@ def __init__( engine_instance, echo=echo, ) + + def set_logit_bias(self, logit_bias): + self.engine.set_logit_bias(logit_bias) diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index 85293c26a..4f2c44761 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -29,6 +29,7 @@ def __init__( client_class=client_class, **kwargs, ): + self.logit_bias = None if client_class is None: raise Exception( @@ -67,6 +68,7 @@ def _generator_completion(self, prompt, temperature): top_p=1.0, # TODO: this should be controllable like temp (from the grammar) temperature=temperature, stream=True, + logit_bias=self.logit_bias, ) self.metrics.engine_input_tokens += len(self.tokenizer(prompt_decoded)) except Exception as e: @@ -139,6 +141,7 @@ def _generator_chat(self, prompt, temperature): top_p=1.0, # TODO: this should be controllable like temp (from the grammar) temperature=temperature, stream=True, + logit_bias=self.logit_bias, ) self.metrics.engine_input_tokens += input_token_count @@ -162,6 +165,9 @@ def _generator(self, prompt, temperature): # Otherwise we are in a chat context return self._generator_chat(prompt, temperature) + def set_logit_bias(self, logit_bias): + self.logit_bias = logit_bias + class OpenAI(Grammarless): def __init__( diff --git a/notebooks/gpt4_select.ipynb b/notebooks/gpt4_select.ipynb new file mode 100644 index 000000000..a172cf764 --- /dev/null +++ b/notebooks/gpt4_select.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from dotenv import load_dotenv\n", + "load_dotenv()\n", + "\n", + "model = os.getenv(\"AZUREAI_CHAT_MODEL\")\n", + "azure_endpoint = os.getenv(\"AZUREAI_CHAT_BASE_ENDPOINT\")\n", + "azure_deployment = os.getenv(\"AZUREAI_CHAT_DEPLOYMENT\")\n", + "azure_api_version = os.getenv(\"AZUREAI_CHAT_API_VERSION\")\n", + "api_key = os.getenv(\"AZUREAI_CHAT_KEY\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from guidance import models, gen\n", + "from guidance import user, assistant\n", + "import tiktoken\n", + "\n", + "def select(azureai_model, choices):\n", + " tokenizer = tiktoken.encoding_for_model(model)\n", + " encoded_choices = tokenizer.encode_batch(choices)\n", + " encoded_choices_tokens = [token for choice in encoded_choices for token in choice]\n", + " logit_bias = {str(token): +98 for token in encoded_choices_tokens}\n", + " logit_bias[\"100257\"] = 100 # \"<|endoftext|>\"\n", + " azureai_model.set_logit_bias(logit_bias)\n", + "\n", + " max_tokens = max(len(choice) for choice in encoded_choices)\n", + " return max_tokens\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
user
Which food does a rabbit prefer?
assistant
carrots
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "azureai_model = models.AzureOpenAI(\n", + " model=model,\n", + " azure_endpoint=azure_endpoint,\n", + " azure_deployment=azure_deployment,\n", + " version=azure_api_version,\n", + " api_key=api_key,\n", + " compute_log_probs=True,\n", + ")\n", + "\n", + "choices = [\"apples\", \"potatos\", \"I don't know\", \"hay\", \"carrots\", \"lettuce\", \"grass\"]\n", + "max_tokens = select(azureai_model, choices)\n", + "\n", + "with user():\n", + " lm = azureai_model + \"Which food does a rabbit prefer?\"\n", + "\n", + "with assistant():\n", + " lm += gen(\"response\", max_tokens=max_tokens)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
user
What is the meaning of life?
assistant
42
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "azureai_model = models.AzureOpenAI(\n", + " model=model,\n", + " azure_endpoint=azure_endpoint,\n", + " azure_deployment=azure_deployment,\n", + " version=azure_api_version,\n", + " api_key=api_key,\n", + " compute_log_probs=True,\n", + ")\n", + "\n", + "choices = [\"beauty\", \"complicated\", \"love\", \"Love\", \"growth\", \"purpose\", \"42\", \"connection\", \"I don't know\"]\n", + "max_tokens = select(azureai_model, choices)\n", + "\n", + "with user():\n", + " lm = azureai_model + \"What is the meaning of life?\"\n", + "\n", + "with assistant():\n", + " lm += gen(\"response\", max_tokens=max_tokens)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "guidance", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}