Skip to content

Commit

Permalink
Merge pull request #79 from teknologi-umum/ai-overhaul
Browse files Browse the repository at this point in the history
Use Stability SDXL 1.0 for image generation
  • Loading branch information
ronnygunawan authored Dec 4, 2023
2 parents 6cede0c + a5ad180 commit 6a665f7
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 690 deletions.
103 changes: 60 additions & 43 deletions BotNet.Services/BotCommands/Art.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
using System.Threading;
using System.Threading.Tasks;
using BotNet.Services.RateLimit;
using BotNet.Services.Stability;
using BotNet.Services.ThisXDoesNotExist;
using Microsoft.Extensions.DependencyInjection;
using Telegram.Bot;
Expand All @@ -24,13 +23,31 @@ public static async Task GetRandomArtAsync(ITelegramBotClient botClient, IServic
try {
GENERATED_ART_RATE_LIMITER.ValidateActionRate(message.Chat.Id, message.From!.Id);

Message busyMessage = await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "Generating image… ⏳",
parseMode: ParseMode.Markdown,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);

try {
byte[] image = await serviceProvider.GetRequiredService<StabilityClient>().GenerateImageAsync(commandArgument, CancellationToken.None);
byte[] image = await serviceProvider.GetRequiredService<Stability.Skills.ImageGenerationBot>().GenerateImageAsync(commandArgument, CancellationToken.None);
using MemoryStream imageStream = new(image);

try {
await botClient.DeleteMessageAsync(
chatId: busyMessage.Chat.Id,
messageId: busyMessage.MessageId,
cancellationToken: cancellationToken
);
} catch (OperationCanceledException) {
throw;
}

await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileStream(imageStream, "art.jpg"),
photo: new InputFileStream(imageStream, "art.png"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
} catch {
Expand Down Expand Up @@ -72,48 +89,48 @@ await botClient.SendTextMessageAsync(
}
}

public static async Task ModifyArtAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, string textPrompt, CancellationToken cancellationToken) {
if (message.ReplyToMessage is { } replyToMessage) {
using MemoryStream originalImageStream = new();
Telegram.Bot.Types.File fileInfo = message.ReplyToMessage.Photo?.Length > 0
? await botClient.GetInfoAndDownloadFileAsync(
fileId: message.ReplyToMessage.Photo.OrderByDescending(photoSize => photoSize.Width).First().FileId,
destination: originalImageStream,
cancellationToken: cancellationToken)
: await botClient.GetInfoAndDownloadFileAsync(
fileId: message.ReplyToMessage.Sticker!.FileId,
destination: originalImageStream,
cancellationToken: cancellationToken);
//public static async Task ModifyArtAsync(ITelegramBotClient botClient, IServiceProvider serviceProvider, Message message, string textPrompt, CancellationToken cancellationToken) {
// if (message.ReplyToMessage is { } replyToMessage) {
// using MemoryStream originalImageStream = new();
// Telegram.Bot.Types.File fileInfo = message.ReplyToMessage.Photo?.Length > 0
// ? await botClient.GetInfoAndDownloadFileAsync(
// fileId: message.ReplyToMessage.Photo.OrderByDescending(photoSize => photoSize.Width).First().FileId,
// destination: originalImageStream,
// cancellationToken: cancellationToken)
// : await botClient.GetInfoAndDownloadFileAsync(
// fileId: message.ReplyToMessage.Sticker!.FileId,
// destination: originalImageStream,
// cancellationToken: cancellationToken);

try {
MODIFY_ART_RATE_LIMITER.ValidateActionRate(message.Chat.Id, message.From!.Id);
// try {
// MODIFY_ART_RATE_LIMITER.ValidateActionRate(message.Chat.Id, message.From!.Id);

try {
byte[] image = await serviceProvider.GetRequiredService<StabilityClient>().ModifyImageAsync(originalImageStream.ToArray(), textPrompt, CancellationToken.None);
using MemoryStream imageStream = new(image);
// try {
// byte[] image = await serviceProvider.GetRequiredService<StabilityClient>().ModifyImageAsync(originalImageStream.ToArray(), textPrompt, CancellationToken.None);
// using MemoryStream imageStream = new(image);

await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileStream(imageStream, "art.jpg"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
} catch {
await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "<code>Could not generate art</code>",
parseMode: ParseMode.Html,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
}
} catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: $"Anda belum mendapat giliran. Coba lagi {cooldown}.",
parseMode: ParseMode.Html,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken);
}
}
}
// await botClient.SendPhotoAsync(
// chatId: message.Chat.Id,
// photo: new InputFileStream(imageStream, "art.jpg"),
// replyToMessageId: message.MessageId,
// cancellationToken: cancellationToken);
// } catch {
// await botClient.SendTextMessageAsync(
// chatId: message.Chat.Id,
// text: "<code>Could not generate art</code>",
// parseMode: ParseMode.Html,
// replyToMessageId: message.MessageId,
// cancellationToken: cancellationToken);
// }
// } catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
// await botClient.SendTextMessageAsync(
// chatId: message.Chat.Id,
// text: $"Anda belum mendapat giliran. Coba lagi {cooldown}.",
// parseMode: ParseMode.Html,
// replyToMessageId: message.MessageId,
// cancellationToken: cancellationToken);
// }
// }
//}
}
}
82 changes: 42 additions & 40 deletions BotNet.Services/BotCommands/OpenAI.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using BotNet.Services.OpenAI.Models;
using BotNet.Services.OpenAI.Skills;
using BotNet.Services.RateLimit;
using BotNet.Services.Stability.Skills;
using Microsoft.Extensions.DependencyInjection;
using RG.Ninja;
using SkiaSharp;
Expand Down Expand Up @@ -766,9 +767,8 @@ await botClient.SendTextMessageAsync(
}
}

