From c9e1a9ce35275a08076628e7a1b67c81abfbed72 Mon Sep 17 00:00:00 2001
From: TKS <32640296+bigsk1@users.noreply.github.com>
Date: Tue, 5 Nov 2024 19:32:59 -0800
Subject: [PATCH] add xai

---
 .env.sample              |  10 +++-
 README.md                |   9 ++-
 app/app.py               | 119 ++++++++++++++++++++++++++++++++++++---
 app/app_logic.py         |   3 +
 app/static/js/scripts.js |   6 ++
 app/templates/index.html |   7 +++
 cli.py                   |  89 ++++++++++++++++++++++++++---
 7 files changed, 221 insertions(+), 22 deletions(-)

diff --git a/.env.sample b/.env.sample
index 5cf9ab7..d43f64e 100644
--- a/.env.sample
+++ b/.env.sample
@@ -2,7 +2,7 @@
 # Depending on the value of MODEL_PROVIDER, the corresponding service will be used when run.
 # You can mix and match; use local Ollama with OpenAI speech or use OpenAI model with local XTTS, etc.
 
-# Model Provider: openai or ollama
+# Model Provider: openai, ollama, or xai
 MODEL_PROVIDER=ollama
 
 # Character to use - Options: alien_scientist, anarchist, bigfoot, chatgpt, clumsyhero, conandoyle, conspiracy, cyberpunk,
@@ -32,7 +32,7 @@ XTTS_SPEED=1.2
 # OpenAI API Key for models and speech (replace with your actual API key)
 OPENAI_API_KEY=sk-proj-1111111
 # Models to use - OPTIONAL: For screen analysis, if MODEL_PROVIDER is ollama, llava will be used by default.
-# Ensure you have llava downloaded with Ollama. If OpenAI is used, gpt-4o-mini works well.
+# Ensure you have llava downloaded with Ollama. If OpenAI is used, gpt-4o-mini works well. xAI does not support vision yet; if xai is selected, screen analysis falls back to OpenAI.
 OPENAI_MODEL=gpt-4o-mini
 
 # Endpoints:
@@ -45,6 +45,12 @@ OLLAMA_BASE_URL=http://localhost:11434
 # Model to use - llama3 or llama3.1 or 3.2 works well for local usage. In the UI you will have a list of popular models to choose from so the model here is just a starting point
 OLLAMA_MODEL=llama3
 
+# xAI Configuration
+XAI_MODEL=grok-beta
+XAI_API_KEY=your_api_key_here
+XAI_BASE_URL=https://api.x.ai/v1
+
+
 # NOTES:
 # List of trigger phrases to have the model view your desktop (desktop, browser, images, etc.).
 # It will describe what it sees, and you can ask questions about it:
diff --git a/README.md b/README.md
index a19df2a..428bca9 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@ You can run all locally, you can use openai for chat and voice, you can mix betw
 
 ## Features
 
-- **Supports both OpenAI and Ollama language models**: Choose the model that best fits your needs.
+- **Supports OpenAI, xAI, or Ollama language models**: Choose the model that best fits your needs.
 - **Provides text-to-speech synthesis using XTTS or OpenAI TTS or ElevenLabs**: Enjoy natural and expressive voices.
 - **No typing needed, just speak**: Hands-free interaction makes conversations smooth and effortless.
 - **Analyzes user mood and adjusts AI responses accordingly**: Get personalized responses based on your mood.
@@ -28,7 +28,7 @@ You can run all locally, you can use openai for chat and voice, you can mix betw
 
 - Python 3.10
 - CUDA-enabled GPU
-- Ollama models or Openai API for chat
+- Ollama models, OpenAI API, or xAI API for chat
 - Local XTTS or Openai API or ElevenLabs API for speech
 - Microsoft C++ Build Tools on windows
 - Microphone
@@ -201,6 +201,11 @@ OLLAMA_BASE_URL=http://localhost:11434
 # Models to use - llama3 works well for local usage.
 OLLAMA_MODEL=llama3
 
+# xAI Configuration
+XAI_MODEL=grok-beta
+XAI_API_KEY=your_api_key_here
+XAI_BASE_URL=https://api.x.ai/v1
+
 # NOTES:
 # List of trigger phrases to have the model view your desktop (desktop, browser, images, etc.).
 # It will describe what it sees, and you can ask questions about it:
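
Note: a minimal sketch for sanity-checking the three new variables outside the app. It assumes the xAI endpoint is OpenAI-compatible (the same assumption the new chatgpt_streamed() branch below makes) and that a non-streaming response uses the OpenAI-style choices[0].message.content shape; the defaults mirror .env.sample:

    import os
    import requests

    # Same request shape the new chatgpt_streamed() branch sends, minus streaming.
    resp = requests.post(
        f"{os.getenv('XAI_BASE_URL', 'https://api.x.ai/v1')}/chat/completions",
        headers={"Authorization": f"Bearer {os.getenv('XAI_API_KEY')}"},
        json={
            "model": os.getenv("XAI_MODEL", "grok-beta"),
            "messages": [{"role": "user", "content": "ping"}],
        },
        timeout=30,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])
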
diff --git a/app/app.py b/app/app.py
index 19e2f7d..06a3155 100644
--- a/app/app.py
+++ b/app/app.py
@@ -34,6 +34,9 @@
 OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
 OPENAI_MODEL = os.getenv('OPENAI_MODEL')
 OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL')
+XAI_API_KEY = os.getenv('XAI_API_KEY')
+XAI_MODEL = os.getenv('XAI_MODEL')
+XAI_BASE_URL = os.getenv('XAI_BASE_URL')
 OLLAMA_MODEL = os.getenv('OLLAMA_MODEL')
 OLLAMA_BASE_URL = os.getenv('OLLAMA_BASE_URL')
 ELEVENLABS_API_KEY = os.getenv('ELEVENLABS_API_KEY')
@@ -85,6 +88,11 @@ def init_openai_model(model_name):
     global OPENAI_MODEL
     OPENAI_MODEL = model_name
     print(f"Switched to OpenAI model: {model_name}")
+
+def init_xai_model(model_name):
+    global XAI_MODEL
+    XAI_MODEL = model_name
+    print(f"Switched to XAI model: {model_name}")
 
 def init_openai_tts_voice(voice_name):
     global OPENAI_TTS_VOICE
@@ -175,7 +183,7 @@ def sync_play_audio(file_path):
 
 print(f"Using device: {device}")
 print(f"Model provider: {MODEL_PROVIDER}")
-print(f"Model: {OPENAI_MODEL if MODEL_PROVIDER == 'openai' else OLLAMA_MODEL}")
+print(f"Model: {OPENAI_MODEL if MODEL_PROVIDER == 'openai' else XAI_MODEL if MODEL_PROVIDER == 'xai' else OLLAMA_MODEL}")
 print(f"Character: {character_display_name}")
 print(f"Text-to-Speech provider: {TTS_PROVIDER}")
 print("To stop chatting say Quit, Leave or Exit. Say, what's on my screen, to have AI view screen. One moment please loading...")
@@ -441,6 +449,50 @@ def chatgpt_streamed(user_input, system_message, mood_prompt, conversation_histo
         except requests.exceptions.RequestException as e:
             full_response = f"Error connecting to Ollama model: {e}"
             print(f"Debug: Ollama error - {e}")
+
+    elif MODEL_PROVIDER == 'xai':
+        messages = [{"role": "system", "content": system_message + "\n" + mood_prompt}] + conversation_history + [{"role": "user", "content": user_input}]
+        headers = {
+            'Authorization': f'Bearer {XAI_API_KEY}',
+            'Content-Type': 'application/json'
+        }
+        payload = {
+            "model": XAI_MODEL,
+            "messages": messages,
+            "stream": True
+        }
+        try:
+            print(f"Debug: Sending request to XAI: {XAI_BASE_URL}")
+            response = requests.post(f"{XAI_BASE_URL}/chat/completions", headers=headers, json=payload, stream=True, timeout=30)
+            response.raise_for_status()
+
+            print("Starting XAI stream...")
+            line_buffer = ""
+            for line in response.iter_lines(decode_unicode=True):
+                if line.startswith("data:"):
+                    line = line[5:].strip()
+                    if line:
+                        try:
+                            chunk = json.loads(line)
+                            delta_content = chunk['choices'][0]['delta'].get('content', '')
+                            if delta_content:
+                                line_buffer += delta_content
+                                if '\n' in line_buffer:
+                                    lines = line_buffer.split('\n')
+                                    for completed_line in lines[:-1]:
+                                        print(NEON_GREEN + completed_line + RESET_COLOR)
+                                        full_response += completed_line + '\n'
+                                    line_buffer = lines[-1]
+                        except json.JSONDecodeError:
+                            continue
+            if line_buffer:
+                print(NEON_GREEN + line_buffer + RESET_COLOR)
+                full_response += line_buffer
+            print("\nXAI stream complete.")
+
+        except requests.exceptions.RequestException as e:
+            full_response = f"Error connecting to XAI model: {e}"
+            print(f"Debug: XAI error - {e}")
 
     elif MODEL_PROVIDER == 'openai':
         messages = [{"role": "system", "content": system_message + "\n" + mood_prompt}] + conversation_history + [{"role": "user", "content": user_input}]
@@ -590,6 +642,7 @@ async def encode_image(image_path):
 # Analyze Image
 async def analyze_image(image_path, question_prompt):
     encoded_image = await encode_image(image_path)
+
     if MODEL_PROVIDER == 'ollama':
         headers = {'Content-Type': 'application/json'}
         payload = {
@@ -612,8 +665,13 @@ async def analyze_image(image_path, question_prompt):
         except aiohttp.ClientError as e:
             print(f"Request failed: {e}")
             return {"choices": [{"message": {"content": "Failed to process the image with the llava model."}}]}
-    else:
-        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {OPENAI_API_KEY}"}
+
+    elif MODEL_PROVIDER == 'xai':
+        # First, try XAI's image analysis if it's supported
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {XAI_API_KEY}"
+        }
         message = {
             "role": "user",
             "content": [
@@ -621,15 +679,58 @@ async def analyze_image(image_path, question_prompt):
                 {"type": "image_url", "image_url": {"url": f"data:image/jpg;base64,{encoded_image}", "detail": "low"}}
             ]
         }
-        payload = {"model": OPENAI_MODEL, "temperature": 0.5, "messages": [message], "max_tokens": 1000}
+        payload = {
+            "model": XAI_MODEL,
+            "temperature": 0.5,
+            "messages": [message],
+            "max_tokens": 1000
+        }
+
         try:
             async with aiohttp.ClientSession() as session:
-                async with session.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=30) as response:
-                    response.raise_for_status()
-                    return await response.json()
+                async with session.post(f"{XAI_BASE_URL}/chat/completions", headers=headers, json=payload, timeout=30) as response:
+                    if response.status == 200:
+                        return await response.json()
+                    else:
+                        # If XAI doesn't support image analysis or returns an error,
+                        # fall back to OpenAI's image analysis
+                        print("XAI image analysis failed or not supported, falling back to OpenAI")
+                        return await fallback_to_openai_image_analysis(encoded_image, question_prompt)
         except aiohttp.ClientError as e:
-            print(f"Request failed: {e}")
-            return {"choices": [{"message": {"content": "Failed to process the image with the OpenAI model."}}]}
+            print(f"XAI image analysis failed: {e}, falling back to OpenAI")
+            return await fallback_to_openai_image_analysis(encoded_image, question_prompt)
+
+    else:  # OpenAI as default
+        return await fallback_to_openai_image_analysis(encoded_image, question_prompt)
+
+async def fallback_to_openai_image_analysis(encoded_image, question_prompt):
+    """Helper function for OpenAI image analysis fallback"""
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {OPENAI_API_KEY}"
+    }
+    message = {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": question_prompt},
+            {"type": "image_url", "image_url": {"url": f"data:image/jpg;base64,{encoded_image}", "detail": "low"}}
+        ]
+    }
+    payload = {
+        "model": OPENAI_MODEL,
+        "temperature": 0.5,
+        "messages": [message],
+        "max_tokens": 1000
+    }
+
+    try:
+        async with aiohttp.ClientSession() as session:
+            async with session.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload, timeout=30) as response:
+                response.raise_for_status()
+                return await response.json()
+    except aiohttp.ClientError as e:
+        print(f"OpenAI fallback request failed: {e}")
+        return {"choices": [{"message": {"content": "Failed to process the image with both XAI and OpenAI models."}}]}
 
 async def generate_speech(text, temp_audio_path):
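
Note: the new xAI branch above parses OpenAI-style SSE framing ("data: {json}" lines). A condensed sketch of that loop as a generator; the explicit "[DONE]" sentinel check is an assumption (the patch instead skips that line via the JSONDecodeError handler):

    import json
    import requests

    def stream_chat(url, headers, payload):
        # Mirrors chatgpt_streamed(): strip the "data:" prefix from each SSE
        # line and accumulate the per-chunk delta content.
        with requests.post(url, headers=headers, json={**payload, "stream": True},
                           stream=True, timeout=30) as resp:
            resp.raise_for_status()
            for raw in resp.iter_lines(decode_unicode=True):
                if not raw or not raw.startswith("data:"):
                    continue
                data = raw[5:].strip()
                if data == "[DONE]":  # assumed sentinel; not checked in the patch
                    break
                chunk = json.loads(data)
                piece = chunk["choices"][0]["delta"].get("content", "")
                if piece:
                    yield piece

The app's version additionally buffers partial lines so it can print and colorize only completed lines as they stream in.
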
diff --git a/app/app_logic.py b/app/app_logic.py
index 87e9e83..5a376a3 100644
--- a/app/app_logic.py
+++ b/app/app_logic.py
@@ -15,6 +15,7 @@
     open_file,
     init_ollama_model,
     init_openai_model,
+    init_xai_model,
     init_openai_tts_voice,
     init_elevenlabs_tts_voice,
     init_xtts_speed,
@@ -117,6 +118,8 @@ def set_env_variable(key: str, value: str):
         init_ollama_model(value)  # Reinitialize Ollama model
     if key == "OPENAI_MODEL":
         init_openai_model(value)  # Reinitialize OpenAI model
+    if key == "XAI_MODEL":
+        init_xai_model(value)  # Reinitialize XAI model
     if key == "OPENAI_TTS_VOICE":
         init_openai_tts_voice(value)  # Reinitialize OpenAI TTS voice
     if key == "ELEVENLABS_TTS_VOICE":
diff --git a/app/static/js/scripts.js b/app/static/js/scripts.js
index 59eba0a..ad8fe71 100644
--- a/app/static/js/scripts.js
+++ b/app/static/js/scripts.js
@@ -161,6 +161,11 @@ document.addEventListener("DOMContentLoaded", function() {
         websocket.send(JSON.stringify({ action: "set_ollama_model", model: selectedModel }));
     }
 
+    function setXAIModel() {
+        const selectedModel = document.getElementById('xai-model-select').value;
+        websocket.send(JSON.stringify({ action: "set_xai_model", model: selectedModel }));
+    }
+
     function setXTTSSpeed() {
         const selectedSpeed = document.getElementById('xtts-speed-select').value;
         websocket.send(JSON.stringify({ action: "set_xtts_speed", speed: selectedSpeed }));
@@ -181,6 +186,7 @@ document.addEventListener("DOMContentLoaded", function() {
     document.getElementById('openai-voice-select').addEventListener('change', setOpenAIVoice);
     document.getElementById('openai-model-select').addEventListener('change', setOpenAIModel);
     document.getElementById('ollama-model-select').addEventListener('change', setOllamaModel);
+    document.getElementById('xai-model-select').addEventListener('change', setXAIModel);
     document.getElementById('xtts-speed-select').addEventListener('change', setXTTSSpeed);
     document.getElementById('elevenlabs-voice-select').addEventListener('change', setElevenLabsVoice);
 
diff --git a/app/templates/index.html b/app/templates/index.html
index 8241380..5b0ceee 100644
--- a/app/templates/index.html
+++ b/app/templates/index.html
@@ -60,6 +60,7 @@
 [hunk body lost in extraction; per the header, one line is added here — presumably the xai entry in the model-provider options]
@@ -110,6 +111,12 @@
 [hunk body lost in extraction; per the header, six lines are added here — the dropdown with id "xai-model-select" that scripts.js binds to setXAIModel()]
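
Note: with MODEL_PROVIDER=xai, screen analysis still requires a valid OPENAI_API_KEY, since analyze_image() falls back to fallback_to_openai_image_analysis() on any non-200 from xAI. A minimal .env for smoke-testing the new provider (placeholder values, taken from .env.sample):

    MODEL_PROVIDER=xai
    XAI_MODEL=grok-beta
    XAI_API_KEY=your_api_key_here
    XAI_BASE_URL=https://api.x.ai/v1
    # Still needed for the vision fallback:
    OPENAI_API_KEY=sk-proj-1111111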