diff --git a/guidance/models/_azure_openai.py b/guidance/models/_azure_openai.py
index 9a5420e51..3493d9d7f 100644
--- a/guidance/models/_azure_openai.py
+++ b/guidance/models/_azure_openai.py
@@ -66,14 +66,14 @@ def __init__(
 
         if api_key is None and azure_ad_token_provider is None:
             raise ValueError("Please provide either api_key or azure_ad_token_provider")
-
+
         parsed_url = urlparse(azure_endpoint)
 
         if azure_deployment is None:
             parts = pathlib.Path(parsed_url.path).parts
             if len(parts) > 2:
                 azure_deployment = parts[3]
-
+
         parsed_query = parse_qs(parsed_url.query)
         api_version = (
             version
@@ -103,3 +103,6 @@ def __init__(
             engine_instance,
             echo=echo,
         )
+
+    def set_logit_bias(self, logit_bias):
+        self.engine.set_logit_bias(logit_bias)
diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py
index 85293c26a..4f2c44761 100644
--- a/guidance/models/_openai.py
+++ b/guidance/models/_openai.py
@@ -29,6 +29,7 @@ def __init__(
         client_class=client_class,
         **kwargs,
     ):
+        self.logit_bias = None
 
         if client_class is None:
             raise Exception(
@@ -67,6 +68,7 @@ def _generator_completion(self, prompt, temperature):
                 top_p=1.0,  # TODO: this should be controllable like temp (from the grammar)
                 temperature=temperature,
                 stream=True,
+                logit_bias=self.logit_bias,
             )
             self.metrics.engine_input_tokens += len(self.tokenizer(prompt_decoded))
         except Exception as e:
@@ -139,6 +141,7 @@ def _generator_chat(self, prompt, temperature):
                 top_p=1.0,  # TODO: this should be controllable like temp (from the grammar)
                 temperature=temperature,
                 stream=True,
+                logit_bias=self.logit_bias,
             )
 
             self.metrics.engine_input_tokens += input_token_count
@@ -162,6 +165,9 @@ def _generator(self, prompt, temperature):
         # Otherwise we are in a chat context
         return self._generator_chat(prompt, temperature)
 
+    def set_logit_bias(self, logit_bias):
+        self.logit_bias = logit_bias
+
 
 class OpenAI(Grammarless):
     def __init__(
diff --git a/notebooks/gpt4_select.ipynb b/notebooks/gpt4_select.ipynb
new file mode 100644
index 000000000..a172cf764
--- /dev/null
+++ b/notebooks/gpt4_select.ipynb
@@ -0,0 +1,142 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "from dotenv import load_dotenv\n",
+    "load_dotenv()\n",
+    "\n",
+    "model = os.getenv(\"AZUREAI_CHAT_MODEL\")\n",
+    "azure_endpoint = os.getenv(\"AZUREAI_CHAT_BASE_ENDPOINT\")\n",
+    "azure_deployment = os.getenv(\"AZUREAI_CHAT_DEPLOYMENT\")\n",
+    "azure_api_version = os.getenv(\"AZUREAI_CHAT_API_VERSION\")\n",
+    "api_key = os.getenv(\"AZUREAI_CHAT_KEY\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from guidance import models, gen\n",
+    "from guidance import user, assistant\n",
+    "import tiktoken\n",
+    "\n",
+    "def select(azureai_model, choices):\n",
+    "    tokenizer = tiktoken.encoding_for_model(model)\n",
+    "    encoded_choices = tokenizer.encode_batch(choices)\n",
+    "    encoded_choices_tokens = [token for choice in encoded_choices for token in choice]\n",
+    "    logit_bias = {str(token): +98 for token in encoded_choices_tokens}\n",
+    "    logit_bias[\"100257\"] = 100 # \"<|endoftext|>\"\n",
+    "    azureai_model.set_logit_bias(logit_bias)\n",
+    "\n",
+    "    max_tokens = max(len(choice) for choice in encoded_choices)\n",
+    "    return max_tokens\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
" + ], + "text/plain": [ + "userWhich food does a rabbit prefer?assistantcarrots
" + ], + "text/plain": [ + "userWhat is the meaning of life?assistant42