Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
Allow resource size and parallelism to be specified as variables; fix…
Browse files Browse the repository at this point in the history
…es for unit tests
  • Loading branch information
cdbartholomew committed Mar 21, 2024
1 parent 2d693d3 commit 9a85746
Show file tree
Hide file tree
Showing 12 changed files with 281 additions and 24 deletions.
16 changes: 16 additions & 0 deletions examples/applications/compute-openai-embeddings/gateways.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
#
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

gateways:
- id: "input"
type: produce
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ pipeline:
type: "python-source"
output: "output-topic"
configuration:
className: example.Exclamation
className: example.Exclamation
dbname: "example"
10 changes: 10 additions & 0 deletions langstream-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,15 @@
<artifactId>jackson-annotations</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>provided</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,51 @@
package ai.langstream.api.model;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import java.util.Map;

@JsonInclude(JsonInclude.Include.NON_NULL)
/** Definition of the resources required by the agent. */
public record ResourcesSpec(Integer parallelism, Integer size, DiskSpec disk) {
@JsonDeserialize(using = ResourcesSpecDeserializer.class) // Use custom deserializer
public record ResourcesSpec(Object parallelism, Object size, DiskSpec disk) {

public static ResourcesSpec DEFAULT = new ResourcesSpec(1, 1, null);
public static final ResourcesSpec DEFAULT = new ResourcesSpec(1, 1, null);

// withDefaultsFrom method
public ResourcesSpec withDefaultsFrom(ResourcesSpec higherLevel) {
if (higherLevel == null) {
return this;
}
Integer newParallelism = parallelism == null ? higherLevel.parallelism() : parallelism;
Integer newUnits = size == null ? higherLevel.size() : size;
DiskSpec newDisk =
disk == null ? higherLevel.disk() : disk.withDefaultsFrom(higherLevel.disk);
return new ResourcesSpec(newParallelism, newUnits, newDisk);
Object newParallelism = parallelism == null ? higherLevel.parallelism() : parallelism;
Object newSize = size == null ? higherLevel.size() : size;
DiskSpec newDisk = disk == null ? higherLevel.disk() : disk;

return new ResourcesSpec(newParallelism, newSize, newDisk);
}

// resolveVariables method
public ResourcesSpec resolveVariables(Map<String, Integer> variableMap) {
Integer resolvedParallelism = resolveToObject(parallelism, variableMap);
Integer resolvedSize = resolveToObject(size, variableMap);
return new ResourcesSpec(resolvedParallelism, resolvedSize, disk);
}

// Helper method to resolve objects or strings to Integer
private static Integer resolveToObject(Object value, Map<String, Integer> variableMap) {
if (value instanceof String) {
String stringValue = (String) value;
if (stringValue.startsWith("${") && stringValue.endsWith("}")) {
String variableName = stringValue.substring(2, stringValue.length() - 1);
return variableMap.getOrDefault(variableName, null);
}
try {
return Integer.parseInt(stringValue);
} catch (NumberFormatException e) {
System.err.println("Error parsing integer: " + e.getMessage());
return null;
}
} else if (value instanceof Integer) {
return (Integer) value;
}
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.api.model;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import java.io.IOException;

public class ResourcesSpecDeserializer extends StdDeserializer<ResourcesSpec> {

public ResourcesSpecDeserializer() {
super(ResourcesSpec.class);
}

@Override
public ResourcesSpec deserialize(JsonParser p, DeserializationContext ctxt)
throws IOException, JsonProcessingException {
JsonNode node = p.getCodec().readTree(p);
Object parallelism = parseField(node, "parallelism");
Object size = parseField(node, "size");
DiskSpec disk = null; // Allow disk to be null

if (node.has("disk") && !node.get("disk").isNull()) {
disk = parseDiskSpec(node.get("disk"), ctxt, p);
}

return new ResourcesSpec(parallelism, size, disk);
}

private Object parseField(JsonNode node, String fieldName) {
if (node == null || !node.has(fieldName)) {
return null;
}
JsonNode fieldNode = node.get(fieldName);
if (fieldNode.isInt()) {
return fieldNode.asInt();
} else if (fieldNode.isTextual()) {
String textValue = fieldNode.asText();
if (textValue.matches("\\$\\{.*\\}")) {
return textValue;
}
try {
return Integer.parseInt(textValue);
} catch (NumberFormatException e) {
throw new IllegalArgumentException(
"Value for '"
+ fieldName
+ "' must be an integer or a string representing an integer.");
}
}
throw new IllegalArgumentException("Unsupported JSON type for field '" + fieldName + "'.");
}

private DiskSpec parseDiskSpec(JsonNode diskNode, DeserializationContext ctxt, JsonParser p)
throws IOException {
Boolean enabled = diskNode.has("enabled") ? diskNode.get("enabled").asBoolean() : null;
String type = diskNode.has("type") ? diskNode.get("type").asText() : null;
String size = diskNode.has("size") ? diskNode.get("size").asText() : null;
return new DiskSpec(enabled, type, size);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void testArgs() throws Exception {
log.info("Last line: {}", lastLine);
assertTrue(
lastLine.contains(
"run --rm -i -e START_BROKER=true -e START_MINIO=true -e START_HERDDB=true "
"run --rm -i -e START_BROKER=true -e USE_PULSAR=false -e START_MINIO=true -e START_HERDDB=true "
+ "-e LANSGSTREAM_TESTER_TENANT=default -e LANSGSTREAM_TESTER_APPLICATIONID=my-app "
+ "-e LANSGSTREAM_TESTER_STARTWEBSERVICES=true -e LANSGSTREAM_TESTER_DRYRUN=false "));
assertTrue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,19 @@ && compareResourcesNoDisk(
}

private static boolean compareResourcesNoDisk(ResourcesSpec a, ResourcesSpec b) {
Integer parallismA = a != null ? a.parallelism() : null;
Integer parallismB = b != null ? b.parallelism() : null;
Integer sizeA = a != null ? a.size() : null;
Integer sizeB = b != null ? b.size() : null;
return Objects.equals(parallismA, parallismB) && Objects.equals(sizeA, sizeB);
Object parallelismA = a != null ? a.parallelism() : null;
Object parallelismB = b != null ? b.parallelism() : null;
Object sizeA = a != null ? a.size() : null;
Object sizeB = b != null ? b.size() : null;

// Convert to String for comparison to handle both Integer and String types
String parallelismAStr = parallelismA != null ? parallelismA.toString() : null;
String parallelismBStr = parallelismB != null ? parallelismB.toString() : null;
String sizeAStr = sizeA != null ? sizeA.toString() : null;
String sizeBStr = sizeB != null ? sizeB.toString() : null;

return Objects.equals(parallelismAStr, parallelismBStr)
&& Objects.equals(sizeAStr, sizeBStr);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import ai.langstream.api.model.Module;
import ai.langstream.api.model.Pipeline;
import ai.langstream.api.model.Resource;
import ai.langstream.api.model.ResourcesSpec;
import ai.langstream.api.model.StreamingCluster;
import ai.langstream.api.model.TopicDefinition;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -58,6 +59,7 @@ private ApplicationPlaceholderResolver() {}
@SneakyThrows
public static Application resolvePlaceholders(Application instance) {
instance = deepCopy(instance);
log.debug("Resolving placeholders in application: {}", instance);
final Map<String, Object> context = createContext(instance);
if (log.isDebugEnabled()) {
log.debug(
Expand All @@ -66,8 +68,14 @@ public static Application resolvePlaceholders(Application instance) {
if (log.isDebugEnabled()) {
log.debug("Resolve context: {}", context);
}
Instance resolvedInstance = resolveInstance(instance, context);
log.debug("Resolved instance: {}", resolvedInstance);

Map<String, Module> resolvedModule = resolveModules(instance, context);
log.debug("Resolved modules: {}", resolvedModule);

instance.setInstance(resolveInstance(instance, context));

instance.setResources(resolveResources(instance, context));
instance.setModules(resolveModules(instance, context));
instance.setGateways(resolveGateways(instance, context));
Expand Down Expand Up @@ -119,11 +127,17 @@ private static Map<String, Module> resolveModules(
for (Map.Entry<String, Pipeline> pipelineEntry : module.getPipelines().entrySet()) {
final Pipeline pipeline = pipelineEntry.getValue();
List<AgentConfiguration> newAgents = new ArrayList<>();
for (AgentConfiguration value : pipeline.getAgents()) {
value.setConfiguration(resolveMap(context, value.getConfiguration()));
value.setInput(resolveConnection(context, value.getInput()));
value.setOutput(resolveConnection(context, value.getOutput()));
newAgents.add(value);
for (AgentConfiguration agent : pipeline.getAgents()) {
agent.setConfiguration(resolveMap(context, agent.getConfiguration()));
agent.setInput(resolveConnection(context, agent.getInput()));
agent.setOutput(resolveConnection(context, agent.getOutput()));

// Resolve ResourcesSpec for each agent
ResourcesSpec resolvedResources =
resolveResourcesSpec(context, agent.getResources());
agent.setResources(resolvedResources);

newAgents.add(agent);
}
pipeline.setAgents(newAgents);
}
Expand All @@ -132,6 +146,32 @@ private static Map<String, Module> resolveModules(
return newModules;
}

private static ResourcesSpec resolveResourcesSpec(
Map<String, Object> context, ResourcesSpec resourcesSpec) {
if (resourcesSpec == null) {
return null;
}

// Resolve parallelism and size.
Integer parallelism = resolveValueAsInteger(context, resourcesSpec.parallelism());
Integer size = resolveValueAsInteger(context, resourcesSpec.size());

return new ResourcesSpec(parallelism, size, resourcesSpec.disk());
}

static Integer resolveValueAsInteger(Map<String, Object> context, Object template) {
// If the template is a string, assume we need to resolve it using resolveValueAsString and
// then parse it as an integer
// If it is an integer already, just return it
if (template instanceof String) {
return Integer.parseInt(resolveValueAsString(context, (String) template));
} else if (template instanceof Integer) {
return (Integer) template;
} else {
return null;
}
}

private static Instance resolveInstance(
Application applicationInstance, Map<String, Object> context) {
final StreamingCluster newCluster;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ void testResolveInAgentConfiguration() throws Exception {
- name: "${globals.input-topic}"
- name: "${globals.output-topic}"
- name: "${globals.stream-response-topic}"
resources:
size: "${globals.size}"
parallelism: "${globals.scaled-agents.parallelism}"
pipeline:
- name: "agent1"
id: "agent1"
Expand All @@ -137,6 +140,9 @@ void testResolveInAgentConfiguration() throws Exception {
input-topic: my-input-topic
output-topic: my-output-topic
stream-response-topic: my-stream-topic
size: 10
scaled-agents:
parallelism: 5
""",
"""
secrets:
Expand Down Expand Up @@ -176,6 +182,8 @@ void testResolveInAgentConfiguration() throws Exception {
assertEquals(
"my-output-topic",
resolved.getModule("module-1").getTopics().get("my-output-topic").getName());
assertEquals(10, agentConfiguration.getResources().size());
assertEquals(5, agentConfiguration.getResources().parallelism());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,31 @@ public static int computeRequestedUnits(ApplicationCustomResource applicationCus
agent.getResources());
return -1;
}
totalUnits += agent.getResources().parallelism() * agent.getResources().size();
// Resolve parallelism and size to integers
Integer parallelism = resolveObjectToInteger(agent.getResources().parallelism());
Integer size = resolveObjectToInteger(agent.getResources().size());
totalUnits += parallelism * size;
}
return totalUnits;
}

private static Integer resolveObjectToInteger(Object value) {
if (value instanceof Integer) {
return (Integer) value;
} else if (value instanceof String) {
// Let's assume it's always correctly formatted as an integer
// by the time it reaches this point
try {
return Integer.parseInt((String) value);
} catch (NumberFormatException e) {
log.error("Error parsing string to integer: {}", value, e);
return null;
}
} else {
log.error(
"Unsupported type for resource spec value: {}",
value.getClass().getSimpleName());
return null;
}
}
}
Loading

0 comments on commit 9a85746

Please sign in to comment.