private static readonly RateLimiter IMAGE_GENERATION_PER_USER_RATE_LIMITER = RateLimiter.PerUser(1, TimeSpan.FromMinutes(10));
private static readonly RateLimiter IMAGE_GENERATION_PER_CHAT_RATE_LIMITER = RateLimiter.PerChat(2, TimeSpan.FromMinutes(5));
private static readonly RateLimiter IMAGE_GENERATION_GLOBAL_RATE_LIMITER = RateLimiter.PerChat(1, TimeSpan.FromMinutes(1));
private static readonly RateLimiter IMAGE_GENERATION_PER_USER_RATE_LIMITER = RateLimiter.PerUser(1, TimeSpan.FromMinutes(5));
private static readonly RateLimiter IMAGE_GENERATION_PER_CHAT_RATE_LIMITER = RateLimiter.PerChat(2, TimeSpan.FromMinutes(3));
public static async Task StreamChatWithFriendlyBotAsync(
ITelegramBotClient botClient,
IServiceProvider serviceProvider,
Expand Down Expand Up @@ -831,46 +831,48 @@ await serviceProvider.GetRequiredService<FriendlyBot>().StreamChatAsync(
replyToMessageId: message.MessageId
);
break;
case ChatIntent.ImageGeneration:
IMAGE_GENERATION_PER_USER_RATE_LIMITER.ValidateActionRate(
chatId: message.Chat.Id,
userId: message.From.Id
);
IMAGE_GENERATION_PER_CHAT_RATE_LIMITER.ValidateActionRate(
chatId: message.Chat.Id,
userId: message.From.Id
);
IMAGE_GENERATION_GLOBAL_RATE_LIMITER.ValidateActionRate(
chatId: 0,
userId: 0
);
Message busyMessage = await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "Generating image… ⏳",
parseMode: ParseMode.Markdown,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
Uri generatedImageUrl = await serviceProvider.GetRequiredService<ImageGenerationBot>().GenerateImageAsync(
prompt: message.Text!,
cancellationToken: cancellationToken
);
try {
await botClient.DeleteMessageAsync(
chatId: busyMessage.Chat.Id,
messageId: busyMessage.MessageId,
case ChatIntent.ImageGeneration: {
IMAGE_GENERATION_PER_USER_RATE_LIMITER.ValidateActionRate(
chatId: message.Chat.Id,
userId: message.From.Id
);
IMAGE_GENERATION_PER_CHAT_RATE_LIMITER.ValidateActionRate(
chatId: message.Chat.Id,
userId: message.From.Id
);
Message busyMessage = await botClient.SendTextMessageAsync(
chatId: message.Chat.Id,
text: "Generating image… ⏳",
parseMode: ParseMode.Markdown,
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
//Uri generatedImageUrl = await serviceProvider.GetRequiredService<ImageGenerationBot>().GenerateImageAsync(
// prompt: message.Text!,
// cancellationToken: cancellationToken
//);
byte[] generatedImage = await serviceProvider.GetRequiredService<Stability.Skills.ImageGenerationBot>().GenerateImageAsync(
prompt: message.Text!,
cancellationToken: cancellationToken
);
using MemoryStream generatedImageStream = new(generatedImage);
try {
await botClient.DeleteMessageAsync(
chatId: busyMessage.Chat.Id,
messageId: busyMessage.MessageId,
cancellationToken: cancellationToken
);
} catch (OperationCanceledException) {
throw;
}
await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileStream(generatedImageStream, "art.png"),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
} catch (OperationCanceledException) {
throw;
break;
}
await botClient.SendPhotoAsync(
chatId: message.Chat.Id,
photo: new InputFileUrl(generatedImageUrl),
replyToMessageId: message.MessageId,
cancellationToken: cancellationToken
);
break;
}
}
} catch (RateLimitExceededException exc) when (exc is { Cooldown: var cooldown }) {
Expand Down
5 changes: 0 additions & 5 deletions BotNet.Services/BotNet.Services.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
<ItemGroup>
<None Remove="CopyPasta\Pasta.json" />
<None Remove="Meme\Images\ramad.jpg" />
<None Remove="Stability\generation.proto" />
<None Remove="FancyText\CharMaps\Bold.json" />
<None Remove="FancyText\CharMaps\BoldItalic.json" />
<None Remove="FancyText\CharMaps\Cursive.json" />
Expand Down Expand Up @@ -100,8 +99,4 @@
<ProjectReference Include="..\pehape\csharp\Pehape\Pehape.csproj" />
</ItemGroup>

<ItemGroup>
<Protobuf Include="Stability\generation.proto" GrpcServices="Client" />
</ItemGroup>

</Project>
13 changes: 13 additions & 0 deletions BotNet.Services/Stability/Models/TextToImageResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using System.Collections.Generic;

namespace BotNet.Services.Stability.Models {
internal sealed record TextToImageResponse(
List<Artifact> Artifacts
);

internal sealed record Artifact(
string Base64,
string FinishReason,
int Seed
);
}
4 changes: 3 additions & 1 deletion BotNet.Services/Stability/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using Microsoft.Extensions.DependencyInjection;
using BotNet.Services.Stability.Skills;
using Microsoft.Extensions.DependencyInjection;

namespace BotNet.Services.Stability {
public static class ServiceCollectionExtensions {
public static IServiceCollection AddStabilityClient(this IServiceCollection services) {
services.AddSingleton<StabilityClient>();
services.AddSingleton<ImageGenerationBot>();
return services;
}
}
Expand Down
21 changes: 21 additions & 0 deletions BotNet.Services/Stability/Skills/ImageGenerationBot.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System.Threading;
using System.Threading.Tasks;

namespace BotNet.Services.Stability.Skills {
public sealed class ImageGenerationBot(
StabilityClient stabilityClient
) {
private readonly StabilityClient _stabilityClient = stabilityClient;

public async Task<byte[]> GenerateImageAsync(
string prompt,
CancellationToken cancellationToken
) {
return await _stabilityClient.GenerateImageAsync(
engine: "stable-diffusion-xl-1024-v1-0",
promptText: prompt,
cancellationToken: cancellationToken
);
}
}
}
Loading

0 comments on commit 6a665f7

Please sign in to comment.