From 8143037d12498bfac3f8ffa8e4684e5ce979b77b Mon Sep 17 00:00:00 2001 From: Ronny Gunawan <3048897+ronnygunawan@users.noreply.github.com> Date: Tue, 5 Dec 2023 23:47:19 +0700 Subject: [PATCH] Allow continuing conversation by replying to generated image --- BotNet.Services/BotCommands/Art.cs | 17 ++++++++++--- BotNet.Services/BotCommands/OpenAI.cs | 34 ++++++++++++++++++++----- BotNet.Services/OpenAI/ThreadTracker.cs | 20 +++++++++++---- 3 files changed, 56 insertions(+), 15 deletions(-) diff --git a/BotNet.Services/BotCommands/Art.cs b/BotNet.Services/BotCommands/Art.cs index 68b8740..b92efa6 100644 --- a/BotNet.Services/BotCommands/Art.cs +++ b/BotNet.Services/BotCommands/Art.cs @@ -3,6 +3,8 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using BotNet.Services.OpenAI.Models; +using BotNet.Services.OpenAI; using BotNet.Services.RateLimit; using BotNet.Services.Stability.Models; using BotNet.Services.ThisXDoesNotExist; @@ -33,8 +35,8 @@ public static async Task GetRandomArtAsync(ITelegramBotClient botClient, IServic ); try { - byte[] image = await serviceProvider.GetRequiredService().GenerateImageAsync(commandArgument, CancellationToken.None); - using MemoryStream imageStream = new(image); + byte[] generatedImage = await serviceProvider.GetRequiredService().GenerateImageAsync(commandArgument, CancellationToken.None); + using MemoryStream imageStream = new(generatedImage); try { await botClient.DeleteMessageAsync( @@ -46,11 +48,20 @@ await botClient.DeleteMessageAsync( throw; } - await botClient.SendPhotoAsync( + Message generatedImageMessage = await botClient.SendPhotoAsync( chatId: message.Chat.Id, photo: new InputFileStream(imageStream, "art.png"), replyToMessageId: message.MessageId, cancellationToken: cancellationToken); + + // Track generated image for continuation + serviceProvider.GetRequiredService().TrackMessage( + messageId: generatedImageMessage.MessageId, + sender: "AI", + text: null, + imageBase64: Convert.ToBase64String(generatedImage), + replyToMessageId: message.MessageId + ); } catch (ContentFilteredException exc) { await botClient.EditMessageTextAsync( chatId: message.Chat.Id, diff --git a/BotNet.Services/BotCommands/OpenAI.cs b/BotNet.Services/BotCommands/OpenAI.cs index 0256f38..be2ee57 100644 --- a/BotNet.Services/BotCommands/OpenAI.cs +++ b/BotNet.Services/BotCommands/OpenAI.cs @@ -865,12 +865,22 @@ await botClient.DeleteMessageAsync( } catch (OperationCanceledException) { throw; } - await botClient.SendPhotoAsync( + + Message generatedImageMessage = await botClient.SendPhotoAsync( chatId: message.Chat.Id, photo: new InputFileStream(generatedImageStream, "art.png"), replyToMessageId: message.MessageId, cancellationToken: cancellationToken ); + + // Track generated image for continuation + serviceProvider.GetRequiredService().TrackMessage( + messageId: generatedImageMessage.MessageId, + sender: "AI", + text: null, + imageBase64: Convert.ToBase64String(generatedImage), + replyToMessageId: message.MessageId + ); } catch (ContentFilteredException exc) { await botClient.EditMessageTextAsync( chatId: busyMessage.Chat.Id, @@ -939,12 +949,22 @@ public static async Task StreamChatWithFriendlyBotAsync(ITelegramBotClient botCl ? CHAT_PRIVATE_RATE_LIMITER : CHAT_GROUP_RATE_LIMITER ).ValidateActionRate(message.Chat.Id, message.From!.Id); - await serviceProvider.GetRequiredService().StreamChatAsync( - message: message.Text!, - thread: thread, - chatId: message.Chat.Id, - replyToMessageId: message.MessageId - ); + + if (thread.FirstOrDefault().ImageBase64 is { } imageBase64) { + await serviceProvider.GetRequiredService().StreamChatAsync( + message: message.Text!, + imageBase64: imageBase64, + chatId: message.Chat.Id, + replyToMessageId: message.MessageId + ); + } else { + await serviceProvider.GetRequiredService().StreamChatAsync( + message: message.Text!, + thread: thread, + chatId: message.Chat.Id, + replyToMessageId: message.MessageId + ); + } } catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) { if (message.Chat.Type == ChatType.Private) { await botClient.SendTextMessageAsync( diff --git a/BotNet.Services/OpenAI/ThreadTracker.cs b/BotNet.Services/OpenAI/ThreadTracker.cs index 569249b..03cf046 100644 --- a/BotNet.Services/OpenAI/ThreadTracker.cs +++ b/BotNet.Services/OpenAI/ThreadTracker.cs @@ -33,15 +33,25 @@ public void TrackMessage( long messageId, int maxLines ) { + bool firstLine = true; while (_memoryCache.TryGetValue( key: new MessageId(messageId), value: out Message? message ) && message != null && maxLines-- > 0) { - yield return ( - Sender: message.Sender, - Text: message.Text, - ImageBase64: message.ImageBase64 - ); + if (firstLine) { + yield return ( + Sender: message.Sender, + Text: message.Text, + ImageBase64: message.ImageBase64 + ); + firstLine = false; + } else { + yield return ( + Sender: message.Sender, + Text: message.Text, + ImageBase64: null // Strip images from the rest of thread + ); + } if (message.ReplyToMessageId == null) { yield break;