Skip to content

Commit

Permalink
Merge pull request #20 from fumito-ito/fix/aws-bedrock-invalid-request
Browse files Browse the repository at this point in the history
Fix issue with incorrect parameters being included in requests when using AWS Bedrock
  • Loading branch information
fumito-ito authored Jul 3, 2024
2 parents 52faa83 + 9686789 commit b68c2d7
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ extension InvokeModelInput {
/// - request: Claude API request. It will be converted to `Data` and contained in bedrock request.
/// - contentType: acceptable request content type
init(accept: String, request: MessagesRequest, contentType: String) throws {
let data = try request.encode(with: ["anthropic_version": AnthropicBedrockClient.anthropicVersion])
let data = try request.encode(
with: ["anthropic_version": AnthropicBedrockClient.anthropicVersion],
without: UnnecessaryParameter.allCases.map { $0.rawValue }
)

self.init(
accept: accept,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ extension InvokeModelWithResponseStreamInput {
/// - request: Claude API request. It will be converted to `Data` and contained in bedrock request.
/// - contentType: acceptable request content type
init(accept: String, request: MessagesRequest, contentType: String) throws {
let data = try request.encode(with: ["anthropic_version": AnthropicBedrockClient.anthropicVersion])
let data = try request.encode(
with: ["anthropic_version": AnthropicBedrockClient.anthropicVersion],
without: UnnecessaryParameter.allCases.map { $0.rawValue }
)

self.init(
accept: accept,
Expand Down
8 changes: 3 additions & 5 deletions Sources/AnthropicSwiftSDK-Bedrock/Messages.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public struct Messages {
topP: Double? = nil,
topK: Int? = nil
) async throws -> AsyncThrowingStream<StreamingResponse, Error> {
// In the inference call, fill the body field with a JSON object that conforms the type call you want to make [Anthropic Claude Messages API](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html).
// In the inference call, fill the body field with a JSON object that conforms the type call you want to make [Anthropic Claude Messages API](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html ).
let requestBody = MessagesRequest(
model: model,
messages: messages,
Expand Down Expand Up @@ -135,12 +135,10 @@ extension BedrockRuntimeClientTypes.ResponseStream {
throw AnthropicBedrockClientError.bedrockRuntimeClientGetsUnknownPayload(self)
}

guard
let data = payload.bytes,
let line = String(data: data, encoding: .utf8) else {
guard let data = payload.bytes else {
throw AnthropicBedrockClientError.cannotGetDataFromBedrockClientPayload(payload)
}

return line
return String(decoding: data, as: UTF8.self)
}
}
17 changes: 17 additions & 0 deletions Sources/AnthropicSwiftSDK-Bedrock/UnnecessaryParameter.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//
// UnnecessaryParameter.swift
//
//
// Created by 伊藤史 on 2024/07/01.
//

import Foundation

/// Unnecessary parameters to use Anthropic claude through AWS Bedrock
///
/// When using the Anthropic API through AWS Bedrock, some of the properties required in a normal Anthropic API request were causing errors as invalid properties.
enum UnnecessaryParameter: String, CaseIterable {
case model
case stream
case metadata
}
6 changes: 5 additions & 1 deletion Sources/AnthropicSwiftSDK/Network/MessagesRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public struct MessagesRequest: Encodable {
}

extension MessagesRequest {
public func encode(with appendingObject: [String: Any]) throws -> Data {
public func encode(with appendingObject: [String: Any], without removingObjectKeys: [String] = []) throws -> Data {
let encoded = try anthropicJSONEncoder.encode(self)
guard var dictionary = try JSONSerialization.jsonObject(with: encoded, options: []) as? [String: Any] else {
return encoded
Expand All @@ -84,6 +84,10 @@ extension MessagesRequest {
dictionary[key] = value
}

removingObjectKeys.forEach { key in
dictionary.removeValue(forKey: key)
}

return try JSONSerialization.data(withJSONObject: dictionary, options: [])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,39 +37,58 @@ final class AnthropicBedrockClientTests: XCTestCase {
let requestData = try XCTUnwrap(invokeModel.body)
let decodedRequestData = try anthropicJSONDecoder.decode(MessagesRequest.self, from: requestData)

XCTAssertEqual(request.model.description, decodedRequestData.model.description)
XCTAssertEqual(request.messages.first?.role, decodedRequestData.messages.first?.role)
XCTAssertEqual(request.messages.first?.content.first, decodedRequestData.messages.first?.content.first)
XCTAssertEqual(request.system, decodedRequestData.system)
XCTAssertEqual(request.maxTokens, decodedRequestData.maxTokens)
XCTAssertEqual(request.metaData?.userId, decodedRequestData.metaData?.userId)
XCTAssertEqual(request.stopSequences, decodedRequestData.stopSequences)
XCTAssertEqual(request.stream, decodedRequestData.stream)
XCTAssertEqual(request.temperature, decodedRequestData.temperature)
XCTAssertEqual(request.topP, decodedRequestData.topP)
XCTAssertEqual(request.topK, decodedRequestData.topK)
}

func testInvokeModelNotContainUnnecessaryParameters() throws {
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: nil, maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let invokeModel = try InvokeModelInput(accept: "application/json", request: request, contentType: "application/json")

let requestData = try XCTUnwrap(invokeModel.body)
let decodedJSON = try JSONSerialization.jsonObject(with: requestData, options: []) as! [String: Any]

UnnecessaryParameter.allCases.map { $0.rawValue }.forEach { key in
XCTAssertFalse(decodedJSON.keys.contains(key))
}
}

func testInvokeModelWithResponseStreamContainEncodedMessageRequest() throws {
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: nil, maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let invokeModel = try InvokeModelWithResponseStreamInput(accept: "application/json", request: request, contentType: "application/json")

let requestData = try XCTUnwrap(invokeModel.body)
let decodedRequestData = try anthropicJSONDecoder.decode(MessagesRequest.self, from: requestData)

XCTAssertEqual(request.model.description, decodedRequestData.model.description)
XCTAssertEqual(request.messages.first?.role, decodedRequestData.messages.first?.role)
XCTAssertEqual(request.messages.first?.content.first, decodedRequestData.messages.first?.content.first)
XCTAssertEqual(request.system, decodedRequestData.system)
XCTAssertEqual(request.maxTokens, decodedRequestData.maxTokens)
XCTAssertEqual(request.metaData?.userId, decodedRequestData.metaData?.userId)
XCTAssertEqual(request.stopSequences, decodedRequestData.stopSequences)
XCTAssertEqual(request.stream, decodedRequestData.stream)
XCTAssertEqual(request.temperature, decodedRequestData.temperature)
XCTAssertEqual(request.topP, decodedRequestData.topP)
XCTAssertEqual(request.topK, decodedRequestData.topK)
}

func testInvokeModelWithResponseStreamNotContainUnnecessaryParameters() throws {
let request = MessagesRequest(model: .claude_3_Haiku, messages: [Message(role: .user, content: [.text("Hello! Claude!")])], system: nil, maxTokens: 1024, metaData: MetaData(userId: "112234"), stopSequences: ["stop sequence"], stream: false, temperature: 0.4, topP: 1, topK: 2)
let invokeModel = try InvokeModelWithResponseStreamInput(accept: "application/json", request: request, contentType: "application/json")

let requestData = try XCTUnwrap(invokeModel.body)
let decodedJSON = try JSONSerialization.jsonObject(with: requestData, options: []) as! [String: Any]

UnnecessaryParameter.allCases.map { $0.rawValue }.forEach { key in
XCTAssertFalse(decodedJSON.keys.contains(key))
}
}


func testInvokeModelOutputShouldBeConvertToMessageResponse() throws {
let json = """
{
Expand Down Expand Up @@ -142,9 +161,7 @@ extension MessagesRequest: Decodable {
case messages
case system
case maxTokens
case metaData
case stopSequences
case stream
case temperature
case topP
case topK
Expand All @@ -153,13 +170,10 @@ extension MessagesRequest: Decodable {
public init(from decoder: any Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.init(
model: try container.decode(Model.self, forKey: .model),
messages: try container.decode([Message].self, forKey: .messages),
system: try? container.decode(String.self, forKey: .system),
maxTokens: try container.decode(Int.self, forKey: .maxTokens),
metaData: try container.decode(MetaData.self, forKey: .metaData),
stopSequences: try container.decode([String].self, forKey: .stopSequences),
stream: try container.decode(Bool.self, forKey: .stream),
temperature: try container.decode(Double.self, forKey: .temperature),
topP: try container.decode(Double.self, forKey: .topP),
topK: try container.decode(Int.self, forKey: .topK)
Expand Down

0 comments on commit b68c2d7

Please sign in to comment.