Skip to content

Commit

Permalink
custom api endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
fabio-garavini committed Dec 26, 2024
1 parent 995dea1 commit e436c15
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 83 deletions.
120 changes: 111 additions & 9 deletions custom_components/openai_whisper_cloud/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
ConfigFlowResult,
OptionsFlowWithConfigEntry,
)
from homeassistant.const import CONF_API_KEY, CONF_MODEL, CONF_NAME, CONF_SOURCE
from homeassistant.const import (
CONF_API_KEY,
CONF_MODEL,
CONF_NAME,
CONF_SOURCE,
CONF_URL,
)
from homeassistant.core import callback
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
Expand All @@ -27,6 +33,7 @@

from .const import (
_LOGGER,
CONF_CUSTOM_PROVIDER,
CONF_PROMPT,
CONF_TEMPERATURE,
DEFAULT_PROMPT,
Expand Down Expand Up @@ -111,7 +118,7 @@ async def async_step_init(
for x in whisper_providers[
self.config_entry.data[CONF_SOURCE]
].models
].index(user_input[CONF_MODEL]),
].index(user_input[CONF_MODEL]) if not self.config_entry.data.get(CONF_CUSTOM_PROVIDER) else user_input[CONF_MODEL],
CONF_TEMPERATURE: user_input[CONF_TEMPERATURE],
CONF_PROMPT: user_input.get(CONF_PROMPT, ""),
},
Expand All @@ -129,7 +136,7 @@ async def async_step_init(
self.config_entry.data[CONF_SOURCE]
].models
]
),
) if not self.config_entry.data.get(CONF_CUSTOM_PROVIDER) else cv.string,
vol.Optional(CONF_TEMPERATURE): vol.All(
vol.Coerce(float), vol.Range(min=0, max=1)
),
Expand All @@ -139,7 +146,7 @@ async def async_step_init(
suggested_values={
CONF_MODEL: whisper_providers[self.config_entry.data[CONF_SOURCE]]
.models[self.config_entry.options[CONF_MODEL]]
.name,
.name if not self.config_entry.data.get(CONF_CUSTOM_PROVIDER) else self.config_entry.options[CONF_MODEL],
CONF_TEMPERATURE: self.config_entry.options[CONF_TEMPERATURE],
CONF_PROMPT: self.config_entry.options.get(CONF_PROMPT, ""),
},
Expand Down Expand Up @@ -172,7 +179,9 @@ async def async_step_user(
"""Handle initial step."""
errors = {}
if user_input is not None:
self._provider = whisper_providers[int(user_input[CONF_SOURCE])]
if int(user_input[CONF_SOURCE]) != 2:
self._provider = whisper_providers[int(user_input[CONF_SOURCE])]

return await self.async_step_whisper()

return self.async_show_form(
Expand All @@ -187,6 +196,23 @@ async def async_step_whisper(
"""Handle initial step."""
errors = {}
if user_input is not None:

if self._provider is None:
return self.async_create_entry(
title=user_input.get(CONF_NAME),
data={
CONF_CUSTOM_PROVIDER: True,
CONF_NAME: user_input[CONF_NAME],
CONF_URL: user_input[CONF_URL],
CONF_API_KEY: user_input.get(CONF_API_KEY),
},
options={
CONF_MODEL: user_input[CONF_MODEL],
CONF_TEMPERATURE: user_input[CONF_TEMPERATURE],
CONF_PROMPT: user_input.get(CONF_PROMPT, ""),
},
)

try:
await validate_input(user_input, self._provider)

Expand Down Expand Up @@ -218,6 +244,26 @@ async def async_step_whisper(
except UnknownError:
errors["base"] = "unknown"

if self._provider is None:
return self.async_show_form(
step_id="whisper",
data_schema=vol.Schema(
{
vol.Required(
CONF_NAME, default="Custom Whisper"
): cv.string,
vol.Required(CONF_URL): cv.string,
vol.Optional(CONF_API_KEY): cv.string,
vol.Required(CONF_MODEL): cv.string,
vol.Optional(
CONF_TEMPERATURE, default=DEFAULT_TEMPERATURE
): vol.All(vol.Coerce(float), vol.Range(min=0, max=1)),
vol.Optional(CONF_PROMPT): cv.string,
}
),
errors=errors,
)

return self.async_show_form(
step_id="whisper",
data_schema=vol.Schema(
Expand Down Expand Up @@ -251,23 +297,79 @@ async def async_step_reconfigure(

entry = self.hass.config_entries.async_get_entry(self.context["entry_id"])

if entry.data.get(CONF_CUSTOM_PROVIDER, False):

if user_input is not None:

self.hass.config_entries.async_update_entry(
entry=entry,
title=user_input.get(CONF_NAME),
data={
CONF_CUSTOM_PROVIDER: True,
CONF_NAME: user_input[CONF_NAME],
CONF_URL: user_input[CONF_URL],
CONF_API_KEY: user_input.get(CONF_API_KEY, entry.data.get(CONF_API_KEY, "")),
},
options={
CONF_MODEL: user_input[CONF_MODEL],
CONF_TEMPERATURE: user_input[CONF_TEMPERATURE],
CONF_PROMPT: user_input.get(CONF_PROMPT, ""),
},
)

await self.hass.config_entries.async_reload(self.context["entry_id"])
return self.async_abort(reason="reconfigure_successful")


return self.async_show_form(
step_id="reconfigure",
data_schema=self.add_suggested_values_to_schema(
data_schema=vol.Schema(
{
vol.Required(
CONF_NAME, default="Custom Whisper"
): cv.string,
vol.Required(CONF_URL): cv.string,
vol.Optional(CONF_API_KEY): cv.string,
vol.Required(CONF_MODEL): cv.string,
vol.Optional(
CONF_TEMPERATURE, default=DEFAULT_TEMPERATURE
): vol.All(vol.Coerce(float), vol.Range(min=0, max=1)),
vol.Optional(CONF_PROMPT): cv.string,
}
),
suggested_values={
CONF_NAME: entry.data.get(CONF_NAME),
CONF_URL: entry.data.get(CONF_URL),
CONF_MODEL: entry.options.get(CONF_MODEL),
CONF_TEMPERATURE: entry.options.get(CONF_TEMPERATURE),
CONF_PROMPT: entry.options.get(CONF_PROMPT),
},
),
errors=errors,
)

provider: WhisperProvider = whisper_providers[entry.data.get(CONF_SOURCE)]
whisper: WhisperModel = provider.models[entry.options.get(CONF_MODEL)]

if user_input is not None:
if CONF_API_KEY not in user_input:
user_input[CONF_API_KEY] = entry.data.get(CONF_API_KEY)

try:
await validate_input(user_input, provider)
await validate_input(
{
**user_input,
CONF_API_KEY: user_input.get(CONF_API_KEY, entry.data.get(CONF_API_KEY, ""))
},
provider
)

self.hass.config_entries.async_update_entry(
entry=entry,
title=user_input[CONF_NAME],
data={
CONF_SOURCE: entry.data.get(CONF_SOURCE),
CONF_NAME: user_input[CONF_NAME],
CONF_API_KEY: user_input.get(CONF_API_KEY),
CONF_API_KEY: user_input.get(CONF_API_KEY, entry.data.get(CONF_API_KEY, "")),
},
options={
CONF_MODEL: [x.name for x in provider.models].index(
Expand Down
1 change: 1 addition & 0 deletions custom_components/openai_whisper_cloud/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

CONF_PROMPT = "prompt"
CONF_TEMPERATURE = "temperature"
CONF_CUSTOM_PROVIDER = "custom_provider"

SUPPORTED_LANGUAGES = [
"af",
Expand Down
2 changes: 1 addition & 1 deletion custom_components/openai_whisper_cloud/manifest.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"domain": "openai_whisper_cloud",
"name": "Whisper Cloud",
"version": "1.1.1",
"version": "1.2.0",
"codeowners": [
"@fabio-garavini"
],
Expand Down
30 changes: 22 additions & 8 deletions custom_components/openai_whisper_cloud/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,23 @@
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, CONF_MODEL, CONF_NAME, CONF_SOURCE
from homeassistant.const import (
CONF_API_KEY,
CONF_MODEL,
CONF_NAME,
CONF_SOURCE,
CONF_URL,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from . import _LOGGER
from .const import CONF_PROMPT, CONF_TEMPERATURE
from .const import (
CONF_CUSTOM_PROVIDER,
CONF_PROMPT,
CONF_TEMPERATURE,
SUPPORTED_LANGUAGES,
)
from .whisper_provider import WhisperModel, whisper_providers


Expand All @@ -35,14 +46,15 @@ async def async_setup_entry(
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Demo speech platform via config entry."""
"""Set up Whisper speech platform via config entry."""
_LOGGER.debug(f"STT setup Entry {config_entry.entry_id}")

async_add_entities([
OpenAIWhisperCloudEntity(
api_url=whisper_providers[config_entry.data[CONF_SOURCE]].url,
api_key=config_entry.data[CONF_API_KEY],
model=whisper_providers[config_entry.data[CONF_SOURCE]].models[config_entry.options[CONF_MODEL]],
custom=config_entry.data.get(CONF_CUSTOM_PROVIDER, False),
api_url=config_entry.data[CONF_URL] if config_entry.data.get(CONF_CUSTOM_PROVIDER) else whisper_providers[config_entry.data[CONF_SOURCE]].url,
api_key=config_entry.data.get(CONF_API_KEY, ""),
model= WhisperModel(config_entry.options[CONF_MODEL], SUPPORTED_LANGUAGES) if config_entry.data.get(CONF_CUSTOM_PROVIDER) else whisper_providers[config_entry.data[CONF_SOURCE]].models[config_entry.options[CONF_MODEL]],
temperature=config_entry.options[CONF_TEMPERATURE],
prompt=config_entry.options[CONF_PROMPT],
name=config_entry.data[CONF_NAME],
Expand All @@ -51,11 +63,13 @@ async def async_setup_entry(
])



class OpenAIWhisperCloudEntity(SpeechToTextEntity):
"""OpenAI Whisper API provider entity."""

def __init__(self, api_url: str, api_key: str, model: WhisperModel, temperature, prompt, name, unique_id) -> None:
def __init__(self, custom: bool, api_url: str, api_key: str, model: WhisperModel, temperature, prompt, name, unique_id) -> None:
"""Init STT service."""
self.custom = custom
self.api_url = api_url
self.api_key = api_key
self.model = model
Expand Down Expand Up @@ -152,7 +166,7 @@ async def async_process_audio_stream(
# Make the request in a separate thread
response = await asyncio.to_thread(
requests.post,
f"{self.api_url}/v1/audio/transcriptions",
f"{self.api_url}/v1/audio/transcriptions" if not self.custom else self.api_url,
headers={
"Authorization": f"Bearer {self.api_key}",
},
Expand Down
10 changes: 7 additions & 3 deletions custom_components/openai_whisper_cloud/translations/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,29 @@
"whisper": {
"data": {
"name": "Name",
"url": "Url",
"api_key": "API key",
"model": "Model",
"temperature": "Temperature",
"prompt": "Prompt (Optional)"
},
"data_description": {
"prompt": "Prompt can be used to improve speech recognition of words or even names.\nou have to provide a list of words or names separated by a comma \", \".\nExample: \"open, close, Chat GPT-3, DALL·E\""
"url": "Full whisper api url. Example: `https://api.openai.com/v1/audio/transcriptions`",
"prompt": "Prompt can be used to improve speech recognition of words or even names.\nou have to provide a list of words or names separated by a comma \", \".\nExample: `open, close, Chat GPT-3, DALL·E`"
}
},
"reconfigure": {
"data": {
"name": "Name",
"url": "Url",
"api_key": "API key (Optional)",
"model": "Model",
"temperature": "Temperature",
"prompt": "Prompt (Optional)"
},
"data_description": {
"prompt": "Prompt can be used to improve speech recognition of words or even names.\nou have to provide a list of words or names separated by a comma \", \".\nExample: \"open, close, Chat GPT-3, DALL·E\""
"url": "Full whisper api url. Example: `https://api.openai.com/v1/audio/transcriptions`",
"prompt": "Prompt can be used to improve speech recognition of words or even names.\nou have to provide a list of words or names separated by a comma \", \".\nExample: `open, close, Chat GPT-3, DALL·E`"
}
}
}
Expand All @@ -63,7 +67,7 @@
"prompt": "Prompt (Optional)"
},
"data_description": {
"prompt": "Prompt can be used to improve speech recognition of words or even names.\nYou have to provide a list of words or names separated by a comma \", \".\nExample: \"open, close, Chat GPT-3, DALL·E\""
"prompt": "Prompt can be used to improve speech recognition of words or even names.\nYou have to provide a list of words or names separated by a comma \", \".\nExample: `open, close, Chat GPT-3, DALL·E`"
}
}
}
Expand Down
Loading

0 comments on commit e436c15

Please sign in to comment.