From 9a85746502f93c6a95f3869e49902b45a85688b7 Mon Sep 17 00:00:00 2001 From: Chris Bartholomew Date: Thu, 21 Mar 2024 15:42:33 -0400 Subject: [PATCH] Allow resource size and parallelism to be specified as variables; fixes for unit tests --- .../compute-openai-embeddings/gateways.yaml | 16 ++++ .../python-source-exclamation/pipeline.yaml | 3 +- langstream-api/pom.xml | 10 +++ .../langstream/api/model/ResourcesSpec.java | 46 +++++++++-- .../api/model/ResourcesSpecDeserializer.java | 77 +++++++++++++++++++ .../docker/LocalRunApplicationCmdTest.java | 2 +- ...ComposableAgentExecutionPlanOptimiser.java | 18 +++-- .../ApplicationPlaceholderResolver.java | 50 ++++++++++-- .../ApplicationPlaceholderResolverTest.java | 8 ++ .../ApplicationResourceLimitsChecker.java | 25 +++++- .../impl/k8s/KubernetesClusterRuntime.java | 26 ++++++- .../application/ApplicationService.java | 24 +++++- 12 files changed, 281 insertions(+), 24 deletions(-) create mode 100644 langstream-api/src/main/java/ai/langstream/api/model/ResourcesSpecDeserializer.java diff --git a/examples/applications/compute-openai-embeddings/gateways.yaml b/examples/applications/compute-openai-embeddings/gateways.yaml index 4802932be..edc0acef3 100644 --- a/examples/applications/compute-openai-embeddings/gateways.yaml +++ b/examples/applications/compute-openai-embeddings/gateways.yaml @@ -1,3 +1,19 @@ +# +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + gateways: - id: "input" type: produce diff --git a/examples/applications/python/python-source-exclamation/pipeline.yaml b/examples/applications/python/python-source-exclamation/pipeline.yaml index 661c174c2..8aacf71ca 100644 --- a/examples/applications/python/python-source-exclamation/pipeline.yaml +++ b/examples/applications/python/python-source-exclamation/pipeline.yaml @@ -23,4 +23,5 @@ pipeline: type: "python-source" output: "output-topic" configuration: - className: example.Exclamation \ No newline at end of file + className: example.Exclamation + dbname: "example" \ No newline at end of file diff --git a/langstream-api/pom.xml b/langstream-api/pom.xml index 13d379e3f..7975ee7cd 100644 --- a/langstream-api/pom.xml +++ b/langstream-api/pom.xml @@ -45,5 +45,15 @@ jackson-annotations provided + + com.fasterxml.jackson.core + jackson-databind + provided + + + com.fasterxml.jackson.core + jackson-core + provided + diff --git a/langstream-api/src/main/java/ai/langstream/api/model/ResourcesSpec.java b/langstream-api/src/main/java/ai/langstream/api/model/ResourcesSpec.java index 8ee35c949..72d368821 100644 --- a/langstream-api/src/main/java/ai/langstream/api/model/ResourcesSpec.java +++ b/langstream-api/src/main/java/ai/langstream/api/model/ResourcesSpec.java @@ -16,21 +16,51 @@ package ai.langstream.api.model; import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import java.util.Map; @JsonInclude(JsonInclude.Include.NON_NULL) -/** Definition of the resources required by the agent. */ -public record ResourcesSpec(Integer parallelism, Integer size, DiskSpec disk) { +@JsonDeserialize(using = ResourcesSpecDeserializer.class) // Use custom deserializer +public record ResourcesSpec(Object parallelism, Object size, DiskSpec disk) { - public static ResourcesSpec DEFAULT = new ResourcesSpec(1, 1, null); + public static final ResourcesSpec DEFAULT = new ResourcesSpec(1, 1, null); + // withDefaultsFrom method public ResourcesSpec withDefaultsFrom(ResourcesSpec higherLevel) { if (higherLevel == null) { return this; } - Integer newParallelism = parallelism == null ? higherLevel.parallelism() : parallelism; - Integer newUnits = size == null ? higherLevel.size() : size; - DiskSpec newDisk = - disk == null ? higherLevel.disk() : disk.withDefaultsFrom(higherLevel.disk); - return new ResourcesSpec(newParallelism, newUnits, newDisk); + Object newParallelism = parallelism == null ? higherLevel.parallelism() : parallelism; + Object newSize = size == null ? higherLevel.size() : size; + DiskSpec newDisk = disk == null ? higherLevel.disk() : disk; + + return new ResourcesSpec(newParallelism, newSize, newDisk); + } + + // resolveVariables method + public ResourcesSpec resolveVariables(Map variableMap) { + Integer resolvedParallelism = resolveToObject(parallelism, variableMap); + Integer resolvedSize = resolveToObject(size, variableMap); + return new ResourcesSpec(resolvedParallelism, resolvedSize, disk); + } + + // Helper method to resolve objects or strings to Integer + private static Integer resolveToObject(Object value, Map variableMap) { + if (value instanceof String) { + String stringValue = (String) value; + if (stringValue.startsWith("${") && stringValue.endsWith("}")) { + String variableName = stringValue.substring(2, stringValue.length() - 1); + return variableMap.getOrDefault(variableName, null); + } + try { + return Integer.parseInt(stringValue); + } catch (NumberFormatException e) { + System.err.println("Error parsing integer: " + e.getMessage()); + return null; + } + } else if (value instanceof Integer) { + return (Integer) value; + } + return null; } } diff --git a/langstream-api/src/main/java/ai/langstream/api/model/ResourcesSpecDeserializer.java b/langstream-api/src/main/java/ai/langstream/api/model/ResourcesSpecDeserializer.java new file mode 100644 index 000000000..39b0600c5 --- /dev/null +++ b/langstream-api/src/main/java/ai/langstream/api/model/ResourcesSpecDeserializer.java @@ -0,0 +1,77 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.api.model; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import java.io.IOException; + +public class ResourcesSpecDeserializer extends StdDeserializer { + + public ResourcesSpecDeserializer() { + super(ResourcesSpec.class); + } + + @Override + public ResourcesSpec deserialize(JsonParser p, DeserializationContext ctxt) + throws IOException, JsonProcessingException { + JsonNode node = p.getCodec().readTree(p); + Object parallelism = parseField(node, "parallelism"); + Object size = parseField(node, "size"); + DiskSpec disk = null; // Allow disk to be null + + if (node.has("disk") && !node.get("disk").isNull()) { + disk = parseDiskSpec(node.get("disk"), ctxt, p); + } + + return new ResourcesSpec(parallelism, size, disk); + } + + private Object parseField(JsonNode node, String fieldName) { + if (node == null || !node.has(fieldName)) { + return null; + } + JsonNode fieldNode = node.get(fieldName); + if (fieldNode.isInt()) { + return fieldNode.asInt(); + } else if (fieldNode.isTextual()) { + String textValue = fieldNode.asText(); + if (textValue.matches("\\$\\{.*\\}")) { + return textValue; + } + try { + return Integer.parseInt(textValue); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + "Value for '" + + fieldName + + "' must be an integer or a string representing an integer."); + } + } + throw new IllegalArgumentException("Unsupported JSON type for field '" + fieldName + "'."); + } + + private DiskSpec parseDiskSpec(JsonNode diskNode, DeserializationContext ctxt, JsonParser p) + throws IOException { + Boolean enabled = diskNode.has("enabled") ? diskNode.get("enabled").asBoolean() : null; + String type = diskNode.has("type") ? diskNode.get("type").asText() : null; + String size = diskNode.has("size") ? diskNode.get("size").asText() : null; + return new DiskSpec(enabled, type, size); + } +} diff --git a/langstream-cli/src/test/java/ai/langstream/cli/commands/docker/LocalRunApplicationCmdTest.java b/langstream-cli/src/test/java/ai/langstream/cli/commands/docker/LocalRunApplicationCmdTest.java index 6cd374f95..4c64fc619 100644 --- a/langstream-cli/src/test/java/ai/langstream/cli/commands/docker/LocalRunApplicationCmdTest.java +++ b/langstream-cli/src/test/java/ai/langstream/cli/commands/docker/LocalRunApplicationCmdTest.java @@ -65,7 +65,7 @@ void testArgs() throws Exception { log.info("Last line: {}", lastLine); assertTrue( lastLine.contains( - "run --rm -i -e START_BROKER=true -e START_MINIO=true -e START_HERDDB=true " + "run --rm -i -e START_BROKER=true -e USE_PULSAR=false -e START_MINIO=true -e START_HERDDB=true " + "-e LANSGSTREAM_TESTER_TENANT=default -e LANSGSTREAM_TESTER_APPLICATIONID=my-app " + "-e LANSGSTREAM_TESTER_STARTWEBSERVICES=true -e LANSGSTREAM_TESTER_DRYRUN=false ")); assertTrue( diff --git a/langstream-core/src/main/java/ai/langstream/impl/agents/ComposableAgentExecutionPlanOptimiser.java b/langstream-core/src/main/java/ai/langstream/impl/agents/ComposableAgentExecutionPlanOptimiser.java index 2e7880f2f..a949db3a1 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/agents/ComposableAgentExecutionPlanOptimiser.java +++ b/langstream-core/src/main/java/ai/langstream/impl/agents/ComposableAgentExecutionPlanOptimiser.java @@ -65,11 +65,19 @@ && compareResourcesNoDisk( } private static boolean compareResourcesNoDisk(ResourcesSpec a, ResourcesSpec b) { - Integer parallismA = a != null ? a.parallelism() : null; - Integer parallismB = b != null ? b.parallelism() : null; - Integer sizeA = a != null ? a.size() : null; - Integer sizeB = b != null ? b.size() : null; - return Objects.equals(parallismA, parallismB) && Objects.equals(sizeA, sizeB); + Object parallelismA = a != null ? a.parallelism() : null; + Object parallelismB = b != null ? b.parallelism() : null; + Object sizeA = a != null ? a.size() : null; + Object sizeB = b != null ? b.size() : null; + + // Convert to String for comparison to handle both Integer and String types + String parallelismAStr = parallelismA != null ? parallelismA.toString() : null; + String parallelismBStr = parallelismB != null ? parallelismB.toString() : null; + String sizeAStr = sizeA != null ? sizeA.toString() : null; + String sizeBStr = sizeB != null ? sizeB.toString() : null; + + return Objects.equals(parallelismAStr, parallelismBStr) + && Objects.equals(sizeAStr, sizeBStr); } @Override diff --git a/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java b/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java index a897e26ae..5e208b992 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java +++ b/langstream-core/src/main/java/ai/langstream/impl/common/ApplicationPlaceholderResolver.java @@ -25,6 +25,7 @@ import ai.langstream.api.model.Module; import ai.langstream.api.model.Pipeline; import ai.langstream.api.model.Resource; +import ai.langstream.api.model.ResourcesSpec; import ai.langstream.api.model.StreamingCluster; import ai.langstream.api.model.TopicDefinition; import com.fasterxml.jackson.databind.ObjectMapper; @@ -58,6 +59,7 @@ private ApplicationPlaceholderResolver() {} @SneakyThrows public static Application resolvePlaceholders(Application instance) { instance = deepCopy(instance); + log.debug("Resolving placeholders in application: {}", instance); final Map context = createContext(instance); if (log.isDebugEnabled()) { log.debug( @@ -66,8 +68,14 @@ public static Application resolvePlaceholders(Application instance) { if (log.isDebugEnabled()) { log.debug("Resolve context: {}", context); } + Instance resolvedInstance = resolveInstance(instance, context); + log.debug("Resolved instance: {}", resolvedInstance); + + Map resolvedModule = resolveModules(instance, context); + log.debug("Resolved modules: {}", resolvedModule); instance.setInstance(resolveInstance(instance, context)); + instance.setResources(resolveResources(instance, context)); instance.setModules(resolveModules(instance, context)); instance.setGateways(resolveGateways(instance, context)); @@ -119,11 +127,17 @@ private static Map resolveModules( for (Map.Entry pipelineEntry : module.getPipelines().entrySet()) { final Pipeline pipeline = pipelineEntry.getValue(); List newAgents = new ArrayList<>(); - for (AgentConfiguration value : pipeline.getAgents()) { - value.setConfiguration(resolveMap(context, value.getConfiguration())); - value.setInput(resolveConnection(context, value.getInput())); - value.setOutput(resolveConnection(context, value.getOutput())); - newAgents.add(value); + for (AgentConfiguration agent : pipeline.getAgents()) { + agent.setConfiguration(resolveMap(context, agent.getConfiguration())); + agent.setInput(resolveConnection(context, agent.getInput())); + agent.setOutput(resolveConnection(context, agent.getOutput())); + + // Resolve ResourcesSpec for each agent + ResourcesSpec resolvedResources = + resolveResourcesSpec(context, agent.getResources()); + agent.setResources(resolvedResources); + + newAgents.add(agent); } pipeline.setAgents(newAgents); } @@ -132,6 +146,32 @@ private static Map resolveModules( return newModules; } + private static ResourcesSpec resolveResourcesSpec( + Map context, ResourcesSpec resourcesSpec) { + if (resourcesSpec == null) { + return null; + } + + // Resolve parallelism and size. + Integer parallelism = resolveValueAsInteger(context, resourcesSpec.parallelism()); + Integer size = resolveValueAsInteger(context, resourcesSpec.size()); + + return new ResourcesSpec(parallelism, size, resourcesSpec.disk()); + } + + static Integer resolveValueAsInteger(Map context, Object template) { + // If the template is a string, assume we need to resolve it using resolveValueAsString and + // then parse it as an integer + // If it is an integer already, just return it + if (template instanceof String) { + return Integer.parseInt(resolveValueAsString(context, (String) template)); + } else if (template instanceof Integer) { + return (Integer) template; + } else { + return null; + } + } + private static Instance resolveInstance( Application applicationInstance, Map context) { final StreamingCluster newCluster; diff --git a/langstream-core/src/test/java/ai/langstream/impl/common/ApplicationPlaceholderResolverTest.java b/langstream-core/src/test/java/ai/langstream/impl/common/ApplicationPlaceholderResolverTest.java index 7f3eadd34..741caca52 100644 --- a/langstream-core/src/test/java/ai/langstream/impl/common/ApplicationPlaceholderResolverTest.java +++ b/langstream-core/src/test/java/ai/langstream/impl/common/ApplicationPlaceholderResolverTest.java @@ -119,6 +119,9 @@ void testResolveInAgentConfiguration() throws Exception { - name: "${globals.input-topic}" - name: "${globals.output-topic}" - name: "${globals.stream-response-topic}" + resources: + size: "${globals.size}" + parallelism: "${globals.scaled-agents.parallelism}" pipeline: - name: "agent1" id: "agent1" @@ -137,6 +140,9 @@ void testResolveInAgentConfiguration() throws Exception { input-topic: my-input-topic output-topic: my-output-topic stream-response-topic: my-stream-topic + size: 10 + scaled-agents: + parallelism: 5 """, """ secrets: @@ -176,6 +182,8 @@ void testResolveInAgentConfiguration() throws Exception { assertEquals( "my-output-topic", resolved.getModule("module-1").getTopics().get("my-output-topic").getName()); + assertEquals(10, agentConfiguration.getResources().size()); + assertEquals(5, agentConfiguration.getResources().parallelism()); } @Test diff --git a/langstream-k8s-deployer/langstream-k8s-deployer-core/src/main/java/ai/langstream/deployer/k8s/limits/ApplicationResourceLimitsChecker.java b/langstream-k8s-deployer/langstream-k8s-deployer-core/src/main/java/ai/langstream/deployer/k8s/limits/ApplicationResourceLimitsChecker.java index 89ddf814b..ba0e403be 100644 --- a/langstream-k8s-deployer/langstream-k8s-deployer-core/src/main/java/ai/langstream/deployer/k8s/limits/ApplicationResourceLimitsChecker.java +++ b/langstream-k8s-deployer/langstream-k8s-deployer-core/src/main/java/ai/langstream/deployer/k8s/limits/ApplicationResourceLimitsChecker.java @@ -175,8 +175,31 @@ public static int computeRequestedUnits(ApplicationCustomResource applicationCus agent.getResources()); return -1; } - totalUnits += agent.getResources().parallelism() * agent.getResources().size(); + // Resolve parallelism and size to integers + Integer parallelism = resolveObjectToInteger(agent.getResources().parallelism()); + Integer size = resolveObjectToInteger(agent.getResources().size()); + totalUnits += parallelism * size; } return totalUnits; } + + private static Integer resolveObjectToInteger(Object value) { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof String) { + // Let's assume it's always correctly formatted as an integer + // by the time it reaches this point + try { + return Integer.parseInt((String) value); + } catch (NumberFormatException e) { + log.error("Error parsing string to integer: {}", value, e); + return null; + } + } else { + log.error( + "Unsupported type for resource spec value: {}", + value.getClass().getSimpleName()); + return null; + } + } } diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/KubernetesClusterRuntime.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/KubernetesClusterRuntime.java index 09cf10d8d..754244f47 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/KubernetesClusterRuntime.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/main/java/ai/langstream/runtime/impl/k8s/KubernetesClusterRuntime.java @@ -338,8 +338,10 @@ private void collectAgentCustomResourceAndSecret( disks = List.of(); } - agentSpec.setResources( - new AgentSpec.Resources(resourcesSpec.parallelism(), resourcesSpec.size())); + Integer resolvedParallelism = resolveObjectToInteger(resourcesSpec.parallelism()); + Integer resolvedSize = resolveObjectToInteger(resourcesSpec.size()); + + agentSpec.setResources(new AgentSpec.Resources(resolvedParallelism, resolvedSize)); agentSpec.serializeAndSetOptions(new AgentSpec.Options(disks)); agentSpec.setAgentConfigSecretRef(secretName); agentSpec.setCodeArchiveId(codeStorageArchiveId); @@ -354,6 +356,26 @@ private void collectAgentCustomResourceAndSecret( secrets.add(secret); } + private static Integer resolveObjectToInteger(Object value) { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof String) { + // Let's assume it's always correctly formatted as an integer + // by the time it reaches this point + try { + return Integer.parseInt((String) value); + } catch (NumberFormatException e) { + log.error("Error parsing string to integer: {}", value, e); + return null; + } + } else { + log.error( + "Unsupported type for resource spec value: {}", + value.getClass().getSimpleName()); + return null; + } + } + private static String bytesToHex(byte[] hash) { StringBuilder hexString = new StringBuilder(2 * hash.length); for (byte b : hash) { diff --git a/langstream-webservice/src/main/java/ai/langstream/webservice/application/ApplicationService.java b/langstream-webservice/src/main/java/ai/langstream/webservice/application/ApplicationService.java index 9b69bb15e..0d9d8fde5 100644 --- a/langstream-webservice/src/main/java/ai/langstream/webservice/application/ApplicationService.java +++ b/langstream-webservice/src/main/java/ai/langstream/webservice/application/ApplicationService.java @@ -123,11 +123,33 @@ private int countRequestedUnits(ExecutionPlan executionPlan) { for (Map.Entry agent : executionPlan.getAgents().entrySet()) { final ResourcesSpec resources = agent.getValue().getResources(); - requestedUnits += resources.size() * resources.parallelism(); + Integer resolvedParallelism = resolveObjectToInteger(resources.parallelism()); + Integer resolvedSize = resolveObjectToInteger(resources.size()); + requestedUnits += resolvedSize * resolvedParallelism; } return requestedUnits; } + private static Integer resolveObjectToInteger(Object value) { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof String) { + // Let's assume it's always correctly formatted as an integer + // by the time it reaches this point + try { + return Integer.parseInt((String) value); + } catch (NumberFormatException e) { + log.error("Error parsing string to integer: {}", value, e); + return null; + } + } else { + log.error( + "Unsupported type for resource spec value: {}", + value.getClass().getSimpleName()); + return null; + } + } + private void validateApplicationModel(Application application) { validateGateways(application); }