Skip to content

Commit

Permalink
#5211 - AI assistant prototype
Browse files Browse the repository at this point in the history
- Fix displaying indexing progress on the bulk processing page
- Added re-index button to assistant sidebar
- Normalize embeddings and use dot product instead of cosine when indexing/searching
- Allow suspending background tasks for a project
- During project import and project initialization from a template, suspend background tasks
- During indexing, embed in batches
- Fix timing logging for embeddings
- Allow auto-detecting the embedding dimensions
- Allow using embeddings with a dimension higher than 1024
  • Loading branch information
reckart committed Jan 2, 2025
1 parent f9eb347 commit 7c483b4
Show file tree
Hide file tree
Showing 56 changed files with 948 additions and 235 deletions.
4 changes: 4 additions & 0 deletions inception/inception-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-collections4</artifactId>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

public interface AssistantService
{
List<MAssistantMessage> listMessages(String aSessionOwner, Project aProject);
List<MAssistantMessage> getConversationMessages(String aSessionOwner, Project aProject);

void processUserMessage(String aSessionOwner, Project aProject,
MAssistantMessage aMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@
import org.springframework.security.core.session.SessionDestroyedEvent;
import org.springframework.security.core.session.SessionRegistry;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;

import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.security.model.User;
Expand Down Expand Up @@ -77,13 +75,13 @@ public class AssistantServiceImpl
private final AssistantProperties properties;
private final UserGuideQueryService documentationIndexingService;
private final DocumentQueryService documentQueryService;

private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();
private final EncodingRegistry encodingRegistry;

public AssistantServiceImpl(SessionRegistry aSessionRegistry,
SimpMessagingTemplate aMsgTemplate, OllamaClient aOllamaClient,
AssistantProperties aProperties, UserGuideQueryService aDocumentationIndexingService,
DocumentQueryService aDocumentQueryService)
DocumentQueryService aDocumentQueryService,
EncodingRegistry aEncodingRegistry)
{
sessionRegistry = aSessionRegistry;
msgTemplate = aMsgTemplate;
Expand All @@ -92,6 +90,7 @@ public AssistantServiceImpl(SessionRegistry aSessionRegistry,
properties = aProperties;
documentationIndexingService = aDocumentationIndexingService;
documentQueryService = aDocumentQueryService;
encodingRegistry = aEncodingRegistry;
}

// Set order so this is handled before session info is removed from sessionRegistry
Expand Down Expand Up @@ -132,7 +131,7 @@ public void onAfterProjectRemoved(AfterProjectRemovedEvent aEvent)
}

@Override
public List<MAssistantMessage> listMessages(String aSessionOwner, Project aProject)
public List<MAssistantMessage> getConversationMessages(String aSessionOwner, Project aProject)
{
var state = getState(aSessionOwner, aProject);
return state.getMessages();
Expand Down Expand Up @@ -161,14 +160,17 @@ public void processUserMessage(String aSessionOwner, Project aProject,
try {
var systemMessages = generateSystemMessages(aSessionOwner, aProject, aMessage);
var transientMessages = generateTransientMessages(aSessionOwner, aProject, aMessage);
var recentMessages = listMessages(aSessionOwner, aProject);
var recentMessages = getConversationMessages(aSessionOwner, aProject);

// We record the message only now to ensure it is not included in the listMessages above
recordMessage(aSessionOwner, aProject, aMessage);

// For testing purposes we send this message to the UI
for (var msg : transientMessages) {
dispatchMessage(aSessionOwner, aProject, msg);
if (properties.isDevMode()) {
// For testing purposes we send this message to the UI but do not record it as
// part of the conversation
for (var msg : transientMessages) {
dispatchMessage(aSessionOwner, aProject, msg);
}
}

var conversation = limitConversationToContextLength(systemMessages, transientMessages,
Expand Down Expand Up @@ -324,7 +326,9 @@ private List<MAssistantMessage> limitConversationToContextLength(
// the tokenizer we use counts fewer tokens than the one user by
// the model and also to cover for message encoding JSON overhead,
// we try to use only 90% of the context window.
var encoding = registry.getEncoding(EncodingType.CL100K_BASE);
var encoding = encodingRegistry.getEncoding(properties.getChat().getEncoding())
.orElseThrow(() -> new IllegalStateException(
"Unknown encoding: " + properties.getChat().getEncoding()));
var limit = floorDiv(aContextLength * 90, 100);

var headMessages = new ArrayList<MAssistantMessage>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public List<MAssistantMessage> onSubscribeToAssistantMessages(SimpMessageHeaderA
throws IOException
{
var project = projectService.getProject(aProjectId);
return assistantService.listMessages(aPrincipal.getName(), project);
return assistantService.getConversationMessages(aPrincipal.getName(), project);
}

@MessageMapping(PROJECT_ASSISTANT_TOPIC_TEMPLATE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.security.core.session.SessionRegistry;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.EncodingRegistry;

import de.tudarmstadt.ukp.inception.assistant.AssistantService;
import de.tudarmstadt.ukp.inception.assistant.AssistantServiceImpl;
import de.tudarmstadt.ukp.inception.assistant.index.DocumentQueryService;
import de.tudarmstadt.ukp.inception.assistant.index.DocumentQueryServiceImpl;
import de.tudarmstadt.ukp.inception.assistant.index.EmbeddingService;
import de.tudarmstadt.ukp.inception.assistant.index.EmbeddingServiceImpl;
import de.tudarmstadt.ukp.inception.assistant.sidebar.AssistantSidebarFactory;
import de.tudarmstadt.ukp.inception.assistant.userguide.UserGuideQueryService;
import de.tudarmstadt.ukp.inception.assistant.userguide.UserGuideQueryServiceImpl;
Expand All @@ -47,10 +52,10 @@ public class AssistantAutoConfiguration
public AssistantService assistantService(SessionRegistry aSessionRegistry,
SimpMessagingTemplate aMsgTemplate, OllamaClient aOllamaClient,
AssistantProperties aProperties, UserGuideQueryService aDocumentationIndexingService,
DocumentQueryService aDocumentQueryService)
DocumentQueryService aDocumentQueryService, EncodingRegistry aEncodingRegistry)
{
return new AssistantServiceImpl(aSessionRegistry, aMsgTemplate, aOllamaClient, aProperties,
aDocumentationIndexingService, aDocumentQueryService);
aDocumentationIndexingService, aDocumentQueryService, aEncodingRegistry);
}

@Bean
Expand All @@ -61,18 +66,28 @@ public AssistantSidebarFactory assistantSidebarFactory()

@Bean
public UserGuideQueryService userManualQueryService(AssistantProperties aProperties,
SchedulingService aSchedulingService, OllamaClient aOllamaClient)
SchedulingService aSchedulingService, EmbeddingService aEmbeddingService)
{
return new UserGuideQueryServiceImpl(aProperties, aSchedulingService, aOllamaClient);
return new UserGuideQueryServiceImpl(aProperties, aSchedulingService, aEmbeddingService);
}

@Bean
public EncodingRegistry encodingRegistry() {
return Encodings.newLazyEncodingRegistry();
}

@Bean
public EmbeddingService EmbeddingService(AssistantProperties aProperties, OllamaClient aOllamaClient) {
return new EmbeddingServiceImpl(aProperties, aOllamaClient);
}

@Bean
public DocumentQueryService documentQueryService(AssistantProperties aProperties,
RepositoryProperties aRepositoryProperties,
AssistantDocumentIndexProperties aIndexProperties, SchedulingService aSchedulingService,
OllamaClient aOllamaClient)
OllamaClient aOllamaClient, EmbeddingService aEmbeddingService)
{
return new DocumentQueryServiceImpl(aProperties, aRepositoryProperties, aIndexProperties,
aSchedulingService, aOllamaClient);
aSchedulingService, aEmbeddingService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ public interface AssistantChatProperties
double getTemperature();

int getContextLength();

String getEncoding();
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

public interface AssistantEmbeddingProperties
{
public static final int AUTO_DETECT_DIMENSION = -1;

String getModel();

double getTopP();
Expand All @@ -30,4 +32,14 @@ public interface AssistantEmbeddingProperties
double getTemperature();

int getSeed();

int getContextLength();

int getBatchSize();

String getEncoding();

int getDimension();

void setDimension(int aI);
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ public interface AssistantProperties
AssistantChatProperties getChat();

AssistantEmbeddingProperties getEmbedding();

boolean isDevMode();
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package de.tudarmstadt.ukp.inception.assistant.config;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties("assistant")
Expand All @@ -32,6 +33,20 @@ public class AssistantPropertiesImpl
private AssistantChatProperties chat = new AssistantChatPropertiesImpl();
private AssistantEmbeddingProperties embedding = new AssistantEmbeddingPropertiesImpl();

@Value("${inception.dev:false}") // Inject system property or use default if not provided
private boolean devMode;

@Override
public boolean isDevMode()
{
return devMode;
}

public void setDevMode(boolean aDevMode)
{
devMode = aDevMode;
}

@Override
public String getUrl()
{
Expand Down Expand Up @@ -108,6 +123,7 @@ public static class AssistantChatPropertiesImpl
private double repeatPenalty = 1.1;
private double temperature = 0.1;
private int contextLength = 4096;
private String encoding = "cl100k_base";

@Override
public String getModel()
Expand Down Expand Up @@ -174,6 +190,17 @@ public void setContextLength(int aContextLength)
{
contextLength = aContextLength;
}

@Override
public String getEncoding()
{
return encoding;
}

public void setEncoding(String aEncoding)
{
encoding = aEncoding;
}
}

public static class AssistantEmbeddingPropertiesImpl
Expand All @@ -186,6 +213,10 @@ public static class AssistantEmbeddingPropertiesImpl
private int topK = 1000;
private double repeatPenalty = 1.0;
private double temperature = 0.0;
private int contextLength = 768;
private int batchSize = 16;
private String encoding = "cl100k_base";
private int dimension = AUTO_DETECT_DIMENSION;

@Override
public String getModel()
Expand Down Expand Up @@ -252,5 +283,49 @@ public void setSeed(int aSeed)
{
seed = aSeed;
}

@Override
public int getContextLength()
{
return contextLength;
}

public void setContextLength(int aContextLength)
{
contextLength = aContextLength;
}

@Override
public int getBatchSize()
{
return batchSize;
}

public void setBatchSize(int aBatchSize)
{
batchSize = aBatchSize;
}

@Override
public String getEncoding()
{
return encoding;
}

public void setEncoding(String aEncoding)
{
encoding = aEncoding;
}

@Override
public int getDimension()
{
return dimension;
}

public void setDimension(int aDimension)
{
dimension = aDimension;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ public interface DocumentQueryService
PooledIndex borrowIndex(Project aProject) throws Exception;

List<String> query(Project aProject, String aQuery, int aTopN, double aScoreThreshold);

void rebuildIndexAsync(Project aProject);
}
Loading

0 comments on commit 7c483b4

Please sign in to comment.