Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed SampleNavigationParameters from generated samples. #9

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions AIDevGallery/Models/MultiModelSampleNavigationParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ internal class MultiModelSampleNavigationParameters(
{
public string[] ModelPaths { get; } = modelPaths;
public HardwareAccelerator[] HardwareAccelerators { get; } = hardwareAccelerators;
public LlmPromptTemplate?[] PromptTemplates { get; } = promptTemplates;

protected override string ChatClientModelPath => ModelPaths[0];
protected override LlmPromptTemplate? ChatClientPromptTemplate => PromptTemplates[0];
protected override LlmPromptTemplate? ChatClientPromptTemplate => promptTemplates[0];
}
}
3 changes: 1 addition & 2 deletions AIDevGallery/Models/SampleNavigationParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ internal class SampleNavigationParameters(
{
public string ModelPath { get; } = modelPath;
public HardwareAccelerator HardwareAccelerator { get; } = hardwareAccelerator;
public LlmPromptTemplate? PromptTemplate { get; } = promptTemplate;

protected override string ChatClientModelPath => ModelPath;
protected override LlmPromptTemplate? ChatClientPromptTemplate => PromptTemplate;
protected override LlmPromptTemplate? ChatClientPromptTemplate => promptTemplate;
}
}
164 changes: 74 additions & 90 deletions AIDevGallery/ProjectGenerator/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,6 @@ private async Task<string> GenerateAsyncInternal(Sample sample, Dictionary<Model
{ "ProjectTemplate.csproj", $"{safeProjectName}.csproj" }
};

string modelTemplateString = GetPromptTemplateString(modelInfos.Values.Select(m => m.ModelPromptTemplate).ToList());

var className = await AddFilesFromSampleAsync(sample, packageReferences, safeProjectName, outputPath, addLllmTypes, modelInfos, cancellationToken);

foreach (var file in files)
Expand Down Expand Up @@ -263,7 +261,6 @@ private async Task<string> GenerateAsyncInternal(Sample sample, Dictionary<Model
content = content.Replace("$XmlEscapedPublisher$", xmlEscapedPublisher);
content = content.Replace("$DotNetVersion$", DotNetVersion);
content = content.Replace("$MainSamplePage$", className);
content = content.Replace("$promptTemplate$", modelTemplateString);

// Write the file
await File.WriteAllTextAsync(outputPathFile, content, cancellationToken);
Expand Down Expand Up @@ -359,21 +356,14 @@ private async Task<string> GenerateAsyncInternal(Sample sample, Dictionary<Model
return outputPath;
}

private string GetChatClientLoaderString(Sample sample, bool isMultiModel, string modelPath)
private string GetChatClientLoaderString(Sample sample, string modelPath, string promptTemplate)
{
if (!sample.SharedCode.Contains(SharedCodeEnum.GenAIModel))
{
return string.Empty;
}

if (isMultiModel)
{
return $"GenAIModel.CreateAsync({modelPath}, sampleParams.PromptTemplates[0], System.Threading.CancellationToken.None)";
}
else
{
return $"GenAIModel.CreateAsync({modelPath}, sampleParams.PromptTemplate, System.Threading.CancellationToken.None)";
}
return $"GenAIModel.CreateAsync({modelPath}, {promptTemplate})";
}

private static async Task CopyFileAsync(string sourceFile, string destinationFile, CancellationToken cancellationToken)
Expand All @@ -391,95 +381,69 @@ private static string EscapeNewLines(string str)
return str;
}

