Skip to content

Commit

Permalink
Allow continuing conversation by replying to generated image
Browse files Browse the repository at this point in the history
  • Loading branch information
ronnygunawan committed Dec 5, 2023
1 parent ed23b81 commit 8143037
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 15 deletions.
17 changes: 14 additions & 3 deletions BotNet.Services/BotCommands/Art.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using BotNet.Services.OpenAI.Models;
using BotNet.Services.OpenAI;
using BotNet.Services.RateLimit;
using BotNet.Services.Stability.Models;
using BotNet.Services.ThisXDoesNotExist;
Expand Down Expand Up @@ -33,8 +35,8 @@ public static async Task GetRandomArtAsync(ITelegramBotClient botClient, IServic
);

try {
byte[] image = await serviceProvider.GetRequiredService<Stability.Skills.ImageGenerationBot>().GenerateImageAsync(commandArgument, CancellationToken.None);
using MemoryStream imageStream = new(image);
byte[] generatedImage = await serviceProvider.GetRequiredService<Stability.Skills.ImageGenerationBot>().GenerateImageAsync(commandArgument, CancellationToken.None);
using MemoryStream imageStream = new(generatedImage);

try {
await botClient.DeleteMessageAsync(
Expand All @@ -46,11 +48,20 @@ await botClient.DeleteMessageAsync(
throw;
}

await botClient.SendPhotoAsync(
Message generatedImageMessage = await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileStream(imageStream, "art.png"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);

// Track generated image for continuation
serviceProvider.GetRequiredService<ThreadTracker>().TrackMessage(
messageId: generatedImageMessage.MessageId,
sender: "AI",
text: null,
imageBase64: Convert.ToBase64String(generatedImage),
replyToMessageId: message.MessageId
);
} catch (ContentFilteredException exc) {
await botClient.EditMessageTextAsync(
chatId: message.Chat.Id,
Expand Down
34 changes: 27 additions & 7 deletions BotNet.Services/BotCommands/OpenAI.cs
Original file line number Diff line number Diff line change
Expand Up @@ -865,12 +865,22 @@ await botClient.DeleteMessageAsync(
} catch (OperationCanceledException) {
throw;
}
await botClient.SendPhotoAsync(

Message generatedImageMessage = await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileStream(generatedImageStream, "art.png"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);

// Track generated image for continuation
serviceProvider.GetRequiredService<ThreadTracker>().TrackMessage(
messageId: generatedImageMessage.MessageId,
sender: "AI",
text: null,
imageBase64: Convert.ToBase64String(generatedImage),
replyToMessageId: message.MessageId
);
} catch (ContentFilteredException exc) {
await botClient.EditMessageTextAsync(
chatId: busyMessage.Chat.Id,
Expand Down Expand Up @@ -939,12 +949,22 @@ public static async Task StreamChatWithFriendlyBotAsync(ITelegramBotClient botCl
? CHAT_PRIVATE_RATE_LIMITER
: CHAT_GROUP_RATE_LIMITER
).ValidateActionRate(message.Chat.Id, message.From!.Id);
await serviceProvider.GetRequiredService<FriendlyBot>().StreamChatAsync(
message: message.Text!,
thread: thread,
chatId: message.Chat.Id,
replyToMessageId: message.MessageId
);

if (thread.FirstOrDefault().ImageBase64 is { } imageBase64) {
await serviceProvider.GetRequiredService<VisionBot>().StreamChatAsync(
message: message.Text!,
imageBase64: imageBase64,
chatId: message.Chat.Id,
replyToMessageId: message.MessageId
);
} else {
await serviceProvider.GetRequiredService<FriendlyBot>().StreamChatAsync(
message: message.Text!,
thread: thread,
chatId: message.Chat.Id,
replyToMessageId: message.MessageId
);
}
} catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
if (message.Chat.Type == ChatType.Private) {
await botClient.SendTextMessageAsync(
Expand Down
20 changes: 15 additions & 5 deletions BotNet.Services/OpenAI/ThreadTracker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,25 @@ public void TrackMessage(
long messageId,
int maxLines
) {
bool firstLine = true;
while (_memoryCache.TryGetValue<Message>(
key: new MessageId(messageId),
value: out Message? message
) && message != null && maxLines-- > 0) {
yield return (
Sender: message.Sender,
Text: message.Text,
ImageBase64: message.ImageBase64
);
if (firstLine) {
yield return (
Sender: message.Sender,
Text: message.Text,
ImageBase64: message.ImageBase64
);
firstLine = false;
} else {
yield return (
Sender: message.Sender,
Text: message.Text,
ImageBase64: null // Strip images from the rest of thread
);
}

if (message.ReplyToMessageId == null) {
yield break;
Expand Down

0 comments on commit 8143037

Please sign in to comment.