diff --git a/changelog/unreleased/kong/ai-gemini-blocks-content-safety.yml b/changelog/unreleased/kong/ai-gemini-blocks-content-safety.yml
new file mode 100644
index 000000000000..3cdd2e3a2842
--- /dev/null
+++ b/changelog/unreleased/kong/ai-gemini-blocks-content-safety.yml
@@ -0,0 +1,3 @@
+message: "**ai-proxy**: Fixed a bug where the Gemini provider would return an error if content safety failed in AI Proxy."
+type: bugfix
+scope: Plugin
diff --git a/kong/llm/drivers/gemini.lua b/kong/llm/drivers/gemini.lua
index bfab0b15743e..5d783e0da97b 100644
--- a/kong/llm/drivers/gemini.lua
+++ b/kong/llm/drivers/gemini.lua
@@ -32,6 +32,14 @@ local function to_gemini_generation_config(request_table)
   }
 end
 
+local function is_content_safety_failure(content)
+  return content
+        and content.candidates
+        and #content.candidates > 0
+        and content.candidates[1].finishReason
+        and content.candidates[1].finishReason == "SAFETY"
+end
+
 local function is_response_content(content)
   return content
         and content.candidates
@@ -210,8 +218,6 @@ local function to_gemini_chat_openai(request_table, model_info, route_type)
     new_r.tools = request_table.tools and to_tools(request_table.tools)
   end
 
-  kong.log.warn(cjson.encode(new_r))
-
   return new_r, "application/json", nil
 end
 
@@ -229,7 +235,16 @@ local function from_gemini_chat_openai(response, model_info, route_type)
   messages.choices = {}
 
   if response.candidates and #response.candidates > 0 then
-    if is_response_content(response) then
+    -- for transformer plugins only
+    if model_info.source
+        and (model_info.source == "ai-request-transformer" or model_info.source == "ai-response-transformer")
+        and is_content_safety_failure(response) then
+      local err = "transformation generation candidate breached Gemini content safety"
+      ngx.log(ngx.ERR, err)
+
+      return nil, err
+
+    elseif is_response_content(response) then
       messages.choices[1] = {
         index = 0,
         message = {
@@ -270,14 +285,6 @@ local function from_gemini_chat_openai(response, model_info, route_type)
       }
     end
 
-  elseif response.candidates
-        and #response.candidates > 0
-        and response.candidates[1].finishReason
-        and response.candidates[1].finishReason == "SAFETY" then
-    local err = "transformation generation candidate breached Gemini content safety"
-    ngx.log(ngx.ERR, err)
-    return nil, err
-
   else -- probably a server fault or other unexpected response
     local err = "no generation candidates received from Gemini, or max_tokens too short"
     ngx.log(ngx.ERR, err)
diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua
index 6a22a6d8297e..b47399f62bc6 100644
--- a/kong/plugins/ai-request-transformer/handler.lua
+++ b/kong/plugins/ai-request-transformer/handler.lua
@@ -65,6 +65,7 @@ function _M:access(conf)
   local http_opts = create_http_opts(conf)
   conf.llm.__plugin_id = conf.__plugin_id
   conf.llm.__key__ = conf.__key__
+  conf.llm.model.source = "ai-request-transformer"
 
   local ai_driver, err = llm.new_driver(conf.llm, http_opts, identity_interface)
   if not ai_driver then
diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua
index d119f98610c5..d409148c8e1f 100644
--- a/kong/plugins/ai-response-transformer/handler.lua
+++ b/kong/plugins/ai-response-transformer/handler.lua
@@ -124,6 +124,7 @@ function _M:access(conf)
   local http_opts = create_http_opts(conf)
   conf.llm.__plugin_id = conf.__plugin_id
   conf.llm.__key__ = conf.__key__
+  conf.llm.model.source = "ai-response-transformer"
 
   local ai_driver, err = llm.new_driver(conf.llm, http_opts, identity_interface)
   if not ai_driver then
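
For reviewers, a minimal standalone sketch of the response shape the new predicate matches. The `is_content_safety_failure` function is copied from the patch above; the two response tables are hand-written assumptions about Gemini's decoded JSON body (a safety block surfaces as a candidate with `finishReason = "SAFETY"` and no content parts), not captured API traffic, and the snippet runs in plain Lua with no Kong runtime:

```lua
-- Predicate copied verbatim from the patch above.
local function is_content_safety_failure(content)
  return content
        and content.candidates
        and #content.candidates > 0
        and content.candidates[1].finishReason
        and content.candidates[1].finishReason == "SAFETY"
end

-- Hypothetical decoded body for a safety-blocked generation:
-- the candidate carries finishReason = "SAFETY" instead of content parts.
local blocked = {
  candidates = {
    { finishReason = "SAFETY", index = 0 },
  },
}

-- Hypothetical decoded body for a normal completion.
local ok = {
  candidates = {
    {
      finishReason = "STOP",
      content = { parts = { { text = "hello" } } },
    },
  },
}

assert(is_content_safety_failure(blocked))
assert(not is_content_safety_failure(ok))
```

The two handler hunks stamp `conf.llm.model.source` so the Gemini driver can tell transformer traffic from proxy traffic: only requests originating from ai-request-transformer or ai-response-transformer turn a SAFETY finish into a hard error, while plain ai-proxy traffic no longer hits the transformation error path.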