private string GetPromptTemplateString(List<PromptTemplate?> promptTemplates)
private string GetPromptTemplateString(PromptTemplate? promptTemplate, int spaceCount)
{
if (promptTemplates.Count == 0 ||
promptTemplates.All(pt => pt == null))
if (promptTemplate == null)
{
return string.Empty;
return "null";
}

StringBuilder modelPromptTemplateSb = new();
if (promptTemplates.Count == 1)
var spaces = new string(' ', spaceCount);
modelPromptTemplateSb.AppendLine("new LlmPromptTemplate");
modelPromptTemplateSb.Append(spaces);
modelPromptTemplateSb.AppendLine("{");
if (!string.IsNullOrEmpty(promptTemplate.System))
{
modelPromptTemplateSb.AppendLine("public LlmPromptTemplate? PromptTemplate { get; } =");
modelPromptTemplateSb.Append(spaces);
modelPromptTemplateSb.AppendLine(
string.Format(
CultureInfo.InvariantCulture,
"""
System = "{0}",
""",
EscapeNewLines(promptTemplate.System)));
}
else

if (!string.IsNullOrEmpty(promptTemplate.User))
{
modelPromptTemplateSb.AppendLine("public LlmPromptTemplate?[] PromptTemplates { get; } = [");
modelPromptTemplateSb.Append(spaces);
modelPromptTemplateSb.AppendLine(string.Format(
CultureInfo.InvariantCulture,
"""
User = "{0}",
""",
EscapeNewLines(promptTemplate.User)));
}

foreach (var promptTemplate in promptTemplates)
if (!string.IsNullOrEmpty(promptTemplate.Assistant))
{
if (promptTemplate == null)
{
modelPromptTemplateSb.AppendLine(" null,");
continue;
}

modelPromptTemplateSb.AppendLine(
$$"""
new LlmPromptTemplate
{
""");
if (!string.IsNullOrEmpty(promptTemplate.System))
{
modelPromptTemplateSb.AppendLine(
string.Format(
CultureInfo.InvariantCulture,
"""
System = "{0}",
""",
EscapeNewLines(promptTemplate.System)));
}

if (!string.IsNullOrEmpty(promptTemplate.User))
{
modelPromptTemplateSb.AppendLine(string.Format(
CultureInfo.InvariantCulture,
"""
User = "{0}",
""",
EscapeNewLines(promptTemplate.User)));
}
modelPromptTemplateSb.Append(spaces);
modelPromptTemplateSb.AppendLine(string.Format(
CultureInfo.InvariantCulture,
"""
Assistant = "{0}",
""",
EscapeNewLines(promptTemplate.Assistant)));
}

if (!string.IsNullOrEmpty(promptTemplate.Assistant))
{
modelPromptTemplateSb.AppendLine(string.Format(
if (promptTemplate.Stop != null && promptTemplate.Stop.Length > 0)
{
modelPromptTemplateSb.Append(spaces);
var stopStr = string.Join(", ", promptTemplate.Stop.Select(s =>
string.Format(
CultureInfo.InvariantCulture,
"""
Assistant = "{0}",
"{0}"
""",
EscapeNewLines(promptTemplate.Assistant)));
}

if (promptTemplate.Stop != null && promptTemplate.Stop.Length > 0)
{
var stopStr = string.Join(", ", promptTemplate.Stop.Select(s =>
string.Format(
CultureInfo.InvariantCulture,
"""
"{0}"
""",
EscapeNewLines(s))));
modelPromptTemplateSb.Append(" Stop = [ ");
modelPromptTemplateSb.Append(stopStr);
modelPromptTemplateSb.AppendLine("]");
}

modelPromptTemplateSb.Append(" }");
if (promptTemplates.Count > 1)
{
modelPromptTemplateSb.AppendLine(",");
}
}

if (promptTemplates.Count > 1)
{
modelPromptTemplateSb.Append(" ]");
EscapeNewLines(s))));
modelPromptTemplateSb.Append(" Stop = [ ");
modelPromptTemplateSb.Append(stopStr);
modelPromptTemplateSb.AppendLine("]");
}

modelPromptTemplateSb.Append(';');
modelPromptTemplateSb.Append(spaces);
modelPromptTemplateSb.Append('}');

return modelPromptTemplateSb.ToString();
}
Expand All @@ -497,7 +461,7 @@ private async Task<string> AddFilesFromSampleAsync(
if (!sharedCode.Contains(SharedCodeEnum.LlmPromptTemplate) &&
(addLllmTypes || sample.SharedCode.Contains(SharedCodeEnum.GenAIModel)))
{
// Always used inside SampleNavigationParameters.cs and GenAIModel.cs
// Always used inside GenAIModel.cs
sharedCode.Add(SharedCodeEnum.LlmPromptTemplate);
}

Expand Down Expand Up @@ -531,19 +495,29 @@ private async Task<string> AddFilesFromSampleAsync(
if (!string.IsNullOrEmpty(sample.XAMLCode))
{
var xamlSource = CleanXamlSource(sample.XAMLCode, safeProjectName, out className);
xamlSource = xamlSource.Replace($"{Environment.NewLine} xmlns:samples=\"using:AIDevGallery.Samples\"", string.Empty);
xamlSource = xamlSource.Replace("<samples:BaseSamplePage", "<Page");
xamlSource = xamlSource.Replace("</samples:BaseSamplePage>", "</Page>");

await File.WriteAllTextAsync(Path.Join(outputPath, $"{className}.xaml"), xamlSource, cancellationToken);
}

if (!string.IsNullOrEmpty(sample.CSCode))
{
var cleanCsSource = CleanCsSource(sample.CSCode, safeProjectName, true);
cleanCsSource = cleanCsSource.Replace("sampleParams.NotifyCompletion();", "App.Window?.ModelLoaded();");
cleanCsSource = cleanCsSource.Replace(": BaseSamplePage", ": Microsoft.UI.Xaml.Controls.Page");
cleanCsSource = cleanCsSource.Replace(
"Task LoadModelAsync(SampleNavigationParameters sampleParams)",
"void OnNavigatedTo(Microsoft.UI.Xaml.Navigation.NavigationEventArgs e)");
cleanCsSource = cleanCsSource.Replace(
"Task LoadModelAsync(MultiModelSampleNavigationParameters sampleParams)",
"void OnNavigatedTo(Microsoft.UI.Xaml.Navigation.NavigationEventArgs e)");
cleanCsSource = cleanCsSource.Replace($"{Environment.NewLine} return Task.CompletedTask;", string.Empty);

string modelPath;
if (modelInfos.Count > 1)
{
cleanCsSource = cleanCsSource.Replace("MultiModelSampleNavigationParameters", "SampleNavigationParameters");

int i = 0;
foreach (var modelInfo in modelInfos)
{
Expand All @@ -564,10 +538,20 @@ private async Task<string> AddFilesFromSampleAsync(

cleanCsSource = cleanCsSource.Replace("sampleParams.CancellationToken", "CancellationToken.None");

var chatClientLoader = GetChatClientLoaderString(sample, modelInfos.Count > 1, modelPath);
if (chatClientLoader != null)
var search = "sampleParams.GetIChatClientAsync()";
int index = cleanCsSource.IndexOf(search, StringComparison.OrdinalIgnoreCase);
if (index > 0)
{
cleanCsSource = cleanCsSource.Replace("sampleParams.GetIChatClientAsync()", chatClientLoader);
int newLineIndex = cleanCsSource[..index].LastIndexOf(Environment.NewLine, StringComparison.OrdinalIgnoreCase);
var subStr = cleanCsSource[(newLineIndex + Environment.NewLine.Length)..];
var subStrWithoutSpaces = subStr.TrimStart();
var spaceCount = subStr.Length - subStrWithoutSpaces.Length;
var promptTemplate = GetPromptTemplateString(modelInfos.Values.First().ModelPromptTemplate, spaceCount);
var chatClientLoader = GetChatClientLoaderString(sample, modelPath, promptTemplate);
if (chatClientLoader != null)
{
cleanCsSource = cleanCsSource.Replace(search, chatClientLoader);
}
}

await File.WriteAllTextAsync(Path.Join(outputPath, $"{className}.xaml.cs"), cleanCsSource, cancellationToken);
Expand Down
4 changes: 1 addition & 3 deletions AIDevGallery/ProjectGenerator/Template/MainWindow.xaml.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
using System.Threading;
using Microsoft.UI.Xaml;
using $safeprojectname$.SharedCode;

namespace $safeprojectname$
{
Expand All @@ -11,7 +9,7 @@ public MainWindow()
this.InitializeComponent();
this.RootFrame.Loaded += (sender, args) =>
{
RootFrame.Navigate(typeof($MainSamplePage$), new SampleNavigationParameters());
RootFrame.Navigate(typeof($MainSamplePage$));
};
}

Expand Down

This file was deleted.

37 changes: 37 additions & 0 deletions AIDevGallery/Samples/BaseSamplePage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using AIDevGallery.Models;
using Microsoft.UI.Xaml.Controls;
using Microsoft.UI.Xaml.Navigation;
using System.Threading.Tasks;

namespace AIDevGallery.Samples
{
internal partial class BaseSamplePage : Page
{
protected override async void OnNavigatedTo(NavigationEventArgs e)
{
base.OnNavigatedTo(e);

if (e.Parameter is SampleNavigationParameters sampleParams)
{
await LoadModelAsync(sampleParams);
}
else if (e.Parameter is MultiModelSampleNavigationParameters sampleParams2)
{
await LoadModelAsync(sampleParams2);
}
}

protected virtual Task LoadModelAsync(SampleNavigationParameters sampleParams)
{
return Task.CompletedTask;
}

protected virtual Task LoadModelAsync(MultiModelSampleNavigationParameters sampleParams)
{
return Task.CompletedTask;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
<?xml version="1.0" encoding="utf-8" ?>
<Page
<samples:BaseSamplePage
xmlns:samples="using:AIDevGallery.Samples"
x:Class="AIDevGallery.Samples.OpenSourceModels.SentenceEmbeddings.Embeddings.RetrievalAugmentedGeneration"
xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
Expand Down Expand Up @@ -180,4 +181,4 @@
</Button>
</Grid>
</Grid>
</Page>
</samples:BaseSamplePage>
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace AIDevGallery.Samples.OpenSourceModels.SentenceEmbeddings.Embeddings
],
Id = "9C1FB14D-4841-449C-9563-4551106BB693",
Icon = "\uE8D4")]
internal sealed partial class RetrievalAugmentedGeneration : Page
internal sealed partial class RetrievalAugmentedGeneration : BaseSamplePage
{
private readonly ChatOptions _chatOptions = GenAIModel.GetDefaultChatOptions();
private EmbeddingGenerator? _embeddings;
Expand Down Expand Up @@ -89,20 +89,16 @@ public RetrievalAugmentedGeneration()
this.Loaded += (s, e) => Page_Loaded(); // <exclude-line>
}

protected override async void OnNavigatedTo(NavigationEventArgs e)
protected override async Task LoadModelAsync(MultiModelSampleNavigationParameters sampleParams)
{
base.OnNavigatedTo(e);
if (e.Parameter is MultiModelSampleNavigationParameters sampleParams)
{
_embeddings = new EmbeddingGenerator(sampleParams.ModelPaths[1], sampleParams.HardwareAccelerators[1]);
_chatClient = await sampleParams.GetIChatClientAsync();
_chatOptions.MaxOutputTokens = 2048;
_embeddings = new EmbeddingGenerator(sampleParams.ModelPaths[1], sampleParams.HardwareAccelerators[1]);
_chatClient = await sampleParams.GetIChatClientAsync();
_chatOptions.MaxOutputTokens = 2048;

sampleParams.NotifyCompletion();
sampleParams.NotifyCompletion();

IndexPDFButton.IsEnabled = true;
IndexPDFText.Text = "Select a PDF";
}
IndexPDFButton.IsEnabled = true;
IndexPDFText.Text = "Select a PDF";
}

// <exclude>
Expand Down
Loading