diff --git a/.editorconfig b/.editorconfig index 562f7ad..677b6b8 100644 --- a/.editorconfig +++ b/.editorconfig @@ -179,7 +179,7 @@ csharp_style_prefer_readonly_struct_member = true # Code-block preferences csharp_prefer_braces = true:silent csharp_prefer_simple_using_statement = true:suggestion -csharp_style_namespace_declarations = block_scoped:warning +csharp_style_namespace_declarations = file_scoped:warning csharp_style_prefer_method_group_conversion = true:silent csharp_style_prefer_primary_constructors = true:suggestion csharp_style_prefer_top_level_statements = true:silent diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..19d50e6 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Switch to file-scoped namespaces +52b04efcd96584a68740f17ea34b3430b288285c \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/DependencyVersionsSourceGenerator.cs b/AIDevGallery.SourceGenerator/DependencyVersionsSourceGenerator.cs index fec9833..ed044df 100644 --- a/AIDevGallery.SourceGenerator/DependencyVersionsSourceGenerator.cs +++ b/AIDevGallery.SourceGenerator/DependencyVersionsSourceGenerator.cs @@ -6,62 +6,61 @@ using System.Collections.Generic; using System.Text; -namespace AIDevGallery.SourceGenerator +namespace AIDevGallery.SourceGenerator; + +[Generator(LanguageNames.CSharp)] +internal class DependencyVersionsSourceGenerator : IIncrementalGenerator { - [Generator(LanguageNames.CSharp)] - internal class DependencyVersionsSourceGenerator : IIncrementalGenerator - { - private Dictionary? packageVersions = null; + private Dictionary? packageVersions = null; - public void Initialize(IncrementalGeneratorInitializationContext context) - { - packageVersions = Helpers.GetPackageVersions(); + public void Initialize(IncrementalGeneratorInitializationContext context) + { + packageVersions = Helpers.GetPackageVersions(); - context.RegisterPostInitializationOutput(Execute); - } + context.RegisterPostInitializationOutput(Execute); + } - public void Execute(IncrementalGeneratorPostInitializationContext context) + public void Execute(IncrementalGeneratorPostInitializationContext context) + { + if (packageVersions == null) { - if (packageVersions == null) - { - return; - } - - GeneratePackageVersionsFile(context, packageVersions); + return; } - private static void GeneratePackageVersionsFile(IncrementalGeneratorPostInitializationContext context, Dictionary packageVersions) - { - var sourceBuilder = new StringBuilder(); + GeneratePackageVersionsFile(context, packageVersions); + } - sourceBuilder.AppendLine( - $$"""" - #nullable enable + private static void GeneratePackageVersionsFile(IncrementalGeneratorPostInitializationContext context, Dictionary packageVersions) + { + var sourceBuilder = new StringBuilder(); - using System.Collections.Generic; - using AIDevGallery.Models; + sourceBuilder.AppendLine( + $$"""" + #nullable enable - namespace AIDevGallery.Samples; + using System.Collections.Generic; + using AIDevGallery.Models; - internal static partial class PackageVersionHelpers - { - """"); + namespace AIDevGallery.Samples; - sourceBuilder.AppendLine(" internal static Dictionary PackageVersions { get; } = new ()"); - sourceBuilder.AppendLine(" {"); - foreach (var packageVersion in packageVersions) + internal static partial class PackageVersionHelpers { - sourceBuilder.AppendLine( - $$"""" - { "{{packageVersion.Key}}", "{{packageVersion.Value}}" }, - """"); - } + """"); - sourceBuilder.AppendLine(" };"); + sourceBuilder.AppendLine(" internal static Dictionary PackageVersions { get; } = new ()"); + sourceBuilder.AppendLine(" {"); + foreach (var packageVersion in packageVersions) + { + sourceBuilder.AppendLine( + $$"""" + { "{{packageVersion.Key}}", "{{packageVersion.Value}}" }, + """"); + } - sourceBuilder.AppendLine("}"); + sourceBuilder.AppendLine(" };"); - context.AddSource($"PackageVersionHelpers.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); - } + sourceBuilder.AppendLine("}"); + + context.AddSource($"PackageVersionHelpers.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Diagnostics/Analyzers/NonUniqueIdAnalyzer.cs b/AIDevGallery.SourceGenerator/Diagnostics/Analyzers/NonUniqueIdAnalyzer.cs index b75d579..ccdcd91 100644 --- a/AIDevGallery.SourceGenerator/Diagnostics/Analyzers/NonUniqueIdAnalyzer.cs +++ b/AIDevGallery.SourceGenerator/Diagnostics/Analyzers/NonUniqueIdAnalyzer.cs @@ -8,60 +8,59 @@ using System.Collections.Immutable; using System.Linq; -namespace AIDevGallery.SourceGenerator.Diagnostics.Analyzers +namespace AIDevGallery.SourceGenerator.Diagnostics.Analyzers; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +internal class NonUniqueIdAnalyzer : DiagnosticAnalyzer { - [DiagnosticAnalyzer(LanguageNames.CSharp)] - internal class NonUniqueIdAnalyzer : DiagnosticAnalyzer + public override ImmutableArray SupportedDiagnostics { get; } = [DiagnosticDescriptors.NonUniqueId]; + + public override void Initialize(AnalysisContext context) { - public override ImmutableArray SupportedDiagnostics { get; } = [DiagnosticDescriptors.NonUniqueId]; + context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics); + context.EnableConcurrentExecution(); - public override void Initialize(AnalysisContext context) + context.RegisterCompilationStartAction(static context => { - context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.Analyze | GeneratedCodeAnalysisFlags.ReportDiagnostics); - context.EnableConcurrentExecution(); - - context.RegisterCompilationStartAction(static context => + // Get the [GallerySample] attribute type symbol + if (context.Compilation.GetTypeByMetadataName(WellKnownTypeNames.GallerySampleAttribute) is not INamedTypeSymbol gallerySampleAttributeSymbol) { - // Get the [GallerySample] attribute type symbol - if (context.Compilation.GetTypeByMetadataName(WellKnownTypeNames.GallerySampleAttribute) is not INamedTypeSymbol gallerySampleAttributeSymbol) - { - return; - } + return; + } - var locations = new ConcurrentDictionary(); + var locations = new ConcurrentDictionary(); - context.RegisterSymbolAction( - context => + context.RegisterSymbolAction( + context => + { + if (context.Symbol is not INamedTypeSymbol { TypeKind: TypeKind.Class, IsImplicitlyDeclared: false }) + { + return; + } + + if (context.Symbol.TryGetAttributeWithType(gallerySampleAttributeSymbol, out AttributeData? attribute) && + attribute != null && + attribute.NamedArguments.FirstOrDefault(a => a.Key == "Id").Value.Value is string id && + !string.IsNullOrEmpty(id) && + attribute.GetLocation() is Location location) { - if (context.Symbol is not INamedTypeSymbol { TypeKind: TypeKind.Class, IsImplicitlyDeclared: false }) + // Check if the id is unique + if (locations.TryAdd(id, location)) { + // ID is unique so far, do nothing return; } - - if (context.Symbol.TryGetAttributeWithType(gallerySampleAttributeSymbol, out AttributeData? attribute) && - attribute != null && - attribute.NamedArguments.FirstOrDefault(a => a.Key == "Id").Value.Value is string id && - !string.IsNullOrEmpty(id) && - attribute.GetLocation() is Location location) + else { - // Check if the id is unique - if (locations.TryAdd(id, location)) - { - // ID is unique so far, do nothing - return; - } - else - { - context.ReportDiagnostic(Diagnostic.Create( - DiagnosticDescriptors.NonUniqueId, - location, - [locations[id]], - id)); - } + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.NonUniqueId, + location, + [locations[id]], + id)); } - }, - SymbolKind.NamedType); - }); - } + } + }, + SymbolKind.NamedType); + }); } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Diagnostics/DiagnosticDescriptors.cs b/AIDevGallery.SourceGenerator/Diagnostics/DiagnosticDescriptors.cs index aa299fc..17f0dd2 100644 --- a/AIDevGallery.SourceGenerator/Diagnostics/DiagnosticDescriptors.cs +++ b/AIDevGallery.SourceGenerator/Diagnostics/DiagnosticDescriptors.cs @@ -3,26 +3,25 @@ using Microsoft.CodeAnalysis; -namespace AIDevGallery.SourceGenerator.Diagnostics +namespace AIDevGallery.SourceGenerator.Diagnostics; + +internal static class DiagnosticDescriptors { - internal static class DiagnosticDescriptors - { - public static readonly DiagnosticDescriptor NonUniqueId = new( - id: "AIDevGallery0001", - title: "Duplicate Id for [GallerySample] sample", - messageFormat: "Id '{0}' is used more than once", - category: nameof(SamplesSourceGenerator), - defaultSeverity: DiagnosticSeverity.Error, - isEnabledByDefault: true, - description: "All gallery samples must have unique ids."); + public static readonly DiagnosticDescriptor NonUniqueId = new( + id: "AIDevGallery0001", + title: "Duplicate Id for [GallerySample] sample", + messageFormat: "Id '{0}' is used more than once", + category: nameof(SamplesSourceGenerator), + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: "All gallery samples must have unique ids."); - public static readonly DiagnosticDescriptor NugetPackageNotUsed = new( - id: "AIDevGallery0002", - title: "Nuget package not used", - messageFormat: "Nuget package '{0}' is not used in the Gallery app", - category: nameof(SamplesSourceGenerator), - defaultSeverity: DiagnosticSeverity.Error, - isEnabledByDefault: true, - description: "Nuget package references must be used in the sample app to be used by a sample."); - } + public static readonly DiagnosticDescriptor NugetPackageNotUsed = new( + id: "AIDevGallery0002", + title: "Nuget package not used", + messageFormat: "Nuget package '{0}' is not used in the Gallery app", + category: nameof(SamplesSourceGenerator), + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: "Nuget package references must be used in the sample app to be used by a sample."); } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Extensions/AttributeDataExtensions.cs b/AIDevGallery.SourceGenerator/Extensions/AttributeDataExtensions.cs index 1e52cd3..d916a66 100644 --- a/AIDevGallery.SourceGenerator/Extensions/AttributeDataExtensions.cs +++ b/AIDevGallery.SourceGenerator/Extensions/AttributeDataExtensions.cs @@ -3,18 +3,17 @@ using Microsoft.CodeAnalysis; -namespace AIDevGallery.SourceGenerator.Extensions +namespace AIDevGallery.SourceGenerator.Extensions; + +internal static class AttributeDataExtensions { - internal static class AttributeDataExtensions + public static Location? GetLocation(this AttributeData attributeData) { - public static Location? GetLocation(this AttributeData attributeData) + if (attributeData.ApplicationSyntaxReference is { } syntaxReference) { - if (attributeData.ApplicationSyntaxReference is { } syntaxReference) - { - return syntaxReference.SyntaxTree.GetLocation(syntaxReference.Span); - } - - return null; + return syntaxReference.SyntaxTree.GetLocation(syntaxReference.Span); } + + return null; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Extensions/ISymbolExtensions.cs b/AIDevGallery.SourceGenerator/Extensions/ISymbolExtensions.cs index 13ae25e..d7299b0 100644 --- a/AIDevGallery.SourceGenerator/Extensions/ISymbolExtensions.cs +++ b/AIDevGallery.SourceGenerator/Extensions/ISymbolExtensions.cs @@ -3,45 +3,44 @@ using Microsoft.CodeAnalysis; -namespace AIDevGallery.SourceGenerator.Extensions +namespace AIDevGallery.SourceGenerator.Extensions; + +/// +/// Extension methods for the type. +/// +internal static class ISymbolExtensions { /// - /// Extension methods for the type. + /// Gets the fully qualified name for a given symbol. /// - internal static class ISymbolExtensions + /// The input instance. + /// The fully qualified name for . + public static string GetFullyQualifiedName(this ISymbol symbol) { - /// - /// Gets the fully qualified name for a given symbol. - /// - /// The input instance. - /// The fully qualified name for . - public static string GetFullyQualifiedName(this ISymbol symbol) - { - return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - } + return symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + } - /// - /// Tries to get an attribute with the specified type. - /// - /// The input instance to check. - /// The instance for the attribute type to look for. - /// The resulting attribute, if it was found. - /// Whether or not has an attribute with the specified name. - public static bool TryGetAttributeWithType(this ISymbol symbol, ITypeSymbol typeSymbol, out AttributeData? attributeData) + /// + /// Tries to get an attribute with the specified type. + /// + /// The input instance to check. + /// The instance for the attribute type to look for. + /// The resulting attribute, if it was found. + /// Whether or not has an attribute with the specified name. + public static bool TryGetAttributeWithType(this ISymbol symbol, ITypeSymbol typeSymbol, out AttributeData? attributeData) + { + foreach (AttributeData attribute in symbol.GetAttributes()) { - foreach (AttributeData attribute in symbol.GetAttributes()) + if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, typeSymbol)) { - if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, typeSymbol)) - { - attributeData = attribute; + attributeData = attribute; - return true; - } + return true; } + } - attributeData = null; + attributeData = null; - return false; - } + return false; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Helpers.cs b/AIDevGallery.SourceGenerator/Helpers.cs index 4c5a517..3ad403d 100644 --- a/AIDevGallery.SourceGenerator/Helpers.cs +++ b/AIDevGallery.SourceGenerator/Helpers.cs @@ -14,132 +14,113 @@ using System.Threading.Tasks; using System.Xml; -namespace AIDevGallery.SourceGenerator +namespace AIDevGallery.SourceGenerator; + +internal static class Helpers { - internal static class Helpers + internal static string EscapeUnicodeString(string unicodeString) { - internal static string EscapeUnicodeString(string unicodeString) - { - return JsonSerializer.Serialize(unicodeString, SourceGenerationContext.Default.String); - } + return JsonSerializer.Serialize(unicodeString, SourceGenerationContext.Default.String); + } - private static ModelFamily Fix(ModelFamily modelFamily) + private static ModelFamily Fix(ModelFamily modelFamily) + { + string? id = modelFamily.Id; + if (string.IsNullOrWhiteSpace(id)) { - string? id = modelFamily.Id; - if (string.IsNullOrWhiteSpace(id)) + id = Guid.NewGuid().ToString(); + return new ModelFamily { - id = Guid.NewGuid().ToString(); - return new ModelFamily - { - Id = id, - Name = modelFamily.Name, - Description = modelFamily.Description, - DocsUrl = modelFamily.DocsUrl, - Models = modelFamily.Models, - ReadmeUrl = modelFamily.ReadmeUrl, - }; - } - - return modelFamily; + Id = id, + Name = modelFamily.Name, + Description = modelFamily.Description, + DocsUrl = modelFamily.DocsUrl, + Models = modelFamily.Models, + ReadmeUrl = modelFamily.ReadmeUrl, + }; } - private static ModelGroup Fix(ModelGroup modelGroup) + return modelFamily; + } + + private static ModelGroup Fix(ModelGroup modelGroup) + { + string? id = modelGroup.Id; + if (string.IsNullOrWhiteSpace(id)) { - string? id = modelGroup.Id; - if (string.IsNullOrWhiteSpace(id)) + id = Guid.NewGuid().ToString(); + return new ModelGroup { - id = Guid.NewGuid().ToString(); - return new ModelGroup - { - Id = id, - Name = modelGroup.Name, - Icon = modelGroup.Icon, - Models = modelGroup.Models - }; - } - - return modelGroup; + Id = id, + Name = modelGroup.Name, + Icon = modelGroup.Icon, + Models = modelGroup.Models + }; } - private static async Task FixAsync(Model model, CancellationToken cancellationToken) - { - long? size = model.Size; - - if (size is null or 0) - { - List filesToDownload; - if (model.Url.StartsWith("https://github.com", StringComparison.InvariantCulture)) - { - var ghUrl = new GitHubUrl(model.Url); - filesToDownload = await ModelInformationHelper.GetDownloadFilesFromGitHub(ghUrl, cancellationToken); - } - else - { - var hfUrl = new HuggingFaceUrl(model.Url); - using var httpClientHandler = new HttpClientHandler(); - filesToDownload = await ModelInformationHelper.GetDownloadFilesFromHuggingFace(hfUrl, httpClientHandler, cancellationToken); - } + return modelGroup; + } - filesToDownload = ModelInformationHelper.FilterFiles(filesToDownload, model.FileFilters); + private static async Task FixAsync(Model model, CancellationToken cancellationToken) + { + long? size = model.Size; - size = filesToDownload.Sum(f => f.Size); + if (size is null or 0) + { + List filesToDownload; + if (model.Url.StartsWith("https://github.com", StringComparison.InvariantCulture)) + { + var ghUrl = new GitHubUrl(model.Url); + filesToDownload = await ModelInformationHelper.GetDownloadFilesFromGitHub(ghUrl, cancellationToken); } - - string? id = model.Id; - if (string.IsNullOrWhiteSpace(id)) + else { - id = Guid.NewGuid().ToString(); + var hfUrl = new HuggingFaceUrl(model.Url); + using var httpClientHandler = new HttpClientHandler(); + filesToDownload = await ModelInformationHelper.GetDownloadFilesFromHuggingFace(hfUrl, httpClientHandler, cancellationToken); } - return new Model - { - Id = id, - Name = model.Name, - Url = model.Url, - Description = model.Description, - HardwareAccelerators = model.HardwareAccelerators, - SupportedOnQualcomm = model.SupportedOnQualcomm, - Size = size, - ParameterSize = model.ParameterSize, - Icon = model.Icon, - PromptTemplate = model.PromptTemplate, - License = model.License, - FileFilters = model.FileFilters - }; + filesToDownload = ModelInformationHelper.FilterFiles(filesToDownload, model.FileFilters); + + size = filesToDownload.Sum(f => f.Size); } - internal static async Task FixModelGroupAsync(Dictionary modelGroups, CancellationToken cancellationToken) + string? id = model.Id; + if (string.IsNullOrWhiteSpace(id)) { - for (int k = 0; k < modelGroups.Values.Count; k++) - { - var modelGroup = modelGroups.ElementAt(k); - modelGroups[modelGroup.Key] = Fix(modelGroup.Value); - modelGroup = modelGroups.ElementAt(k); - - for (int j = 0; j < modelGroup.Value.Models.Count; j++) - { - var modelFamily = modelGroup.Value.Models.ElementAt(j); - modelGroup.Value.Models[modelFamily.Key] = Fix(modelFamily.Value); - modelFamily = modelGroup.Value.Models.ElementAt(j); - - for (int i = 0; i < modelFamily.Value.Models.Count; i++) - { - var model = modelFamily.Value.Models.ElementAt(i); - modelFamily.Value.Models[model.Key] = await FixAsync(model.Value, cancellationToken); - } - } - } - - return JsonSerializer.Serialize(modelGroups, SourceGenerationContext.Default.DictionaryStringModelGroup); + id = Guid.NewGuid().ToString(); } - internal static async Task FixModelFamiliesAsync(Dictionary modelFamilies, CancellationToken cancellationToken) + return new Model { - for (int j = 0; j < modelFamilies.Values.Count; j++) + Id = id, + Name = model.Name, + Url = model.Url, + Description = model.Description, + HardwareAccelerators = model.HardwareAccelerators, + SupportedOnQualcomm = model.SupportedOnQualcomm, + Size = size, + ParameterSize = model.ParameterSize, + Icon = model.Icon, + PromptTemplate = model.PromptTemplate, + License = model.License, + FileFilters = model.FileFilters + }; + } + + internal static async Task FixModelGroupAsync(Dictionary modelGroups, CancellationToken cancellationToken) + { + for (int k = 0; k < modelGroups.Values.Count; k++) + { + var modelGroup = modelGroups.ElementAt(k); + modelGroups[modelGroup.Key] = Fix(modelGroup.Value); + modelGroup = modelGroups.ElementAt(k); + + for (int j = 0; j < modelGroup.Value.Models.Count; j++) { - var modelFamily = modelFamilies.ElementAt(j); - modelFamilies[modelFamily.Key] = Fix(modelFamily.Value); - modelFamily = modelFamilies.ElementAt(j); + var modelFamily = modelGroup.Value.Models.ElementAt(j); + modelGroup.Value.Models[modelFamily.Key] = Fix(modelFamily.Value); + modelFamily = modelGroup.Value.Models.ElementAt(j); for (int i = 0; i < modelFamily.Value.Models.Count; i++) { @@ -147,26 +128,44 @@ internal static async Task FixModelFamiliesAsync(Dictionary GetPackageVersions() + return JsonSerializer.Serialize(modelGroups, SourceGenerationContext.Default.DictionaryStringModelGroup); + } + + internal static async Task FixModelFamiliesAsync(Dictionary modelFamilies, CancellationToken cancellationToken) + { + for (int j = 0; j < modelFamilies.Values.Count; j++) { - var assembly = Assembly.GetExecutingAssembly(); - var packageVersions = new Dictionary(); + var modelFamily = modelFamilies.ElementAt(j); + modelFamilies[modelFamily.Key] = Fix(modelFamily.Value); + modelFamily = modelFamilies.ElementAt(j); - using (Stream stream = assembly.GetManifestResourceStream("AIDevGallery.SourceGenerator.Directory.Packages.props")) + for (int i = 0; i < modelFamily.Value.Models.Count; i++) { - using (XmlTextReader xmlReader = new(stream)) - { - while (xmlReader.ReadToFollowing("PackageVersion")) - { - packageVersions.Add(xmlReader.GetAttribute("Include"), xmlReader.GetAttribute("Version")); - } + var model = modelFamily.Value.Models.ElementAt(i); + modelFamily.Value.Models[model.Key] = await FixAsync(model.Value, cancellationToken); + } + } - return packageVersions; + return JsonSerializer.Serialize(modelFamilies, SourceGenerationContext.Default.DictionaryStringModelFamily); + } + + internal static Dictionary GetPackageVersions() + { + var assembly = Assembly.GetExecutingAssembly(); + var packageVersions = new Dictionary(); + + using (Stream stream = assembly.GetManifestResourceStream("AIDevGallery.SourceGenerator.Directory.Packages.props")) + { + using (XmlTextReader xmlReader = new(stream)) + { + while (xmlReader.ReadToFollowing("PackageVersion")) + { + packageVersions.Add(xmlReader.GetAttribute("Include"), xmlReader.GetAttribute("Version")); } + + return packageVersions; } } } diff --git a/AIDevGallery.SourceGenerator/Models/ApiDefinition.cs b/AIDevGallery.SourceGenerator/Models/ApiDefinition.cs index de87bc6..bf0e903 100644 --- a/AIDevGallery.SourceGenerator/Models/ApiDefinition.cs +++ b/AIDevGallery.SourceGenerator/Models/ApiDefinition.cs @@ -1,14 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class ApiDefinition { - internal class ApiDefinition - { - public required string Id { get; init; } - public required string Name { get; init; } - public required string Icon { get; init; } - public required string ReadmeUrl { get; init; } - public required string License { get; init; } - } + public required string Id { get; init; } + public required string Name { get; init; } + public required string Icon { get; init; } + public required string ReadmeUrl { get; init; } + public required string License { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/ApiGroup.cs b/AIDevGallery.SourceGenerator/Models/ApiGroup.cs index 2e3732b..299000d 100644 --- a/AIDevGallery.SourceGenerator/Models/ApiGroup.cs +++ b/AIDevGallery.SourceGenerator/Models/ApiGroup.cs @@ -3,14 +3,13 @@ using System.Collections.Generic; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class ApiGroup : IModelGroup { - internal class ApiGroup : IModelGroup - { - public required string Id { get; init; } - public required string Name { get; init; } - public required string Icon { get; init; } - public int? Order { get; init; } - public required Dictionary Apis { get; init; } - } + public required string Id { get; init; } + public required string Name { get; init; } + public required string Icon { get; init; } + public int? Order { get; init; } + public required Dictionary Apis { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/HardwareAccelerator.cs b/AIDevGallery.SourceGenerator/Models/HardwareAccelerator.cs index 3f2c81d..64da1c3 100644 --- a/AIDevGallery.SourceGenerator/Models/HardwareAccelerator.cs +++ b/AIDevGallery.SourceGenerator/Models/HardwareAccelerator.cs @@ -3,13 +3,12 @@ using System.Text.Json.Serialization; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +[JsonConverter(typeof(JsonStringEnumConverter))] +internal enum HardwareAccelerator { - [JsonConverter(typeof(JsonStringEnumConverter))] - internal enum HardwareAccelerator - { - CPU, - DML, - QNN - } + CPU, + DML, + QNN } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/IModelGroup.cs b/AIDevGallery.SourceGenerator/Models/IModelGroup.cs index 4970057..59e99ce 100644 --- a/AIDevGallery.SourceGenerator/Models/IModelGroup.cs +++ b/AIDevGallery.SourceGenerator/Models/IModelGroup.cs @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal interface IModelGroup { - internal interface IModelGroup - { - public string Id { get; init; } - public string Name { get; init; } - public string Icon { get; init; } - public int? Order { get; init; } - } + public string Id { get; init; } + public string Name { get; init; } + public string Icon { get; init; } + public int? Order { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/Model.cs b/AIDevGallery.SourceGenerator/Models/Model.cs index 2710434..f6c56ca 100644 --- a/AIDevGallery.SourceGenerator/Models/Model.cs +++ b/AIDevGallery.SourceGenerator/Models/Model.cs @@ -4,25 +4,24 @@ using System.Collections.Generic; using System.Text.Json.Serialization; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class Model { - internal class Model - { - public string? Id { get; init; } - public required string Name { get; init; } - public required string Url { get; init; } - public required string Description { get; init; } - [JsonConverter(typeof(SingleOrListOfHardwareAcceleratorConverter))] - [JsonPropertyName("HardwareAccelerator")] - public required List HardwareAccelerators { get; init; } - public bool? SupportedOnQualcomm { get; init; } - public long? Size { get; init; } - public string? Icon { get; init; } - public string? ParameterSize { get; init; } - public string? PromptTemplate { get; init; } - public required string License { get; init; } - [JsonConverter(typeof(SingleOrListOfStringConverter))] - [JsonPropertyName("FileFilter")] - public List? FileFilters { get; init; } - } + public string? Id { get; init; } + public required string Name { get; init; } + public required string Url { get; init; } + public required string Description { get; init; } + [JsonConverter(typeof(SingleOrListOfHardwareAcceleratorConverter))] + [JsonPropertyName("HardwareAccelerator")] + public required List HardwareAccelerators { get; init; } + public bool? SupportedOnQualcomm { get; init; } + public long? Size { get; init; } + public string? Icon { get; init; } + public string? ParameterSize { get; init; } + public string? PromptTemplate { get; init; } + public required string License { get; init; } + [JsonConverter(typeof(SingleOrListOfStringConverter))] + [JsonPropertyName("FileFilter")] + public List? FileFilters { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/ModelFamily.cs b/AIDevGallery.SourceGenerator/Models/ModelFamily.cs index 41a7e22..5bb699e 100644 --- a/AIDevGallery.SourceGenerator/Models/ModelFamily.cs +++ b/AIDevGallery.SourceGenerator/Models/ModelFamily.cs @@ -3,16 +3,15 @@ using System.Collections.Generic; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class ModelFamily { - internal class ModelFamily - { - public string? Id { get; init; } - public required string Name { get; init; } - public required string Description { get; init; } - public string? DocsUrl { get; init; } - public int? Order { get; init; } - public required Dictionary Models { get; init; } - public string ReadmeUrl { get; init; } = null!; - } + public string? Id { get; init; } + public required string Name { get; init; } + public required string Description { get; init; } + public string? DocsUrl { get; init; } + public int? Order { get; init; } + public required Dictionary Models { get; init; } + public string ReadmeUrl { get; init; } = null!; } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/ModelGroup.cs b/AIDevGallery.SourceGenerator/Models/ModelGroup.cs index 89611f7..4baf770 100644 --- a/AIDevGallery.SourceGenerator/Models/ModelGroup.cs +++ b/AIDevGallery.SourceGenerator/Models/ModelGroup.cs @@ -3,14 +3,13 @@ using System.Collections.Generic; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class ModelGroup : IModelGroup { - internal class ModelGroup : IModelGroup - { - public required string Id { get; init; } - public required string Name { get; init; } - public required string Icon { get; init; } - public int? Order { get; init; } - public required Dictionary Models { get; init; } - } + public required string Id { get; init; } + public required string Name { get; init; } + public required string Icon { get; init; } + public int? Order { get; init; } + public required Dictionary Models { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/PromptTemplate.cs b/AIDevGallery.SourceGenerator/Models/PromptTemplate.cs index fd471e2..67cb7cf 100644 --- a/AIDevGallery.SourceGenerator/Models/PromptTemplate.cs +++ b/AIDevGallery.SourceGenerator/Models/PromptTemplate.cs @@ -3,17 +3,16 @@ using System.Text.Json.Serialization; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class PromptTemplate { - internal class PromptTemplate - { - [JsonPropertyName("system")] - public string? System { get; init; } - [JsonPropertyName("user")] - public required string User { get; init; } - [JsonPropertyName("assistant")] - public string? Assistant { get; init; } - [JsonPropertyName("stop")] - public required string[] Stop { get; init; } - } + [JsonPropertyName("system")] + public string? System { get; init; } + [JsonPropertyName("user")] + public required string User { get; init; } + [JsonPropertyName("assistant")] + public string? Assistant { get; init; } + [JsonPropertyName("stop")] + public required string[] Stop { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/Scenario.cs b/AIDevGallery.SourceGenerator/Models/Scenario.cs index 09c86c9..c56ee10 100644 --- a/AIDevGallery.SourceGenerator/Models/Scenario.cs +++ b/AIDevGallery.SourceGenerator/Models/Scenario.cs @@ -1,12 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class Scenario { - internal class Scenario - { - public required string Name { get; init; } - public string? Description { get; init; } - public required string Id { get; init; } - } + public required string Name { get; init; } + public string? Description { get; init; } + public required string Id { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/ScenarioCategory.cs b/AIDevGallery.SourceGenerator/Models/ScenarioCategory.cs index 058fb0b..4e6205b 100644 --- a/AIDevGallery.SourceGenerator/Models/ScenarioCategory.cs +++ b/AIDevGallery.SourceGenerator/Models/ScenarioCategory.cs @@ -3,12 +3,11 @@ using System.Collections.Generic; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class ScenarioCategory { - internal class ScenarioCategory - { - public required string Name { get; init; } - public required string Icon { get; init; } - public required Dictionary Scenarios { get; init; } - } + public required string Name { get; init; } + public required string Icon { get; init; } + public required Dictionary Scenarios { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/SingleOrListOfHardwareAcceleratorConverter.cs b/AIDevGallery.SourceGenerator/Models/SingleOrListOfHardwareAcceleratorConverter.cs index 9bd5155..2381b7d 100644 --- a/AIDevGallery.SourceGenerator/Models/SingleOrListOfHardwareAcceleratorConverter.cs +++ b/AIDevGallery.SourceGenerator/Models/SingleOrListOfHardwareAcceleratorConverter.cs @@ -7,50 +7,49 @@ using System.Text.Json; using System.Text.Json.Serialization; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class SingleOrListOfHardwareAcceleratorConverter : JsonConverter> { - internal class SingleOrListOfHardwareAcceleratorConverter : JsonConverter> + public override List Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - public override List Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + var list = new List(); + if (reader.TokenType == JsonTokenType.StartArray) { - var list = new List(); - if (reader.TokenType == JsonTokenType.StartArray) + while (reader.Read()) { - while (reader.Read()) + if (reader.TokenType == JsonTokenType.EndArray) { - if (reader.TokenType == JsonTokenType.EndArray) - { - break; - } - - list.Add(JsonSerializer.Deserialize(ref reader, SourceGenerationContext.Default.HardwareAccelerator)); + break; } - } - else if (reader.TokenType != JsonTokenType.Null) - { - var singleValue = JsonSerializer.Deserialize(ref reader, SourceGenerationContext.Default.HardwareAccelerator); - list.Add(singleValue); - } - return list; + list.Add(JsonSerializer.Deserialize(ref reader, SourceGenerationContext.Default.HardwareAccelerator)); + } + } + else if (reader.TokenType != JsonTokenType.Null) + { + var singleValue = JsonSerializer.Deserialize(ref reader, SourceGenerationContext.Default.HardwareAccelerator); + list.Add(singleValue); } - public override void Write(Utf8JsonWriter writer, List value, JsonSerializerOptions options) + return list; + } + + public override void Write(Utf8JsonWriter writer, List value, JsonSerializerOptions options) + { + if (value.Count == 1) + { + JsonSerializer.Serialize(writer, value.First(), SourceGenerationContext.Default.HardwareAccelerator); + } + else { - if (value.Count == 1) + writer.WriteStartArray(); + foreach (var item in value) { - JsonSerializer.Serialize(writer, value.First(), SourceGenerationContext.Default.HardwareAccelerator); + JsonSerializer.Serialize(writer, item, SourceGenerationContext.Default.HardwareAccelerator); } - else - { - writer.WriteStartArray(); - foreach (var item in value) - { - JsonSerializer.Serialize(writer, item, SourceGenerationContext.Default.HardwareAccelerator); - } - writer.WriteEndArray(); - } + writer.WriteEndArray(); } } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/SingleOrListOfStringConverter.cs b/AIDevGallery.SourceGenerator/Models/SingleOrListOfStringConverter.cs index 4013bcb..ea72a0e 100644 --- a/AIDevGallery.SourceGenerator/Models/SingleOrListOfStringConverter.cs +++ b/AIDevGallery.SourceGenerator/Models/SingleOrListOfStringConverter.cs @@ -7,50 +7,49 @@ using System.Text.Json; using System.Text.Json.Serialization; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +internal class SingleOrListOfStringConverter : JsonConverter> { - internal class SingleOrListOfStringConverter : JsonConverter> + public override List Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - public override List Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + var list = new List(); + if (reader.TokenType == JsonTokenType.StartArray) { - var list = new List(); - if (reader.TokenType == JsonTokenType.StartArray) + while (reader.Read()) { - while (reader.Read()) + if (reader.TokenType == JsonTokenType.EndArray) { - if (reader.TokenType == JsonTokenType.EndArray) - { - break; - } - - list.Add(JsonSerializer.Deserialize(ref reader, SourceGenerationContext.Default.String) ?? string.Empty); + break; } - } - else if (reader.TokenType != JsonTokenType.Null) - { - var singleValue = JsonSerializer.Deserialize(ref reader, SourceGenerationContext.Default.String); - list.Add(singleValue ?? string.Empty); - } - return list; + list.Add(JsonSerializer.Deserialize(ref reader, SourceGenerationContext.Default.String) ?? string.Empty); + } + } + else if (reader.TokenType != JsonTokenType.Null) + { + var singleValue = JsonSerializer.Deserialize(ref reader, SourceGenerationContext.Default.String); + list.Add(singleValue ?? string.Empty); } - public override void Write(Utf8JsonWriter writer, List value, JsonSerializerOptions options) + return list; + } + + public override void Write(Utf8JsonWriter writer, List value, JsonSerializerOptions options) + { + if (value.Count == 1) + { + JsonSerializer.Serialize(writer, value.First(), SourceGenerationContext.Default.String); + } + else { - if (value.Count == 1) + writer.WriteStartArray(); + foreach (var item in value) { - JsonSerializer.Serialize(writer, value.First(), SourceGenerationContext.Default.String); + JsonSerializer.Serialize(writer, item, SourceGenerationContext.Default.String); } - else - { - writer.WriteStartArray(); - foreach (var item in value) - { - JsonSerializer.Serialize(writer, item, SourceGenerationContext.Default.String); - } - writer.WriteEndArray(); - } + writer.WriteEndArray(); } } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/Models/SourceGenerationContext.cs b/AIDevGallery.SourceGenerator/Models/SourceGenerationContext.cs index 0eb639b..a246001 100644 --- a/AIDevGallery.SourceGenerator/Models/SourceGenerationContext.cs +++ b/AIDevGallery.SourceGenerator/Models/SourceGenerationContext.cs @@ -4,15 +4,14 @@ using System.Collections.Generic; using System.Text.Json.Serialization; -namespace AIDevGallery.SourceGenerator.Models +namespace AIDevGallery.SourceGenerator.Models; + +[JsonSourceGenerationOptions(WriteIndented = true, AllowTrailingCommas = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +internal partial class SourceGenerationContext : JsonSerializerContext { - [JsonSourceGenerationOptions(WriteIndented = true, AllowTrailingCommas = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] - [JsonSerializable(typeof(Dictionary))] - [JsonSerializable(typeof(Dictionary))] - [JsonSerializable(typeof(Dictionary))] - [JsonSerializable(typeof(Dictionary))] - [JsonSerializable(typeof(Dictionary))] - internal partial class SourceGenerationContext : JsonSerializerContext - { - } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/ModelsSourceGenerator.cs b/AIDevGallery.SourceGenerator/ModelsSourceGenerator.cs index d72aac9..d21a185 100644 --- a/AIDevGallery.SourceGenerator/ModelsSourceGenerator.cs +++ b/AIDevGallery.SourceGenerator/ModelsSourceGenerator.cs @@ -16,444 +16,443 @@ #pragma warning disable RS1035 // Do not use APIs banned for analyzers -namespace AIDevGallery.SourceGenerator +namespace AIDevGallery.SourceGenerator; + +[Generator(LanguageNames.CSharp)] +internal class ModelSourceGenerator : IIncrementalGenerator { - [Generator(LanguageNames.CSharp)] - internal class ModelSourceGenerator : IIncrementalGenerator + public void Initialize(IncrementalGeneratorInitializationContext context) { - public void Initialize(IncrementalGeneratorInitializationContext context) - { - IncrementalValuesProvider modelJsons = context.AdditionalTextsProvider.Where( - static file => file.Path.EndsWith(".json") && - Path.GetFileName(Path.GetDirectoryName(file.Path)).Equals("ModelsDefinitions", StringComparison.OrdinalIgnoreCase)); + IncrementalValuesProvider modelJsons = context.AdditionalTextsProvider.Where( + static file => file.Path.EndsWith(".json") && + Path.GetFileName(Path.GetDirectoryName(file.Path)).Equals("ModelsDefinitions", StringComparison.OrdinalIgnoreCase)); - var pathsAndContents = modelJsons.Select((text, cancellationToken) => - (text.Path, Content: text.GetText(cancellationToken)!.ToString(), CancellationToken: cancellationToken)) - .Collect(); + var pathsAndContents = modelJsons.Select((text, cancellationToken) => + (text.Path, Content: text.GetText(cancellationToken)!.ToString(), CancellationToken: cancellationToken)) + .Collect(); - context.RegisterSourceOutput(pathsAndContents, Execute); - } + context.RegisterSourceOutput(pathsAndContents, Execute); + } - public void Execute(SourceProductionContext context, ImmutableArray<(string Path, string Content, CancellationToken CancellationToken)> modelJsons) - { - Dictionary modelTypes = []; - var sourceBuilder = new StringBuilder(); + public void Execute(SourceProductionContext context, ImmutableArray<(string Path, string Content, CancellationToken CancellationToken)> modelJsons) + { + Dictionary modelTypes = []; + var sourceBuilder = new StringBuilder(); - sourceBuilder.AppendLine( - $$"""" - #nullable enable + sourceBuilder.AppendLine( + $$"""" + #nullable enable - using System.Collections.Generic; + using System.Collections.Generic; - namespace AIDevGallery.Models; - - internal enum ModelType - { - """"); + namespace AIDevGallery.Models; + + internal enum ModelType + { + """"); - foreach (var modelJson in modelJsons) + foreach (var modelJson in modelJsons) + { + try { - try + var success = true; + switch (modelJson.Path) { - var success = true; - switch (modelJson.Path) - { - case var path when path.EndsWith("apis.json"): - var apiGroups = JsonSerializer.Deserialize(modelJson.Content, SourceGenerationContext.Default.DictionaryStringApiGroup); - if (apiGroups == null) - { - throw new InvalidOperationException("Failed to deserialize api.json"); - } + case var path when path.EndsWith("apis.json"): + var apiGroups = JsonSerializer.Deserialize(modelJson.Content, SourceGenerationContext.Default.DictionaryStringApiGroup); + if (apiGroups == null) + { + throw new InvalidOperationException("Failed to deserialize api.json"); + } - AddApis(sourceBuilder, apiGroups); + AddApis(sourceBuilder, apiGroups); - break; + break; - case var path when path.EndsWith(".model.json"): - var modelFamilies = JsonSerializer.Deserialize(modelJson.Content, SourceGenerationContext.Default.DictionaryStringModelFamily); - if (modelFamilies == null) + case var path when path.EndsWith(".model.json"): + var modelFamilies = JsonSerializer.Deserialize(modelJson.Content, SourceGenerationContext.Default.DictionaryStringModelFamily); + if (modelFamilies == null) + { + throw new InvalidOperationException("Failed to deserialize model.json"); + } + + foreach (var modelFamily in modelFamilies) + { + if (!AddEnumValue(sourceBuilder, modelFamily.Key, modelFamily)) { - throw new InvalidOperationException("Failed to deserialize model.json"); + success = false; } - foreach (var modelFamily in modelFamilies) + foreach (var model in modelFamily.Value.Models) { - if (!AddEnumValue(sourceBuilder, modelFamily.Key, modelFamily)) + if (!AddEnumValue(sourceBuilder, $"{modelFamily.Key}{model.Key}", model)) { success = false; } - - foreach (var model in modelFamily.Value.Models) - { - if (!AddEnumValue(sourceBuilder, $"{modelFamily.Key}{model.Key}", model)) - { - success = false; - } - } } + } - if (!success) - { - File.WriteAllText(modelJson.Path, Helpers.FixModelFamiliesAsync(modelFamilies, modelJson.CancellationToken).Result); - } + if (!success) + { + File.WriteAllText(modelJson.Path, Helpers.FixModelFamiliesAsync(modelFamilies, modelJson.CancellationToken).Result); + } - break; + break; - case var path when path.EndsWith(".modelgroup.json"): - var modelGroups = JsonSerializer.Deserialize(modelJson.Content, SourceGenerationContext.Default.DictionaryStringModelGroup); - if (modelGroups == null) + case var path when path.EndsWith(".modelgroup.json"): + var modelGroups = JsonSerializer.Deserialize(modelJson.Content, SourceGenerationContext.Default.DictionaryStringModelGroup); + if (modelGroups == null) + { + throw new InvalidOperationException("Failed to deserialize modelgroup.json"); + } + + foreach (var modelGroup in modelGroups) + { + if (!AddEnumValue(sourceBuilder, modelGroup.Key, modelGroup)) { - throw new InvalidOperationException("Failed to deserialize modelgroup.json"); + success = false; } - foreach (var modelGroup in modelGroups) + foreach (var modelFamily in modelGroup.Value.Models) { - if (!AddEnumValue(sourceBuilder, modelGroup.Key, modelGroup)) + if (!AddEnumValue(sourceBuilder, modelFamily.Key, modelFamily)) { success = false; } - foreach (var modelFamily in modelGroup.Value.Models) + foreach (var model in modelFamily.Value.Models) { - if (!AddEnumValue(sourceBuilder, modelFamily.Key, modelFamily)) + if (!AddEnumValue(sourceBuilder, $"{modelFamily.Key}{model.Key}", model)) { success = false; } - - foreach (var model in modelFamily.Value.Models) - { - if (!AddEnumValue(sourceBuilder, $"{modelFamily.Key}{model.Key}", model)) - { - success = false; - } - } } } + } - if (!success) - { - File.WriteAllText(modelJson.Path, Helpers.FixModelGroupAsync(modelGroups, modelJson.CancellationToken).Result); - } + if (!success) + { + File.WriteAllText(modelJson.Path, Helpers.FixModelGroupAsync(modelGroups, modelJson.CancellationToken).Result); + } - break; + break; - default: - break; - } - } - catch (Exception e) - { - Debug.WriteLine(e); - throw new Exception($"Error when processing '{modelJson.Path}' - Internal error: '{e.Message}'", e); + default: + break; } } + catch (Exception e) + { + Debug.WriteLine(e); + throw new Exception($"Error when processing '{modelJson.Path}' - Internal error: '{e.Message}'", e); + } + } - sourceBuilder.AppendLine( - """ - } - """); + sourceBuilder.AppendLine( + """ + } + """); - context.AddSource("ModelType.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + context.AddSource("ModelType.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); - bool AddEnumValue(StringBuilder sourceBuilder, string enumValueName, KeyValuePair dict) + bool AddEnumValue(StringBuilder sourceBuilder, string enumValueName, KeyValuePair dict) + { + if (dict.Value != null) { - if (dict.Value != null) + sourceBuilder.AppendLine($" {enumValueName},"); + modelTypes.Add(enumValueName, dict.Value); + if (dict.Value is Model model && + (model.Size == null || model.Size == 0 || + string.IsNullOrWhiteSpace(model.Id))) { - sourceBuilder.AppendLine($" {enumValueName},"); - modelTypes.Add(enumValueName, dict.Value); - if (dict.Value is Model model && - (model.Size == null || model.Size == 0 || - string.IsNullOrWhiteSpace(model.Id))) - { - return false; - } + return false; } - - return true; } - void AddApis(StringBuilder sourceBuilder, Dictionary apiGroups) + return true; + } + + void AddApis(StringBuilder sourceBuilder, Dictionary apiGroups) + { + foreach (var apiGroup in apiGroups) { - foreach (var apiGroup in apiGroups) + AddEnumValue(sourceBuilder, apiGroup.Key, apiGroup); + if (apiGroup.Value.Apis != null) { - AddEnumValue(sourceBuilder, apiGroup.Key, apiGroup); - if (apiGroup.Value.Apis != null) + foreach (var apiDefinition in apiGroup.Value.Apis) { - foreach (var apiDefinition in apiGroup.Value.Apis) - { - AddEnumValue(sourceBuilder, apiDefinition.Key, apiDefinition); - } + AddEnumValue(sourceBuilder, apiDefinition.Key, apiDefinition); } } } - - GenerateModelTypeHelpersFile(context, modelTypes); } - private void GenerateModelTypeHelpersFile(SourceProductionContext context, Dictionary modelTypes) - { - var sourceBuilder = new StringBuilder(); + GenerateModelTypeHelpersFile(context, modelTypes); + } - sourceBuilder.AppendLine( - $$"""" - #nullable enable + private void GenerateModelTypeHelpersFile(SourceProductionContext context, Dictionary modelTypes) + { + var sourceBuilder = new StringBuilder(); - using System.Collections.Generic; - using AIDevGallery.Models; + sourceBuilder.AppendLine( + $$"""" + #nullable enable - namespace AIDevGallery.Samples; + using System.Collections.Generic; + using AIDevGallery.Models; - internal static class ModelTypeHelpers - { - """"); + namespace AIDevGallery.Samples; - GenerateModelDetails( - sourceBuilder, - modelTypes - .Where(kp => kp.Value is Model) - .ToDictionary(kp => kp.Key, kp => (Model)kp.Value)); + internal static class ModelTypeHelpers + { + """"); - var modelFamilies = modelTypes - .Where(kp => kp.Value is ModelFamily) - .ToDictionary(kp => kp.Key, kp => (ModelFamily)kp.Value); - GenerateModelFamilyDetails(sourceBuilder, modelFamilies); + GenerateModelDetails( + sourceBuilder, + modelTypes + .Where(kp => kp.Value is Model) + .ToDictionary(kp => kp.Key, kp => (Model)kp.Value)); - var apiDefinitions = modelTypes - .Where(kp => kp.Value is ApiDefinition) - .ToDictionary(kp => kp.Key, kp => (ApiDefinition)kp.Value); - GenerateApiDefinitionDetails(sourceBuilder, apiDefinitions); + var modelFamilies = modelTypes + .Where(kp => kp.Value is ModelFamily) + .ToDictionary(kp => kp.Key, kp => (ModelFamily)kp.Value); + GenerateModelFamilyDetails(sourceBuilder, modelFamilies); - var modelGroups = modelTypes - .Where(kp => kp.Value is IModelGroup) - .ToDictionary(kp => kp.Key, kp => (IModelGroup)kp.Value); - GenerateModelGroupDetails(sourceBuilder, modelGroups); + var apiDefinitions = modelTypes + .Where(kp => kp.Value is ApiDefinition) + .ToDictionary(kp => kp.Key, kp => (ApiDefinition)kp.Value); + GenerateApiDefinitionDetails(sourceBuilder, apiDefinitions); - GenerateModelParentMapping(sourceBuilder, modelGroups, modelFamilies, apiDefinitions); + var modelGroups = modelTypes + .Where(kp => kp.Value is IModelGroup) + .ToDictionary(kp => kp.Key, kp => (IModelGroup)kp.Value); + GenerateModelGroupDetails(sourceBuilder, modelGroups); - GenerateGetModelOrder(sourceBuilder, modelGroups, modelFamilies); + GenerateModelParentMapping(sourceBuilder, modelGroups, modelFamilies, apiDefinitions); - sourceBuilder.AppendLine("}"); + GenerateGetModelOrder(sourceBuilder, modelGroups, modelFamilies); - context.AddSource($"ModelTypeHelpers.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); - } + sourceBuilder.AppendLine("}"); - private void GenerateModelDetails(StringBuilder sourceBuilder, Dictionary modelTypes) - { - sourceBuilder.AppendLine(" internal static Dictionary ModelDetails { get; } = new ()"); - sourceBuilder.AppendLine(" {"); - foreach (var modelType in modelTypes) - { - var modelDefinition = modelType.Value; - var promptTemplate = modelDefinition.PromptTemplate != null ? - $"PromptTemplateHelpers.PromptTemplates[PromptTemplateType.{modelDefinition.PromptTemplate}]" : - "null"; - var hardwareAccelerator = string.Join(", ", modelDefinition.HardwareAccelerators.Select(ha => $"HardwareAccelerator.{ha}")); - var supportedOnQualcomm = modelDefinition.SupportedOnQualcomm.HasValue ? modelDefinition.SupportedOnQualcomm.Value.ToString().ToLower() : "null"; - var icon = !string.IsNullOrEmpty(modelDefinition.Icon) ? $"\"{modelDefinition.Icon}\"" : "string.Empty"; - var fileFilters = modelDefinition.FileFilters != null ? string.Join(", ", modelDefinition.FileFilters.Select(ff => $"\"{ff}\"")) : string.Empty; - - sourceBuilder.AppendLine( - $$"""" - { - ModelType.{{modelType.Key}}, - new ModelDetails - { - Name = "{{modelDefinition.Name}}", - Id = "{{modelDefinition.Id}}", - Description = "{{modelDefinition.Description}}", - Url = "{{modelDefinition.Url}}", - HardwareAccelerators = [ {{hardwareAccelerator}} ], - SupportedOnQualcomm = {{supportedOnQualcomm}}, - Size = {{modelDefinition.Size}}, - ParameterSize = "{{modelDefinition.ParameterSize}}", - PromptTemplate = {{promptTemplate}}, - Icon = {{icon}}, - License = "{{modelDefinition.License}}", - FileFilters = [ {{fileFilters}} ] - } - }, - """"); - } - - sourceBuilder.AppendLine(" };"); - } + context.AddSource($"ModelTypeHelpers.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + } - private void GenerateModelFamilyDetails(StringBuilder sourceBuilder, Dictionary modelFamily) + private void GenerateModelDetails(StringBuilder sourceBuilder, Dictionary modelTypes) + { + sourceBuilder.AppendLine(" internal static Dictionary ModelDetails { get; } = new ()"); + sourceBuilder.AppendLine(" {"); + foreach (var modelType in modelTypes) { - sourceBuilder.AppendLine(" internal static Dictionary ModelFamilyDetails { get; } = new ()"); - sourceBuilder.AppendLine(" {"); - foreach (var modelFamilyType in modelFamily) - { - var modelFamilyDefinition = modelFamilyType.Value; - sourceBuilder.AppendLine( - $$"""" - { - ModelType.{{modelFamilyType.Key}}, - new ModelFamily - { - Id = "{{modelFamilyDefinition.Id}}", - Name = "{{modelFamilyDefinition.Name}}", - Description = "{{modelFamilyDefinition.Description}}", - DocsUrl = "{{modelFamilyDefinition.DocsUrl}}", - ReadmeUrl = "{{modelFamilyDefinition.ReadmeUrl}}", - } - }, - """"); - } - - sourceBuilder.AppendLine(" };"); + var modelDefinition = modelType.Value; + var promptTemplate = modelDefinition.PromptTemplate != null ? + $"PromptTemplateHelpers.PromptTemplates[PromptTemplateType.{modelDefinition.PromptTemplate}]" : + "null"; + var hardwareAccelerator = string.Join(", ", modelDefinition.HardwareAccelerators.Select(ha => $"HardwareAccelerator.{ha}")); + var supportedOnQualcomm = modelDefinition.SupportedOnQualcomm.HasValue ? modelDefinition.SupportedOnQualcomm.Value.ToString().ToLower() : "null"; + var icon = !string.IsNullOrEmpty(modelDefinition.Icon) ? $"\"{modelDefinition.Icon}\"" : "string.Empty"; + var fileFilters = modelDefinition.FileFilters != null ? string.Join(", ", modelDefinition.FileFilters.Select(ff => $"\"{ff}\"")) : string.Empty; - sourceBuilder.AppendLine(); + sourceBuilder.AppendLine( + $$"""" + { + ModelType.{{modelType.Key}}, + new ModelDetails + { + Name = "{{modelDefinition.Name}}", + Id = "{{modelDefinition.Id}}", + Description = "{{modelDefinition.Description}}", + Url = "{{modelDefinition.Url}}", + HardwareAccelerators = [ {{hardwareAccelerator}} ], + SupportedOnQualcomm = {{supportedOnQualcomm}}, + Size = {{modelDefinition.Size}}, + ParameterSize = "{{modelDefinition.ParameterSize}}", + PromptTemplate = {{promptTemplate}}, + Icon = {{icon}}, + License = "{{modelDefinition.License}}", + FileFilters = [ {{fileFilters}} ] + } + }, + """"); } - private void GenerateApiDefinitionDetails(StringBuilder sourceBuilder, Dictionary apiDefinitions) + sourceBuilder.AppendLine(" };"); + } + + private void GenerateModelFamilyDetails(StringBuilder sourceBuilder, Dictionary modelFamily) + { + sourceBuilder.AppendLine(" internal static Dictionary ModelFamilyDetails { get; } = new ()"); + sourceBuilder.AppendLine(" {"); + foreach (var modelFamilyType in modelFamily) { - sourceBuilder.AppendLine(" internal static Dictionary ApiDefinitionDetails { get; } = new ()"); - sourceBuilder.AppendLine(" {"); - foreach (var apiDefinitionType in apiDefinitions) - { - var apiDefinition = apiDefinitionType.Value; - sourceBuilder.AppendLine( + var modelFamilyDefinition = modelFamilyType.Value; + sourceBuilder.AppendLine( $$"""" + { + ModelType.{{modelFamilyType.Key}}, + new ModelFamily { - ModelType.{{apiDefinitionType.Key}}, - new ApiDefinition - { - Id = "{{apiDefinition.Id}}", - Name = "{{apiDefinition.Name}}", - Icon = "{{apiDefinition.Icon}}", - ReadmeUrl = "{{apiDefinition.ReadmeUrl}}", - License = "{{apiDefinition.License}}", - } - }, - """"); - } + Id = "{{modelFamilyDefinition.Id}}", + Name = "{{modelFamilyDefinition.Name}}", + Description = "{{modelFamilyDefinition.Description}}", + DocsUrl = "{{modelFamilyDefinition.DocsUrl}}", + ReadmeUrl = "{{modelFamilyDefinition.ReadmeUrl}}", + } + }, + """"); + } - sourceBuilder.AppendLine(" };"); + sourceBuilder.AppendLine(" };"); - sourceBuilder.AppendLine(); - } + sourceBuilder.AppendLine(); + } - private void GenerateModelGroupDetails(StringBuilder sourceBuilder, Dictionary modelGroups) + private void GenerateApiDefinitionDetails(StringBuilder sourceBuilder, Dictionary apiDefinitions) + { + sourceBuilder.AppendLine(" internal static Dictionary ApiDefinitionDetails { get; } = new ()"); + sourceBuilder.AppendLine(" {"); + foreach (var apiDefinitionType in apiDefinitions) { - sourceBuilder.AppendLine(" internal static Dictionary ModelGroupDetails { get; } = new ()"); - sourceBuilder.AppendLine(" {"); - - foreach (var modelGroupType in modelGroups) - { - var modelGroupDefinition = modelGroupType.Value; - sourceBuilder.AppendLine( - $$"""" + var apiDefinition = apiDefinitionType.Value; + sourceBuilder.AppendLine( + $$"""" + { + ModelType.{{apiDefinitionType.Key}}, + new ApiDefinition { - ModelType.{{modelGroupType.Key}}, - new ModelGroup - { - Id = "{{modelGroupDefinition.Id}}", - Name = "{{modelGroupDefinition.Name}}", - Icon = {{Helpers.EscapeUnicodeString(modelGroupDefinition.Icon)}}, - IsApi = {{(modelGroupDefinition is ApiGroup).ToString().ToLower()}} - } - }, - """"); - } + Id = "{{apiDefinition.Id}}", + Name = "{{apiDefinition.Name}}", + Icon = "{{apiDefinition.Icon}}", + ReadmeUrl = "{{apiDefinition.ReadmeUrl}}", + License = "{{apiDefinition.License}}", + } + }, + """"); + } - sourceBuilder.AppendLine(" };"); + sourceBuilder.AppendLine(" };"); - sourceBuilder.AppendLine(); - } + sourceBuilder.AppendLine(); + } - private void GenerateModelParentMapping(StringBuilder sourceBuilder, Dictionary modelGroups, Dictionary modelFamilies, Dictionary apiDefinitions) + private void GenerateModelGroupDetails(StringBuilder sourceBuilder, Dictionary modelGroups) + { + sourceBuilder.AppendLine(" internal static Dictionary ModelGroupDetails { get; } = new ()"); + sourceBuilder.AppendLine(" {"); + + foreach (var modelGroupType in modelGroups) { - sourceBuilder.AppendLine(" internal static Dictionary> ParentMapping { get; } = new ()"); - sourceBuilder.AppendLine(" {"); + var modelGroupDefinition = modelGroupType.Value; + sourceBuilder.AppendLine( + $$"""" + { + ModelType.{{modelGroupType.Key}}, + new ModelGroup + { + Id = "{{modelGroupDefinition.Id}}", + Name = "{{modelGroupDefinition.Name}}", + Icon = {{Helpers.EscapeUnicodeString(modelGroupDefinition.Icon)}}, + IsApi = {{(modelGroupDefinition is ApiGroup).ToString().ToLower()}} + } + }, + """"); + } - var addedKeys = new HashSet(); + sourceBuilder.AppendLine(" };"); - void Print(string key, IEnumerable values) - { - if (addedKeys.Contains(key)) - { - return; - } - - addedKeys.Add(key); + sourceBuilder.AppendLine(); + } - sourceBuilder.AppendLine($$""" { ModelType.{{key}}, ["""); - foreach (var value in values) - { - sourceBuilder.AppendLine($$""" ModelType.{{value}}, """); - } + private void GenerateModelParentMapping(StringBuilder sourceBuilder, Dictionary modelGroups, Dictionary modelFamilies, Dictionary apiDefinitions) + { + sourceBuilder.AppendLine(" internal static Dictionary> ParentMapping { get; } = new ()"); + sourceBuilder.AppendLine(" {"); - sourceBuilder.AppendLine($$""" ] },"""); - } + var addedKeys = new HashSet(); - foreach (var modelGroupType in modelGroups.Where(kp => kp.Value is ModelGroup).ToDictionary(kp => kp.Key, kp => (ModelGroup)kp.Value)) + void Print(string key, IEnumerable values) + { + if (addedKeys.Contains(key)) { - Print(modelGroupType.Key, modelGroupType.Value.Models.Select(m => m.Key)); - - foreach (var modelFamilyType in modelGroupType.Value.Models) - { - Print(modelFamilyType.Key, modelFamilyType.Value.Models.Select(m => $"{modelFamilyType.Key}{m.Key}")); - } + return; } - foreach (var apiGroup in modelGroups.Where(kp => kp.Value is ApiGroup).ToDictionary(kp => kp.Key, kp => (ApiGroup)kp.Value)) + addedKeys.Add(key); + + sourceBuilder.AppendLine($$""" { ModelType.{{key}}, ["""); + foreach (var value in values) { - Print(apiGroup.Key, apiGroup.Value.Apis.Select(m => m.Key)); + sourceBuilder.AppendLine($$""" ModelType.{{value}}, """); } - foreach (var modelFamilyType in modelFamilies) + sourceBuilder.AppendLine($$""" ] },"""); + } + + foreach (var modelGroupType in modelGroups.Where(kp => kp.Value is ModelGroup).ToDictionary(kp => kp.Key, kp => (ModelGroup)kp.Value)) + { + Print(modelGroupType.Key, modelGroupType.Value.Models.Select(m => m.Key)); + + foreach (var modelFamilyType in modelGroupType.Value.Models) { Print(modelFamilyType.Key, modelFamilyType.Value.Models.Select(m => $"{modelFamilyType.Key}{m.Key}")); } + } - foreach (var apiDefinition in apiDefinitions) - { - Print(apiDefinition.Key, []); - } + foreach (var apiGroup in modelGroups.Where(kp => kp.Value is ApiGroup).ToDictionary(kp => kp.Key, kp => (ApiGroup)kp.Value)) + { + Print(apiGroup.Key, apiGroup.Value.Apis.Select(m => m.Key)); + } - sourceBuilder.AppendLine(" };"); + foreach (var modelFamilyType in modelFamilies) + { + Print(modelFamilyType.Key, modelFamilyType.Value.Models.Select(m => $"{modelFamilyType.Key}{m.Key}")); } - private void GenerateGetModelOrder(StringBuilder sourceBuilder, Dictionary modelGroups, Dictionary modelFamilies) + foreach (var apiDefinition in apiDefinitions) { - var addedOrders = new Dictionary(); + Print(apiDefinition.Key, []); + } - foreach (var modelGroupType in modelGroups) - { - addedOrders[modelGroupType.Key] = modelGroupType.Value.Order ?? int.MaxValue; + sourceBuilder.AppendLine(" };"); + } - if (modelGroupType.Value is ModelGroup modelGroup) - { - foreach (var modelFamilyType in modelGroup.Models) - { - addedOrders[modelFamilyType.Key] = modelFamilyType.Value.Order ?? int.MaxValue; - } - } - } + private void GenerateGetModelOrder(StringBuilder sourceBuilder, Dictionary modelGroups, Dictionary modelFamilies) + { + var addedOrders = new Dictionary(); - foreach (var modelFamilyType in modelFamilies) - { - addedOrders[modelFamilyType.Key] = modelFamilyType.Value.Order ?? int.MaxValue; - } + foreach (var modelGroupType in modelGroups) + { + addedOrders[modelGroupType.Key] = modelGroupType.Value.Order ?? int.MaxValue; - sourceBuilder.AppendLine(" internal static int GetModelOrder(ModelType modelType)"); - sourceBuilder.AppendLine(" {"); - sourceBuilder.AppendLine(" return modelType switch"); - sourceBuilder.AppendLine(" {"); - foreach (var keyOrder in addedOrders.OrderBy(kvp => kvp.Value)) + if (modelGroupType.Value is ModelGroup modelGroup) { - sourceBuilder.AppendLine( - $$"""" - ModelType.{{keyOrder.Key}} => {{keyOrder.Value}}, - """"); + foreach (var modelFamilyType in modelGroup.Models) + { + addedOrders[modelFamilyType.Key] = modelFamilyType.Value.Order ?? int.MaxValue; + } } + } - sourceBuilder.AppendLine(" _ => int.MaxValue,"); - sourceBuilder.AppendLine(" };"); + foreach (var modelFamilyType in modelFamilies) + { + addedOrders[modelFamilyType.Key] = modelFamilyType.Value.Order ?? int.MaxValue; + } - sourceBuilder.AppendLine(" }"); + sourceBuilder.AppendLine(" internal static int GetModelOrder(ModelType modelType)"); + sourceBuilder.AppendLine(" {"); + sourceBuilder.AppendLine(" return modelType switch"); + sourceBuilder.AppendLine(" {"); + foreach (var keyOrder in addedOrders.OrderBy(kvp => kvp.Value)) + { + sourceBuilder.AppendLine( + $$"""" + ModelType.{{keyOrder.Key}} => {{keyOrder.Value}}, + """"); } + + sourceBuilder.AppendLine(" _ => int.MaxValue,"); + sourceBuilder.AppendLine(" };"); + + sourceBuilder.AppendLine(" }"); } } diff --git a/AIDevGallery.SourceGenerator/PromptTemplatesSourceGenerator.cs b/AIDevGallery.SourceGenerator/PromptTemplatesSourceGenerator.cs index f33bb3b..aad1056 100644 --- a/AIDevGallery.SourceGenerator/PromptTemplatesSourceGenerator.cs +++ b/AIDevGallery.SourceGenerator/PromptTemplatesSourceGenerator.cs @@ -11,126 +11,125 @@ using System.Text; using System.Text.Json; -namespace AIDevGallery.SourceGenerator +namespace AIDevGallery.SourceGenerator; + +[Generator(LanguageNames.CSharp)] +internal class PromptTemplatesSourceGenerator : IIncrementalGenerator { - [Generator(LanguageNames.CSharp)] - internal class PromptTemplatesSourceGenerator : IIncrementalGenerator - { - private Dictionary? promptTemplates = null; + private Dictionary? promptTemplates = null; - public void Initialize(IncrementalGeneratorInitializationContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) + { + string promptTemplateJson; + var assembly = Assembly.GetExecutingAssembly(); + using (Stream stream = assembly.GetManifestResourceStream("AIDevGallery.SourceGenerator.promptTemplates.json")) { - string promptTemplateJson; - var assembly = Assembly.GetExecutingAssembly(); - using (Stream stream = assembly.GetManifestResourceStream("AIDevGallery.SourceGenerator.promptTemplates.json")) + using (StreamReader reader = new(stream)) { - using (StreamReader reader = new(stream)) - { - promptTemplateJson = reader.ReadToEnd().Trim(); - } + promptTemplateJson = reader.ReadToEnd().Trim(); } - - promptTemplates = JsonSerializer.Deserialize(promptTemplateJson, SourceGenerationContext.Default.DictionaryStringPromptTemplate); - context.RegisterPostInitializationOutput(Execute); } - public void Execute(IncrementalGeneratorPostInitializationContext context) - { - if (promptTemplates == null) - { - return; - } - - GeneratePromptTemplatesTypeFile(context, promptTemplates); + promptTemplates = JsonSerializer.Deserialize(promptTemplateJson, SourceGenerationContext.Default.DictionaryStringPromptTemplate); + context.RegisterPostInitializationOutput(Execute); + } - GeneratePromptTemplatesHelpersFile(context, promptTemplates); + public void Execute(IncrementalGeneratorPostInitializationContext context) + { + if (promptTemplates == null) + { + return; } - private void GeneratePromptTemplatesHelpersFile(IncrementalGeneratorPostInitializationContext context, Dictionary promptTemplates) - { - var sourceBuilder = new StringBuilder(); + GeneratePromptTemplatesTypeFile(context, promptTemplates); - sourceBuilder.AppendLine( - $$"""" - #nullable enable + GeneratePromptTemplatesHelpersFile(context, promptTemplates); + } - using System.Collections.Generic; - using AIDevGallery.Models; + private void GeneratePromptTemplatesHelpersFile(IncrementalGeneratorPostInitializationContext context, Dictionary promptTemplates) + { + var sourceBuilder = new StringBuilder(); - namespace AIDevGallery.Samples; + sourceBuilder.AppendLine( + $$"""" + #nullable enable - internal static partial class PromptTemplateHelpers - { - """"); + using System.Collections.Generic; + using AIDevGallery.Models; - sourceBuilder.AppendLine(" internal static Dictionary PromptTemplates { get; } = new ()"); - sourceBuilder.AppendLine(" {"); - foreach (var promptTemplate in promptTemplates) + namespace AIDevGallery.Samples; + + internal static partial class PromptTemplateHelpers { - sourceBuilder.AppendLine( - $$"""" + """"); + + sourceBuilder.AppendLine(" internal static Dictionary PromptTemplates { get; } = new ()"); + sourceBuilder.AppendLine(" {"); + foreach (var promptTemplate in promptTemplates) + { + sourceBuilder.AppendLine( + $$"""" + { + PromptTemplateType.{{promptTemplate.Key}}, + new PromptTemplate { - PromptTemplateType.{{promptTemplate.Key}}, - new PromptTemplate - { - """"); - if (promptTemplate.Value.System != null) - { - sourceBuilder.AppendLine($$""" System = {{Helpers.EscapeUnicodeString(promptTemplate.Value.System)}},"""); - } - - sourceBuilder.AppendLine( - $$"""" - User = {{Helpers.EscapeUnicodeString(promptTemplate.Value.User)}}, - """"); - if (promptTemplate.Value.Assistant != null) - { - sourceBuilder.AppendLine($$""" Assistant = {{Helpers.EscapeUnicodeString(promptTemplate.Value.Assistant)}},"""); - } - - sourceBuilder.AppendLine( - $$"""" - Stop = [{{string.Join(", ", promptTemplate.Value.Stop.Select(Helpers.EscapeUnicodeString))}}] - } - }, - """"); + """"); + if (promptTemplate.Value.System != null) + { + sourceBuilder.AppendLine($$""" System = {{Helpers.EscapeUnicodeString(promptTemplate.Value.System)}},"""); } - sourceBuilder.AppendLine(" };"); - - sourceBuilder.AppendLine("}"); + sourceBuilder.AppendLine( + $$"""" + User = {{Helpers.EscapeUnicodeString(promptTemplate.Value.User)}}, + """"); + if (promptTemplate.Value.Assistant != null) + { + sourceBuilder.AppendLine($$""" Assistant = {{Helpers.EscapeUnicodeString(promptTemplate.Value.Assistant)}},"""); + } - context.AddSource($"PromptTemplateHelpers.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + sourceBuilder.AppendLine( + $$"""" + Stop = [{{string.Join(", ", promptTemplate.Value.Stop.Select(Helpers.EscapeUnicodeString))}}] + } + }, + """"); } - private static void GeneratePromptTemplatesTypeFile(IncrementalGeneratorPostInitializationContext context, Dictionary promptTemplates) - { - var sourceBuilder = new StringBuilder(); + sourceBuilder.AppendLine(" };"); - sourceBuilder.AppendLine( - $$"""" - #nullable enable + sourceBuilder.AppendLine("}"); - using System.Collections.Generic; - using AIDevGallery.Models; + context.AddSource($"PromptTemplateHelpers.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + } - namespace AIDevGallery.Models; + private static void GeneratePromptTemplatesTypeFile(IncrementalGeneratorPostInitializationContext context, Dictionary promptTemplates) + { + var sourceBuilder = new StringBuilder(); - internal enum PromptTemplateType - { - """"); + sourceBuilder.AppendLine( + $$"""" + #nullable enable - foreach (var promptTemplate in promptTemplates) - { - sourceBuilder.AppendLine($" {promptTemplate.Key},"); - } + using System.Collections.Generic; + using AIDevGallery.Models; - sourceBuilder.AppendLine( - $$"""" - } - """"); + namespace AIDevGallery.Models; + + internal enum PromptTemplateType + { + """"); - context.AddSource("PromptTemplateType.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + foreach (var promptTemplate in promptTemplates) + { + sourceBuilder.AppendLine($" {promptTemplate.Key},"); } + + sourceBuilder.AppendLine( + $$"""" + } + """"); + + context.AddSource("PromptTemplateType.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/SamplesSourceGenerator.cs b/AIDevGallery.SourceGenerator/SamplesSourceGenerator.cs index 561a8ab..089a495 100644 --- a/AIDevGallery.SourceGenerator/SamplesSourceGenerator.cs +++ b/AIDevGallery.SourceGenerator/SamplesSourceGenerator.cs @@ -18,426 +18,425 @@ #pragma warning disable RS1035 // Do not use APIs banned for analyzers -namespace AIDevGallery.SourceGenerator +namespace AIDevGallery.SourceGenerator; + +[Generator(LanguageNames.CSharp)] +internal class SamplesSourceGenerator : IIncrementalGenerator { - [Generator(LanguageNames.CSharp)] - internal class SamplesSourceGenerator : IIncrementalGenerator + private void ExecuteSharedCodeEnumGeneration(SourceProductionContext context, ImmutableArray typeSymbols) { - private void ExecuteSharedCodeEnumGeneration(SourceProductionContext context, ImmutableArray typeSymbols) + // Filter types by the target namespace + var typesInNamespace = typeSymbols + .Where(typeSymbol => typeSymbol != null && + (typeSymbol.ContainingNamespace.ToDisplayString().StartsWith("AIDevGallery.Samples.SharedCode", StringComparison.Ordinal) || + typeSymbol.GetFullyQualifiedName().Equals("global::AIDevGallery.Utils.DeviceUtils", StringComparison.Ordinal))) + .ToList(); + + if (!typesInNamespace.Any()) { - // Filter types by the target namespace - var typesInNamespace = typeSymbols - .Where(typeSymbol => typeSymbol != null && - (typeSymbol.ContainingNamespace.ToDisplayString().StartsWith("AIDevGallery.Samples.SharedCode", StringComparison.Ordinal) || - typeSymbol.GetFullyQualifiedName().Equals("global::AIDevGallery.Utils.DeviceUtils", StringComparison.Ordinal))) - .ToList(); - - if (!typesInNamespace.Any()) - { - return; - } + return; + } - // Generate the enum source - var sourceBuilder = new StringBuilder(); - sourceBuilder.AppendLine("#nullable enable"); - sourceBuilder.AppendLine(); - sourceBuilder.AppendLine($"namespace AIDevGallery.Samples"); - sourceBuilder.AppendLine("{"); - sourceBuilder.AppendLine(" internal enum SharedCodeEnum"); - sourceBuilder.AppendLine(" {"); + // Generate the enum source + var sourceBuilder = new StringBuilder(); + sourceBuilder.AppendLine("#nullable enable"); + sourceBuilder.AppendLine(); + sourceBuilder.AppendLine($"namespace AIDevGallery.Samples"); + sourceBuilder.AppendLine("{"); + sourceBuilder.AppendLine(" internal enum SharedCodeEnum"); + sourceBuilder.AppendLine(" {"); - List filePaths = []; + List filePaths = []; - foreach (var type in typesInNamespace) - { - var filePath = type!.Locations[0].SourceTree?.FilePath; + foreach (var type in typesInNamespace) + { + var filePath = type!.Locations[0].SourceTree?.FilePath; - if (filePath != null) + if (filePath != null) + { + if (!filePaths.Contains(filePath)) { - if (!filePaths.Contains(filePath)) - { - filePaths.Add(filePath); - } + filePaths.Add(filePath); } } + } - filePaths.Add("NativeMethods.txt"); + filePaths.Add("NativeMethods.txt"); - foreach (var filePath in filePaths) + foreach (var filePath in filePaths) + { + var fileName = Path.GetFileNameWithoutExtension(filePath); + if (File.Exists(Path.ChangeExtension(filePath, ".xaml"))) { - var fileName = Path.GetFileNameWithoutExtension(filePath); - if (File.Exists(Path.ChangeExtension(filePath, ".xaml"))) - { - sourceBuilder.AppendLine($" {fileName}Cs,"); - sourceBuilder.AppendLine($" {fileName}Xaml,"); - } - else - { - sourceBuilder.AppendLine($" {fileName},"); - } + sourceBuilder.AppendLine($" {fileName}Cs,"); + sourceBuilder.AppendLine($" {fileName}Xaml,"); } + else + { + sourceBuilder.AppendLine($" {fileName},"); + } + } - sourceBuilder.AppendLine(" }"); - sourceBuilder.AppendLine(); - sourceBuilder.AppendLine(" internal static class SharedCodeHelpers"); - sourceBuilder.AppendLine(" {"); - sourceBuilder.AppendLine(" internal static string GetName(SharedCodeEnum sharedCode)"); - sourceBuilder.AppendLine(" {"); - sourceBuilder.AppendLine(" return sharedCode switch"); - sourceBuilder.AppendLine(" {"); - foreach (var filePath in filePaths) + sourceBuilder.AppendLine(" }"); + sourceBuilder.AppendLine(); + sourceBuilder.AppendLine(" internal static class SharedCodeHelpers"); + sourceBuilder.AppendLine(" {"); + sourceBuilder.AppendLine(" internal static string GetName(SharedCodeEnum sharedCode)"); + sourceBuilder.AppendLine(" {"); + sourceBuilder.AppendLine(" return sharedCode switch"); + sourceBuilder.AppendLine(" {"); + foreach (var filePath in filePaths) + { + var fileName = Path.GetFileNameWithoutExtension(filePath); + var filePathXaml = Path.ChangeExtension(filePath, ".xaml"); + if (File.Exists(filePathXaml)) { - var fileName = Path.GetFileNameWithoutExtension(filePath); - var filePathXaml = Path.ChangeExtension(filePath, ".xaml"); - if (File.Exists(filePathXaml)) - { - sourceBuilder.AppendLine($" SharedCodeEnum.{Path.GetFileNameWithoutExtension(filePath)}Xaml => \"{Path.GetFileName(filePathXaml)}\","); - sourceBuilder.AppendLine($" SharedCodeEnum.{Path.GetFileNameWithoutExtension(filePath)}Cs => \"{Path.GetFileName(filePath)}\","); - } - else - { - sourceBuilder.AppendLine($" SharedCodeEnum.{Path.GetFileNameWithoutExtension(filePath)} => \"{Path.GetFileName(filePath)}\","); - } + sourceBuilder.AppendLine($" SharedCodeEnum.{Path.GetFileNameWithoutExtension(filePath)}Xaml => \"{Path.GetFileName(filePathXaml)}\","); + sourceBuilder.AppendLine($" SharedCodeEnum.{Path.GetFileNameWithoutExtension(filePath)}Cs => \"{Path.GetFileName(filePath)}\","); + } + else + { + sourceBuilder.AppendLine($" SharedCodeEnum.{Path.GetFileNameWithoutExtension(filePath)} => \"{Path.GetFileName(filePath)}\","); } + } - sourceBuilder.AppendLine(" _ => string.Empty,"); - sourceBuilder.AppendLine(" };"); - sourceBuilder.AppendLine(" }"); - sourceBuilder.AppendLine(" internal static string GetSource(SharedCodeEnum sharedCode)"); - sourceBuilder.AppendLine(" {"); - sourceBuilder.AppendLine(" return sharedCode switch"); - sourceBuilder.AppendLine(" {"); - foreach (var filePath in filePaths) + sourceBuilder.AppendLine(" _ => string.Empty,"); + sourceBuilder.AppendLine(" };"); + sourceBuilder.AppendLine(" }"); + sourceBuilder.AppendLine(" internal static string GetSource(SharedCodeEnum sharedCode)"); + sourceBuilder.AppendLine(" {"); + sourceBuilder.AppendLine(" return sharedCode switch"); + sourceBuilder.AppendLine(" {"); + foreach (var filePath in filePaths) + { + var fileName = Path.GetFileNameWithoutExtension(filePath); + var filePathXaml = Path.ChangeExtension(filePath, ".xaml"); + if (File.Exists(filePathXaml)) { - var fileName = Path.GetFileNameWithoutExtension(filePath); - var filePathXaml = Path.ChangeExtension(filePath, ".xaml"); - if (File.Exists(filePathXaml)) - { - var fileContentXaml = XamlSourceCleanUp(File.ReadAllText(filePathXaml)); - sourceBuilder.AppendLine( - $$"""""" - SharedCodeEnum.{{fileName}}Xaml => - """ - {{fileContentXaml}} - """, - """"""); + var fileContentXaml = XamlSourceCleanUp(File.ReadAllText(filePathXaml)); + sourceBuilder.AppendLine( + $$"""""" + SharedCodeEnum.{{fileName}}Xaml => + """ + {{fileContentXaml}} + """, + """"""); - fileName = $"{fileName}Cs"; - } + fileName = $"{fileName}Cs"; + } - string fileContent; - if (fileName == "NativeMethods") + string fileContent; + if (fileName == "NativeMethods") + { + var assembly = Assembly.GetExecutingAssembly(); + using (Stream stream = assembly.GetManifestResourceStream("AIDevGallery.SourceGenerator.NativeMethods.txt")) { - var assembly = Assembly.GetExecutingAssembly(); - using (Stream stream = assembly.GetManifestResourceStream("AIDevGallery.SourceGenerator.NativeMethods.txt")) + using (StreamReader reader = new(stream)) { - using (StreamReader reader = new(stream)) - { - fileContent = reader.ReadToEnd().Trim(); - } + fileContent = reader.ReadToEnd().Trim(); } } - else - { - fileContent = SampleSourceCleanUp(File.ReadAllText(filePath), filePath); - } - - sourceBuilder.AppendLine( - $$"""""" - SharedCodeEnum.{{fileName}} => - """ - {{fileContent}} - """, - """"""); + } + else + { + fileContent = SampleSourceCleanUp(File.ReadAllText(filePath), filePath); } - sourceBuilder.AppendLine(" _ => string.Empty,"); - sourceBuilder.AppendLine(" };"); - sourceBuilder.AppendLine(" }"); - sourceBuilder.AppendLine(" }"); + sourceBuilder.AppendLine( + $$"""""" + SharedCodeEnum.{{fileName}} => + """ + {{fileContent}} + """, + """"""); + } - sourceBuilder.AppendLine("}"); + sourceBuilder.AppendLine(" _ => string.Empty,"); + sourceBuilder.AppendLine(" };"); + sourceBuilder.AppendLine(" }"); + sourceBuilder.AppendLine(" }"); - context.AddSource("SharedCodeEnum.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); - } + sourceBuilder.AppendLine("}"); - private static readonly Regex GallerySampleAttributeRemovalRegex = new(@"\n(\s)*\[GallerySample\((?>[^()]+|\((?)|\)(?<-DEPTH>))*(?(DEPTH)(?!))\)\]", RegexOptions.Compiled); - private static readonly Regex ExcludedElementXamlRemovalRegex = new(@")|(.*<\/EXCLUDE:[a-zA-Z]*>))", RegexOptions.Singleline | RegexOptions.Compiled); - private static readonly Regex ExcludedAttrbitueXamlRemovalRegex = new(@"EXCLUDE:[^""]*""[^""]*""", RegexOptions.Singleline | RegexOptions.Compiled); + context.AddSource("SharedCodeEnum.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + } - private static string SampleSourceCleanUp(string input, string filePath) - { - var header = @"// Copyright (c) Microsoft Corporation. All rights reserved. + private static readonly Regex GallerySampleAttributeRemovalRegex = new(@"\n(\s)*\[GallerySample\((?>[^()]+|\((?)|\)(?<-DEPTH>))*(?(DEPTH)(?!))\)\]", RegexOptions.Compiled); + private static readonly Regex ExcludedElementXamlRemovalRegex = new(@")|(.*<\/EXCLUDE:[a-zA-Z]*>))", RegexOptions.Singleline | RegexOptions.Compiled); + private static readonly Regex ExcludedAttrbitueXamlRemovalRegex = new(@"EXCLUDE:[^""]*""[^""]*""", RegexOptions.Singleline | RegexOptions.Compiled); + + private static string SampleSourceCleanUp(string input, string filePath) + { + var header = @"// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License."; - if (input.StartsWith(header, StringComparison.Ordinal)) - { - input = input.Substring(header.Length) - .TrimStart(Environment.NewLine.ToCharArray()) - .TrimStart(); - } + if (input.StartsWith(header, StringComparison.Ordinal)) + { + input = input.Substring(header.Length) + .TrimStart(Environment.NewLine.ToCharArray()) + .TrimStart(); + } - input = GallerySampleAttributeRemovalRegex.Replace(input, string.Empty); - input = RemoveExcludedLinesCs(input, filePath); + input = GallerySampleAttributeRemovalRegex.Replace(input, string.Empty); + input = RemoveExcludedLinesCs(input, filePath); - return input; - } + return input; + } - private static string XamlSourceCleanUp(string input) + private static string XamlSourceCleanUp(string input) + { + if (input.Contains("xmlns:EXCLUDE")) { - if (input.Contains("xmlns:EXCLUDE")) - { - input = ExcludedElementXamlRemovalRegex.Replace(input, string.Empty); - input = ExcludedAttrbitueXamlRemovalRegex.Replace(input, string.Empty); - input = RemoveEmptyLines(input); - } - - return input; + input = ExcludedElementXamlRemovalRegex.Replace(input, string.Empty); + input = ExcludedAttrbitueXamlRemovalRegex.Replace(input, string.Empty); + input = RemoveEmptyLines(input); } - private static string RemoveEmptyLines(string input) - { - var lines = input.Split([Environment.NewLine], StringSplitOptions.None); - var nonEmptyLines = lines.Where(line => !string.IsNullOrWhiteSpace(line)); + return input; + } - return string.Join(Environment.NewLine, nonEmptyLines); - } + private static string RemoveEmptyLines(string input) + { + var lines = input.Split([Environment.NewLine], StringSplitOptions.None); + var nonEmptyLines = lines.Where(line => !string.IsNullOrWhiteSpace(line)); - private static string RemoveExcludedLinesCs(string input, string filePath) - { - List lines = new(input.Split([Environment.NewLine], StringSplitOptions.None)); + return string.Join(Environment.NewLine, nonEmptyLines); + } + + private static string RemoveExcludedLinesCs(string input, string filePath) + { + List lines = new(input.Split([Environment.NewLine], StringSplitOptions.None)); - for (int i = 0; i < lines.Count;) + for (int i = 0; i < lines.Count;) + { + if (lines[i].Contains("//") || lines[i].Contains("// ")) { - if (lines[i].Contains("//") || lines[i].Contains("// ")) + while (!lines[i].Contains("//") && !lines[i].Contains("// ")) { - while (!lines[i].Contains("//") && !lines[i].Contains("// ")) + lines.RemoveAt(i); + if (i >= lines.Count) { - lines.RemoveAt(i); - if (i >= lines.Count) - { - throw new InvalidOperationException($" block is never closed in file '{filePath}'"); - } + throw new InvalidOperationException($" block is never closed in file '{filePath}'"); } - - lines.RemoveAt(i); - } - else if (lines[i].Contains("// ") || lines[i].Contains("//")) - { - lines.RemoveAt(i); } - else - { - i++; - } - } - return string.Join(Environment.NewLine, lines); + lines.RemoveAt(i); + } + else if (lines[i].Contains("// ") || lines[i].Contains("//")) + { + lines.RemoveAt(i); + } + else + { + i++; + } } - private static INamedTypeSymbol? GetTypeSymbol(GeneratorSyntaxContext context) - { - var typeDeclaration = (TypeDeclarationSyntax)context.Node; - return context.SemanticModel.GetDeclaredSymbol(typeDeclaration) as INamedTypeSymbol; - } + return string.Join(Environment.NewLine, lines); + } - public void Initialize(IncrementalGeneratorInitializationContext context) - { - var typeDeclarations = context.SyntaxProvider - .CreateSyntaxProvider( - predicate: static (s, _) => s is TypeDeclarationSyntax, - transform: static (ctx, _) => GetTypeSymbol(ctx)) - .Where(static m => m != null) - .Collect(); - - // Combine the results of the syntax provider and add the source - context.RegisterSourceOutput(typeDeclarations, ExecuteSharedCodeEnumGeneration); - - var gallerySamplePipeline = context.SyntaxProvider - .ForAttributeWithMetadataName( - WellKnownTypeNames.GallerySampleAttribute, - predicate: (node, _) => node is ClassDeclarationSyntax, - transform: static (context, cancellationToken) => + private static INamedTypeSymbol? GetTypeSymbol(GeneratorSyntaxContext context) + { + var typeDeclaration = (TypeDeclarationSyntax)context.Node; + return context.SemanticModel.GetDeclaredSymbol(typeDeclaration) as INamedTypeSymbol; + } + + public void Initialize(IncrementalGeneratorInitializationContext context) + { + var typeDeclarations = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (s, _) => s is TypeDeclarationSyntax, + transform: static (ctx, _) => GetTypeSymbol(ctx)) + .Where(static m => m != null) + .Collect(); + + // Combine the results of the syntax provider and add the source + context.RegisterSourceOutput(typeDeclarations, ExecuteSharedCodeEnumGeneration); + + var gallerySamplePipeline = context.SyntaxProvider + .ForAttributeWithMetadataName( + WellKnownTypeNames.GallerySampleAttribute, + predicate: (node, _) => node is ClassDeclarationSyntax, + transform: static (context, cancellationToken) => + { + INamedTypeSymbol typeSymbol = (INamedTypeSymbol)context.TargetSymbol; + + if (context.TargetSymbol!.Locations[0].SourceTree == null) { - INamedTypeSymbol typeSymbol = (INamedTypeSymbol)context.TargetSymbol; + return null; + } - if (context.TargetSymbol!.Locations[0].SourceTree == null) - { - return null; - } + WellKnownTypeSymbols typeSymbols = new(context.SemanticModel.Compilation); - WellKnownTypeSymbols typeSymbols = new(context.SemanticModel.Compilation); + if (!typeSymbol.TryGetAttributeWithType(typeSymbols.GallerySampleAttribute, out AttributeData? attributeData)) + { + return null; + } - if (!typeSymbol.TryGetAttributeWithType(typeSymbols.GallerySampleAttribute, out AttributeData? attributeData)) - { - return null; - } + var filePath = context.TargetSymbol!.Locations[0].SourceTree!.FilePath; + var folder = Path.GetDirectoryName(filePath); + var fileName = Path.GetFileName(filePath); + var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(filePath); + if (fileNameWithoutExtension.EndsWith(".xaml")) + { + fileNameWithoutExtension = Path.GetFileNameWithoutExtension(fileNameWithoutExtension); + } - var filePath = context.TargetSymbol!.Locations[0].SourceTree!.FilePath; - var folder = Path.GetDirectoryName(filePath); - var fileName = Path.GetFileName(filePath); - var fileNameWithoutExtension = Path.GetFileNameWithoutExtension(filePath); - if (fileNameWithoutExtension.EndsWith(".xaml")) - { - fileNameWithoutExtension = Path.GetFileNameWithoutExtension(fileNameWithoutExtension); - } + var sampleXamlFile = Directory.GetFiles(folder).Where(f => f.EndsWith($"\\{fileNameWithoutExtension}.xaml", StringComparison.InvariantCultureIgnoreCase)).FirstOrDefault(); + var sampleXamlFileContent = XamlSourceCleanUp(File.ReadAllText(sampleXamlFile)); - var sampleXamlFile = Directory.GetFiles(folder).Where(f => f.EndsWith($"\\{fileNameWithoutExtension}.xaml", StringComparison.InvariantCultureIgnoreCase)).FirstOrDefault(); - var sampleXamlFileContent = XamlSourceCleanUp(File.ReadAllText(sampleXamlFile)); + var pageType = context.TargetSymbol.GetFullyQualifiedName(); - var pageType = context.TargetSymbol.GetFullyQualifiedName(); + var sampleXamlCsFile = Directory.GetFiles(folder).Where(f => f.EndsWith($"\\{fileNameWithoutExtension}.xaml.cs", StringComparison.InvariantCultureIgnoreCase)).FirstOrDefault(); + var sampleXamlCsFileContent = SampleSourceCleanUp(File.ReadAllText(sampleXamlCsFile), sampleXamlCsFile); - var sampleXamlCsFile = Directory.GetFiles(folder).Where(f => f.EndsWith($"\\{fileNameWithoutExtension}.xaml.cs", StringComparison.InvariantCultureIgnoreCase)).FirstOrDefault(); - var sampleXamlCsFileContent = SampleSourceCleanUp(File.ReadAllText(sampleXamlCsFile), sampleXamlCsFile); + if (attributeData == null) + { + return null; + } - if (attributeData == null) + try + { + string name = (string)attributeData.NamedArguments.First(a => a.Key == "Name").Value.Value!; + string id = attributeData.NamedArguments.FirstOrDefault(a => a.Key == "Id").Value.Value as string ?? string.Empty; + string icon = attributeData.NamedArguments.FirstOrDefault(a => a.Key == "Icon").Value.Value as string ?? string.Empty; + string? scenario = attributeData.NamedArguments.FirstOrDefault(a => a.Key == "Scenario").Value.Value?.ToString(); + string[]? nugetPackageReferences = null; + var nugetPackageReferencesRef = attributeData.NamedArguments.FirstOrDefault(a => a.Key == "NugetPackageReferences"); + if (!nugetPackageReferencesRef.Value.IsNull) { - return null; + nugetPackageReferences = nugetPackageReferencesRef.Value.Values.Select(v => (string)v.Value!).ToArray(); } - try - { - string name = (string)attributeData.NamedArguments.First(a => a.Key == "Name").Value.Value!; - string id = attributeData.NamedArguments.FirstOrDefault(a => a.Key == "Id").Value.Value as string ?? string.Empty; - string icon = attributeData.NamedArguments.FirstOrDefault(a => a.Key == "Icon").Value.Value as string ?? string.Empty; - string? scenario = attributeData.NamedArguments.FirstOrDefault(a => a.Key == "Scenario").Value.Value?.ToString(); - string[]? nugetPackageReferences = null; - var nugetPackageReferencesRef = attributeData.NamedArguments.FirstOrDefault(a => a.Key == "NugetPackageReferences"); - if (!nugetPackageReferencesRef.Value.IsNull) - { - nugetPackageReferences = nugetPackageReferencesRef.Value.Values.Select(v => (string)v.Value!).ToArray(); - } + return new SampleModel( + Owner: typeSymbol.GetFullyQualifiedName(), + Name: name, + PageType: pageType, + XAMLCode: sampleXamlFileContent, + CSCode: sampleXamlCsFileContent, + Id: id, + Icon: icon, + Scenario: scenario, + NugetPackageReferences: nugetPackageReferences, + attributeData.GetLocation()); + } + catch (Exception) + { + throw new InvalidOperationException($"Error when processing {typeSymbol.GetFullyQualifiedName()} - GallerySampleAttribute: {attributeData}"); + } + }) + .Where(static m => m != null) + .Select(static (m, _) => m!) + .Collect(); - return new SampleModel( - Owner: typeSymbol.GetFullyQualifiedName(), - Name: name, - PageType: pageType, - XAMLCode: sampleXamlFileContent, - CSCode: sampleXamlCsFileContent, - Id: id, - Icon: icon, - Scenario: scenario, - NugetPackageReferences: nugetPackageReferences, - attributeData.GetLocation()); - } - catch (Exception) - { - throw new InvalidOperationException($"Error when processing {typeSymbol.GetFullyQualifiedName()} - GallerySampleAttribute: {attributeData}"); - } - }) - .Where(static m => m != null) - .Select(static (m, _) => m!) - .Collect(); + context.RegisterImplementationSourceOutput(gallerySamplePipeline, static (context, samples) => + { + var sourceBuilder = new StringBuilder(); - context.RegisterImplementationSourceOutput(gallerySamplePipeline, static (context, samples) => - { - var sourceBuilder = new StringBuilder(); + sourceBuilder.AppendLine( + $$"""" + #nullable enable - sourceBuilder.AppendLine( - $$"""" - #nullable enable + using System.Collections.Generic; + using System.Linq; + using AIDevGallery.Models; + using AIDevGallery.Samples.Attributes; - using System.Collections.Generic; - using System.Linq; - using AIDevGallery.Models; - using AIDevGallery.Samples.Attributes; + namespace AIDevGallery.Samples; - namespace AIDevGallery.Samples; + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + internal static class SampleDetails + { + private static List GetSharedCodeFrom(System.Type type) + { + return type.GetCustomAttributes(typeof(GallerySampleAttribute), false) + .Cast() + .First().SharedCode?.ToList() ?? new (); + } - [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] - internal static class SampleDetails + private static List? GetModelTypesFrom(int index, System.Type type) { - private static List GetSharedCodeFrom(System.Type type) + if (index == 2) { return type.GetCustomAttributes(typeof(GallerySampleAttribute), false) .Cast() - .First().SharedCode?.ToList() ?? new (); + .First().Model2Types?.ToList(); } - private static List? GetModelTypesFrom(int index, System.Type type) - { - if (index == 2) - { - return type.GetCustomAttributes(typeof(GallerySampleAttribute), false) - .Cast() - .First().Model2Types?.ToList(); - } - - return type.GetCustomAttributes(typeof(GallerySampleAttribute), false) - .Cast() - .First().Model1Types.ToList(); - } + return type.GetCustomAttributes(typeof(GallerySampleAttribute), false) + .Cast() + .First().Model1Types.ToList(); + } - internal static List Samples = [ - """"); + internal static List Samples = [ + """"); - var packageVersions = Helpers.GetPackageVersions(); + var packageVersions = Helpers.GetPackageVersions(); - foreach (var sample in samples) + foreach (var sample in samples) + { + if (sample.Scenario == null) { - if (sample.Scenario == null) - { - // TODO: Remove when APIs are added, and mark scenario as required on GallerySampleAttribute - Debug.WriteLine($"Scenario is null for {sample.Name}"); - } - else - { - var nugetPackageReferences = sample.NugetPackageReferences != null && sample.NugetPackageReferences.Length > 0 - ? string.Join(", ", sample.NugetPackageReferences.Select(r => $"\"{r}\"")) - : string.Empty; + // TODO: Remove when APIs are added, and mark scenario as required on GallerySampleAttribute + Debug.WriteLine($"Scenario is null for {sample.Name}"); + } + else + { + var nugetPackageReferences = sample.NugetPackageReferences != null && sample.NugetPackageReferences.Length > 0 + ? string.Join(", ", sample.NugetPackageReferences.Select(r => $"\"{r}\"")) + : string.Empty; - if (sample.NugetPackageReferences != null && sample.NugetPackageReferences.Length > 0) + if (sample.NugetPackageReferences != null && sample.NugetPackageReferences.Length > 0) + { + foreach (var packageReference in sample.NugetPackageReferences) { - foreach (var packageReference in sample.NugetPackageReferences) + if (!packageVersions.ContainsKey(packageReference)) { - if (!packageVersions.ContainsKey(packageReference)) - { - context.ReportDiagnostic(Diagnostic.Create( - DiagnosticDescriptors.NugetPackageNotUsed, - sample.Location, - packageReference)); - } + context.ReportDiagnostic(Diagnostic.Create( + DiagnosticDescriptors.NugetPackageNotUsed, + sample.Location, + packageReference)); } } - - sourceBuilder.AppendLine( - $$"""""" - new Sample - { - Name = "{{sample.Name}}", - PageType = typeof({{sample.PageType}}), - XAMLCode = - """ - {{sample.XAMLCode}} - """, - CSCode = - """" - {{sample.CSCode}} - """", - Id = "{{sample.Id}}", - Icon = {{Helpers.EscapeUnicodeString(sample.Icon)}}, - Scenario = (ScenarioType){{sample.Scenario}}, - Model1Types = GetModelTypesFrom(1, typeof({{sample.Owner}}))!, - Model2Types = GetModelTypesFrom(2, typeof({{sample.Owner}})), - SharedCode = GetSharedCodeFrom(typeof({{sample.Owner}})), - NugetPackageReferences = [ {{nugetPackageReferences}} ] - }, - """"""); } - } - sourceBuilder.AppendLine( - """ - ]; - } - """); + sourceBuilder.AppendLine( + $$"""""" + new Sample + { + Name = "{{sample.Name}}", + PageType = typeof({{sample.PageType}}), + XAMLCode = + """ + {{sample.XAMLCode}} + """, + CSCode = + """" + {{sample.CSCode}} + """", + Id = "{{sample.Id}}", + Icon = {{Helpers.EscapeUnicodeString(sample.Icon)}}, + Scenario = (ScenarioType){{sample.Scenario}}, + Model1Types = GetModelTypesFrom(1, typeof({{sample.Owner}}))!, + Model2Types = GetModelTypesFrom(2, typeof({{sample.Owner}})), + SharedCode = GetSharedCodeFrom(typeof({{sample.Owner}})), + NugetPackageReferences = [ {{nugetPackageReferences}} ] + }, + """"""); + } + } - context.AddSource($"SampleDetails.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); - }); - } + sourceBuilder.AppendLine( + """ + ]; + } + """); - private record ScenarioModel(string EnumName, string ScenarioCategoryType, string Id, string Name, string Description); - private record SampleModel(string Owner, string Name, string PageType, string XAMLCode, string CSCode, string Id, string Icon, string? Scenario, string[]? NugetPackageReferences, Location? Location); - private record ModelDefinitionModel(string EnumName, string Parent, string Name, string Id, string Description, string Url, string HardwareAccelerator, long Size, string ParameterSize); + context.AddSource($"SampleDetails.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + }); } + + private record ScenarioModel(string EnumName, string ScenarioCategoryType, string Id, string Name, string Description); + private record SampleModel(string Owner, string Name, string PageType, string XAMLCode, string CSCode, string Id, string Icon, string? Scenario, string[]? NugetPackageReferences, Location? Location); + private record ModelDefinitionModel(string EnumName, string Parent, string Name, string Id, string Description, string Url, string HardwareAccelerator, long Size, string ParameterSize); } #pragma warning restore RS1035 // Do not use APIs banned for analyzers \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/ScenariosSourceGenerator.cs b/AIDevGallery.SourceGenerator/ScenariosSourceGenerator.cs index 2acb2c1..ae2cb54 100644 --- a/AIDevGallery.SourceGenerator/ScenariosSourceGenerator.cs +++ b/AIDevGallery.SourceGenerator/ScenariosSourceGenerator.cs @@ -10,162 +10,161 @@ using System.Text; using System.Text.Json; -namespace AIDevGallery.SourceGenerator +namespace AIDevGallery.SourceGenerator; + +[Generator(LanguageNames.CSharp)] +internal class ScenariosSourceGenerator : IIncrementalGenerator { - [Generator(LanguageNames.CSharp)] - internal class ScenariosSourceGenerator : IIncrementalGenerator - { - private Dictionary? scenarioCategories = null; + private Dictionary? scenarioCategories = null; - public void Initialize(IncrementalGeneratorInitializationContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) + { + string scenarioJson; + var assembly = Assembly.GetExecutingAssembly(); + using (Stream stream = assembly.GetManifestResourceStream("AIDevGallery.SourceGenerator.scenarios.json")) { - string scenarioJson; - var assembly = Assembly.GetExecutingAssembly(); - using (Stream stream = assembly.GetManifestResourceStream("AIDevGallery.SourceGenerator.scenarios.json")) + using (StreamReader reader = new(stream)) { - using (StreamReader reader = new(stream)) - { - scenarioJson = reader.ReadToEnd().Trim(); - } + scenarioJson = reader.ReadToEnd().Trim(); } - - scenarioCategories = JsonSerializer.Deserialize(scenarioJson, SourceGenerationContext.Default.DictionaryStringScenarioCategory); - context.RegisterPostInitializationOutput(Execute); } - public void Execute(IncrementalGeneratorPostInitializationContext context) - { - if (scenarioCategories == null) - { - return; - } + scenarioCategories = JsonSerializer.Deserialize(scenarioJson, SourceGenerationContext.Default.DictionaryStringScenarioCategory); + context.RegisterPostInitializationOutput(Execute); + } - GenerateScenarioCategoryTypeFile(context, scenarioCategories); + public void Execute(IncrementalGeneratorPostInitializationContext context) + { + if (scenarioCategories == null) + { + return; + } - GenerateScenariosTypeFile(context, scenarioCategories); + GenerateScenarioCategoryTypeFile(context, scenarioCategories); - GenerateScenarioHelpersFile(context, scenarioCategories); - } + GenerateScenariosTypeFile(context, scenarioCategories); - private void GenerateScenarioHelpersFile(IncrementalGeneratorPostInitializationContext context, Dictionary scenarioCategories) - { - var sourceBuilder = new StringBuilder(); + GenerateScenarioHelpersFile(context, scenarioCategories); + } - sourceBuilder.AppendLine( - $$"""" - #nullable enable + private void GenerateScenarioHelpersFile(IncrementalGeneratorPostInitializationContext context, Dictionary scenarioCategories) + { + var sourceBuilder = new StringBuilder(); - using System.Collections.Generic; - using AIDevGallery.Models; + sourceBuilder.AppendLine( + $$"""" + #nullable enable - namespace AIDevGallery.Samples; + using System.Collections.Generic; + using AIDevGallery.Models; - internal static partial class ScenarioCategoryHelpers - { - """"); + namespace AIDevGallery.Samples; - sourceBuilder.AppendLine(" internal static List AllScenarioCategories { get; } = ["); - foreach (var scenarioCategory in scenarioCategories) + internal static partial class ScenarioCategoryHelpers { - string icon = Helpers.EscapeUnicodeString(scenarioCategory.Value.Icon); + """"); - sourceBuilder.AppendLine( - $$"""" - new ScenarioCategory - { - Name = "{{scenarioCategory.Value.Name}}", - Icon = {{icon}}, - Scenarios = new List - { - """"); - foreach (var scenario in scenarioCategory.Value.Scenarios) - { - sourceBuilder.AppendLine( - $$"""""" - new Scenario - { - ScenarioType = ScenarioType.{{scenarioCategory.Key}}{{scenario.Key}}, - Name = "{{scenario.Value.Name}}", - Description = "{{scenario.Value.Description}}", - Id = "{{scenario.Value.Id}}", - Icon = {{icon}} - }, - """"""); - } + sourceBuilder.AppendLine(" internal static List AllScenarioCategories { get; } = ["); + foreach (var scenarioCategory in scenarioCategories) + { + string icon = Helpers.EscapeUnicodeString(scenarioCategory.Value.Icon); + sourceBuilder.AppendLine( + $$"""" + new ScenarioCategory + { + Name = "{{scenarioCategory.Value.Name}}", + Icon = {{icon}}, + Scenarios = new List + { + """"); + foreach (var scenario in scenarioCategory.Value.Scenarios) + { sourceBuilder.AppendLine( - $$"""" - } - }, - """"); + $$"""""" + new Scenario + { + ScenarioType = ScenarioType.{{scenarioCategory.Key}}{{scenario.Key}}, + Name = "{{scenario.Value.Name}}", + Description = "{{scenario.Value.Description}}", + Id = "{{scenario.Value.Id}}", + Icon = {{icon}} + }, + """"""); } - sourceBuilder.AppendLine(" ];"); + sourceBuilder.AppendLine( + $$"""" + } + }, + """"); + } + + sourceBuilder.AppendLine(" ];"); - sourceBuilder.AppendLine("}"); + sourceBuilder.AppendLine("}"); - context.AddSource($"ScenarioCategoryHelpers.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); - } + context.AddSource($"ScenarioCategoryHelpers.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + } - private void GenerateScenariosTypeFile(IncrementalGeneratorPostInitializationContext context, Dictionary scenarioCategories) - { - var sourceBuilder = new StringBuilder(); + private void GenerateScenariosTypeFile(IncrementalGeneratorPostInitializationContext context, Dictionary scenarioCategories) + { + var sourceBuilder = new StringBuilder(); - sourceBuilder.AppendLine( - $$"""" - #nullable enable + sourceBuilder.AppendLine( + $$"""" + #nullable enable - using System.Collections.Generic; + using System.Collections.Generic; - namespace AIDevGallery.Models; - - internal enum ScenarioType - { - """"); - foreach (var scenarioCategory in scenarioCategories) + namespace AIDevGallery.Models; + + internal enum ScenarioType { - foreach (var scenario in scenarioCategory.Value.Scenarios) - { - sourceBuilder.AppendLine($" {scenarioCategory.Key}{scenario.Key},"); - } + """"); + foreach (var scenarioCategory in scenarioCategories) + { + foreach (var scenario in scenarioCategory.Value.Scenarios) + { + sourceBuilder.AppendLine($" {scenarioCategory.Key}{scenario.Key},"); } - - sourceBuilder.AppendLine( - """ - } - """); - - context.AddSource("ScenarioType.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); } - private static void GenerateScenarioCategoryTypeFile(IncrementalGeneratorPostInitializationContext context, Dictionary scenarioCategories) - { - var sourceBuilder = new StringBuilder(); + sourceBuilder.AppendLine( + """ + } + """); - sourceBuilder.AppendLine( - $$"""" - #nullable enable + context.AddSource("ScenarioType.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + } - using System.Collections.Generic; - using AIDevGallery.Models; + private static void GenerateScenarioCategoryTypeFile(IncrementalGeneratorPostInitializationContext context, Dictionary scenarioCategories) + { + var sourceBuilder = new StringBuilder(); - namespace AIDevGallery.Models; + sourceBuilder.AppendLine( + $$"""" + #nullable enable - internal enum ScenarioCategoryType - { - """"); + using System.Collections.Generic; + using AIDevGallery.Models; - foreach (var scenarioCategory in scenarioCategories) - { - sourceBuilder.AppendLine($" {scenarioCategory.Key},"); - } + namespace AIDevGallery.Models; - sourceBuilder.AppendLine( - $$"""" - } - """"); + internal enum ScenarioCategoryType + { + """"); - context.AddSource("ScenarioCategoryType.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); + foreach (var scenarioCategory in scenarioCategories) + { + sourceBuilder.AppendLine($" {scenarioCategory.Key},"); } + + sourceBuilder.AppendLine( + $$"""" + } + """"); + + context.AddSource("ScenarioCategoryType.g.cs", SourceText.From(sourceBuilder.ToString(), Encoding.UTF8)); } } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/WellKnownTypeNames.cs b/AIDevGallery.SourceGenerator/WellKnownTypeNames.cs index 2d475d0..5e970ac 100644 --- a/AIDevGallery.SourceGenerator/WellKnownTypeNames.cs +++ b/AIDevGallery.SourceGenerator/WellKnownTypeNames.cs @@ -1,16 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.SourceGenerator +namespace AIDevGallery.SourceGenerator; + +/// +/// A container for well known type names. +/// +internal sealed class WellKnownTypeNames { /// - /// A container for well known type names. + /// The AIDevGallery.Samples.Attributes.GallerySampleAttribute type name. /// - internal sealed class WellKnownTypeNames - { - /// - /// The AIDevGallery.Samples.Attributes.GallerySampleAttribute type name. - /// - public const string GallerySampleAttribute = "AIDevGallery.Samples.Attributes.GallerySampleAttribute"; - } + public const string GallerySampleAttribute = "AIDevGallery.Samples.Attributes.GallerySampleAttribute"; } \ No newline at end of file diff --git a/AIDevGallery.SourceGenerator/WellKnownTypeSymbols.cs b/AIDevGallery.SourceGenerator/WellKnownTypeSymbols.cs index 44ebf51..2e64f89 100644 --- a/AIDevGallery.SourceGenerator/WellKnownTypeSymbols.cs +++ b/AIDevGallery.SourceGenerator/WellKnownTypeSymbols.cs @@ -3,43 +3,42 @@ using Microsoft.CodeAnalysis; -namespace AIDevGallery.SourceGenerator +namespace AIDevGallery.SourceGenerator; + +/// +/// A simple helper providing quick access to known type symbols. +/// +internal sealed class WellKnownTypeSymbols { /// - /// A simple helper providing quick access to known type symbols. + /// The input instance. /// - internal sealed class WellKnownTypeSymbols - { - /// - /// The input instance. - /// - private readonly Compilation compilation; + private readonly Compilation compilation; - private INamedTypeSymbol? gallerySampleAttribute; + private INamedTypeSymbol? gallerySampleAttribute; - /// - /// Initializes a new instance of the class. - /// - /// The input instance. - public WellKnownTypeSymbols(Compilation compilation) - { - this.compilation = compilation; - } + /// + /// Initializes a new instance of the class. + /// + /// The input instance. + public WellKnownTypeSymbols(Compilation compilation) + { + this.compilation = compilation; + } - /// - /// Gets the for AIDevGallery.Samples.Attributes.GallerySampleAttribute. - /// - public INamedTypeSymbol GallerySampleAttribute => Get(ref this.gallerySampleAttribute, WellKnownTypeNames.GallerySampleAttribute); + /// + /// Gets the for AIDevGallery.Samples.Attributes.GallerySampleAttribute. + /// + public INamedTypeSymbol GallerySampleAttribute => Get(ref this.gallerySampleAttribute, WellKnownTypeNames.GallerySampleAttribute); - /// - /// Gets an instance with a specified fully qualified metadata name. - /// - /// The backing storage to save the result. - /// The fully qualified metadata name of the instance to get. - /// The resulting instance. - private INamedTypeSymbol Get(ref INamedTypeSymbol? storage, string fullyQualifiedMetadataName) - { - return storage ??= this.compilation.GetTypeByMetadataName(fullyQualifiedMetadataName)!; - } + /// + /// Gets an instance with a specified fully qualified metadata name. + /// + /// The backing storage to save the result. + /// The fully qualified metadata name of the instance to get. + /// The resulting instance. + private INamedTypeSymbol Get(ref INamedTypeSymbol? storage, string fullyQualifiedMetadataName) + { + return storage ??= this.compilation.GetTypeByMetadataName(fullyQualifiedMetadataName)!; } } \ No newline at end of file diff --git a/AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs b/AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs index 71c9161..c6a2ee3 100644 --- a/AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs +++ b/AIDevGallery.UnitTests/ProjectGeneratorUnitTests.cs @@ -21,191 +21,190 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery.UnitTests +namespace AIDevGallery.UnitTests; + +[TestClass] +public class ProjectGenerator { - [TestClass] - public class ProjectGenerator + private readonly Generator generator = new(); + private static readonly string TmpPath = Path.Combine(Path.GetTempPath(), "AIDevGalleryTests"); + private static readonly string TmpPathProjectGenerator = Path.Combine(TmpPath, "ProjectGenerator"); + private static readonly string TmpPathLogs = Path.Combine(TmpPath, "Logs"); + + public TestContext TestContext { get; set; } = null!; + + [ClassInitialize] + public static void Initialize(TestContext context) { - private readonly Generator generator = new(); - private static readonly string TmpPath = Path.Combine(Path.GetTempPath(), "AIDevGalleryTests"); - private static readonly string TmpPathProjectGenerator = Path.Combine(TmpPath, "ProjectGenerator"); - private static readonly string TmpPathLogs = Path.Combine(TmpPath, "Logs"); + ArgumentNullException.ThrowIfNull(context); - public TestContext TestContext { get; set; } = null!; + if (Directory.Exists(TmpPath)) + { + Directory.Delete(TmpPath, true); + } - [ClassInitialize] - public static void Initialize(TestContext context) + if (Directory.Exists(TmpPathProjectGenerator)) { - ArgumentNullException.ThrowIfNull(context); + Directory.Delete(TmpPathProjectGenerator, true); + } - if (Directory.Exists(TmpPath)) - { - Directory.Delete(TmpPath, true); - } + if (Directory.Exists(TmpPathLogs)) + { + Directory.Delete(TmpPathLogs, true); + } - if (Directory.Exists(TmpPathProjectGenerator)) - { - Directory.Delete(TmpPathProjectGenerator, true); - } + Directory.CreateDirectory(TmpPath); + Directory.CreateDirectory(TmpPathProjectGenerator); + Directory.CreateDirectory(TmpPathLogs); + } - if (Directory.Exists(TmpPathLogs)) - { - Directory.Delete(TmpPathLogs, true); - } + private class SampleUIData : INotifyPropertyChanged + { + private Brush? statusColor; - Directory.CreateDirectory(TmpPath); - Directory.CreateDirectory(TmpPathProjectGenerator); - Directory.CreateDirectory(TmpPathLogs); + public required Sample Sample { get; init; } + public Brush? StatusColor + { + get => statusColor; + set => SetProperty(ref statusColor, value); } - private class SampleUIData : INotifyPropertyChanged + private void SetProperty(ref Brush? field, Brush? value, [System.Runtime.CompilerServices.CallerMemberName] string? propertyName = null) { - private Brush? statusColor; - - public required Sample Sample { get; init; } - public Brush? StatusColor - { - get => statusColor; - set => SetProperty(ref statusColor, value); - } - - private void SetProperty(ref Brush? field, Brush? value, [System.Runtime.CompilerServices.CallerMemberName] string? propertyName = null) + if (field != value) { - if (field != value) - { - field = value; - PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName)); - } + field = value; + PropertyChanged?.Invoke(this, new PropertyChangedEventArgs(propertyName)); } - - public event PropertyChangedEventHandler? PropertyChanged; } - [TestMethod] - public async Task GenerateForAllSamples() + public event PropertyChangedEventHandler? PropertyChanged; + } + + [TestMethod] + public async Task GenerateForAllSamples() + { + List source = null!; + ListView listView = null!; + SolidColorBrush green = null!; + SolidColorBrush red = null!; + SolidColorBrush yellow = null!; + TaskCompletionSource taskCompletionSource = new(); + + UITestMethodAttribute.DispatcherQueue?.TryEnqueue(() => { - List source = null!; - ListView listView = null!; - SolidColorBrush green = null!; - SolidColorBrush red = null!; - SolidColorBrush yellow = null!; - TaskCompletionSource taskCompletionSource = new(); - - UITestMethodAttribute.DispatcherQueue?.TryEnqueue(() => + source = SampleDetails.Samples.Select(s => new SampleUIData { - source = SampleDetails.Samples.Select(s => new SampleUIData - { - Sample = s, - StatusColor = new SolidColorBrush(Colors.LightGray) - }).ToList(); - - green = new SolidColorBrush(Colors.Green); - red = new SolidColorBrush(Colors.Red); - yellow = new SolidColorBrush(Colors.Yellow); - - listView = new ListView - { - ItemsSource = source, - ItemTemplate = Microsoft.UI.Xaml.Application.Current.Resources["SampleItemTemplate"] as Microsoft.UI.Xaml.DataTemplate - }; - UnitTestApp.SetWindowContent(listView); - - taskCompletionSource.SetResult(); - }); + Sample = s, + StatusColor = new SolidColorBrush(Colors.LightGray) + }).ToList(); - await taskCompletionSource.Task; + green = new SolidColorBrush(Colors.Green); + red = new SolidColorBrush(Colors.Red); + yellow = new SolidColorBrush(Colors.Yellow); - Dictionary successDict = []; + listView = new ListView + { + ItemsSource = source, + ItemTemplate = Microsoft.UI.Xaml.Application.Current.Resources["SampleItemTemplate"] as Microsoft.UI.Xaml.DataTemplate + }; + UnitTestApp.SetWindowContent(listView); - // write test count - TestContext.WriteLine($"Running {source.Count} tests"); + taskCompletionSource.SetResult(); + }); - await Parallel.ForEachAsync(source, new ParallelOptions { MaxDegreeOfParallelism = 4 }, async (item, ct) => - { - listView.DispatcherQueue.TryEnqueue(() => - { - item.StatusColor = yellow; - }); + await taskCompletionSource.Task; - var success = await GenerateForSample(item.Sample, ct); + Dictionary successDict = []; - TestContext.WriteLine($"Built {item.Sample.Name} with status {success}"); - Debug.WriteLine($"Built {item.Sample.Name} with status {success}"); + // write test count + TestContext.WriteLine($"Running {source.Count} tests"); - listView.DispatcherQueue.TryEnqueue(() => - { - item.StatusColor = success ? green : red; - }); - successDict.Add(item.Sample.Name, success); + await Parallel.ForEachAsync(source, new ParallelOptions { MaxDegreeOfParallelism = 4 }, async (item, ct) => + { + listView.DispatcherQueue.TryEnqueue(() => + { + item.StatusColor = yellow; }); - successDict.Should().AllSatisfy(kvp => kvp.Value.Should().BeTrue($"{kvp.Key} should build successfully")); - } + var success = await GenerateForSample(item.Sample, ct); - private async Task GenerateForSample(Sample sample, CancellationToken cancellationToken) - { - var modelsDetails = ModelDetailsHelper.GetModelDetails(sample); + TestContext.WriteLine($"Built {item.Sample.Name} with status {success}"); + Debug.WriteLine($"Built {item.Sample.Name} with status {success}"); - ModelDetails modelDetails1 = modelsDetails[0].Values.First().First(); - Dictionary cachedModelsToGenerator = new() + listView.DispatcherQueue.TryEnqueue(() => { - [sample.Model1Types.First()] = ("FakePath", modelsDetails[0].Values.First().First().Url, modelDetails1.HardwareAccelerators.First()) - }; + item.StatusColor = success ? green : red; + }); + successDict.Add(item.Sample.Name, success); + }); - if (sample.Model2Types != null && modelsDetails.Count > 1) - { - ModelDetails modelDetails2 = modelsDetails[1].Values.First().First(); - cachedModelsToGenerator[sample.Model2Types.First()] = ("FakePath", modelDetails2.Url, modelDetails2.HardwareAccelerators.First()); - } + successDict.Should().AllSatisfy(kvp => kvp.Value.Should().BeTrue($"{kvp.Key} should build successfully")); + } - var projectPath = await generator.GenerateAsync(sample, cachedModelsToGenerator, false, TmpPathProjectGenerator, cancellationToken); + private async Task GenerateForSample(Sample sample, CancellationToken cancellationToken) + { + var modelsDetails = ModelDetailsHelper.GetModelDetails(sample); - var safeProjectName = Path.GetFileName(projectPath); - string logFileName = $"build_{safeProjectName}.log"; + ModelDetails modelDetails1 = modelsDetails[0].Values.First().First(); + Dictionary cachedModelsToGenerator = new() + { + [sample.Model1Types.First()] = ("FakePath", modelsDetails[0].Values.First().First().Url, modelDetails1.HardwareAccelerators.First()) + }; - var arch = DeviceUtils.IsArm64() ? "arm64" : "x64"; + if (sample.Model2Types != null && modelsDetails.Count > 1) + { + ModelDetails modelDetails2 = modelsDetails[1].Values.First().First(); + cachedModelsToGenerator[sample.Model2Types.First()] = ("FakePath", modelDetails2.Url, modelDetails2.HardwareAccelerators.First()); + } - var process = Process.Start(new ProcessStartInfo - { - FileName = @"C:\Program Files\dotnet\dotnet", - WorkingDirectory = projectPath, - Arguments = $"build -r win-{arch} -f {Generator.DotNetVersion}-windows10.0.22621.0 /p:Configuration=Release /p:Platform={arch} /flp:logfile={logFileName}", - RedirectStandardOutput = true, - RedirectStandardError = true, - UseShellExecute = false, - CreateNoWindow = true - }); + var projectPath = await generator.GenerateAsync(sample, cachedModelsToGenerator, false, TmpPathProjectGenerator, cancellationToken); - if (process == null) - { - return false; - } + var safeProjectName = Path.GetFileName(projectPath); + string logFileName = $"build_{safeProjectName}.log"; - var console = process.StandardOutput.ReadToEnd(); - var error = process.StandardError.ReadToEnd(); + var arch = DeviceUtils.IsArm64() ? "arm64" : "x64"; - process.WaitForExit(); + var process = Process.Start(new ProcessStartInfo + { + FileName = @"C:\Program Files\dotnet\dotnet", + WorkingDirectory = projectPath, + Arguments = $"build -r win-{arch} -f {Generator.DotNetVersion}-windows10.0.22621.0 /p:Configuration=Release /p:Platform={arch} /flp:logfile={logFileName}", + RedirectStandardOutput = true, + RedirectStandardError = true, + UseShellExecute = false, + CreateNoWindow = true + }); + + if (process == null) + { + return false; + } - var logFilePath = Path.Combine(TmpPathLogs, logFileName); - File.Move(Path.Combine(projectPath, logFileName), logFilePath, true); + var console = process.StandardOutput.ReadToEnd(); + var error = process.StandardError.ReadToEnd(); - TestContext.AddResultFile(logFilePath); + process.WaitForExit(); - if (process.ExitCode != 0) - { - Debug.Write(console); - Debug.WriteLine(string.Empty); - Debug.Write(error); - Debug.WriteLine(string.Empty); - } + var logFilePath = Path.Combine(TmpPathLogs, logFileName); + File.Move(Path.Combine(projectPath, logFileName), logFilePath, true); - return process.ExitCode == 0; - } + TestContext.AddResultFile(logFilePath); - [ClassCleanup] - public static void Cleanup() + if (process.ExitCode != 0) { - Directory.Delete(TmpPathProjectGenerator, true); + Debug.Write(console); + Debug.WriteLine(string.Empty); + Debug.Write(error); + Debug.WriteLine(string.Empty); } + + return process.ExitCode == 0; + } + + [ClassCleanup] + public static void Cleanup() + { + Directory.Delete(TmpPathProjectGenerator, true); } } \ No newline at end of file diff --git a/AIDevGallery.UnitTests/UnitTestApp.xaml.cs b/AIDevGallery.UnitTests/UnitTestApp.xaml.cs index e509fad..1b0cfee 100644 --- a/AIDevGallery.UnitTests/UnitTestApp.xaml.cs +++ b/AIDevGallery.UnitTests/UnitTestApp.xaml.cs @@ -5,44 +5,43 @@ using Microsoft.VisualStudio.TestTools.UnitTesting.AppContainer; using System; -namespace AIDevGallery.UnitTests +namespace AIDevGallery.UnitTests; + +/// +/// Provides application-specific behavior to supplement the default Application class. +/// +public partial class UnitTestApp : Application { /// - /// Provides application-specific behavior to supplement the default Application class. + /// Initializes a new instance of the class. + /// Initializes the singleton application object. This is the first line of authored code + /// executed, and as such is the logical equivalent of main() or WinMain(). /// - public partial class UnitTestApp : Application + public UnitTestApp() + { + this.InitializeComponent(); + } + + /// + /// Invoked when the application is launched. + /// + /// Details about the launch request and process. + protected override void OnLaunched(LaunchActivatedEventArgs args) + { + Microsoft.VisualStudio.TestPlatform.TestExecutor.UnitTestClient.CreateDefaultUI(); + + window = new UnitTestAppWindow(); + window.Activate(); + + UITestMethodAttribute.DispatcherQueue = window.DispatcherQueue; + + Microsoft.VisualStudio.TestPlatform.TestExecutor.UnitTestClient.Run(Environment.CommandLine); + } + + private static UnitTestAppWindow? window; + + public static void SetWindowContent(UIElement content) { - /// - /// Initializes a new instance of the class. - /// Initializes the singleton application object. This is the first line of authored code - /// executed, and as such is the logical equivalent of main() or WinMain(). - /// - public UnitTestApp() - { - this.InitializeComponent(); - } - - /// - /// Invoked when the application is launched. - /// - /// Details about the launch request and process. - protected override void OnLaunched(LaunchActivatedEventArgs args) - { - Microsoft.VisualStudio.TestPlatform.TestExecutor.UnitTestClient.CreateDefaultUI(); - - window = new UnitTestAppWindow(); - window.Activate(); - - UITestMethodAttribute.DispatcherQueue = window.DispatcherQueue; - - Microsoft.VisualStudio.TestPlatform.TestExecutor.UnitTestClient.Run(Environment.CommandLine); - } - - private static UnitTestAppWindow? window; - - public static void SetWindowContent(UIElement content) - { - window?.SetRootGridContent(content); - } + window?.SetRootGridContent(content); } } \ No newline at end of file diff --git a/AIDevGallery.UnitTests/UnitTestAppWindow.xaml.cs b/AIDevGallery.UnitTests/UnitTestAppWindow.xaml.cs index 35ed627..739ddc5 100644 --- a/AIDevGallery.UnitTests/UnitTestAppWindow.xaml.cs +++ b/AIDevGallery.UnitTests/UnitTestAppWindow.xaml.cs @@ -3,19 +3,18 @@ using Microsoft.UI.Xaml; -namespace AIDevGallery.UnitTests +namespace AIDevGallery.UnitTests; + +internal sealed partial class UnitTestAppWindow : Window { - internal sealed partial class UnitTestAppWindow : Window + public UnitTestAppWindow() { - public UnitTestAppWindow() - { - this.InitializeComponent(); - } + this.InitializeComponent(); + } - public void SetRootGridContent(UIElement content) - { - this.RootGrid.Children.Clear(); - this.RootGrid.Children.Add(content); - } + public void SetRootGridContent(UIElement content) + { + this.RootGrid.Children.Clear(); + this.RootGrid.Children.Add(content); } } \ No newline at end of file diff --git a/AIDevGallery.Utils/GitHubModelFileDetails.cs b/AIDevGallery.Utils/GitHubModelFileDetails.cs index e81426a..fca956e 100644 --- a/AIDevGallery.Utils/GitHubModelFileDetails.cs +++ b/AIDevGallery.Utils/GitHubModelFileDetails.cs @@ -3,47 +3,46 @@ using System.Text.Json.Serialization; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +/// +/// GitHub model file details +/// +public class GitHubModelFileDetails { /// - /// GitHub model file details + /// Gets the name of the file + /// + [JsonPropertyName("name")] + public string? Name { get; init; } + + /// + /// Gets the relative path to the file + /// + [JsonPropertyName("path")] + public string? Path { get; init; } + + /// + /// Gets the SHA of the file + /// + [JsonPropertyName("sha")] + public string? Sha { get; init; } + + /// + /// Gets the size of the file + /// + [JsonPropertyName("size")] + public int Size { get; init; } + + /// + /// Gets the URL to download the file from + /// + [JsonPropertyName("download_url")] + public string? DownloadUrl { get; init; } + + /// + /// Gets the type of the file /// - public class GitHubModelFileDetails - { - /// - /// Gets the name of the file - /// - [JsonPropertyName("name")] - public string? Name { get; init; } - - /// - /// Gets the relative path to the file - /// - [JsonPropertyName("path")] - public string? Path { get; init; } - - /// - /// Gets the SHA of the file - /// - [JsonPropertyName("sha")] - public string? Sha { get; init; } - - /// - /// Gets the size of the file - /// - [JsonPropertyName("size")] - public int Size { get; init; } - - /// - /// Gets the URL to download the file from - /// - [JsonPropertyName("download_url")] - public string? DownloadUrl { get; init; } - - /// - /// Gets the type of the file - /// - [JsonPropertyName("type")] - public string? Type { get; init; } - } + [JsonPropertyName("type")] + public string? Type { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs b/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs index bd570d3..9a98282 100644 --- a/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs +++ b/AIDevGallery.Utils/HuggingFaceModelFileDetails.cs @@ -3,29 +3,28 @@ using System.Text.Json.Serialization; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +/// +/// Details of a file in a Hugging Face model. +/// +public class HuggingFaceModelFileDetails { /// - /// Details of a file in a Hugging Face model. + /// Gets the Type of the file. /// - public class HuggingFaceModelFileDetails - { - /// - /// Gets the Type of the file. - /// - [JsonPropertyName("type")] - public string? Type { get; init; } + [JsonPropertyName("type")] + public string? Type { get; init; } - /// - /// Gets the size of the file. - /// - [JsonPropertyName("size")] - public long Size { get; init; } + /// + /// Gets the size of the file. + /// + [JsonPropertyName("size")] + public long Size { get; init; } - /// - /// Gets the path of the file. - /// - [JsonPropertyName("path")] - public string? Path { get; init; } - } + /// + /// Gets the path of the file. + /// + [JsonPropertyName("path")] + public string? Path { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.Utils/ModelFileDetails.cs b/AIDevGallery.Utils/ModelFileDetails.cs index b20c3ed..f5e717d 100644 --- a/AIDevGallery.Utils/ModelFileDetails.cs +++ b/AIDevGallery.Utils/ModelFileDetails.cs @@ -1,31 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +/// +/// Model file details +/// +public class ModelFileDetails { /// - /// Model file details + /// Gets the URL to download the model from /// - public class ModelFileDetails - { - /// - /// Gets the URL to download the model from - /// - public string? DownloadUrl { get; init; } + public string? DownloadUrl { get; init; } - /// - /// Gets the size of the file - /// - public long Size { get; init; } + /// + /// Gets the size of the file + /// + public long Size { get; init; } - /// - /// Gets the name of the file - /// - public string? Name { get; init; } + /// + /// Gets the name of the file + /// + public string? Name { get; init; } - /// - /// Gets the relative path to the file - /// - public string? Path { get; init; } - } + /// + /// Gets the relative path to the file + /// + public string? Path { get; init; } } \ No newline at end of file diff --git a/AIDevGallery.Utils/ModelInformationHelper.cs b/AIDevGallery.Utils/ModelInformationHelper.cs index 5cde2bf..08293a9 100644 --- a/AIDevGallery.Utils/ModelInformationHelper.cs +++ b/AIDevGallery.Utils/ModelInformationHelper.cs @@ -11,194 +11,193 @@ using System.Threading.Tasks; using System.Threading.Tasks.Dataflow; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +/// +/// Provides helper methods to retrieve model file details from GitHub and Hugging Face. +/// +public static class ModelInformationHelper { /// - /// Provides helper methods to retrieve model file details from GitHub and Hugging Face. + /// Retrieves a list of model file details from a specified GitHub repository. /// - public static class ModelInformationHelper + /// The GitHub URL containing the organization, repository, path, and reference. + /// A token to monitor for cancellation requests. + /// A list of model file details. + public static async Task> GetDownloadFilesFromGitHub(GitHubUrl url, CancellationToken cancellationToken) { - /// - /// Retrieves a list of model file details from a specified GitHub repository. - /// - /// The GitHub URL containing the organization, repository, path, and reference. - /// A token to monitor for cancellation requests. - /// A list of model file details. - public static async Task> GetDownloadFilesFromGitHub(GitHubUrl url, CancellationToken cancellationToken) - { - string getModelDetailsUrl = $"https://api.github.com/repos/{url.Organization}/{url.Repo}/contents/{url.Path}?ref={url.Ref}"; + string getModelDetailsUrl = $"https://api.github.com/repos/{url.Organization}/{url.Repo}/contents/{url.Path}?ref={url.Ref}"; - // call api and get json - using var client = new HttpClient(); - client.DefaultRequestHeaders.Add("User-Agent", "AIDevGallery"); - var response = await client.GetAsync(getModelDetailsUrl, cancellationToken); + // call api and get json + using var client = new HttpClient(); + client.DefaultRequestHeaders.Add("User-Agent", "AIDevGallery"); + var response = await client.GetAsync(getModelDetailsUrl, cancellationToken); #if NET8_0_OR_GREATER - var responseContent = await response.Content.ReadAsStringAsync(cancellationToken); + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken); #else - var responseContent = await response.Content.ReadAsStringAsync(); + var responseContent = await response.Content.ReadAsStringAsync(); #endif - // make it a list if it isn't already - responseContent = responseContent.Trim(); + // make it a list if it isn't already + responseContent = responseContent.Trim(); #if NET8_0_OR_GREATER - if (!responseContent.StartsWith('[')) + if (!responseContent.StartsWith('[')) #else - if (!responseContent.StartsWith("[")) + if (!responseContent.StartsWith("[")) #endif - { - responseContent = $"[{responseContent}]"; - } + { + responseContent = $"[{responseContent}]"; + } - var files = JsonSerializer.Deserialize(responseContent, SourceGenerationContext.Default.ListGitHubModelFileDetails); + var files = JsonSerializer.Deserialize(responseContent, SourceGenerationContext.Default.ListGitHubModelFileDetails); - if (files == null) + if (files == null) + { + Debug.WriteLine("Failed to get model details from GitHub"); + return []; + } + + return files.Select(f => + new ModelFileDetails() { - Debug.WriteLine("Failed to get model details from GitHub"); - return []; - } + DownloadUrl = f.DownloadUrl, + Size = f.Size, + Name = (f.Path ?? string.Empty).Split(["/"], StringSplitOptions.RemoveEmptyEntries).LastOrDefault(), + Path = f.Path + }).ToList(); + } - return files.Select(f => - new ModelFileDetails() - { - DownloadUrl = f.DownloadUrl, - Size = f.Size, - Name = (f.Path ?? string.Empty).Split(["/"], StringSplitOptions.RemoveEmptyEntries).LastOrDefault(), - Path = f.Path - }).ToList(); - } + /// + /// Retrieves a list of model file details from a specified Hugging Face repository. + /// + /// The Hugging Face URL containing the organization, repository, path, and reference. + /// The HTTP message handler used to configure the HTTP client. + /// A token to monitor for cancellation requests. + /// A list of model file details. + public static async Task> GetDownloadFilesFromHuggingFace(HuggingFaceUrl hfUrl, HttpMessageHandler? httpMessageHandler = null, CancellationToken cancellationToken = default) + { + string getModelDetailsUrl; - /// - /// Retrieves a list of model file details from a specified Hugging Face repository. - /// - /// The Hugging Face URL containing the organization, repository, path, and reference. - /// The HTTP message handler used to configure the HTTP client. - /// A token to monitor for cancellation requests. - /// A list of model file details. - public static async Task> GetDownloadFilesFromHuggingFace(HuggingFaceUrl hfUrl, HttpMessageHandler? httpMessageHandler = null, CancellationToken cancellationToken = default) + if (hfUrl.IsFile) { - string getModelDetailsUrl; - - if (hfUrl.IsFile) + getModelDetailsUrl = $"https://huggingface.co/api/models/{hfUrl.Organization}/{hfUrl.Repo}/tree/{hfUrl.Ref}"; + if (hfUrl.Path != null) { - getModelDetailsUrl = $"https://huggingface.co/api/models/{hfUrl.Organization}/{hfUrl.Repo}/tree/{hfUrl.Ref}"; - if (hfUrl.Path != null) - { - var filePath = hfUrl.Path.Split('/'); - filePath = filePath.Take(filePath.Length - 1).ToArray(); + var filePath = hfUrl.Path.Split('/'); + filePath = filePath.Take(filePath.Length - 1).ToArray(); - if (filePath.Length > 0) - { - getModelDetailsUrl = $"{getModelDetailsUrl}/{string.Join("/", filePath)}"; - } + if (filePath.Length > 0) + { + getModelDetailsUrl = $"{getModelDetailsUrl}/{string.Join("/", filePath)}"; } } - else - { - getModelDetailsUrl = $"https://huggingface.co/api/models/{hfUrl.PartialUrl}"; - } + } + else + { + getModelDetailsUrl = $"https://huggingface.co/api/models/{hfUrl.PartialUrl}"; + } - // call api and get json - using var client = new HttpClient(); - var response = await client.GetAsync(getModelDetailsUrl, cancellationToken); + // call api and get json + using var client = new HttpClient(); + var response = await client.GetAsync(getModelDetailsUrl, cancellationToken); #if NET8_0_OR_GREATER - var responseContent = await response.Content.ReadAsStringAsync(cancellationToken); + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken); #else - var responseContent = await response.Content.ReadAsStringAsync(); + var responseContent = await response.Content.ReadAsStringAsync(); #endif - var hfFiles = JsonSerializer.Deserialize(responseContent, SourceGenerationContext.Default.ListHuggingFaceModelFileDetails); + var hfFiles = JsonSerializer.Deserialize(responseContent, SourceGenerationContext.Default.ListHuggingFaceModelFileDetails); - if (hfFiles == null) - { - Debug.WriteLine("Failed to get model details from Hugging Face"); - return []; - } + if (hfFiles == null) + { + Debug.WriteLine("Failed to get model details from Hugging Face"); + return []; + } - if (hfUrl.IsFile) - { - hfFiles = hfFiles.Where(f => f.Path == hfUrl.Path).ToList(); - } + if (hfUrl.IsFile) + { + hfFiles = hfFiles.Where(f => f.Path == hfUrl.Path).ToList(); + } - if (hfFiles.Any(f => f.Type == "directory")) - { - var baseUrl = $"https://huggingface.co/api/models/{hfUrl.Organization}/{hfUrl.Repo}/tree/{hfUrl.Ref}"; + if (hfFiles.Any(f => f.Type == "directory")) + { + var baseUrl = $"https://huggingface.co/api/models/{hfUrl.Organization}/{hfUrl.Repo}/tree/{hfUrl.Ref}"; - using var httpClient = httpMessageHandler != null ? new HttpClient(httpMessageHandler) : new HttpClient(); + using var httpClient = httpMessageHandler != null ? new HttpClient(httpMessageHandler) : new HttpClient(); - ActionBlock actionBlock = null!; - actionBlock = new ActionBlock( - async (string path) => - { - var response = await httpClient.GetAsync($"{baseUrl}/{path}", cancellationToken); + ActionBlock actionBlock = null!; + actionBlock = new ActionBlock( + async (string path) => + { + var response = await httpClient.GetAsync($"{baseUrl}/{path}", cancellationToken); #if NET8_0_OR_GREATER - var stream = await response.Content.ReadAsStreamAsync(cancellationToken); + var stream = await response.Content.ReadAsStreamAsync(cancellationToken); #else - var stream = await response.Content.ReadAsStreamAsync(); + var stream = await response.Content.ReadAsStreamAsync(); #endif - var files = await JsonSerializer.DeserializeAsync(stream, SourceGenerationContext.Default.ListHuggingFaceModelFileDetails, cancellationToken); - if (files != null) + var files = await JsonSerializer.DeserializeAsync(stream, SourceGenerationContext.Default.ListHuggingFaceModelFileDetails, cancellationToken); + if (files != null) + { + lock (hfFiles) { - lock (hfFiles) + foreach (var file in files.Where(f => f.Type != "directory")) { - foreach (var file in files.Where(f => f.Type != "directory")) - { - hfFiles.Add(file); - } - } - - foreach (var folder in files.Where(f => f.Type == "directory" && f.Path != null)) - { - actionBlock.Post(folder.Path!); + hfFiles.Add(file); } } - if (actionBlock.InputCount == 0) + foreach (var folder in files.Where(f => f.Type == "directory" && f.Path != null)) { - actionBlock.Complete(); + actionBlock.Post(folder.Path!); } - }, - new ExecutionDataflowBlockOptions - { - MaxDegreeOfParallelism = 4, - CancellationToken = cancellationToken - }); + } - foreach (var folder in hfFiles.Where(f => f.Type == "directory" && f.Path != null)) + if (actionBlock.InputCount == 0) + { + actionBlock.Complete(); + } + }, + new ExecutionDataflowBlockOptions { - actionBlock.Post(folder.Path!); - } + MaxDegreeOfParallelism = 4, + CancellationToken = cancellationToken + }); - await actionBlock.Completion; + foreach (var folder in hfFiles.Where(f => f.Type == "directory" && f.Path != null)) + { + actionBlock.Post(folder.Path!); } - return hfFiles.Where(f => f.Type != "directory").Select(f => - new ModelFileDetails() - { - DownloadUrl = $"https://huggingface.co/{hfUrl.Organization}/{hfUrl.Repo}/resolve/{hfUrl.Ref}/{f.Path}", - Size = f.Size, - Name = (f.Path ?? string.Empty).Split(["/"], StringSplitOptions.RemoveEmptyEntries).LastOrDefault(), - Path = f.Path - }).ToList(); + await actionBlock.Completion; } - /// - /// Filters the list of files to download based on the specified file filters. - /// - /// The list of files to download. - /// The list of file filters (wildcards) to apply. - /// The filtered list of files to download. - public static List FilterFiles(List filesToDownload, List? fileFilters) - { - if (fileFilters == null || fileFilters.Count == 0) + return hfFiles.Where(f => f.Type != "directory").Select(f => + new ModelFileDetails() { - return filesToDownload; - } + DownloadUrl = $"https://huggingface.co/{hfUrl.Organization}/{hfUrl.Repo}/resolve/{hfUrl.Ref}/{f.Path}", + Size = f.Size, + Name = (f.Path ?? string.Empty).Split(["/"], StringSplitOptions.RemoveEmptyEntries).LastOrDefault(), + Path = f.Path + }).ToList(); + } - return filesToDownload - .Where(f => fileFilters.Any(filter => - f.Path != null && - f.Path.EndsWith(filter, StringComparison.InvariantCultureIgnoreCase))) - .ToList(); + /// + /// Filters the list of files to download based on the specified file filters. + /// + /// The list of files to download. + /// The list of file filters (wildcards) to apply. + /// The filtered list of files to download. + public static List FilterFiles(List filesToDownload, List? fileFilters) + { + if (fileFilters == null || fileFilters.Count == 0) + { + return filesToDownload; } + + return filesToDownload + .Where(f => fileFilters.Any(filter => + f.Path != null && + f.Path.EndsWith(filter, StringComparison.InvariantCultureIgnoreCase))) + .ToList(); } } \ No newline at end of file diff --git a/AIDevGallery.Utils/SourceGenerationContext.cs b/AIDevGallery.Utils/SourceGenerationContext.cs index 86679e0..2fa31fd 100644 --- a/AIDevGallery.Utils/SourceGenerationContext.cs +++ b/AIDevGallery.Utils/SourceGenerationContext.cs @@ -4,12 +4,11 @@ using System.Collections.Generic; using System.Text.Json.Serialization; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +[JsonSourceGenerationOptions(WriteIndented = true, AllowTrailingCommas = true)] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(List))] +internal partial class SourceGenerationContext : JsonSerializerContext { - [JsonSourceGenerationOptions(WriteIndented = true, AllowTrailingCommas = true)] - [JsonSerializable(typeof(List))] - [JsonSerializable(typeof(List))] - internal partial class SourceGenerationContext : JsonSerializerContext - { - } } \ No newline at end of file diff --git a/AIDevGallery/App.xaml.cs b/AIDevGallery/App.xaml.cs index 51dac25..5b6558a 100644 --- a/AIDevGallery/App.xaml.cs +++ b/AIDevGallery/App.xaml.cs @@ -12,142 +12,141 @@ using System.Linq; using System.Threading.Tasks; -namespace AIDevGallery +namespace AIDevGallery; + +/// +/// Provides application-specific behavior to supplement the default Application class. +/// +public partial class App : Application { /// - /// Provides application-specific behavior to supplement the default Application class. + /// Gets, or initializes, the singleton application object. This is the first line of authored code + /// executed, and as such is the logical equivalent of main() or WinMain(). /// - public partial class App : Application + internal static MainWindow MainWindow { get; private set; } = null!; + internal static ModelCache ModelCache { get; private set; } = null!; + internal static AppData AppData { get; private set; } = null!; + internal static List SearchIndex { get; private set; } = null!; + + internal App() { - /// - /// Gets, or initializes, the singleton application object. This is the first line of authored code - /// executed, and as such is the logical equivalent of main() or WinMain(). - /// - internal static MainWindow MainWindow { get; private set; } = null!; - internal static ModelCache ModelCache { get; private set; } = null!; - internal static AppData AppData { get; private set; } = null!; - internal static List SearchIndex { get; private set; } = null!; - - internal App() - { - this.InitializeComponent(); - } + this.InitializeComponent(); + } - /// - /// Invoked when the application is launched. - /// - /// Details about the launch request and process. - protected override async void OnLaunched(LaunchActivatedEventArgs args) - { - await LoadSamples(); - AppActivationArguments appActivationArguments = AppInstance.GetCurrent().GetActivatedEventArgs(); - var activationParam = ActivationHelper.GetActivationParam(appActivationArguments); - MainWindow = new MainWindow(activationParam); + /// + /// Invoked when the application is launched. + /// + /// Details about the launch request and process. + protected override async void OnLaunched(LaunchActivatedEventArgs args) + { + await LoadSamples(); + AppActivationArguments appActivationArguments = AppInstance.GetCurrent().GetActivatedEventArgs(); + var activationParam = ActivationHelper.GetActivationParam(appActivationArguments); + MainWindow = new MainWindow(activationParam); - MainWindow.Activate(); - } + MainWindow.Activate(); + } - internal static List FindSampleItemById(string id) + internal static List FindSampleItemById(string id) + { + foreach (var sample in SampleDetails.Samples) { - foreach (var sample in SampleDetails.Samples) + if (sample.Id == id) { - if (sample.Id == id) - { - return sample.Model1Types; - } + return sample.Model1Types; } + } - foreach (var modelFamily in ModelTypeHelpers.ModelFamilyDetails) + foreach (var modelFamily in ModelTypeHelpers.ModelFamilyDetails) + { + if (modelFamily.Value.Id == id) { - if (modelFamily.Value.Id == id) - { - return [modelFamily.Key]; - } + return [modelFamily.Key]; } + } - foreach (var modelGroup in ModelTypeHelpers.ModelGroupDetails) + foreach (var modelGroup in ModelTypeHelpers.ModelGroupDetails) + { + if (modelGroup.Value.Id == id) { - if (modelGroup.Value.Id == id) - { - return [modelGroup.Key]; - } + return [modelGroup.Key]; } + } - foreach (var modelDetails in ModelTypeHelpers.ModelDetails) + foreach (var modelDetails in ModelTypeHelpers.ModelDetails) + { + if (modelDetails.Value.Id == id) { - if (modelDetails.Value.Id == id) - { - return [modelDetails.Key]; - } + return [modelDetails.Key]; } + } - foreach (var apiDefinition in ModelTypeHelpers.ApiDefinitionDetails) + foreach (var apiDefinition in ModelTypeHelpers.ApiDefinitionDetails) + { + if (apiDefinition.Value.Id == id) { - if (apiDefinition.Value.Id == id) - { - return [apiDefinition.Key]; - } + return [apiDefinition.Key]; } - - return []; } - internal static Scenario? FindScenarioById(string id) + return []; + } + + internal static Scenario? FindScenarioById(string id) + { + foreach (var category in ScenarioCategoryHelpers.AllScenarioCategories) { - foreach (var category in ScenarioCategoryHelpers.AllScenarioCategories) + var foundScenario = category.Scenarios.FirstOrDefault(scenario => scenario.Id == id); + if (foundScenario != null) { - var foundScenario = category.Scenarios.FirstOrDefault(scenario => scenario.Id == id); - if (foundScenario != null) - { - return foundScenario; - } + return foundScenario; } - - return null; } - private async Task LoadSamples() - { - AppData = await AppData.GetForApp(); - TelemetryFactory.Get().IsDiagnosticTelemetryOn = false; // AppData.IsDiagnosticDataEnabled; - ModelCache = await ModelCache.CreateForApp(AppData); - GenerateSearchIndex(); - } + return null; + } + + private async Task LoadSamples() + { + AppData = await AppData.GetForApp(); + TelemetryFactory.Get().IsDiagnosticTelemetryOn = false; // AppData.IsDiagnosticDataEnabled; + ModelCache = await ModelCache.CreateForApp(AppData); + GenerateSearchIndex(); + } - private void GenerateSearchIndex() + private void GenerateSearchIndex() + { + SearchIndex = []; + foreach (ScenarioCategory category in ScenarioCategoryHelpers.AllScenarioCategories) { - SearchIndex = []; - foreach (ScenarioCategory category in ScenarioCategoryHelpers.AllScenarioCategories) + foreach (Scenario scenario in category.Scenarios) { - foreach (Scenario scenario in category.Scenarios) - { - SearchIndex.Add(new SearchResult() { Label = scenario.Name, Icon = scenario.Icon!, Description = scenario.Description!, Tag = scenario }); - } + SearchIndex.Add(new SearchResult() { Label = scenario.Name, Icon = scenario.Icon!, Description = scenario.Description!, Tag = scenario }); } + } - List rootModels = [.. ModelTypeHelpers.ModelGroupDetails.Keys]; - rootModels.AddRange(ModelTypeHelpers.ModelFamilyDetails.Keys); + List rootModels = [.. ModelTypeHelpers.ModelGroupDetails.Keys]; + rootModels.AddRange(ModelTypeHelpers.ModelFamilyDetails.Keys); - foreach (var key in rootModels) + foreach (var key in rootModels) + { + if (ModelTypeHelpers.ParentMapping.TryGetValue(key, out List? innerItems)) { - if (ModelTypeHelpers.ParentMapping.TryGetValue(key, out List? innerItems)) + if (innerItems?.Count > 0) { - if (innerItems?.Count > 0) + foreach (var childNavigationItem in innerItems) { - foreach (var childNavigationItem in innerItems) + if (ModelTypeHelpers.ModelGroupDetails.TryGetValue(childNavigationItem, out var modelGroup)) + { + SearchIndex.Add(new SearchResult() { Label = modelGroup.Name, Icon = modelGroup.Icon, Description = modelGroup.Name!, Tag = childNavigationItem }); + } + else if (ModelTypeHelpers.ModelFamilyDetails.TryGetValue(childNavigationItem, out var modelFamily)) + { + SearchIndex.Add(new SearchResult() { Label = modelFamily.Name, Description = modelFamily.Description, Tag = childNavigationItem }); + } + else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(childNavigationItem, out var apiDefinition)) { - if (ModelTypeHelpers.ModelGroupDetails.TryGetValue(childNavigationItem, out var modelGroup)) - { - SearchIndex.Add(new SearchResult() { Label = modelGroup.Name, Icon = modelGroup.Icon, Description = modelGroup.Name!, Tag = childNavigationItem }); - } - else if (ModelTypeHelpers.ModelFamilyDetails.TryGetValue(childNavigationItem, out var modelFamily)) - { - SearchIndex.Add(new SearchResult() { Label = modelFamily.Name, Description = modelFamily.Description, Tag = childNavigationItem }); - } - else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(childNavigationItem, out var apiDefinition)) - { - SearchIndex.Add(new SearchResult() { Label = apiDefinition.Name, Icon = apiDefinition.Icon, Description = apiDefinition.Name!, Tag = childNavigationItem }); - } + SearchIndex.Add(new SearchResult() { Label = apiDefinition.Name, Icon = apiDefinition.Icon, Description = apiDefinition.Name!, Tag = childNavigationItem }); } } } diff --git a/AIDevGallery/Controls/CopyButton/CopyButton.cs b/AIDevGallery/Controls/CopyButton/CopyButton.cs index 3179e1f..f416b63 100644 --- a/AIDevGallery/Controls/CopyButton/CopyButton.cs +++ b/AIDevGallery/Controls/CopyButton/CopyButton.cs @@ -6,30 +6,29 @@ using Microsoft.UI.Xaml.Controls; using Microsoft.UI.Xaml.Media.Animation; -namespace AIDevGallery.Controls +namespace AIDevGallery.Controls; + +internal sealed partial class CopyButton : Button { - internal sealed partial class CopyButton : Button + public CopyButton() { - public CopyButton() - { - this.DefaultStyleKey = typeof(CopyButton); - } + this.DefaultStyleKey = typeof(CopyButton); + } - private void CopyButton_Click(object sender, RoutedEventArgs e) + private void CopyButton_Click(object sender, RoutedEventArgs e) + { + if (GetTemplateChild("CopyToClipboardSuccessAnimation") is Storyboard storyBoard) { - if (GetTemplateChild("CopyToClipboardSuccessAnimation") is Storyboard storyBoard) - { - storyBoard.Begin(); - } - - NarratorHelper.Announce(this, "Copied to clipboard", "CopiedToClipboardActivityId"); + storyBoard.Begin(); } - protected override void OnApplyTemplate() - { - Click -= CopyButton_Click; - base.OnApplyTemplate(); - Click += CopyButton_Click; - } + NarratorHelper.Announce(this, "Copied to clipboard", "CopiedToClipboardActivityId"); + } + + protected override void OnApplyTemplate() + { + Click -= CopyButton_Click; + base.OnApplyTemplate(); + Click += CopyButton_Click; } } \ No newline at end of file diff --git a/AIDevGallery/Controls/DownloadProgressList.xaml.cs b/AIDevGallery/Controls/DownloadProgressList.xaml.cs index 242614b..43cef58 100644 --- a/AIDevGallery/Controls/DownloadProgressList.xaml.cs +++ b/AIDevGallery/Controls/DownloadProgressList.xaml.cs @@ -8,82 +8,81 @@ using System.Collections.ObjectModel; using System.Linq; -namespace AIDevGallery.Controls +namespace AIDevGallery.Controls; + +internal sealed partial class DownloadProgressList : UserControl { - internal sealed partial class DownloadProgressList : UserControl + private readonly ObservableCollection downloadProgresses = []; + public DownloadProgressList() + { + this.InitializeComponent(); + App.ModelCache.DownloadQueue.ModelsChanged += DownloadQueue_ModelsChanged; + PopulateModels(); + } + + private void PopulateModels() { - private readonly ObservableCollection downloadProgresses = []; - public DownloadProgressList() + downloadProgresses.Clear(); + foreach (var model in App.ModelCache.DownloadQueue.GetDownloads()) { - this.InitializeComponent(); - App.ModelCache.DownloadQueue.ModelsChanged += DownloadQueue_ModelsChanged; - PopulateModels(); + downloadProgresses.Add(new DownloadableModel(model)); } + } - private void PopulateModels() + private void DownloadQueue_ModelsChanged(ModelDownloadQueue sender) + { + foreach (var model in sender.GetDownloads()) { - downloadProgresses.Clear(); - foreach (var model in App.ModelCache.DownloadQueue.GetDownloads()) + var existingDownload = downloadProgresses.FirstOrDefault(x => x.ModelDetails.Url == model.Details.Url); + if (existingDownload != null && existingDownload.Status == DownloadStatus.Canceled) { - downloadProgresses.Add(new DownloadableModel(model)); + downloadProgresses.Remove(existingDownload); } - } - private void DownloadQueue_ModelsChanged(ModelDownloadQueue sender) - { - foreach (var model in sender.GetDownloads()) + if (existingDownload == null || existingDownload.Status == DownloadStatus.Canceled) { - var existingDownload = downloadProgresses.FirstOrDefault(x => x.ModelDetails.Url == model.Details.Url); - if (existingDownload != null && existingDownload.Status == DownloadStatus.Canceled) - { - downloadProgresses.Remove(existingDownload); - } - - if (existingDownload == null || existingDownload.Status == DownloadStatus.Canceled) - { - downloadProgresses.Add(new DownloadableModel(model)); - } + downloadProgresses.Add(new DownloadableModel(model)); } } + } - private void CancelDownloadModelButton_Click(object sender, RoutedEventArgs e) + private void CancelDownloadModelButton_Click(object sender, RoutedEventArgs e) + { + if (sender is Button button && button.Tag is DownloadableModel downloadableModel) { - if (sender is Button button && button.Tag is DownloadableModel downloadableModel) - { - downloadableModel.CancelDownload(); - } + downloadableModel.CancelDownload(); } + } - private void GoToModelPageClicked(object sender, RoutedEventArgs e) + private void GoToModelPageClicked(object sender, RoutedEventArgs e) + { + if (sender is Button button && button.Tag is DownloadableModel downloadableModel) { - if (sender is Button button && button.Tag is DownloadableModel downloadableModel) - { - var modelDetails = downloadableModel.ModelDetails; + var modelDetails = downloadableModel.ModelDetails; - if (modelDetails != null) - { - App.MainWindow.Navigate("Models", modelDetails); - } + if (modelDetails != null) + { + App.MainWindow.Navigate("Models", modelDetails); } } + } - private void RetryDownloadClicked(object sender, RoutedEventArgs e) + private void RetryDownloadClicked(object sender, RoutedEventArgs e) + { + if (sender is Button button && button.Tag is DownloadableModel downloadableModel) { - if (sender is Button button && button.Tag is DownloadableModel downloadableModel) - { - downloadProgresses.Remove(downloadableModel); - App.ModelCache.AddModelToDownloadQueue(downloadableModel.ModelDetails); - } + downloadProgresses.Remove(downloadableModel); + App.ModelCache.AddModelToDownloadQueue(downloadableModel.ModelDetails); } + } - private void ClearHistory_Click(object sender, RoutedEventArgs e) + private void ClearHistory_Click(object sender, RoutedEventArgs e) + { + foreach (DownloadableModel model in downloadProgresses.ToList()) { - foreach (DownloadableModel model in downloadProgresses.ToList()) + if (model.Status is DownloadStatus.Completed or DownloadStatus.Canceled) { - if (model.Status is DownloadStatus.Completed or DownloadStatus.Canceled) - { - downloadProgresses.Remove(model); - } + downloadProgresses.Remove(model); } } } diff --git a/AIDevGallery/Controls/HomePage/SampleRow.xaml.cs b/AIDevGallery/Controls/HomePage/SampleRow.xaml.cs index f42bfd4..ba426e8 100644 --- a/AIDevGallery/Controls/HomePage/SampleRow.xaml.cs +++ b/AIDevGallery/Controls/HomePage/SampleRow.xaml.cs @@ -6,83 +6,82 @@ using System; using System.Collections.ObjectModel; -namespace AIDevGallery.Controls -{ - internal partial class SampleRow : UserControl - { - public static readonly DependencyProperty ShowCategoryProperty = DependencyProperty.Register( - nameof(ShowCategory), - typeof(bool), - typeof(SampleRow), - new PropertyMetadata(defaultValue: true)); +namespace AIDevGallery.Controls; - public bool ShowCategory - { - get => (bool)GetValue(ShowCategoryProperty); - set => SetValue(ShowCategoryProperty, value); - } +internal partial class SampleRow : UserControl +{ + public static readonly DependencyProperty ShowCategoryProperty = DependencyProperty.Register( + nameof(ShowCategory), + typeof(bool), + typeof(SampleRow), + new PropertyMetadata(defaultValue: true)); - public static readonly DependencyProperty CategoryImageUrlProperty = DependencyProperty.Register( - nameof(CategoryImageUrl), - typeof(Uri), - typeof(SampleRow), - new PropertyMetadata(defaultValue: null)); + public bool ShowCategory + { + get => (bool)GetValue(ShowCategoryProperty); + set => SetValue(ShowCategoryProperty, value); + } - public Uri CategoryImageUrl - { - get => (Uri)GetValue(CategoryImageUrlProperty); - set => SetValue(CategoryImageUrlProperty, value); - } + public static readonly DependencyProperty CategoryImageUrlProperty = DependencyProperty.Register( + nameof(CategoryImageUrl), + typeof(Uri), + typeof(SampleRow), + new PropertyMetadata(defaultValue: null)); - public static readonly DependencyProperty CategoryHeaderProperty = DependencyProperty.Register(nameof(CategoryHeader), typeof(string), typeof(SampleRow), new PropertyMetadata(defaultValue: null)); + public Uri CategoryImageUrl + { + get => (Uri)GetValue(CategoryImageUrlProperty); + set => SetValue(CategoryImageUrlProperty, value); + } - public string CategoryHeader - { - get => (string)GetValue(CategoryHeaderProperty); - set => SetValue(CategoryHeaderProperty, value); - } + public static readonly DependencyProperty CategoryHeaderProperty = DependencyProperty.Register(nameof(CategoryHeader), typeof(string), typeof(SampleRow), new PropertyMetadata(defaultValue: null)); - public static readonly DependencyProperty CategoryDescriptionProperty = DependencyProperty.Register(nameof(CategoryDescription), typeof(string), typeof(SampleRow), new PropertyMetadata(defaultValue: null)); + public string CategoryHeader + { + get => (string)GetValue(CategoryHeaderProperty); + set => SetValue(CategoryHeaderProperty, value); + } - public string CategoryDescription - { - get => (string)GetValue(CategoryDescriptionProperty); - set => SetValue(CategoryDescriptionProperty, value); - } + public static readonly DependencyProperty CategoryDescriptionProperty = DependencyProperty.Register(nameof(CategoryDescription), typeof(string), typeof(SampleRow), new PropertyMetadata(defaultValue: null)); - public static readonly DependencyProperty SampleCardsProperty = DependencyProperty.Register(nameof(SampleCards), typeof(ObservableCollection), typeof(SampleRow), new PropertyMetadata(null)); + public string CategoryDescription + { + get => (string)GetValue(CategoryDescriptionProperty); + set => SetValue(CategoryDescriptionProperty, value); + } - public ObservableCollection SampleCards - { - get => (ObservableCollection)GetValue(SampleCardsProperty); - set => SetValue(SampleCardsProperty, value); - } + public static readonly DependencyProperty SampleCardsProperty = DependencyProperty.Register(nameof(SampleCards), typeof(ObservableCollection), typeof(SampleRow), new PropertyMetadata(null)); - public SampleRow() - { - this.InitializeComponent(); - SampleCards = []; - } + public ObservableCollection SampleCards + { + get => (ObservableCollection)GetValue(SampleCardsProperty); + set => SetValue(SampleCardsProperty, value); + } - private void ItemsView_ItemInvoked(ItemsView sender, ItemsViewItemInvokedEventArgs args) - { - if (args.InvokedItem is RowSample item) - { - App.MainWindow.NavigateToPage(App.FindScenarioById(item.Id!)); - } - } + public SampleRow() + { + this.InitializeComponent(); + SampleCards = []; + } - private void AllSamplesButton_Click(object sender, RoutedEventArgs e) + private void ItemsView_ItemInvoked(ItemsView sender, ItemsViewItemInvokedEventArgs args) + { + if (args.InvokedItem is RowSample item) { - App.MainWindow.Navigate("samples"); + App.MainWindow.NavigateToPage(App.FindScenarioById(item.Id!)); } } - internal class RowSample + private void AllSamplesButton_Click(object sender, RoutedEventArgs e) { - public string? Title { get; set; } - public string? Description { get; set; } - public IconElement? Icon { get; set; } - public string? Id { get; set; } + App.MainWindow.Navigate("samples"); } +} + +internal class RowSample +{ + public string? Title { get; set; } + public string? Description { get; set; } + public IconElement? Icon { get; set; } + public string? Id { get; set; } } \ No newline at end of file diff --git a/AIDevGallery/Controls/HomePage/SamplesCarousel.xaml.cs b/AIDevGallery/Controls/HomePage/SamplesCarousel.xaml.cs index fb8ba92..8776e01 100644 --- a/AIDevGallery/Controls/HomePage/SamplesCarousel.xaml.cs +++ b/AIDevGallery/Controls/HomePage/SamplesCarousel.xaml.cs @@ -4,46 +4,45 @@ using Microsoft.UI.Xaml; using Microsoft.UI.Xaml.Controls; -namespace AIDevGallery.Controls +namespace AIDevGallery.Controls; + +internal partial class SamplesCarousel : UserControl { - internal partial class SamplesCarousel : UserControl + public SamplesCarousel() { - public SamplesCarousel() - { - this.InitializeComponent(); - SetupSampleView(); - } + this.InitializeComponent(); + SetupSampleView(); + } - private void SetupSampleView() + private void SetupSampleView() + { + if (App.AppData.MostRecentlyUsedItems.Count > 0) { - if (App.AppData.MostRecentlyUsedItems.Count > 0) - { - RecentItem.Visibility = Visibility.Visible; + RecentItem.Visibility = Visibility.Visible; - foreach (var item in App.AppData.MostRecentlyUsedItems) + foreach (var item in App.AppData.MostRecentlyUsedItems) + { + RowSample s = new() { - RowSample s = new() - { - Title = item.DisplayName, - Icon = new FontIcon() { Glyph = item.Icon }, - Description = item.Description, - Id = item.ItemId - }; - RecentItemsRow.SampleCards.Add(s); - } + Title = item.DisplayName, + Icon = new FontIcon() { Glyph = item.Icon }, + Description = item.Description, + Id = item.ItemId + }; + RecentItemsRow.SampleCards.Add(s); } } + } - private void UserControl_Loaded(object sender, RoutedEventArgs e) + private void UserControl_Loaded(object sender, RoutedEventArgs e) + { + if (App.AppData.MostRecentlyUsedItems.Count > 0) { - if (App.AppData.MostRecentlyUsedItems.Count > 0) - { - FilterBar.SelectedItem = FilterBar.Items[0]; - } - else - { - FilterBar.SelectedItem = FilterBar.Items[1]; - } + FilterBar.SelectedItem = FilterBar.Items[0]; + } + else + { + FilterBar.SelectedItem = FilterBar.Items[1]; } } } \ No newline at end of file diff --git a/AIDevGallery/Controls/HomePage/TileGallery.xaml.cs b/AIDevGallery/Controls/HomePage/TileGallery.xaml.cs index 6acc3c9..ff15c8a 100644 --- a/AIDevGallery/Controls/HomePage/TileGallery.xaml.cs +++ b/AIDevGallery/Controls/HomePage/TileGallery.xaml.cs @@ -4,76 +4,75 @@ using Microsoft.UI.Xaml; using Microsoft.UI.Xaml.Controls; -namespace AIDevGallery.Controls +namespace AIDevGallery.Controls; + +internal sealed partial class TileGallery : UserControl { - internal sealed partial class TileGallery : UserControl + public TileGallery() + { + this.InitializeComponent(); + } + + public object Source + { + get => (object)GetValue(SourceProperty); + set => SetValue(SourceProperty, value); + } + + public static readonly DependencyProperty SourceProperty = + DependencyProperty.Register("Source", typeof(object), typeof(TileGallery), new PropertyMetadata(null)); + + private void Scroller_ViewChanging(object sender, ScrollViewerViewChangingEventArgs e) { - public TileGallery() + if (e.FinalView.HorizontalOffset < 1) { - this.InitializeComponent(); + ScrollBackBtn.Visibility = Visibility.Collapsed; } - - public object Source + else if (e.FinalView.HorizontalOffset > 1) { - get => (object)GetValue(SourceProperty); - set => SetValue(SourceProperty, value); + ScrollBackBtn.Visibility = Visibility.Visible; } - public static readonly DependencyProperty SourceProperty = - DependencyProperty.Register("Source", typeof(object), typeof(TileGallery), new PropertyMetadata(null)); - - private void Scroller_ViewChanging(object sender, ScrollViewerViewChangingEventArgs e) + if (e.FinalView.HorizontalOffset > scroller.ScrollableWidth - 1) { - if (e.FinalView.HorizontalOffset < 1) - { - ScrollBackBtn.Visibility = Visibility.Collapsed; - } - else if (e.FinalView.HorizontalOffset > 1) - { - ScrollBackBtn.Visibility = Visibility.Visible; - } - - if (e.FinalView.HorizontalOffset > scroller.ScrollableWidth - 1) - { - ScrollForwardBtn.Visibility = Visibility.Collapsed; - } - else if (e.FinalView.HorizontalOffset < scroller.ScrollableWidth - 1) - { - ScrollForwardBtn.Visibility = Visibility.Visible; - } + ScrollForwardBtn.Visibility = Visibility.Collapsed; } - - private void ScrollBackBtn_Click(object sender, RoutedEventArgs e) + else if (e.FinalView.HorizontalOffset < scroller.ScrollableWidth - 1) { - scroller.ChangeView(scroller.HorizontalOffset - scroller.ViewportWidth, null, null); - - // Manually focus to ScrollForwardBtn since this button disappears after scrolling to the end. - ScrollForwardBtn.Focus(FocusState.Programmatic); + ScrollForwardBtn.Visibility = Visibility.Visible; } + } - private void ScrollForwardBtn_Click(object sender, RoutedEventArgs e) - { - scroller.ChangeView(scroller.HorizontalOffset + scroller.ViewportWidth, null, null); + private void ScrollBackBtn_Click(object sender, RoutedEventArgs e) + { + scroller.ChangeView(scroller.HorizontalOffset - scroller.ViewportWidth, null, null); - // Manually focus to ScrollBackBtn since this button disappears after scrolling to the end. - ScrollBackBtn.Focus(FocusState.Programmatic); - } + // Manually focus to ScrollForwardBtn since this button disappears after scrolling to the end. + ScrollForwardBtn.Focus(FocusState.Programmatic); + } - private void Scroller_SizeChanged(object sender, SizeChangedEventArgs e) + private void ScrollForwardBtn_Click(object sender, RoutedEventArgs e) + { + scroller.ChangeView(scroller.HorizontalOffset + scroller.ViewportWidth, null, null); + + // Manually focus to ScrollBackBtn since this button disappears after scrolling to the end. + ScrollBackBtn.Focus(FocusState.Programmatic); + } + + private void Scroller_SizeChanged(object sender, SizeChangedEventArgs e) + { + UpdateScrollButtonsVisibility(); + } + + private void UpdateScrollButtonsVisibility() + { + if (scroller.ScrollableWidth > 0) { - UpdateScrollButtonsVisibility(); + ScrollForwardBtn.Visibility = Visibility.Visible; } - - private void UpdateScrollButtonsVisibility() + else { - if (scroller.ScrollableWidth > 0) - { - ScrollForwardBtn.Visibility = Visibility.Visible; - } - else - { - ScrollForwardBtn.Visibility = Visibility.Collapsed; - } + ScrollForwardBtn.Visibility = Visibility.Collapsed; } } } \ No newline at end of file diff --git a/AIDevGallery/Controls/ModelDropDownButton.xaml.cs b/AIDevGallery/Controls/ModelDropDownButton.xaml.cs index 58f3b92..6be6f54 100644 --- a/AIDevGallery/Controls/ModelDropDownButton.xaml.cs +++ b/AIDevGallery/Controls/ModelDropDownButton.xaml.cs @@ -5,34 +5,33 @@ using Microsoft.UI.Xaml; using Microsoft.UI.Xaml.Controls; -namespace AIDevGallery.Controls +namespace AIDevGallery.Controls; + +internal sealed partial class ModelDropDownButton : UserControl { - internal sealed partial class ModelDropDownButton : UserControl + public ModelDropDownButton() + { + this.InitializeComponent(); + } + + public static readonly DependencyProperty FlyoutContentProperty = DependencyProperty.Register(nameof(FlyoutContent), typeof(object), typeof(ModelDropDownButton), new PropertyMetadata(defaultValue: null)); + + public object FlyoutContent + { + get => (object)GetValue(FlyoutContentProperty); + set => SetValue(FlyoutContentProperty, value); + } + + public static readonly DependencyProperty ModelProperty = DependencyProperty.Register(nameof(Model), typeof(ModelDetails), typeof(ModelDropDownButton), new PropertyMetadata(defaultValue: null)); + + public ModelDetails? Model + { + get => (ModelDetails)GetValue(ModelProperty); + set => SetValue(ModelProperty, value); + } + + public void HideFlyout() { - public ModelDropDownButton() - { - this.InitializeComponent(); - } - - public static readonly DependencyProperty FlyoutContentProperty = DependencyProperty.Register(nameof(FlyoutContent), typeof(object), typeof(ModelDropDownButton), new PropertyMetadata(defaultValue: null)); - - public object FlyoutContent - { - get => (object)GetValue(FlyoutContentProperty); - set => SetValue(FlyoutContentProperty, value); - } - - public static readonly DependencyProperty ModelProperty = DependencyProperty.Register(nameof(Model), typeof(ModelDetails), typeof(ModelDropDownButton), new PropertyMetadata(defaultValue: null)); - - public ModelDetails? Model - { - get => (ModelDetails)GetValue(ModelProperty); - set => SetValue(ModelProperty, value); - } - - public void HideFlyout() - { - DropDown.Flyout.Hide(); - } + DropDown.Flyout.Hide(); } } \ No newline at end of file diff --git a/AIDevGallery/Controls/SampleContainer.xaml.cs b/AIDevGallery/Controls/SampleContainer.xaml.cs index 2ede336..1e5d040 100644 --- a/AIDevGallery/Controls/SampleContainer.xaml.cs +++ b/AIDevGallery/Controls/SampleContainer.xaml.cs @@ -18,344 +18,343 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery.Controls +namespace AIDevGallery.Controls; + +internal sealed partial class SampleContainer : UserControl { - internal sealed partial class SampleContainer : UserControl - { - private Sample? _sampleCache; - private List? _modelsCache; - private CancellationTokenSource? _sampleLoadingCts; - private TaskCompletionSource? _sampleLoadedCompletionSource; - private double _codePaneWidth; + private Sample? _sampleCache; + private List? _modelsCache; + private CancellationTokenSource? _sampleLoadingCts; + private TaskCompletionSource? _sampleLoadedCompletionSource; + private double _codePaneWidth; - private static readonly List> References = []; + private static readonly List> References = []; - internal static bool AnySamplesLoading() - { - return References.Any(r => r.TryGetTarget(out var sampleContainer) && sampleContainer._sampleLoadedCompletionSource != null); - } + internal static bool AnySamplesLoading() + { + return References.Any(r => r.TryGetTarget(out var sampleContainer) && sampleContainer._sampleLoadedCompletionSource != null); + } - internal static async Task WaitUnloadAllAsync() + internal static async Task WaitUnloadAllAsync() + { + foreach (var reference in References) { - foreach (var reference in References) + if (reference.TryGetTarget(out var sampleContainer)) { - if (reference.TryGetTarget(out var sampleContainer)) + sampleContainer.CancelCTS(); + if (sampleContainer._sampleLoadedCompletionSource != null) { - sampleContainer.CancelCTS(); - if (sampleContainer._sampleLoadedCompletionSource != null) + try + { + await sampleContainer._sampleLoadedCompletionSource.Task; + } + catch (Exception) { - try - { - await sampleContainer._sampleLoadedCompletionSource.Task; - } - catch (Exception) - { - } - finally - { - sampleContainer._sampleLoadedCompletionSource = null; - } + } + finally + { + sampleContainer._sampleLoadedCompletionSource = null; } } } + } + + References.Clear(); + } - References.Clear(); + private void CancelCTS() + { + if (_sampleLoadingCts != null) + { + _sampleLoadingCts.Cancel(); + _sampleLoadingCts = null; } + } - private void CancelCTS() + public SampleContainer() + { + this.InitializeComponent(); + References.Add(new WeakReference(this)); + this.Unloaded += (sender, args) => { - if (_sampleLoadingCts != null) + CancelCTS(); + var reference = References.FirstOrDefault(r => r.TryGetTarget(out var sampleContainer) && sampleContainer == this); + if (reference != null) { - _sampleLoadingCts.Cancel(); - _sampleLoadingCts = null; + References.Remove(reference); } - } + }; + } - public SampleContainer() + public async Task LoadSampleAsync(Sample? sample, List? models) + { + if (sample == null) { - this.InitializeComponent(); - References.Add(new WeakReference(this)); - this.Unloaded += (sender, args) => - { - CancelCTS(); - var reference = References.FirstOrDefault(r => r.TryGetTarget(out var sampleContainer) && sampleContainer == this); - if (reference != null) - { - References.Remove(reference); - } - }; + this.Visibility = Visibility.Collapsed; + return; } - public async Task LoadSampleAsync(Sample? sample, List? models) + this.Visibility = Visibility.Visible; + if (!LoadSampleMetadata(sample, models)) { - if (sample == null) - { - this.Visibility = Visibility.Collapsed; - return; - } + return; + } - this.Visibility = Visibility.Visible; - if (!LoadSampleMetadata(sample, models)) - { - return; - } + CancelCTS(); - CancelCTS(); + if (models == null) + { + NavigatedToSampleEvent.Log(sample.Name ?? string.Empty); + SampleFrame.Navigate(sample.PageType); + VisualStateManager.GoToState(this, "SampleLoaded", true); + return; + } - if (models == null) - { - NavigatedToSampleEvent.Log(sample.Name ?? string.Empty); - SampleFrame.Navigate(sample.PageType); - VisualStateManager.GoToState(this, "SampleLoaded", true); - return; - } + if (models == null || models.Count == 0) + { + VisualStateManager.GoToState(this, "Disabled", true); + SampleFrame.Content = null; + return; + } - if (models == null || models.Count == 0) + var cachedModelsPaths = models.Select(m => + { + // If it is an API, use the URL just to count + if (m.Size == 0) { - VisualStateManager.GoToState(this, "Disabled", true); - SampleFrame.Content = null; - return; + return m.Url; } - var cachedModelsPaths = models.Select(m => - { - // If it is an API, use the URL just to count - if (m.Size == 0) - { - return m.Url; - } - - return App.ModelCache.GetCachedModel(m.Url)?.Path; - }) - .Where(cm => cm != null) - .Select(cm => cm!) - .ToList(); - - if (cachedModelsPaths == null || cachedModelsPaths.Count != models.Count) - { - VisualStateManager.GoToState(this, "Disabled", true); - SampleFrame.Content = null; - return; - } + return App.ModelCache.GetCachedModel(m.Url)?.Path; + }) + .Where(cm => cm != null) + .Select(cm => cm!) + .ToList(); - // model available - VisualStateManager.GoToState(this, "SampleLoading", true); + if (cachedModelsPaths == null || cachedModelsPaths.Count != models.Count) + { + VisualStateManager.GoToState(this, "Disabled", true); SampleFrame.Content = null; + return; + } - _sampleLoadingCts = new CancellationTokenSource(); - _sampleLoadedCompletionSource = new TaskCompletionSource(); - BaseSampleNavigationParameters sampleNavigationParameters; + // model available + VisualStateManager.GoToState(this, "SampleLoading", true); + SampleFrame.Content = null; - var modelPath = cachedModelsPaths.First(); - var token = _sampleLoadingCts.Token; + _sampleLoadingCts = new CancellationTokenSource(); + _sampleLoadedCompletionSource = new TaskCompletionSource(); + BaseSampleNavigationParameters sampleNavigationParameters; - if (cachedModelsPaths.Count == 1) + var modelPath = cachedModelsPaths.First(); + var token = _sampleLoadingCts.Token; + + if (cachedModelsPaths.Count == 1) + { + sampleNavigationParameters = new SampleNavigationParameters( + modelPath, + models.First().HardwareAccelerators.First(), + models.First().PromptTemplate?.ToLlmPromptTemplate(), + _sampleLoadedCompletionSource, + token); + } + else + { + var hardwareAccelerators = new List(); + var promptTemplates = new List(); + foreach (var model in models) { - sampleNavigationParameters = new SampleNavigationParameters( - modelPath, - models.First().HardwareAccelerators.First(), - models.First().PromptTemplate?.ToLlmPromptTemplate(), - _sampleLoadedCompletionSource, - token); + hardwareAccelerators.Add(model.HardwareAccelerators.First()); + promptTemplates.Add(model.PromptTemplate?.ToLlmPromptTemplate()); } - else - { - var hardwareAccelerators = new List(); - var promptTemplates = new List(); - foreach (var model in models) - { - hardwareAccelerators.Add(model.HardwareAccelerators.First()); - promptTemplates.Add(model.PromptTemplate?.ToLlmPromptTemplate()); - } - sampleNavigationParameters = new MultiModelSampleNavigationParameters( - [.. cachedModelsPaths], - [.. hardwareAccelerators], - [.. promptTemplates], - _sampleLoadedCompletionSource, - token); - } + sampleNavigationParameters = new MultiModelSampleNavigationParameters( + [.. cachedModelsPaths], + [.. hardwareAccelerators], + [.. promptTemplates], + _sampleLoadedCompletionSource, + token); + } - NavigatedToSampleEvent.Log(sample.Name ?? string.Empty); - SampleFrame.Navigate(sample.PageType, sampleNavigationParameters); + NavigatedToSampleEvent.Log(sample.Name ?? string.Empty); + SampleFrame.Navigate(sample.PageType, sampleNavigationParameters); - await _sampleLoadedCompletionSource.Task; + await _sampleLoadedCompletionSource.Task; - _sampleLoadedCompletionSource = null; - _sampleLoadingCts = null; + _sampleLoadedCompletionSource = null; + _sampleLoadingCts = null; - NavigatedToSampleLoadedEvent.Log(sample.Name ?? string.Empty); + NavigatedToSampleLoadedEvent.Log(sample.Name ?? string.Empty); - VisualStateManager.GoToState(this, "SampleLoaded", true); + VisualStateManager.GoToState(this, "SampleLoaded", true); - CodePivot.Items.Clear(); + CodePivot.Items.Clear(); - RenderCode(); - } + RenderCode(); + } - [MemberNotNull(nameof(_sampleCache))] - private bool LoadSampleMetadata(Sample sample, List? models) + [MemberNotNull(nameof(_sampleCache))] + private bool LoadSampleMetadata(Sample sample, List? models) + { + if (_sampleCache == sample && + _modelsCache != null && + models != null) { - if (_sampleCache == sample && - _modelsCache != null && - models != null) + var modelsAreEqual = true; + if (_modelsCache.Count != models.Count) { - var modelsAreEqual = true; - if (_modelsCache.Count != models.Count) - { - modelsAreEqual = false; - } - else + modelsAreEqual = false; + } + else + { + for (int i = 0; i < models.Count; i++) { - for (int i = 0; i < models.Count; i++) + ModelDetails? model = models[i]; + if (!_modelsCache[i].Id.Equals(model.Id, StringComparison.Ordinal) || + !_modelsCache[i].HardwareAccelerators.SequenceEqual(model.HardwareAccelerators)) { - ModelDetails? model = models[i]; - if (!_modelsCache[i].Id.Equals(model.Id, StringComparison.Ordinal) || - !_modelsCache[i].HardwareAccelerators.SequenceEqual(model.HardwareAccelerators)) - { - modelsAreEqual = false; - } + modelsAreEqual = false; } } - - if (modelsAreEqual) - { - return false; - } } - _sampleCache = sample; - _modelsCache = models; - - if (sample == null) + if (modelsAreEqual) { - Visibility = Visibility.Collapsed; + return false; } - - return true; } - private void RenderCode(bool force = false) - { - var codeFormatter = new RichTextBlockFormatter(GetStylesFromTheme(ActualTheme)); - - if (_sampleCache == null) - { - return; - } - - if (CodePivot.Items.Count > 0 && !force) - { - return; - } + _sampleCache = sample; + _modelsCache = models; - CodePivot.Items.Clear(); - - if (!string.IsNullOrEmpty(_sampleCache.CSCode)) - { - CodePivot.Items.Add(CreateCodeBlock(codeFormatter, "Sample.xaml.cs", _sampleCache.CSCode, Languages.CSharp)); - } + if (sample == null) + { + Visibility = Visibility.Collapsed; + } - if (!string.IsNullOrEmpty(_sampleCache.XAMLCode)) - { - CodePivot.Items.Add(CreateCodeBlock(codeFormatter, "Sample.xaml", _sampleCache.XAMLCode, Languages.FindById("xaml"))); - } + return true; + } - if (_sampleCache.SharedCode != null && _sampleCache.SharedCode.Count != 0) - { - foreach (var sharedCodeEnum in _sampleCache.SharedCode) - { - string sharedCodeName = Samples.SharedCodeHelpers.GetName(sharedCodeEnum); - string sharedCodeContent = Samples.SharedCodeHelpers.GetSource(sharedCodeEnum); + private void RenderCode(bool force = false) + { + var codeFormatter = new RichTextBlockFormatter(GetStylesFromTheme(ActualTheme)); - CodePivot.Items.Add(CreateCodeBlock(codeFormatter, sharedCodeName, sharedCodeContent, Languages.CSharp)); - } - } + if (_sampleCache == null) + { + return; } - private PivotItem CreateCodeBlock(RichTextBlockFormatter codeFormatter, string header, string code, ILanguage language) + if (CodePivot.Items.Count > 0 && !force) { - var textBlock = new RichTextBlock() - { - Margin = new Thickness(0, 12, 0, 12), - FontFamily = new Microsoft.UI.Xaml.Media.FontFamily("Consolas"), - FontSize = 14, - IsTextSelectionEnabled = true - }; + return; + } - codeFormatter.FormatRichTextBlock(code, language, textBlock); + CodePivot.Items.Clear(); - PivotItem item = new() - { - Header = header, - Content = new ScrollViewer() - { - HorizontalScrollMode = ScrollMode.Auto, - HorizontalScrollBarVisibility = ScrollBarVisibility.Visible, - VerticalScrollMode = ScrollMode.Auto, - VerticalScrollBarVisibility = ScrollBarVisibility.Visible, - Content = textBlock, - Padding = new Thickness(0, 0, 16, 16) - } - }; - AutomationProperties.SetName(item, header); - return item; + if (!string.IsNullOrEmpty(_sampleCache.CSCode)) + { + CodePivot.Items.Add(CreateCodeBlock(codeFormatter, "Sample.xaml.cs", _sampleCache.CSCode, Languages.CSharp)); } - private void UserControl_ActualThemeChanged(FrameworkElement sender, object args) + if (!string.IsNullOrEmpty(_sampleCache.XAMLCode)) { - RenderCode(true); + CodePivot.Items.Add(CreateCodeBlock(codeFormatter, "Sample.xaml", _sampleCache.XAMLCode, Languages.FindById("xaml"))); } - public void ShowCode() + if (_sampleCache.SharedCode != null && _sampleCache.SharedCode.Count != 0) { - RenderCode(); + foreach (var sharedCodeEnum in _sampleCache.SharedCode) + { + string sharedCodeName = Samples.SharedCodeHelpers.GetName(sharedCodeEnum); + string sharedCodeContent = Samples.SharedCodeHelpers.GetSource(sharedCodeEnum); - CodeColumn.Width = _codePaneWidth == 0 ? new GridLength(1, GridUnitType.Star) : new GridLength(_codePaneWidth); - VisualStateManager.GoToState(this, "ShowCodePane", true); + CodePivot.Items.Add(CreateCodeBlock(codeFormatter, sharedCodeName, sharedCodeContent, Languages.CSharp)); + } } + } - public void HideCode() + private PivotItem CreateCodeBlock(RichTextBlockFormatter codeFormatter, string header, string code, ILanguage language) + { + var textBlock = new RichTextBlock() { - _codePaneWidth = CodeColumn.ActualWidth; - VisualStateManager.GoToState(this, "HideCodePane", true); - } + Margin = new Thickness(0, 12, 0, 12), + FontFamily = new Microsoft.UI.Xaml.Media.FontFamily("Consolas"), + FontSize = 14, + IsTextSelectionEnabled = true + }; - private async void NuGetPackage_Click(object sender, RoutedEventArgs e) + codeFormatter.FormatRichTextBlock(code, language, textBlock); + + PivotItem item = new() { - if (sender is HyperlinkButton button && button.Tag is string url) + Header = header, + Content = new ScrollViewer() { - await Windows.System.Launcher.LaunchUriAsync(new Uri("https://www.nuget.org/packages/" + url)); + HorizontalScrollMode = ScrollMode.Auto, + HorizontalScrollBarVisibility = ScrollBarVisibility.Visible, + VerticalScrollMode = ScrollMode.Auto, + VerticalScrollBarVisibility = ScrollBarVisibility.Visible, + Content = textBlock, + Padding = new Thickness(0, 0, 16, 16) } + }; + AutomationProperties.SetName(item, header); + return item; + } + + private void UserControl_ActualThemeChanged(FrameworkElement sender, object args) + { + RenderCode(true); + } + + public void ShowCode() + { + RenderCode(); + + CodeColumn.Width = _codePaneWidth == 0 ? new GridLength(1, GridUnitType.Star) : new GridLength(_codePaneWidth); + VisualStateManager.GoToState(this, "ShowCodePane", true); + } + + public void HideCode() + { + _codePaneWidth = CodeColumn.ActualWidth; + VisualStateManager.GoToState(this, "HideCodePane", true); + } + + private async void NuGetPackage_Click(object sender, RoutedEventArgs e) + { + if (sender is HyperlinkButton button && button.Tag is string url) + { + await Windows.System.Launcher.LaunchUriAsync(new Uri("https://www.nuget.org/packages/" + url)); } + } - private StyleDictionary GetStylesFromTheme(ElementTheme theme) + private StyleDictionary GetStylesFromTheme(ElementTheme theme) + { + if (theme == ElementTheme.Dark) { - if (theme == ElementTheme.Dark) - { - // Adjust DefaultDark Theme to meet contrast accessibility requirements - StyleDictionary darkStyles = StyleDictionary.DefaultDark; - darkStyles[ScopeName.Comment].Foreground = StyleDictionary.BrightGreen; - darkStyles[ScopeName.XmlDocComment].Foreground = StyleDictionary.BrightGreen; - darkStyles[ScopeName.XmlDocTag].Foreground = StyleDictionary.BrightGreen; - darkStyles[ScopeName.XmlComment].Foreground = StyleDictionary.BrightGreen; - darkStyles[ScopeName.XmlDelimiter].Foreground = StyleDictionary.White; - darkStyles[ScopeName.Keyword].Foreground = "#FF41D6FF"; - darkStyles[ScopeName.String].Foreground = "#FFFFB100"; - darkStyles[ScopeName.XmlAttributeValue].Foreground = "#FF41D6FF"; - darkStyles[ScopeName.XmlAttributeQuotes].Foreground = "#FF41D6FF"; - return darkStyles; - } - else - { - StyleDictionary lightStyles = StyleDictionary.DefaultLight; - lightStyles[ScopeName.XmlDocComment].Foreground = "#FF006828"; - lightStyles[ScopeName.XmlDocTag].Foreground = "#FF006828"; - lightStyles[ScopeName.Comment].Foreground = "#FF006828"; - lightStyles[ScopeName.XmlAttribute].Foreground = "#FFB5004D"; - lightStyles[ScopeName.XmlName].Foreground = "#FF400000"; - return lightStyles; - } + // Adjust DefaultDark Theme to meet contrast accessibility requirements + StyleDictionary darkStyles = StyleDictionary.DefaultDark; + darkStyles[ScopeName.Comment].Foreground = StyleDictionary.BrightGreen; + darkStyles[ScopeName.XmlDocComment].Foreground = StyleDictionary.BrightGreen; + darkStyles[ScopeName.XmlDocTag].Foreground = StyleDictionary.BrightGreen; + darkStyles[ScopeName.XmlComment].Foreground = StyleDictionary.BrightGreen; + darkStyles[ScopeName.XmlDelimiter].Foreground = StyleDictionary.White; + darkStyles[ScopeName.Keyword].Foreground = "#FF41D6FF"; + darkStyles[ScopeName.String].Foreground = "#FFFFB100"; + darkStyles[ScopeName.XmlAttributeValue].Foreground = "#FF41D6FF"; + darkStyles[ScopeName.XmlAttributeQuotes].Foreground = "#FF41D6FF"; + return darkStyles; + } + else + { + StyleDictionary lightStyles = StyleDictionary.DefaultLight; + lightStyles[ScopeName.XmlDocComment].Foreground = "#FF006828"; + lightStyles[ScopeName.XmlDocTag].Foreground = "#FF006828"; + lightStyles[ScopeName.Comment].Foreground = "#FF006828"; + lightStyles[ScopeName.XmlAttribute].Foreground = "#FFB5004D"; + lightStyles[ScopeName.XmlName].Foreground = "#FF400000"; + return lightStyles; } } } \ No newline at end of file diff --git a/AIDevGallery/Controls/TitleBar/TitleBar.cs b/AIDevGallery/Controls/TitleBar/TitleBar.cs index 573b878..fd7ba91 100644 --- a/AIDevGallery/Controls/TitleBar/TitleBar.cs +++ b/AIDevGallery/Controls/TitleBar/TitleBar.cs @@ -4,212 +4,211 @@ using Microsoft.UI.Xaml; using Microsoft.UI.Xaml.Controls; -namespace AIDevGallery.Controls -{ - [TemplateVisualState(Name = BackButtonVisibleState, GroupName = BackButtonStates)] - [TemplateVisualState(Name = BackButtonCollapsedState, GroupName = BackButtonStates)] - [TemplateVisualState(Name = PaneButtonVisibleState, GroupName = PaneButtonStates)] - [TemplateVisualState(Name = PaneButtonCollapsedState, GroupName = PaneButtonStates)] - [TemplateVisualState(Name = WindowActivatedState, GroupName = ActivationStates)] - [TemplateVisualState(Name = WindowDeactivatedState, GroupName = ActivationStates)] - [TemplateVisualState(Name = StandardState, GroupName = DisplayModeStates)] - [TemplateVisualState(Name = TallState, GroupName = DisplayModeStates)] - [TemplateVisualState(Name = IconVisibleState, GroupName = IconStates)] - [TemplateVisualState(Name = IconCollapsedState, GroupName = IconStates)] - [TemplateVisualState(Name = ContentVisibleState, GroupName = ContentStates)] - [TemplateVisualState(Name = ContentCollapsedState, GroupName = ContentStates)] - [TemplateVisualState(Name = FooterVisibleState, GroupName = FooterStates)] - [TemplateVisualState(Name = FooterCollapsedState, GroupName = FooterStates)] - [TemplateVisualState(Name = WideState, GroupName = ReflowStates)] - [TemplateVisualState(Name = NarrowState, GroupName = ReflowStates)] - [TemplatePart(Name = PartBackButton, Type = typeof(Button))] - [TemplatePart(Name = PartPaneButton, Type = typeof(Button))] - [TemplatePart(Name = nameof(PART_LeftPaddingColumn), Type = typeof(ColumnDefinition))] - [TemplatePart(Name = nameof(PART_RightPaddingColumn), Type = typeof(ColumnDefinition))] - [TemplatePart(Name = nameof(PART_ButtonHolder), Type = typeof(StackPanel))] +namespace AIDevGallery.Controls; + +[TemplateVisualState(Name = BackButtonVisibleState, GroupName = BackButtonStates)] +[TemplateVisualState(Name = BackButtonCollapsedState, GroupName = BackButtonStates)] +[TemplateVisualState(Name = PaneButtonVisibleState, GroupName = PaneButtonStates)] +[TemplateVisualState(Name = PaneButtonCollapsedState, GroupName = PaneButtonStates)] +[TemplateVisualState(Name = WindowActivatedState, GroupName = ActivationStates)] +[TemplateVisualState(Name = WindowDeactivatedState, GroupName = ActivationStates)] +[TemplateVisualState(Name = StandardState, GroupName = DisplayModeStates)] +[TemplateVisualState(Name = TallState, GroupName = DisplayModeStates)] +[TemplateVisualState(Name = IconVisibleState, GroupName = IconStates)] +[TemplateVisualState(Name = IconCollapsedState, GroupName = IconStates)] +[TemplateVisualState(Name = ContentVisibleState, GroupName = ContentStates)] +[TemplateVisualState(Name = ContentCollapsedState, GroupName = ContentStates)] +[TemplateVisualState(Name = FooterVisibleState, GroupName = FooterStates)] +[TemplateVisualState(Name = FooterCollapsedState, GroupName = FooterStates)] +[TemplateVisualState(Name = WideState, GroupName = ReflowStates)] +[TemplateVisualState(Name = NarrowState, GroupName = ReflowStates)] +[TemplatePart(Name = PartBackButton, Type = typeof(Button))] +[TemplatePart(Name = PartPaneButton, Type = typeof(Button))] +[TemplatePart(Name = nameof(PART_LeftPaddingColumn), Type = typeof(ColumnDefinition))] +[TemplatePart(Name = nameof(PART_RightPaddingColumn), Type = typeof(ColumnDefinition))] +[TemplatePart(Name = nameof(PART_ButtonHolder), Type = typeof(StackPanel))] #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - internal partial class TitleBar : Control - { - private const string PartBackButton = "PART_BackButton"; - private const string PartPaneButton = "PART_PaneButton"; +internal partial class TitleBar : Control +{ + private const string PartBackButton = "PART_BackButton"; + private const string PartPaneButton = "PART_PaneButton"; - private const string BackButtonVisibleState = "BackButtonVisible"; - private const string BackButtonCollapsedState = "BackButtonCollapsed"; - private const string BackButtonStates = "BackButtonStates"; + private const string BackButtonVisibleState = "BackButtonVisible"; + private const string BackButtonCollapsedState = "BackButtonCollapsed"; + private const string BackButtonStates = "BackButtonStates"; - private const string PaneButtonVisibleState = "PaneButtonVisible"; - private const string PaneButtonCollapsedState = "PaneButtonCollapsed"; - private const string PaneButtonStates = "PaneButtonStates"; + private const string PaneButtonVisibleState = "PaneButtonVisible"; + private const string PaneButtonCollapsedState = "PaneButtonCollapsed"; + private const string PaneButtonStates = "PaneButtonStates"; - private const string WindowActivatedState = "Activated"; - private const string WindowDeactivatedState = "Deactivated"; - private const string ActivationStates = "WindowActivationStates"; + private const string WindowActivatedState = "Activated"; + private const string WindowDeactivatedState = "Deactivated"; + private const string ActivationStates = "WindowActivationStates"; - private const string IconVisibleState = "IconVisible"; - private const string IconCollapsedState = "IconCollapsed"; - private const string IconStates = "IconStates"; + private const string IconVisibleState = "IconVisible"; + private const string IconCollapsedState = "IconCollapsed"; + private const string IconStates = "IconStates"; - private const string StandardState = "Standard"; - private const string TallState = "Tall"; - private const string DisplayModeStates = "DisplayModeStates"; + private const string StandardState = "Standard"; + private const string TallState = "Tall"; + private const string DisplayModeStates = "DisplayModeStates"; - private const string ContentVisibleState = "ContentVisible"; - private const string ContentCollapsedState = "ContentCollapsed"; - private const string ContentStates = "ContentStates"; + private const string ContentVisibleState = "ContentVisible"; + private const string ContentCollapsedState = "ContentCollapsed"; + private const string ContentStates = "ContentStates"; - private const string FooterVisibleState = "FooterVisible"; - private const string FooterCollapsedState = "FooterCollapsed"; - private const string FooterStates = "FooterStates"; + private const string FooterVisibleState = "FooterVisible"; + private const string FooterCollapsedState = "FooterCollapsed"; + private const string FooterStates = "FooterStates"; - private const string WideState = "Wide"; - private const string NarrowState = "Narrow"; - private const string ReflowStates = "ReflowStates"; + private const string WideState = "Wide"; + private const string NarrowState = "Narrow"; + private const string ReflowStates = "ReflowStates"; #pragma warning disable SA1306 // Field names should begin with lower-case letter - private ColumnDefinition? PART_LeftPaddingColumn; - private ColumnDefinition? PART_RightPaddingColumn; - private StackPanel? PART_ButtonHolder; + private ColumnDefinition? PART_LeftPaddingColumn; + private ColumnDefinition? PART_RightPaddingColumn; + private StackPanel? PART_ButtonHolder; #pragma warning restore SA1306 // Field names should begin with lower-case letter - // Internal tracking (if AutoConfigureCustomTitleBar is on) if we've actually setup the TitleBar yet or not - // We only want to reset TitleBar configuration in app, if we're the TitleBar instance that's managing that state. - private bool _isAutoConfigCompleted; + // Internal tracking (if AutoConfigureCustomTitleBar is on) if we've actually setup the TitleBar yet or not + // We only want to reset TitleBar configuration in app, if we're the TitleBar instance that's managing that state. + private bool _isAutoConfigCompleted; + + public TitleBar() + { + this.DefaultStyleKey = typeof(TitleBar); + } - public TitleBar() + protected override void OnApplyTemplate() + { + PART_LeftPaddingColumn = GetTemplateChild(nameof(PART_LeftPaddingColumn)) as ColumnDefinition; + PART_RightPaddingColumn = GetTemplateChild(nameof(PART_RightPaddingColumn)) as ColumnDefinition; + ConfigureButtonHolder(); + Configure(); + if (GetTemplateChild(PartBackButton) is Button backButton) { - this.DefaultStyleKey = typeof(TitleBar); + backButton.Click -= BackButton_Click; + backButton.Click += BackButton_Click; } - protected override void OnApplyTemplate() + if (GetTemplateChild(PartPaneButton) is Button paneButton) { - PART_LeftPaddingColumn = GetTemplateChild(nameof(PART_LeftPaddingColumn)) as ColumnDefinition; - PART_RightPaddingColumn = GetTemplateChild(nameof(PART_RightPaddingColumn)) as ColumnDefinition; - ConfigureButtonHolder(); - Configure(); - if (GetTemplateChild(PartBackButton) is Button backButton) - { - backButton.Click -= BackButton_Click; - backButton.Click += BackButton_Click; - } - - if (GetTemplateChild(PartPaneButton) is Button paneButton) - { - paneButton.Click -= PaneButton_Click; - paneButton.Click += PaneButton_Click; - } + paneButton.Click -= PaneButton_Click; + paneButton.Click += PaneButton_Click; + } - SizeChanged -= this.TitleBar_SizeChanged; - SizeChanged += this.TitleBar_SizeChanged; + SizeChanged -= this.TitleBar_SizeChanged; + SizeChanged += this.TitleBar_SizeChanged; - Update(); - base.OnApplyTemplate(); - } + Update(); + base.OnApplyTemplate(); + } - private void TitleBar_SizeChanged(object sender, SizeChangedEventArgs e) - { - UpdateVisualStateAndDragRegion(e.NewSize); - } + private void TitleBar_SizeChanged(object sender, SizeChangedEventArgs e) + { + UpdateVisualStateAndDragRegion(e.NewSize); + } - private void UpdateVisualStateAndDragRegion(Windows.Foundation.Size size) + private void UpdateVisualStateAndDragRegion(Windows.Foundation.Size size) + { + if (size.Width <= CompactStateBreakpoint) { - if (size.Width <= CompactStateBreakpoint) + if (Content != null || Footer != null) { - if (Content != null || Footer != null) - { - VisualStateManager.GoToState(this, NarrowState, true); - } + VisualStateManager.GoToState(this, NarrowState, true); } - else - { - VisualStateManager.GoToState(this, WideState, true); - } - - SetDragRegionForCustomTitleBar(); } - - private void BackButton_Click(object sender, RoutedEventArgs e) + else { - BackButtonClick?.Invoke(this, new RoutedEventArgs()); + VisualStateManager.GoToState(this, WideState, true); } - private void PaneButton_Click(object sender, RoutedEventArgs e) + SetDragRegionForCustomTitleBar(); + } + + private void BackButton_Click(object sender, RoutedEventArgs e) + { + BackButtonClick?.Invoke(this, new RoutedEventArgs()); + } + + private void PaneButton_Click(object sender, RoutedEventArgs e) + { + PaneButtonClick?.Invoke(this, new RoutedEventArgs()); + } + + private void ConfigureButtonHolder() + { + if (PART_ButtonHolder != null) { - PaneButtonClick?.Invoke(this, new RoutedEventArgs()); + PART_ButtonHolder.SizeChanged -= PART_ButtonHolder_SizeChanged; } - private void ConfigureButtonHolder() + PART_ButtonHolder = GetTemplateChild(nameof(PART_ButtonHolder)) as StackPanel; + + if (PART_ButtonHolder != null) { - if (PART_ButtonHolder != null) - { - PART_ButtonHolder.SizeChanged -= PART_ButtonHolder_SizeChanged; - } + PART_ButtonHolder.SizeChanged += PART_ButtonHolder_SizeChanged; + } + } + + private void PART_ButtonHolder_SizeChanged(object sender, SizeChangedEventArgs e) + { + SetDragRegionForCustomTitleBar(); + } - PART_ButtonHolder = GetTemplateChild(nameof(PART_ButtonHolder)) as StackPanel; + private void Configure() + { + SetWASDKTitleBar(); + } - if (PART_ButtonHolder != null) - { - PART_ButtonHolder.SizeChanged += PART_ButtonHolder_SizeChanged; - } - } + public void Reset() + { + ResetWASDKTitleBar(); + } - private void PART_ButtonHolder_SizeChanged(object sender, SizeChangedEventArgs e) + private void Update() + { + if (Icon != null) { - SetDragRegionForCustomTitleBar(); + VisualStateManager.GoToState(this, IconVisibleState, true); } - - private void Configure() + else { - SetWASDKTitleBar(); + VisualStateManager.GoToState(this, IconCollapsedState, true); } - public void Reset() + VisualStateManager.GoToState(this, IsBackButtonVisible ? BackButtonVisibleState : BackButtonCollapsedState, true); + VisualStateManager.GoToState(this, IsPaneButtonVisible ? PaneButtonVisibleState : PaneButtonCollapsedState, true); + + if (DisplayMode == DisplayMode.Tall) { - ResetWASDKTitleBar(); + VisualStateManager.GoToState(this, TallState, true); } - - private void Update() + else { - if (Icon != null) - { - VisualStateManager.GoToState(this, IconVisibleState, true); - } - else - { - VisualStateManager.GoToState(this, IconCollapsedState, true); - } - - VisualStateManager.GoToState(this, IsBackButtonVisible ? BackButtonVisibleState : BackButtonCollapsedState, true); - VisualStateManager.GoToState(this, IsPaneButtonVisible ? PaneButtonVisibleState : PaneButtonCollapsedState, true); - - if (DisplayMode == DisplayMode.Tall) - { - VisualStateManager.GoToState(this, TallState, true); - } - else - { - VisualStateManager.GoToState(this, StandardState, true); - } - - if (Content != null) - { - VisualStateManager.GoToState(this, ContentVisibleState, true); - } - else - { - VisualStateManager.GoToState(this, ContentCollapsedState, true); - } + VisualStateManager.GoToState(this, StandardState, true); + } - if (Footer != null) - { - VisualStateManager.GoToState(this, FooterVisibleState, true); - } - else - { - VisualStateManager.GoToState(this, FooterCollapsedState, true); - } + if (Content != null) + { + VisualStateManager.GoToState(this, ContentVisibleState, true); + } + else + { + VisualStateManager.GoToState(this, ContentCollapsedState, true); + } - SetDragRegionForCustomTitleBar(); + if (Footer != null) + { + VisualStateManager.GoToState(this, FooterVisibleState, true); } + else + { + VisualStateManager.GoToState(this, FooterCollapsedState, true); + } + + SetDragRegionForCustomTitleBar(); } -#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member -} \ No newline at end of file +} +#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member \ No newline at end of file diff --git a/AIDevGallery/Controls/Token.xaml.cs b/AIDevGallery/Controls/Token.xaml.cs index dec758a..714f57e 100644 --- a/AIDevGallery/Controls/Token.xaml.cs +++ b/AIDevGallery/Controls/Token.xaml.cs @@ -4,22 +4,21 @@ using Microsoft.UI.Xaml; using Microsoft.UI.Xaml.Controls; -namespace AIDevGallery.Controls +namespace AIDevGallery.Controls; + +internal sealed partial class Token : UserControl { - internal sealed partial class Token : UserControl + public string Text { - public string Text - { - get { return (string)GetValue(TextProperty); } - set { SetValue(TextProperty, value); } - } + get { return (string)GetValue(TextProperty); } + set { SetValue(TextProperty, value); } + } - public static readonly DependencyProperty TextProperty = - DependencyProperty.Register(nameof(Text), typeof(string), typeof(Token), new PropertyMetadata(null)); + public static readonly DependencyProperty TextProperty = + DependencyProperty.Register(nameof(Text), typeof(string), typeof(Token), new PropertyMetadata(null)); - public Token() - { - this.InitializeComponent(); - } + public Token() + { + this.InitializeComponent(); } } \ No newline at end of file diff --git a/AIDevGallery/Converters/RecentUsedItemTemplateSelector.cs b/AIDevGallery/Converters/RecentUsedItemTemplateSelector.cs index 2ac2195..978a032 100644 --- a/AIDevGallery/Converters/RecentUsedItemTemplateSelector.cs +++ b/AIDevGallery/Converters/RecentUsedItemTemplateSelector.cs @@ -5,24 +5,23 @@ using Microsoft.UI.Xaml; using Microsoft.UI.Xaml.Controls; -namespace AIDevGallery.Converters +namespace AIDevGallery.Converters; + +internal partial class RecentUsedItemTemplateSelector : DataTemplateSelector { - internal partial class RecentUsedItemTemplateSelector : DataTemplateSelector - { - public DataTemplate ScenarioTemplate { get; set; } = null!; + public DataTemplate ScenarioTemplate { get; set; } = null!; - public DataTemplate ModelTemplate { get; set; } = null!; + public DataTemplate ModelTemplate { get; set; } = null!; - protected override DataTemplate SelectTemplateCore(object item, DependencyObject container) + protected override DataTemplate SelectTemplateCore(object item, DependencyObject container) + { + if (item is MostRecentlyUsedItem selectedItem && selectedItem.Type == MostRecentlyUsedItemType.Scenario) + { + return ScenarioTemplate; + } + else { - if (item is MostRecentlyUsedItem selectedItem && selectedItem.Type == MostRecentlyUsedItemType.Scenario) - { - return ScenarioTemplate; - } - else - { - return ModelTemplate; - } + return ModelTemplate; } } } \ No newline at end of file diff --git a/AIDevGallery/Converters/SearchResultTemplateSelector.cs b/AIDevGallery/Converters/SearchResultTemplateSelector.cs index 4b53c21..b2b4390 100644 --- a/AIDevGallery/Converters/SearchResultTemplateSelector.cs +++ b/AIDevGallery/Converters/SearchResultTemplateSelector.cs @@ -5,24 +5,23 @@ using Microsoft.UI.Xaml; using Microsoft.UI.Xaml.Controls; -namespace AIDevGallery.Converters +namespace AIDevGallery.Converters; + +internal partial class SearchResultTemplateSelector : DataTemplateSelector { - internal partial class SearchResultTemplateSelector : DataTemplateSelector - { - public DataTemplate ScenarioTemplate { get; set; } = null!; + public DataTemplate ScenarioTemplate { get; set; } = null!; - public DataTemplate ModelTemplate { get; set; } = null!; + public DataTemplate ModelTemplate { get; set; } = null!; - protected override DataTemplate SelectTemplateCore(object item, DependencyObject container) + protected override DataTemplate SelectTemplateCore(object item, DependencyObject container) + { + if (item is SearchResult selectedItem && selectedItem.Tag.GetType() == typeof(Scenario)) + { + return ScenarioTemplate; + } + else { - if (item is SearchResult selectedItem && selectedItem.Tag.GetType() == typeof(Scenario)) - { - return ScenarioTemplate; - } - else - { - return ModelTemplate; - } + return ModelTemplate; } } } \ No newline at end of file diff --git a/AIDevGallery/Helpers/ActivationHelper.cs b/AIDevGallery/Helpers/ActivationHelper.cs index 4042b76..74915a3 100644 --- a/AIDevGallery/Helpers/ActivationHelper.cs +++ b/AIDevGallery/Helpers/ActivationHelper.cs @@ -8,59 +8,58 @@ using System.Linq; using Windows.ApplicationModel.Activation; -namespace AIDevGallery.Helpers +namespace AIDevGallery.Helpers; + +internal static class ActivationHelper { - internal static class ActivationHelper + public static object? GetActivationParam(AppActivationArguments appActivationArguments) { - public static object? GetActivationParam(AppActivationArguments appActivationArguments) + if (appActivationArguments.Kind == ExtendedActivationKind.Protocol && appActivationArguments.Data is ProtocolActivatedEventArgs protocolArgs) { - if (appActivationArguments.Kind == ExtendedActivationKind.Protocol && appActivationArguments.Data is ProtocolActivatedEventArgs protocolArgs) + var uriComponents = protocolArgs.Uri.LocalPath.Split('/', System.StringSplitOptions.RemoveEmptyEntries); + if (uriComponents?.Length > 0) { - var uriComponents = protocolArgs.Uri.LocalPath.Split('/', System.StringSplitOptions.RemoveEmptyEntries); - if (uriComponents?.Length > 0) - { - var itemId = uriComponents[0]; - string? subItemId = uriComponents.Length > 1 ? uriComponents[1] : null; + var itemId = uriComponents[0]; + string? subItemId = uriComponents.Length > 1 ? uriComponents[1] : null; - DeepLinkActivatedEvent.Log(protocolArgs.Uri.ToString()); + DeepLinkActivatedEvent.Log(protocolArgs.Uri.ToString()); - if (protocolArgs.Uri.Host == "models") - { - var sampleModelTypes = App.FindSampleItemById(itemId); + if (protocolArgs.Uri.Host == "models") + { + var sampleModelTypes = App.FindSampleItemById(itemId); - if (sampleModelTypes.Count > 0) - { - return sampleModelTypes; - } + if (sampleModelTypes.Count > 0) + { + return sampleModelTypes; } - else if (protocolArgs.Uri.Host == "scenarios") + } + else if (protocolArgs.Uri.Host == "scenarios") + { + Scenario? selectedScenario = App.FindScenarioById(itemId); + if (selectedScenario != null) { - Scenario? selectedScenario = App.FindScenarioById(itemId); - if (selectedScenario != null) - { - return selectedScenario; - } + return selectedScenario; } } } - else if (appActivationArguments.Kind == ExtendedActivationKind.ToastNotification && appActivationArguments.Data is ToastNotificationActivatedEventArgs toastArgs) + } + else if (appActivationArguments.Kind == ExtendedActivationKind.ToastNotification && appActivationArguments.Data is ToastNotificationActivatedEventArgs toastArgs) + { + var argsSplit = toastArgs.Argument.Split('='); + if (argsSplit.Length > 0 && argsSplit[1] != null) { - var argsSplit = toastArgs.Argument.Split('='); - if (argsSplit.Length > 0 && argsSplit[1] != null) + var modelType = App.FindSampleItemById(argsSplit[1]); + if (modelType.Count > 0) { - var modelType = App.FindSampleItemById(argsSplit[1]); - if (modelType.Count > 0) + var selectedSample = ModelTypeHelpers.ParentMapping.FirstOrDefault(kv => kv.Value.Contains(modelType[0])); + if (selectedSample.Value != null) { - var selectedSample = ModelTypeHelpers.ParentMapping.FirstOrDefault(kv => kv.Value.Contains(modelType[0])); - if (selectedSample.Value != null) - { - return selectedSample.Key; - } + return selectedSample.Key; } } } - - return null; } + + return null; } } \ No newline at end of file diff --git a/AIDevGallery/Helpers/ModelDetailsHelper.cs b/AIDevGallery/Helpers/ModelDetailsHelper.cs index b4c799c..1a8afee 100644 --- a/AIDevGallery/Helpers/ModelDetailsHelper.cs +++ b/AIDevGallery/Helpers/ModelDetailsHelper.cs @@ -6,127 +6,126 @@ using System.Collections.Generic; using System.Linq; -namespace AIDevGallery.Helpers +namespace AIDevGallery.Helpers; + +internal static class ModelDetailsHelper { - internal static class ModelDetailsHelper + public static ModelFamily? GetFamily(this ModelDetails modelDetails) { - public static ModelFamily? GetFamily(this ModelDetails modelDetails) + if (ModelTypeHelpers.ModelDetails.Any(md => md.Value.Url == modelDetails.Url)) { - if (ModelTypeHelpers.ModelDetails.Any(md => md.Value.Url == modelDetails.Url)) - { - var myKey = ModelTypeHelpers.ModelDetails.FirstOrDefault(md => md.Value.Url == modelDetails.Url).Key; + var myKey = ModelTypeHelpers.ModelDetails.FirstOrDefault(md => md.Value.Url == modelDetails.Url).Key; - if (ModelTypeHelpers.ParentMapping.Values.Any(parent => parent.Contains(myKey))) - { - var parentKey = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(myKey)).Key; - var parent = ModelTypeHelpers.ModelFamilyDetails[parentKey]; - return parent; - } + if (ModelTypeHelpers.ParentMapping.Values.Any(parent => parent.Contains(myKey))) + { + var parentKey = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(myKey)).Key; + var parent = ModelTypeHelpers.ModelFamilyDetails[parentKey]; + return parent; } - - return null; } - public static ModelDetails GetModelDetailsFromApiDefinition(ModelType modelType, ApiDefinition apiDefinition) + return null; + } + + public static ModelDetails GetModelDetailsFromApiDefinition(ModelType modelType, ApiDefinition apiDefinition) + { + return new ModelDetails { - return new ModelDetails - { - Id = apiDefinition.Id, - Icon = apiDefinition.Icon, - Name = apiDefinition.Name, - HardwareAccelerators = [HardwareAccelerator.DML], - IsUserAdded = false, - SupportedOnQualcomm = true, - ReadmeUrl = apiDefinition.ReadmeUrl, - Url = $"file://{modelType}", - License = apiDefinition.License - }; + Id = apiDefinition.Id, + Icon = apiDefinition.Icon, + Name = apiDefinition.Name, + HardwareAccelerators = [HardwareAccelerator.DML], + IsUserAdded = false, + SupportedOnQualcomm = true, + ReadmeUrl = apiDefinition.ReadmeUrl, + Url = $"file://{modelType}", + License = apiDefinition.License + }; + } + + public static List>> GetModelDetails(Sample sample) + { + Dictionary> model1Details = []; + foreach (ModelType modelType in sample.Model1Types) + { + model1Details[modelType] = GetSamplesForModelType(modelType); } - public static List>> GetModelDetails(Sample sample) + List>> listModelDetails = [model1Details]; + + if (sample.Model2Types != null) { - Dictionary> model1Details = []; - foreach (ModelType modelType in sample.Model1Types) + Dictionary> model2Details = []; + foreach (ModelType modelType in sample.Model2Types) { - model1Details[modelType] = GetSamplesForModelType(modelType); + model2Details[modelType] = GetSamplesForModelType(modelType); } - List>> listModelDetails = [model1Details]; - - if (sample.Model2Types != null) - { - Dictionary> model2Details = []; - foreach (ModelType modelType in sample.Model2Types) - { - model2Details[modelType] = GetSamplesForModelType(modelType); - } + listModelDetails.Add(model2Details); + } - listModelDetails.Add(model2Details); - } + return listModelDetails; - return listModelDetails; + static List GetSamplesForModelType(ModelType initialModelType) + { + Queue leafs = new(); + leafs.Enqueue(initialModelType); + bool added = true; - static List GetSamplesForModelType(ModelType initialModelType) + do { - Queue leafs = new(); - leafs.Enqueue(initialModelType); - bool added = true; + added = false; + int initialCount = leafs.Count; - do + for (int i = 0; i < initialCount; i++) { - added = false; - int initialCount = leafs.Count; - - for (int i = 0; i < initialCount; i++) + var leaf = leafs.Dequeue(); + if (ModelTypeHelpers.ParentMapping.TryGetValue(leaf, out List? values)) { - var leaf = leafs.Dequeue(); - if (ModelTypeHelpers.ParentMapping.TryGetValue(leaf, out List? values)) + if (values.Count > 0) { - if (values.Count > 0) - { - added = true; + added = true; - foreach (var value in values) - { - leafs.Enqueue(value); - } - } - else + foreach (var value in values) { - // Is API, just add back but don't mark as added - leafs.Enqueue(leaf); + leafs.Enqueue(value); } } else { - // Re-enqueue the leaf since it's actually a leaf node + // Is API, just add back but don't mark as added leafs.Enqueue(leaf); } } - } - while (leafs.Count > 0 && added); - - var allModelDetails = new List(); - foreach (var modelType in leafs.ToList()) - { - if (ModelTypeHelpers.ModelDetails.TryGetValue(modelType, out ModelDetails? modelDetails)) - { - allModelDetails.Add(modelDetails); - } - else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(modelType, out ApiDefinition? apiDefinition)) + else { - allModelDetails.Add(GetModelDetailsFromApiDefinition(modelType, apiDefinition)); + // Re-enqueue the leaf since it's actually a leaf node + leafs.Enqueue(leaf); } } + } + while (leafs.Count > 0 && added); - if (initialModelType == ModelType.LanguageModels && App.ModelCache != null) + var allModelDetails = new List(); + foreach (var modelType in leafs.ToList()) + { + if (ModelTypeHelpers.ModelDetails.TryGetValue(modelType, out ModelDetails? modelDetails)) { - var userAddedModels = App.ModelCache.Models.Where(m => m.Details.IsUserAdded).ToList(); - allModelDetails.AddRange(userAddedModels.Select(c => c.Details)); + allModelDetails.Add(modelDetails); } + else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(modelType, out ApiDefinition? apiDefinition)) + { + allModelDetails.Add(GetModelDetailsFromApiDefinition(modelType, apiDefinition)); + } + } - return allModelDetails; + if (initialModelType == ModelType.LanguageModels && App.ModelCache != null) + { + var userAddedModels = App.ModelCache.Models.Where(m => m.Details.IsUserAdded).ToList(); + allModelDetails.AddRange(userAddedModels.Select(c => c.Details)); } + + return allModelDetails; } } } \ No newline at end of file diff --git a/AIDevGallery/Helpers/NarratorHelper.cs b/AIDevGallery/Helpers/NarratorHelper.cs index ff377b1..3af64c0 100644 --- a/AIDevGallery/Helpers/NarratorHelper.cs +++ b/AIDevGallery/Helpers/NarratorHelper.cs @@ -4,14 +4,13 @@ using Microsoft.UI.Xaml; using Microsoft.UI.Xaml.Automation.Peers; -namespace AIDevGallery.Helpers +namespace AIDevGallery.Helpers; + +internal static class NarratorHelper { - internal static class NarratorHelper + public static void Announce(UIElement ue, string annoucement, string activityID) { - public static void Announce(UIElement ue, string annoucement, string activityID) - { - var peer = FrameworkElementAutomationPeer.FromElement(ue); - peer.RaiseNotificationEvent(AutomationNotificationKind.ActionCompleted, AutomationNotificationProcessing.ImportantMostRecent, annoucement, activityID); - } + var peer = FrameworkElementAutomationPeer.FromElement(ue); + peer.RaiseNotificationEvent(AutomationNotificationKind.ActionCompleted, AutomationNotificationProcessing.ImportantMostRecent, annoucement, activityID); } } \ No newline at end of file diff --git a/AIDevGallery/Helpers/NavItemIconHelper.cs b/AIDevGallery/Helpers/NavItemIconHelper.cs index 2dfb6dd..a04c16b 100644 --- a/AIDevGallery/Helpers/NavItemIconHelper.cs +++ b/AIDevGallery/Helpers/NavItemIconHelper.cs @@ -3,83 +3,82 @@ using Microsoft.UI.Xaml; -namespace AIDevGallery.Helpers +namespace AIDevGallery.Helpers; + +internal class NavItemIconHelper { - internal class NavItemIconHelper + public static object GetSelectedIcon(DependencyObject obj) { - public static object GetSelectedIcon(DependencyObject obj) - { - return obj.GetValue(SelectedIconProperty); - } - - public static void SetSelectedIcon(DependencyObject obj, object value) - { - obj.SetValue(SelectedIconProperty, value); - } + return obj.GetValue(SelectedIconProperty); + } - public static readonly DependencyProperty SelectedIconProperty = - DependencyProperty.RegisterAttached("SelectedIcon", typeof(object), typeof(NavItemIconHelper), new PropertyMetadata(null)); + public static void SetSelectedIcon(DependencyObject obj, object value) + { + obj.SetValue(SelectedIconProperty, value); + } - /// - /// Gets the value of for a - /// - /// Returns a boolean indicating whether the notification dot should be shown. - public static bool GetShowNotificationDot(DependencyObject obj) - { - return (bool)obj.GetValue(ShowNotificationDotProperty); - } + public static readonly DependencyProperty SelectedIconProperty = + DependencyProperty.RegisterAttached("SelectedIcon", typeof(object), typeof(NavItemIconHelper), new PropertyMetadata(null)); - /// - /// Sets on a - /// - public static void SetShowNotificationDot(DependencyObject obj, bool value) - { - obj.SetValue(ShowNotificationDotProperty, value); - } + /// + /// Gets the value of for a + /// + /// Returns a boolean indicating whether the notification dot should be shown. + public static bool GetShowNotificationDot(DependencyObject obj) + { + return (bool)obj.GetValue(ShowNotificationDotProperty); + } - /// - /// An attached property that sets whether or not a notification dot should be shown on an associated - /// - public static readonly DependencyProperty ShowNotificationDotProperty = - DependencyProperty.RegisterAttached("ShowNotificationDot", typeof(bool), typeof(NavItemIconHelper), new PropertyMetadata(false)); + /// + /// Sets on a + /// + public static void SetShowNotificationDot(DependencyObject obj, bool value) + { + obj.SetValue(ShowNotificationDotProperty, value); + } - /// - /// Gets the value of for a - /// - /// Returns the unselected icon as an object. - public static object GetUnselectedIcon(DependencyObject obj) - { - return (object)obj.GetValue(UnselectedIconProperty); - } + /// + /// An attached property that sets whether or not a notification dot should be shown on an associated + /// + public static readonly DependencyProperty ShowNotificationDotProperty = + DependencyProperty.RegisterAttached("ShowNotificationDot", typeof(bool), typeof(NavItemIconHelper), new PropertyMetadata(false)); - /// - /// Sets the value of for a - /// - public static void SetUnselectedIcon(DependencyObject obj, object value) - { - obj.SetValue(UnselectedIconProperty, value); - } + /// + /// Gets the value of for a + /// + /// Returns the unselected icon as an object. + public static object GetUnselectedIcon(DependencyObject obj) + { + return (object)obj.GetValue(UnselectedIconProperty); + } - /// - /// An attached property that sets the unselected icon on an associated - /// - public static readonly DependencyProperty UnselectedIconProperty = - DependencyProperty.RegisterAttached("UnselectedIcon", typeof(object), typeof(NavItemIconHelper), new PropertyMetadata(null)); + /// + /// Sets the value of for a + /// + public static void SetUnselectedIcon(DependencyObject obj, object value) + { + obj.SetValue(UnselectedIconProperty, value); + } - public static Visibility GetStaticIconVisibility(DependencyObject obj) - { - return (Visibility)obj.GetValue(StaticIconVisibilityProperty); - } + /// + /// An attached property that sets the unselected icon on an associated + /// + public static readonly DependencyProperty UnselectedIconProperty = + DependencyProperty.RegisterAttached("UnselectedIcon", typeof(object), typeof(NavItemIconHelper), new PropertyMetadata(null)); - public static void SetStaticIconVisibility(DependencyObject obj, Visibility value) - { - obj.SetValue(StaticIconVisibilityProperty, value); - } + public static Visibility GetStaticIconVisibility(DependencyObject obj) + { + return (Visibility)obj.GetValue(StaticIconVisibilityProperty); + } - /// - /// An attached property that sets the visibility of the static icon in the associated . - /// - public static readonly DependencyProperty StaticIconVisibilityProperty = - DependencyProperty.RegisterAttached("StaticIconVisibility", typeof(Visibility), typeof(NavItemIconHelper), new PropertyMetadata(Visibility.Collapsed)); + public static void SetStaticIconVisibility(DependencyObject obj, Visibility value) + { + obj.SetValue(StaticIconVisibilityProperty, value); } + + /// + /// An attached property that sets the visibility of the static icon in the associated . + /// + public static readonly DependencyProperty StaticIconVisibilityProperty = + DependencyProperty.RegisterAttached("StaticIconVisibility", typeof(Visibility), typeof(NavItemIconHelper), new PropertyMetadata(Visibility.Collapsed)); } \ No newline at end of file diff --git a/AIDevGallery/MainWindow.xaml.cs b/AIDevGallery/MainWindow.xaml.cs index e702a2f..4fdd021 100644 --- a/AIDevGallery/MainWindow.xaml.cs +++ b/AIDevGallery/MainWindow.xaml.cs @@ -14,198 +14,197 @@ using Windows.System; using WinUIEx; -namespace AIDevGallery +namespace AIDevGallery; + +internal sealed partial class MainWindow : WindowEx { - internal sealed partial class MainWindow : WindowEx + public MainWindow(object? obj = null) { - public MainWindow(object? obj = null) + this.InitializeComponent(); + SetTitleBar(); + App.ModelCache.DownloadQueue.ModelsChanged += DownloadQueue_ModelsChanged; + + this.NavView.Loaded += (sender, args) => { - this.InitializeComponent(); - SetTitleBar(); - App.ModelCache.DownloadQueue.ModelsChanged += DownloadQueue_ModelsChanged; + NavigateToPage(obj); + }; - this.NavView.Loaded += (sender, args) => + Closed += async (sender, args) => + { + if (SampleContainer.AnySamplesLoading()) { - NavigateToPage(obj); - }; + this.Hide(); + args.Handled = true; + await SampleContainer.WaitUnloadAllAsync(); + Close(); + } + }; + } - Closed += async (sender, args) => - { - if (SampleContainer.AnySamplesLoading()) - { - this.Hide(); - args.Handled = true; - await SampleContainer.WaitUnloadAllAsync(); - Close(); - } - }; + public void NavigateToPage(object? obj) + { + if (obj is Scenario) + { + Navigate("Samples", obj); } - - public void NavigateToPage(object? obj) + else if (obj is ModelType or List) { - if (obj is Scenario) - { - Navigate("Samples", obj); - } - else if (obj is ModelType or List) - { - Navigate("Models", obj); - } - else - { - Navigate("Home"); - } + Navigate("Models", obj); } + else + { + Navigate("Home"); + } + } - private void NavView_ItemInvoked(NavigationView sender, NavigationViewItemInvokedEventArgs args) + private void NavView_ItemInvoked(NavigationView sender, NavigationViewItemInvokedEventArgs args) + { + Navigate(args.InvokedItem.ToString()!); + } + + public void Navigate(string Tag, object? obj = null) + { + Tag = Tag.ToLower(CultureInfo.CurrentCulture); + + switch (Tag) { - Navigate(args.InvokedItem.ToString()!); + case "home": + Navigate(typeof(HomePage)); + break; + case "samples": + Navigate(typeof(ScenarioSelectionPage), obj); + break; + case "models": + Navigate(typeof(ModelSelectionPage), obj); + break; + case "guides": + Navigate(typeof(GuidesPage)); + break; + case "contribute": + _ = Launcher.LaunchUriAsync(new Uri("https://aka.ms/ai-dev-gallery")); + break; + case "settings": + Navigate(typeof(SettingsPage), obj); + break; } + } - public void Navigate(string Tag, object? obj = null) + private void Navigate(Type page, object? param = null) + { + DispatcherQueue.TryEnqueue(() => { - Tag = Tag.ToLower(CultureInfo.CurrentCulture); + NavFrame.Navigate(page, param); + }); + } - switch (Tag) - { - case "home": - Navigate(typeof(HomePage)); - break; - case "samples": - Navigate(typeof(ScenarioSelectionPage), obj); - break; - case "models": - Navigate(typeof(ModelSelectionPage), obj); - break; - case "guides": - Navigate(typeof(GuidesPage)); - break; - case "contribute": - _ = Launcher.LaunchUriAsync(new Uri("https://aka.ms/ai-dev-gallery")); - break; - case "settings": - Navigate(typeof(SettingsPage), obj); - break; - } + public void Navigate(MostRecentlyUsedItem mru) + { + if (mru.Type == MostRecentlyUsedItemType.Model) + { + Navigate("models", mru); } - - private void Navigate(Type page, object? param = null) + else { - DispatcherQueue.TryEnqueue(() => - { - NavFrame.Navigate(page, param); - }); + Navigate("samples", mru); } + } + + public void Navigate(Sample sample) + { + Navigate("samples", sample); + } - public void Navigate(MostRecentlyUsedItem mru) + public void Navigate(SearchResult result) + { + if (result.Tag is Scenario scenario) { - if (mru.Type == MostRecentlyUsedItemType.Model) - { - Navigate("models", mru); - } - else - { - Navigate("samples", mru); - } + Navigate("samples", scenario); } - - public void Navigate(Sample sample) + else if (result.Tag is ModelType modelType) { - Navigate("samples", sample); + Navigate("models", modelType); } + } + + private void SetTitleBar() + { + this.ExtendsContentIntoTitleBar = true; + this.SetTitleBar(titleBar); + titleBar.Window = this; + this.AppWindow.SetIcon("Assets/AppIcon/Icon.ico"); - public void Navigate(SearchResult result) + this.Title = Windows.ApplicationModel.Package.Current.DisplayName; + + if (this.Title.EndsWith("Dev", StringComparison.InvariantCulture)) { - if (result.Tag is Scenario scenario) - { - Navigate("samples", scenario); - } - else if (result.Tag is ModelType modelType) - { - Navigate("models", modelType); - } + titleBar.Subtitle = "Dev"; } - - private void SetTitleBar() + else if (this.Title.EndsWith("Preview", StringComparison.InvariantCulture)) { - this.ExtendsContentIntoTitleBar = true; - this.SetTitleBar(titleBar); - titleBar.Window = this; - this.AppWindow.SetIcon("Assets/AppIcon/Icon.ico"); + titleBar.Subtitle = "Preview"; + } + } - this.Title = Windows.ApplicationModel.Package.Current.DisplayName; + private void DownloadQueue_ModelsChanged(ModelDownloadQueue sender) + { + DownloadProgressPanel.Visibility = Visibility.Visible; + DownloadProgressRing.IsActive = sender.GetDownloads().Count > 0; + DownloadFlyout.ShowAt(DownloadBtn); + } - if (this.Title.EndsWith("Dev", StringComparison.InvariantCulture)) - { - titleBar.Subtitle = "Dev"; - } - else if (this.Title.EndsWith("Preview", StringComparison.InvariantCulture)) - { - titleBar.Subtitle = "Preview"; - } - } + private void ManageModelsClicked(object sender, RoutedEventArgs e) + { + NavFrame.Navigate(typeof(SettingsPage), "ModelManagement"); + } - private void DownloadQueue_ModelsChanged(ModelDownloadQueue sender) + private void SearchBox_TextChanged(AutoSuggestBox sender, AutoSuggestBoxTextChangedEventArgs args) + { + if (args.Reason == AutoSuggestionBoxTextChangeReason.UserInput && !string.IsNullOrWhiteSpace(SearchBox.Text)) { - DownloadProgressPanel.Visibility = Visibility.Visible; - DownloadProgressRing.IsActive = sender.GetDownloads().Count > 0; - DownloadFlyout.ShowAt(DownloadBtn); + var filteredSearchResults = App.SearchIndex.Where(sr => sr.Label.Contains(sender.Text, StringComparison.OrdinalIgnoreCase)).ToList(); + SearchBox.ItemsSource = filteredSearchResults.OrderByDescending(i => i.Label.StartsWith(sender.Text, StringComparison.CurrentCultureIgnoreCase)).ThenBy(i => i.Label); } + } - private void ManageModelsClicked(object sender, RoutedEventArgs e) + private void SearchBox_QuerySubmitted(AutoSuggestBox sender, AutoSuggestBoxQuerySubmittedEventArgs args) + { + if (args.ChosenSuggestion is SearchResult result) { - NavFrame.Navigate(typeof(SettingsPage), "ModelManagement"); + Navigate(result); } - private void SearchBox_TextChanged(AutoSuggestBox sender, AutoSuggestBoxTextChangedEventArgs args) + SearchBox.Text = string.Empty; + } + + private void TitleBar_BackButtonClick(object sender, RoutedEventArgs e) + { + if (NavFrame.CanGoBack) { - if (args.Reason == AutoSuggestionBoxTextChangeReason.UserInput && !string.IsNullOrWhiteSpace(SearchBox.Text)) - { - var filteredSearchResults = App.SearchIndex.Where(sr => sr.Label.Contains(sender.Text, StringComparison.OrdinalIgnoreCase)).ToList(); - SearchBox.ItemsSource = filteredSearchResults.OrderByDescending(i => i.Label.StartsWith(sender.Text, StringComparison.CurrentCultureIgnoreCase)).ThenBy(i => i.Label); - } + NavFrame.GoBack(); } + } - private void SearchBox_QuerySubmitted(AutoSuggestBox sender, AutoSuggestBoxQuerySubmittedEventArgs args) + private void NavFrame_Navigating(object sender, Microsoft.UI.Xaml.Navigation.NavigatingCancelEventArgs e) + { + if (e.SourcePageType == typeof(ScenarioSelectionPage)) { - if (args.ChosenSuggestion is SearchResult result) - { - Navigate(result); - } - - SearchBox.Text = string.Empty; + NavView.SelectedItem = NavView.MenuItems[1]; } - - private void TitleBar_BackButtonClick(object sender, RoutedEventArgs e) + else if (e.SourcePageType == typeof(ModelSelectionPage)) { - if (NavFrame.CanGoBack) - { - NavFrame.GoBack(); - } + NavView.SelectedItem = NavView.MenuItems[2]; } - - private void NavFrame_Navigating(object sender, Microsoft.UI.Xaml.Navigation.NavigatingCancelEventArgs e) + else if (e.SourcePageType == typeof(GuidesPage)) { - if (e.SourcePageType == typeof(ScenarioSelectionPage)) - { - NavView.SelectedItem = NavView.MenuItems[1]; - } - else if (e.SourcePageType == typeof(ModelSelectionPage)) - { - NavView.SelectedItem = NavView.MenuItems[2]; - } - else if (e.SourcePageType == typeof(GuidesPage)) - { - NavView.SelectedItem = NavView.MenuItems[3]; - } - else if (e.SourcePageType == typeof(SettingsPage)) - { - NavView.SelectedItem = NavView.FooterMenuItems[1]; - } - else - { - NavView.SelectedItem = NavView.MenuItems[0]; - } + NavView.SelectedItem = NavView.MenuItems[3]; + } + else if (e.SourcePageType == typeof(SettingsPage)) + { + NavView.SelectedItem = NavView.FooterMenuItems[1]; + } + else + { + NavView.SelectedItem = NavView.MenuItems[0]; } } } \ No newline at end of file diff --git a/AIDevGallery/Models/BaseSampleNavigationParameters.cs b/AIDevGallery/Models/BaseSampleNavigationParameters.cs index a190d18..607ddb9 100644 --- a/AIDevGallery/Models/BaseSampleNavigationParameters.cs +++ b/AIDevGallery/Models/BaseSampleNavigationParameters.cs @@ -6,23 +6,22 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery.Models +namespace AIDevGallery.Models; + +internal abstract class BaseSampleNavigationParameters(TaskCompletionSource sampleLoadedCompletionSource, CancellationToken loadingCanceledToken) { - internal abstract class BaseSampleNavigationParameters(TaskCompletionSource sampleLoadedCompletionSource, CancellationToken loadingCanceledToken) - { - public CancellationToken CancellationToken { get; private set; } = loadingCanceledToken; + public CancellationToken CancellationToken { get; private set; } = loadingCanceledToken; - protected abstract string ChatClientModelPath { get; } - protected abstract LlmPromptTemplate? ChatClientPromptTemplate { get; } + protected abstract string ChatClientModelPath { get; } + protected abstract LlmPromptTemplate? ChatClientPromptTemplate { get; } - public void NotifyCompletion() - { - sampleLoadedCompletionSource.SetResult(); - } + public void NotifyCompletion() + { + sampleLoadedCompletionSource.SetResult(); + } - public async Task GetIChatClientAsync() - { - return await GenAIModel.CreateAsync(ChatClientModelPath, ChatClientPromptTemplate, CancellationToken).ConfigureAwait(false); - } + public async Task GetIChatClientAsync() + { + return await GenAIModel.CreateAsync(ChatClientModelPath, ChatClientPromptTemplate, CancellationToken).ConfigureAwait(false); } } \ No newline at end of file diff --git a/AIDevGallery/Models/MultiModelSampleNavigationParameters.cs b/AIDevGallery/Models/MultiModelSampleNavigationParameters.cs index 7c8177e..231c707 100644 --- a/AIDevGallery/Models/MultiModelSampleNavigationParameters.cs +++ b/AIDevGallery/Models/MultiModelSampleNavigationParameters.cs @@ -5,20 +5,19 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery.Models +namespace AIDevGallery.Models; + +internal class MultiModelSampleNavigationParameters( + string[] modelPaths, + HardwareAccelerator[] hardwareAccelerators, + LlmPromptTemplate?[] promptTemplates, + TaskCompletionSource sampleLoadedCompletionSource, + CancellationToken loadingCanceledToken) + : BaseSampleNavigationParameters(sampleLoadedCompletionSource, loadingCanceledToken) { - internal class MultiModelSampleNavigationParameters( - string[] modelPaths, - HardwareAccelerator[] hardwareAccelerators, - LlmPromptTemplate?[] promptTemplates, - TaskCompletionSource sampleLoadedCompletionSource, - CancellationToken loadingCanceledToken) - : BaseSampleNavigationParameters(sampleLoadedCompletionSource, loadingCanceledToken) - { - public string[] ModelPaths { get; } = modelPaths; - public HardwareAccelerator[] HardwareAccelerators { get; } = hardwareAccelerators; + public string[] ModelPaths { get; } = modelPaths; + public HardwareAccelerator[] HardwareAccelerators { get; } = hardwareAccelerators; - protected override string ChatClientModelPath => ModelPaths[0]; - protected override LlmPromptTemplate? ChatClientPromptTemplate => promptTemplates[0]; - } + protected override string ChatClientModelPath => ModelPaths[0]; + protected override LlmPromptTemplate? ChatClientPromptTemplate => promptTemplates[0]; } \ No newline at end of file diff --git a/AIDevGallery/Models/SampleNavigationArgs.cs b/AIDevGallery/Models/SampleNavigationArgs.cs index 37c00a5..20e0227 100644 --- a/AIDevGallery/Models/SampleNavigationArgs.cs +++ b/AIDevGallery/Models/SampleNavigationArgs.cs @@ -1,22 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.Models +namespace AIDevGallery.Models; + +internal record SampleNavigationArgs { - internal record SampleNavigationArgs - { - public Sample Sample { get; private set; } - public ModelDetails? ModelDetails { get; private set; } + public Sample Sample { get; private set; } + public ModelDetails? ModelDetails { get; private set; } - public SampleNavigationArgs(Sample sample) - { - Sample = sample; - } + public SampleNavigationArgs(Sample sample) + { + Sample = sample; + } - public SampleNavigationArgs(Sample sample, ModelDetails? modelDetails) - { - Sample = sample; - ModelDetails = modelDetails; - } + public SampleNavigationArgs(Sample sample, ModelDetails? modelDetails) + { + Sample = sample; + ModelDetails = modelDetails; } } \ No newline at end of file diff --git a/AIDevGallery/Models/SampleNavigationParameters.cs b/AIDevGallery/Models/SampleNavigationParameters.cs index e18c10c..13f4473 100644 --- a/AIDevGallery/Models/SampleNavigationParameters.cs +++ b/AIDevGallery/Models/SampleNavigationParameters.cs @@ -5,20 +5,19 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery.Models +namespace AIDevGallery.Models; + +internal class SampleNavigationParameters( + string modelPath, + HardwareAccelerator hardwareAccelerator, + LlmPromptTemplate? promptTemplate, + TaskCompletionSource sampleLoadedCompletionSource, + CancellationToken loadingCanceledToken) + : BaseSampleNavigationParameters(sampleLoadedCompletionSource, loadingCanceledToken) { - internal class SampleNavigationParameters( - string modelPath, - HardwareAccelerator hardwareAccelerator, - LlmPromptTemplate? promptTemplate, - TaskCompletionSource sampleLoadedCompletionSource, - CancellationToken loadingCanceledToken) - : BaseSampleNavigationParameters(sampleLoadedCompletionSource, loadingCanceledToken) - { - public string ModelPath { get; } = modelPath; - public HardwareAccelerator HardwareAccelerator { get; } = hardwareAccelerator; + public string ModelPath { get; } = modelPath; + public HardwareAccelerator HardwareAccelerator { get; } = hardwareAccelerator; - protected override string ChatClientModelPath => ModelPath; - protected override LlmPromptTemplate? ChatClientPromptTemplate => promptTemplate; - } + protected override string ChatClientModelPath => ModelPath; + protected override LlmPromptTemplate? ChatClientPromptTemplate => promptTemplate; } \ No newline at end of file diff --git a/AIDevGallery/Models/Samples.cs b/AIDevGallery/Models/Samples.cs index 4cfe392..ffb7d2b 100644 --- a/AIDevGallery/Models/Samples.cs +++ b/AIDevGallery/Models/Samples.cs @@ -7,154 +7,153 @@ #pragma warning disable SA1649 // File name should match first type name #pragma warning disable SA1402 // File may only contain a single type -namespace AIDevGallery.Models +namespace AIDevGallery.Models; + +internal class ModelGroup { - internal class ModelGroup - { - public required string Id { get; init; } + public required string Id { get; init; } - public required string Name { get; init; } + public required string Name { get; init; } - public required string Icon { get; init; } - public required bool IsApi { get; init; } - } + public required string Icon { get; init; } + public required bool IsApi { get; init; } +} - internal class Sample - { - public string Id { get; init; } = null!; - public string Name { get; init; } = null!; - public string Icon { get; init; } = null!; - public ScenarioType Scenario { get; init; } - public List Model1Types { get; init; } = null!; - public List? Model2Types { get; init; } - public Type PageType { get; init; } = null!; - public string CSCode { get; init; } = null!; - public string XAMLCode { get; init; } = null!; - public List SharedCode { get; init; } = null!; - public List NugetPackageReferences { get; init; } = null!; - } +internal class Sample +{ + public string Id { get; init; } = null!; + public string Name { get; init; } = null!; + public string Icon { get; init; } = null!; + public ScenarioType Scenario { get; init; } + public List Model1Types { get; init; } = null!; + public List? Model2Types { get; init; } + public Type PageType { get; init; } = null!; + public string CSCode { get; init; } = null!; + public string XAMLCode { get; init; } = null!; + public List SharedCode { get; init; } = null!; + public List NugetPackageReferences { get; init; } = null!; +} - internal class ModelFamily - { - public string Id { get; init; } = null!; - public string Name { get; init; } = null!; - public string Description { get; init; } = null!; - public string? DocsUrl { get; set; } - public string ReadmeUrl { get; init; } = null!; - } +internal class ModelFamily +{ + public string Id { get; init; } = null!; + public string Name { get; init; } = null!; + public string Description { get; init; } = null!; + public string? DocsUrl { get; set; } + public string ReadmeUrl { get; init; } = null!; +} - internal class ApiDefinition - { - public string Id { get; init; } = null!; - public string Name { get; init; } = null!; - public string Icon { get; init; } = null!; - public string ReadmeUrl { get; init; } = null!; - public string License { get; init; } = null!; - } +internal class ApiDefinition +{ + public string Id { get; init; } = null!; + public string Name { get; init; } = null!; + public string Icon { get; init; } = null!; + public string ReadmeUrl { get; init; } = null!; + public string License { get; init; } = null!; +} - internal class ModelDetails +internal class ModelDetails +{ + public string Id { get; set; } = null!; + public string Name { get; set; } = null!; + public string Url { get; set; } = null!; + + public string Description { get; set; } = null!; + [JsonConverter(typeof(SingleOrListOfHardwareAcceleratorConverter))] + [JsonPropertyName("HardwareAccelerator")] + public List HardwareAccelerators { get; set; } = null!; + public long Size { get; set; } + public bool? SupportedOnQualcomm { get; set; } + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public string? ParameterSize { get; set; } + public bool IsUserAdded { get; set; } + public PromptTemplate? PromptTemplate { get; set; } + public string? ReadmeUrl { get; set; } + public string? License { get; set; } + public List? FileFilters { get; set; } + + private ModelCompatibility? compatibility; + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public ModelCompatibility Compatibility { - public string Id { get; set; } = null!; - public string Name { get; set; } = null!; - public string Url { get; set; } = null!; - - public string Description { get; set; } = null!; - [JsonConverter(typeof(SingleOrListOfHardwareAcceleratorConverter))] - [JsonPropertyName("HardwareAccelerator")] - public List HardwareAccelerators { get; set; } = null!; - public long Size { get; set; } - public bool? SupportedOnQualcomm { get; set; } - [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] - public string? ParameterSize { get; set; } - public bool IsUserAdded { get; set; } - public PromptTemplate? PromptTemplate { get; set; } - public string? ReadmeUrl { get; set; } - public string? License { get; set; } - public List? FileFilters { get; set; } - - private ModelCompatibility? compatibility; - [JsonIgnore(Condition = JsonIgnoreCondition.Always)] - public ModelCompatibility Compatibility + get { - get - { - compatibility ??= ModelCompatibility.GetModelCompatibility(this); + compatibility ??= ModelCompatibility.GetModelCompatibility(this); - return compatibility; - } + return compatibility; } + } - private string? icon; - [JsonIgnore(Condition = JsonIgnoreCondition.Always)] - public string Icon + private string? icon; + [JsonIgnore(Condition = JsonIgnoreCondition.Always)] + public string Icon + { + get { - get + // Full path is already set + if (string.IsNullOrEmpty(icon)) { - // Full path is already set - if (string.IsNullOrEmpty(icon)) + if (Url.StartsWith("https://github", StringComparison.InvariantCultureIgnoreCase)) { - if (Url.StartsWith("https://github", StringComparison.InvariantCultureIgnoreCase)) + if (App.Current.RequestedTheme == Microsoft.UI.Xaml.ApplicationTheme.Light) { - if (App.Current.RequestedTheme == Microsoft.UI.Xaml.ApplicationTheme.Light) - { - icon = "GitHub.light.svg"; - } - else - { - icon = "GitHub.dark.svg"; - } + icon = "GitHub.light.svg"; } else { - icon = "HuggingFace.svg"; + icon = "GitHub.dark.svg"; } } - - // In some cases the full path is already set - if (!icon.StartsWith("ms-appx", StringComparison.InvariantCultureIgnoreCase)) + else { - icon = "ms-appx:///Assets/ModelIcons/" + icon; + icon = "HuggingFace.svg"; } + } - return icon; + // In some cases the full path is already set + if (!icon.StartsWith("ms-appx", StringComparison.InvariantCultureIgnoreCase)) + { + icon = "ms-appx:///Assets/ModelIcons/" + icon; } - set => icon = value; + return icon; } - } - internal class PromptTemplate - { - public string? System { get; set; } - public string? User { get; init; } - public string? Assistant { get; set; } - public string[]? Stop { get; init; } + set => icon = value; } +} - internal class ScenarioCategory - { - public required string Name { get; init; } - public required string Icon { get; init; } - public required List Scenarios { get; init; } - } +internal class PromptTemplate +{ + public string? System { get; set; } + public string? User { get; init; } + public string? Assistant { get; set; } + public string[]? Stop { get; init; } +} - internal class Scenario - { - public string Name { get; init; } = null!; - public string Description { get; init; } = null!; - public string Id { get; init; } = null!; +internal class ScenarioCategory +{ + public required string Name { get; init; } + public required string Icon { get; init; } + public required List Scenarios { get; init; } +} - public string? Icon { get; init; } - public ScenarioType ScenarioType { get; set; } - } +internal class Scenario +{ + public string Name { get; init; } = null!; + public string Description { get; init; } = null!; + public string Id { get; init; } = null!; - [JsonConverter(typeof(JsonStringEnumConverter))] - internal enum HardwareAccelerator - { - CPU, - DML, - QNN - } + public string? Icon { get; init; } + public ScenarioType ScenarioType { get; set; } +} + +[JsonConverter(typeof(JsonStringEnumConverter))] +internal enum HardwareAccelerator +{ + CPU, + DML, + QNN } #pragma warning restore SA1402 // File may only contain a single type diff --git a/AIDevGallery/Models/SearchResult.cs b/AIDevGallery/Models/SearchResult.cs index d8a7483..dd7558e 100644 --- a/AIDevGallery/Models/SearchResult.cs +++ b/AIDevGallery/Models/SearchResult.cs @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.Models +namespace AIDevGallery.Models; + +internal class SearchResult { - internal class SearchResult - { - public string Icon { get; set; } = null!; - public string Label { get; set; } = null!; - public string Description { get; set; } = null!; - public object Tag { get; set; } = null!; - } + public string Icon { get; set; } = null!; + public string Label { get; set; } = null!; + public string Description { get; set; } = null!; + public object Tag { get; set; } = null!; } \ No newline at end of file diff --git a/AIDevGallery/Models/SingleOrListOfHardwareAcceleratorConverter.cs b/AIDevGallery/Models/SingleOrListOfHardwareAcceleratorConverter.cs index 85a3ef7..8b2b4ae 100644 --- a/AIDevGallery/Models/SingleOrListOfHardwareAcceleratorConverter.cs +++ b/AIDevGallery/Models/SingleOrListOfHardwareAcceleratorConverter.cs @@ -8,33 +8,22 @@ using System.Text.Json; using System.Text.Json.Serialization; -namespace AIDevGallery.Models +namespace AIDevGallery.Models; + +internal class SingleOrListOfHardwareAcceleratorConverter : JsonConverter> { - internal class SingleOrListOfHardwareAcceleratorConverter : JsonConverter> + public override List Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - public override List Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + var list = new List(); + if (reader.TokenType == JsonTokenType.StartArray) { - var list = new List(); - if (reader.TokenType == JsonTokenType.StartArray) + while (reader.Read()) { - while (reader.Read()) + if (reader.TokenType == JsonTokenType.EndArray) { - if (reader.TokenType == JsonTokenType.EndArray) - { - break; - } - - try - { - list.Add(JsonSerializer.Deserialize(ref reader, AppDataSourceGenerationContext.Default.HardwareAccelerator)); - } - catch (Exception) - { - } + break; } - } - else if (reader.TokenType != JsonTokenType.Null) - { + try { list.Add(JsonSerializer.Deserialize(ref reader, AppDataSourceGenerationContext.Default.HardwareAccelerator)); @@ -43,31 +32,41 @@ public override List Read(ref Utf8JsonReader reader, Type t { } } - - if (list.Count == 0) + } + else if (reader.TokenType != JsonTokenType.Null) + { + try { - list.Add(HardwareAccelerator.CPU); + list.Add(JsonSerializer.Deserialize(ref reader, AppDataSourceGenerationContext.Default.HardwareAccelerator)); } + catch (Exception) + { + } + } - return list; + if (list.Count == 0) + { + list.Add(HardwareAccelerator.CPU); } - public override void Write(Utf8JsonWriter writer, List value, JsonSerializerOptions options) + return list; + } + + public override void Write(Utf8JsonWriter writer, List value, JsonSerializerOptions options) + { + if (value.Count == 1) + { + JsonSerializer.Serialize(writer, value.First(), AppDataSourceGenerationContext.Default.HardwareAccelerator); + } + else { - if (value.Count == 1) + writer.WriteStartArray(); + foreach (var item in value) { - JsonSerializer.Serialize(writer, value.First(), AppDataSourceGenerationContext.Default.HardwareAccelerator); + JsonSerializer.Serialize(writer, item, AppDataSourceGenerationContext.Default.HardwareAccelerator); } - else - { - writer.WriteStartArray(); - foreach (var item in value) - { - JsonSerializer.Serialize(writer, item, AppDataSourceGenerationContext.Default.HardwareAccelerator); - } - writer.WriteEndArray(); - } + writer.WriteEndArray(); } } } \ No newline at end of file diff --git a/AIDevGallery/Pages/GuidesPage.xaml.cs b/AIDevGallery/Pages/GuidesPage.xaml.cs index 91584e3..14cec8f 100644 --- a/AIDevGallery/Pages/GuidesPage.xaml.cs +++ b/AIDevGallery/Pages/GuidesPage.xaml.cs @@ -5,23 +5,22 @@ using Microsoft.UI.Xaml.Controls; using Microsoft.UI.Xaml.Navigation; -namespace AIDevGallery.Pages +namespace AIDevGallery.Pages; + +/// +/// An empty page that can be used on its own or navigated to within a Frame. +/// +internal sealed partial class GuidesPage : Page { - /// - /// An empty page that can be used on its own or navigated to within a Frame. - /// - internal sealed partial class GuidesPage : Page + public GuidesPage() { - public GuidesPage() - { - this.InitializeComponent(); - } + this.InitializeComponent(); + } - protected override void OnNavigatedTo(NavigationEventArgs e) - { - base.OnNavigatedTo(e); + protected override void OnNavigatedTo(NavigationEventArgs e) + { + base.OnNavigatedTo(e); - NavigatedToPageEvent.Log(nameof(GuidesPage)); - } + NavigatedToPageEvent.Log(nameof(GuidesPage)); } } \ No newline at end of file diff --git a/AIDevGallery/Pages/HomePage.xaml.cs b/AIDevGallery/Pages/HomePage.xaml.cs index a2bdb25..017f1b0 100644 --- a/AIDevGallery/Pages/HomePage.xaml.cs +++ b/AIDevGallery/Pages/HomePage.xaml.cs @@ -5,19 +5,18 @@ using Microsoft.UI.Xaml.Controls; using Microsoft.UI.Xaml.Navigation; -namespace AIDevGallery.Pages +namespace AIDevGallery.Pages; + +internal sealed partial class HomePage : Page { - internal sealed partial class HomePage : Page + public HomePage() { - public HomePage() - { - this.InitializeComponent(); - } + this.InitializeComponent(); + } - protected override void OnNavigatedTo(NavigationEventArgs e) - { - base.OnNavigatedTo(e); - NavigatedToPageEvent.Log(nameof(HomePage)); - } + protected override void OnNavigatedTo(NavigationEventArgs e) + { + base.OnNavigatedTo(e); + NavigatedToPageEvent.Log(nameof(HomePage)); } } \ No newline at end of file diff --git a/AIDevGallery/Pages/ModelPage.xaml.cs b/AIDevGallery/Pages/ModelPage.xaml.cs index f67a78d..3c8ed2a 100644 --- a/AIDevGallery/Pages/ModelPage.xaml.cs +++ b/AIDevGallery/Pages/ModelPage.xaml.cs @@ -17,194 +17,193 @@ using System.Threading.Tasks; using Windows.ApplicationModel.DataTransfer; -namespace AIDevGallery.Pages +namespace AIDevGallery.Pages; + +internal sealed partial class ModelPage : Page { - internal sealed partial class ModelPage : Page + public ModelFamily? ModelFamily { get; set; } + private ModelType? modelFamilyType; + public bool IsNotApi => !modelFamilyType.HasValue || !ModelTypeHelpers.ApiDefinitionDetails.ContainsKey(modelFamilyType.Value); + + public ModelPage() { - public ModelFamily? ModelFamily { get; set; } - private ModelType? modelFamilyType; - public bool IsNotApi => !modelFamilyType.HasValue || !ModelTypeHelpers.ApiDefinitionDetails.ContainsKey(modelFamilyType.Value); + this.InitializeComponent(); + this.Unloaded += ModelPage_Unloaded; + } - public ModelPage() + protected override void OnNavigatedTo(NavigationEventArgs e) + { + base.OnNavigatedTo(e); + if (e.Parameter is MostRecentlyUsedItem mru) { - this.InitializeComponent(); - this.Unloaded += ModelPage_Unloaded; + var modelFamilyId = mru.ItemId; } - - protected override void OnNavigatedTo(NavigationEventArgs e) + else if (e.Parameter is ModelType modelType && ModelTypeHelpers.ModelFamilyDetails.TryGetValue(modelType, out var modelFamilyDetails)) { - base.OnNavigatedTo(e); - if (e.Parameter is MostRecentlyUsedItem mru) - { - var modelFamilyId = mru.ItemId; - } - else if (e.Parameter is ModelType modelType && ModelTypeHelpers.ModelFamilyDetails.TryGetValue(modelType, out var modelFamilyDetails)) - { - modelFamilyType = modelType; - ModelFamily = modelFamilyDetails; + modelFamilyType = modelType; + ModelFamily = modelFamilyDetails; - modelSelectionControl.SetModels(GetAllSampleDetails().ToList()); - } - else if (e.Parameter is ModelDetails details) - { - // this is likely user added model - modelSelectionControl.SetModels([details]); + modelSelectionControl.SetModels(GetAllSampleDetails().ToList()); + } + else if (e.Parameter is ModelDetails details) + { + // this is likely user added model + modelSelectionControl.SetModels([details]); - ModelFamily = new ModelFamily - { - ReadmeUrl = details.ReadmeUrl ?? string.Empty, - Name = details.Name - }; - } - else if (e.Parameter is ModelType apiType && ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(apiType, out var apiDefinition)) + ModelFamily = new ModelFamily { - // API - modelFamilyType = apiType; - - ModelFamily = new ModelFamily - { - Id = apiDefinition.Id, - ReadmeUrl = apiDefinition.ReadmeUrl, - DocsUrl = apiDefinition.ReadmeUrl, - Name = apiDefinition.Name, - }; + ReadmeUrl = details.ReadmeUrl ?? string.Empty, + Name = details.Name + }; + } + else if (e.Parameter is ModelType apiType && ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(apiType, out var apiDefinition)) + { + // API + modelFamilyType = apiType; - modelSelectionControl.SetModels(GetAllSampleDetails().ToList()); - } - else + ModelFamily = new ModelFamily { - throw new InvalidOperationException("Invalid navigation parameter"); - } + Id = apiDefinition.Id, + ReadmeUrl = apiDefinition.ReadmeUrl, + DocsUrl = apiDefinition.ReadmeUrl, + Name = apiDefinition.Name, + }; - if (ModelFamily != null && !string.IsNullOrWhiteSpace(ModelFamily.ReadmeUrl)) - { - var loadReadme = LoadReadme(ModelFamily.ReadmeUrl); - } - else - { - summaryGrid.Visibility = Visibility.Collapsed; - } - - EnableSampleListIfModelIsDownloaded(); - App.ModelCache.CacheStore.ModelsChanged += CacheStore_ModelsChanged; + modelSelectionControl.SetModels(GetAllSampleDetails().ToList()); } - - private void ModelPage_Unloaded(object sender, RoutedEventArgs e) + else { - App.ModelCache.CacheStore.ModelsChanged -= CacheStore_ModelsChanged; + throw new InvalidOperationException("Invalid navigation parameter"); } - private void CacheStore_ModelsChanged(ModelCacheStore sender) + if (ModelFamily != null && !string.IsNullOrWhiteSpace(ModelFamily.ReadmeUrl)) + { + var loadReadme = LoadReadme(ModelFamily.ReadmeUrl); + } + else { - EnableSampleListIfModelIsDownloaded(); + summaryGrid.Visibility = Visibility.Collapsed; } - private void EnableSampleListIfModelIsDownloaded() + EnableSampleListIfModelIsDownloaded(); + App.ModelCache.CacheStore.ModelsChanged += CacheStore_ModelsChanged; + } + + private void ModelPage_Unloaded(object sender, RoutedEventArgs e) + { + App.ModelCache.CacheStore.ModelsChanged -= CacheStore_ModelsChanged; + } + + private void CacheStore_ModelsChanged(ModelCacheStore sender) + { + EnableSampleListIfModelIsDownloaded(); + } + + private void EnableSampleListIfModelIsDownloaded() + { + if (modelSelectionControl.Models != null && modelSelectionControl.Models.Count > 0) { - if (modelSelectionControl.Models != null && modelSelectionControl.Models.Count > 0) + foreach (var model in modelSelectionControl.Models) { - foreach (var model in modelSelectionControl.Models) + if (App.ModelCache.GetCachedModel(model.Url) != null || model.Size == 0) { - if (App.ModelCache.GetCachedModel(model.Url) != null || model.Size == 0) - { - SampleList.IsEnabled = true; - } + SampleList.IsEnabled = true; } } } + } - private async Task LoadReadme(string url) + private async Task LoadReadme(string url) + { + string readmeContents = string.Empty; + + if (url.StartsWith("https://github.com", StringComparison.InvariantCultureIgnoreCase)) { - string readmeContents = string.Empty; + readmeContents = await GithubApi.GetContentsOfTextFile(url); + } + else if (url.StartsWith("https://huggingface.co", StringComparison.InvariantCultureIgnoreCase)) + { + readmeContents = await HuggingFaceApi.GetContentsOfTextFile(url); + } - if (url.StartsWith("https://github.com", StringComparison.InvariantCultureIgnoreCase)) - { - readmeContents = await GithubApi.GetContentsOfTextFile(url); - } - else if (url.StartsWith("https://huggingface.co", StringComparison.InvariantCultureIgnoreCase)) - { - readmeContents = await HuggingFaceApi.GetContentsOfTextFile(url); - } + if (!string.IsNullOrWhiteSpace(readmeContents)) + { + readmeContents = Regex.Replace(readmeContents, @"\A---\n[\s\S]*?---\n", string.Empty, RegexOptions.Multiline); + markdownTextBlock.Text = readmeContents; + } - if (!string.IsNullOrWhiteSpace(readmeContents)) - { - readmeContents = Regex.Replace(readmeContents, @"\A---\n[\s\S]*?---\n", string.Empty, RegexOptions.Multiline); - markdownTextBlock.Text = readmeContents; - } + readmeProgressRing.IsActive = false; + } - readmeProgressRing.IsActive = false; + private IEnumerable GetAllSampleDetails() + { + if (!modelFamilyType.HasValue || !ModelTypeHelpers.ParentMapping.TryGetValue(modelFamilyType.Value, out List? modelTypes)) + { + yield break; } - private IEnumerable GetAllSampleDetails() + if (modelTypes.Count == 0) { - if (!modelFamilyType.HasValue || !ModelTypeHelpers.ParentMapping.TryGetValue(modelFamilyType.Value, out List? modelTypes)) - { - yield break; - } + // Its an API + modelTypes = [modelFamilyType.Value]; + } - if (modelTypes.Count == 0) + foreach (var modelType in modelTypes) + { + if (ModelTypeHelpers.ModelDetails.TryGetValue(modelType, out var modelDetails)) { - // Its an API - modelTypes = [modelFamilyType.Value]; + yield return modelDetails; } - - foreach (var modelType in modelTypes) + else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(modelType, out var apiDefinition)) { - if (ModelTypeHelpers.ModelDetails.TryGetValue(modelType, out var modelDetails)) - { - yield return modelDetails; - } - else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(modelType, out var apiDefinition)) - { - yield return ModelDetailsHelper.GetModelDetailsFromApiDefinition(modelType, apiDefinition); - } + yield return ModelDetailsHelper.GetModelDetailsFromApiDefinition(modelType, apiDefinition); } } + } - private void ModelSelectionControl_SelectedModelChanged(object sender, ModelDetails? modelDetails) - { - // if we don't have a modelType, we are in a user added language model, use same samples as Phi - var modelType = modelFamilyType ?? ModelType.Phi3Mini; - - var samples = SampleDetails.Samples.Where(s => s.Model1Types.Contains(modelType) || s.Model2Types?.Contains(modelType) == true).ToList(); - if (ModelTypeHelpers.ParentMapping.Values.Any(parent => parent.Contains(modelType))) - { - var parent = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(modelType)).Key; - samples.AddRange(SampleDetails.Samples.Where(s => s.Model1Types.Contains(parent) || s.Model2Types?.Contains(parent) == true)); - } + private void ModelSelectionControl_SelectedModelChanged(object sender, ModelDetails? modelDetails) + { + // if we don't have a modelType, we are in a user added language model, use same samples as Phi + var modelType = modelFamilyType ?? ModelType.Phi3Mini; - SampleList.ItemsSource = samples; + var samples = SampleDetails.Samples.Where(s => s.Model1Types.Contains(modelType) || s.Model2Types?.Contains(modelType) == true).ToList(); + if (ModelTypeHelpers.ParentMapping.Values.Any(parent => parent.Contains(modelType))) + { + var parent = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(modelType)).Key; + samples.AddRange(SampleDetails.Samples.Where(s => s.Model1Types.Contains(parent) || s.Model2Types?.Contains(parent) == true)); } - private void CopyButton_Click(object sender, RoutedEventArgs e) - { - if (ModelFamily == null || ModelFamily.Id == null) - { - return; - } + SampleList.ItemsSource = samples; + } - var dataPackage = new DataPackage(); - dataPackage.SetText($"aidevgallery://models/{ModelFamily.Id}"); - Clipboard.SetContentWithOptions(dataPackage, null); + private void CopyButton_Click(object sender, RoutedEventArgs e) + { + if (ModelFamily == null || ModelFamily.Id == null) + { + return; } - private void MarkdownTextBlock_LinkClicked(object sender, CommunityToolkit.WinUI.UI.Controls.LinkClickedEventArgs e) + var dataPackage = new DataPackage(); + dataPackage.SetText($"aidevgallery://models/{ModelFamily.Id}"); + Clipboard.SetContentWithOptions(dataPackage, null); + } + + private void MarkdownTextBlock_LinkClicked(object sender, CommunityToolkit.WinUI.UI.Controls.LinkClickedEventArgs e) + { + ModelDetailsLinkClickedEvent.Log(e.Link); + Process.Start(new ProcessStartInfo() { - ModelDetailsLinkClickedEvent.Log(e.Link); - Process.Start(new ProcessStartInfo() - { - FileName = e.Link, - UseShellExecute = true - }); - } + FileName = e.Link, + UseShellExecute = true + }); + } - private void SampleList_ItemInvoked(ItemsView sender, ItemsViewItemInvokedEventArgs args) + private void SampleList_ItemInvoked(ItemsView sender, ItemsViewItemInvokedEventArgs args) + { + if (args.InvokedItem is Sample sample) { - if (args.InvokedItem is Sample sample) - { - var availableModel = modelSelectionControl.DownloadedModels.FirstOrDefault(); - App.MainWindow.Navigate("Samples", new SampleNavigationArgs(sample, availableModel)); - } + var availableModel = modelSelectionControl.DownloadedModels.FirstOrDefault(); + App.MainWindow.Navigate("Samples", new SampleNavigationArgs(sample, availableModel)); } } } \ No newline at end of file diff --git a/AIDevGallery/Pages/ModelSelectionPage.xaml.cs b/AIDevGallery/Pages/ModelSelectionPage.xaml.cs index 0ab2a04..7c3f107 100644 --- a/AIDevGallery/Pages/ModelSelectionPage.xaml.cs +++ b/AIDevGallery/Pages/ModelSelectionPage.xaml.cs @@ -11,265 +11,264 @@ using System.Collections.Generic; using System.Linq; -namespace AIDevGallery.Pages +namespace AIDevGallery.Pages; + +internal record LastInternalNavigation(Type PageType, object? Parameter = null); + +internal sealed partial class ModelSelectionPage : Page { - internal record LastInternalNavigation(Type PageType, object? Parameter = null); + private static LastInternalNavigation? lastInternalNavigation; - internal sealed partial class ModelSelectionPage : Page + public ModelSelectionPage() { - private static LastInternalNavigation? lastInternalNavigation; + this.InitializeComponent(); + } - public ModelSelectionPage() - { - this.InitializeComponent(); - } + protected override void OnNavigatedTo(NavigationEventArgs e) + { + NavigatedToPageEvent.Log(nameof(ModelSelectionPage)); - protected override void OnNavigatedTo(NavigationEventArgs e) + SetUpModels(); + NavView.Loaded += (sender, args) => { - NavigatedToPageEvent.Log(nameof(ModelSelectionPage)); + List? modelTypes = null; + ModelDetails? details = null; + object? parameter = e.Parameter; - SetUpModels(); - NavView.Loaded += (sender, args) => + if (e.Parameter == null && lastInternalNavigation != null) { - List? modelTypes = null; - ModelDetails? details = null; - object? parameter = e.Parameter; + parameter = lastInternalNavigation.Parameter; + } - if (e.Parameter == null && lastInternalNavigation != null) - { - parameter = lastInternalNavigation.Parameter; - } + if (parameter is ModelType sample) + { + modelTypes = [sample]; + } + else if (parameter is List samples) + { + modelTypes = samples; + } + else if (parameter is MostRecentlyUsedItem mru) + { + modelTypes = App.FindSampleItemById(mru.ItemId); + } + else if (parameter is string modelOrApiId) + { + modelTypes = GetFamilyModelType(App.FindSampleItemById(modelOrApiId)); + } + else if (parameter is ModelDetails modelDetails) + { + details = modelDetails; + modelTypes = GetFamilyModelType(App.FindSampleItemById(details.Id)); + } - if (parameter is ModelType sample) - { - modelTypes = [sample]; - } - else if (parameter is List samples) - { - modelTypes = samples; - } - else if (parameter is MostRecentlyUsedItem mru) - { - modelTypes = App.FindSampleItemById(mru.ItemId); - } - else if (parameter is string modelOrApiId) - { - modelTypes = GetFamilyModelType(App.FindSampleItemById(modelOrApiId)); - } - else if (parameter is ModelDetails modelDetails) + if (modelTypes != null || details != null) + { + foreach (NavigationViewItem item in NavView.MenuItems) { - details = modelDetails; - modelTypes = GetFamilyModelType(App.FindSampleItemById(details.Id)); + SetSelectedSampleInMenu(item, modelTypes, details); } - - if (modelTypes != null || details != null) + } + else + { + if (NavView.MenuItems.FirstOrDefault() is NavigationViewItem item) { - foreach (NavigationViewItem item in NavView.MenuItems) + if (item.MenuItems != null && item.MenuItems.Count > 0) { - SetSelectedSampleInMenu(item, modelTypes, details); + item.IsExpanded = true; + NavView.SelectedItem = item.MenuItems[0]; } - } - else - { - if (NavView.MenuItems.FirstOrDefault() is NavigationViewItem item) + else { - if (item.MenuItems != null && item.MenuItems.Count > 0) - { - item.IsExpanded = true; - NavView.SelectedItem = item.MenuItems[0]; - } - else - { - NavView.SelectedItem = item; - } + NavView.SelectedItem = item; } } + } - static List? GetFamilyModelType(List? modelTypes) + static List? GetFamilyModelType(List? modelTypes) + { + if (modelTypes != null && modelTypes.Count > 0) { - if (modelTypes != null && modelTypes.Count > 0) + var modelType = modelTypes.First(); + if (ModelTypeHelpers.ModelDetails.ContainsKey(modelType)) { - var modelType = modelTypes.First(); - if (ModelTypeHelpers.ModelDetails.ContainsKey(modelType)) - { - var parent = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(modelType)); - modelTypes = [parent.Key]; - } + var parent = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(modelType)); + modelTypes = [parent.Key]; } - - return modelTypes; } - }; - base.OnNavigatedTo(e); - } + return modelTypes; + } + }; - private void SetUpModels() + base.OnNavigatedTo(e); + } + + private void SetUpModels() + { + List rootModels = [.. ModelTypeHelpers.ModelGroupDetails.Keys]; + rootModels.AddRange(ModelTypeHelpers.ModelFamilyDetails.Keys); + foreach (var key in ModelTypeHelpers.ModelFamilyDetails) { - List rootModels = [.. ModelTypeHelpers.ModelGroupDetails.Keys]; - rootModels.AddRange(ModelTypeHelpers.ModelFamilyDetails.Keys); - foreach (var key in ModelTypeHelpers.ModelFamilyDetails) + foreach (var mapping in ModelTypeHelpers.ParentMapping) { - foreach (var mapping in ModelTypeHelpers.ParentMapping) + foreach (var key2 in mapping.Value) { - foreach (var key2 in mapping.Value) + if (key.Key == key2) { - if (key.Key == key2) - { - rootModels.Remove(key.Key); - } + rootModels.Remove(key.Key); } } } + } - NavigationViewItem? languageModelsNavItem = null; + NavigationViewItem? languageModelsNavItem = null; - foreach (var key in rootModels.OrderBy(ModelTypeHelpers.GetModelOrder)) - { - var navItem = CreateFromItem(key, ModelTypeHelpers.ModelGroupDetails.ContainsKey(key)); - NavView.MenuItems.Add(navItem); + foreach (var key in rootModels.OrderBy(ModelTypeHelpers.GetModelOrder)) + { + var navItem = CreateFromItem(key, ModelTypeHelpers.ModelGroupDetails.ContainsKey(key)); + NavView.MenuItems.Add(navItem); - if (key == ModelType.LanguageModels) - { - languageModelsNavItem = navItem; - } + if (key == ModelType.LanguageModels) + { + languageModelsNavItem = navItem; } + } - if (languageModelsNavItem != null) - { - var userAddedModels = App.ModelCache.Models.Where(m => m.Details.IsUserAdded).ToList(); + if (languageModelsNavItem != null) + { + var userAddedModels = App.ModelCache.Models.Where(m => m.Details.IsUserAdded).ToList(); - foreach (var cachedModel in userAddedModels) + foreach (var cachedModel in userAddedModels) + { + languageModelsNavItem.MenuItems.Add(new NavigationViewItem { - languageModelsNavItem.MenuItems.Add(new NavigationViewItem - { - Content = cachedModel.Details.Name.Split('/').Last(), - Tag = cachedModel.Details, - }); - } + Content = cachedModel.Details.Name.Split('/').Last(), + Tag = cachedModel.Details, + }); } } + } - private static NavigationViewItem CreateFromItem(ModelType key, bool includeChildren) + private static NavigationViewItem CreateFromItem(ModelType key, bool includeChildren) + { + string name; + string? icon = null; + if (ModelTypeHelpers.ModelGroupDetails.TryGetValue(key, out var modelGroup)) + { + name = modelGroup.Name; + icon = modelGroup.Icon; + } + else { - string name; - string? icon = null; - if (ModelTypeHelpers.ModelGroupDetails.TryGetValue(key, out var modelGroup)) + if (ModelTypeHelpers.ModelFamilyDetails.TryGetValue(key, out var modelFamily)) + { + name = modelFamily.Name ?? key.ToString(); + } + else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(key, out var apiDefinition)) { - name = modelGroup.Name; - icon = modelGroup.Icon; + name = apiDefinition.Name ?? key.ToString(); } else { - if (ModelTypeHelpers.ModelFamilyDetails.TryGetValue(key, out var modelFamily)) - { - name = modelFamily.Name ?? key.ToString(); - } - else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(key, out var apiDefinition)) - { - name = apiDefinition.Name ?? key.ToString(); - } - else - { - name = key.ToString(); - } + name = key.ToString(); } + } - NavigationViewItem item = new() - { - Content = name, - Tag = key - }; + NavigationViewItem item = new() + { + Content = name, + Tag = key + }; - if (!string.IsNullOrWhiteSpace(icon)) - { - item.Icon = new FontIcon() { Glyph = icon }; - } + if (!string.IsNullOrWhiteSpace(icon)) + { + item.Icon = new FontIcon() { Glyph = icon }; + } - if (ModelTypeHelpers.ParentMapping.TryGetValue(key, out List? innerItems)) + if (ModelTypeHelpers.ParentMapping.TryGetValue(key, out List? innerItems)) + { + if (innerItems?.Count > 0) { - if (innerItems?.Count > 0) + if (includeChildren) { - if (includeChildren) + item.SelectsOnInvoked = false; + foreach (var childNavigationItem in innerItems) { - item.SelectsOnInvoked = false; - foreach (var childNavigationItem in innerItems) - { - item.MenuItems.Add(CreateFromItem(childNavigationItem, false)); - } + item.MenuItems.Add(CreateFromItem(childNavigationItem, false)); } } } - - return item; } - private void NavView_SelectionChanged(NavigationView sender, NavigationViewSelectionChangedEventArgs args) - { - Type pageType = typeof(ModelPage); - object? parameter = null; - - if (args.SelectedItem is NavigationViewItem item) - { - if (item.Tag is ModelType modelType) - { - parameter = modelType; - } - else if (item.Tag is string tag && tag == "ModelManagement") - { - pageType = typeof(ModelSelectionPage); - } - else if (item.Tag is ModelDetails details) - { - parameter = details; - } + return item; + } - lastInternalNavigation = new LastInternalNavigation(pageType, parameter); - NavFrame.Navigate(pageType, parameter); - } - } + private void NavView_SelectionChanged(NavigationView sender, NavigationViewSelectionChangedEventArgs args) + { + Type pageType = typeof(ModelPage); + object? parameter = null; - private void SetSelectedSampleInMenu(NavigationViewItem item, List? selectedSample = null, ModelDetails? details = null) + if (args.SelectedItem is NavigationViewItem item) { - if (selectedSample == null && details == null) + if (item.Tag is ModelType modelType) { - return; + parameter = modelType; } - - if (item.Tag is ModelType mt && selectedSample != null && selectedSample.Contains(mt)) + else if (item.Tag is string tag && tag == "ModelManagement") { - NavView.SelectedItem = item; - return; + pageType = typeof(ModelSelectionPage); } - - foreach (var menuItem in item.MenuItems) + else if (item.Tag is ModelDetails details) { - if (menuItem is NavigationViewItem navItem) - { - if ((navItem.Tag is ModelType modelType && selectedSample != null && selectedSample.Contains(modelType)) || - (navItem.Tag is ModelDetails modelDetails && details != null && modelDetails.Url == details.Url)) - { - item.IsExpanded = true; - NavView.SelectedItem = navItem; - return; - } - else if (navItem.MenuItems.Count > 0) - { - SetSelectedSampleInMenu(navItem, selectedSample, details); - } - } + parameter = details; } + + lastInternalNavigation = new LastInternalNavigation(pageType, parameter); + NavFrame.Navigate(pageType, parameter); } + } - private void AddModelClicked(object sender, Microsoft.UI.Xaml.RoutedEventArgs e) + private void SetSelectedSampleInMenu(NavigationViewItem item, List? selectedSample = null, ModelDetails? details = null) + { + if (selectedSample == null && details == null) { - NavView.SelectedItem = null; - NavFrame.Navigate(typeof(AddModelPage)); + return; } - private void ManageModelsClicked(object sender, Microsoft.UI.Xaml.RoutedEventArgs e) + if (item.Tag is ModelType mt && selectedSample != null && selectedSample.Contains(mt)) { - App.MainWindow.Navigate("settings", "ModelManagement"); + NavView.SelectedItem = item; + return; } + + foreach (var menuItem in item.MenuItems) + { + if (menuItem is NavigationViewItem navItem) + { + if ((navItem.Tag is ModelType modelType && selectedSample != null && selectedSample.Contains(modelType)) || + (navItem.Tag is ModelDetails modelDetails && details != null && modelDetails.Url == details.Url)) + { + item.IsExpanded = true; + NavView.SelectedItem = navItem; + return; + } + else if (navItem.MenuItems.Count > 0) + { + SetSelectedSampleInMenu(navItem, selectedSample, details); + } + } + } + } + + private void AddModelClicked(object sender, Microsoft.UI.Xaml.RoutedEventArgs e) + { + NavView.SelectedItem = null; + NavFrame.Navigate(typeof(AddModelPage)); + } + + private void ManageModelsClicked(object sender, Microsoft.UI.Xaml.RoutedEventArgs e) + { + App.MainWindow.Navigate("settings", "ModelManagement"); } } \ No newline at end of file diff --git a/AIDevGallery/Pages/ScenarioPage.xaml.cs b/AIDevGallery/Pages/ScenarioPage.xaml.cs index 907a839..71ee7a7 100644 --- a/AIDevGallery/Pages/ScenarioPage.xaml.cs +++ b/AIDevGallery/Pages/ScenarioPage.xaml.cs @@ -21,449 +21,448 @@ using Windows.ApplicationModel.DataTransfer; using Windows.Storage.Pickers; -namespace AIDevGallery.Pages +namespace AIDevGallery.Pages; + +internal sealed partial class ScenarioPage : Page { - internal sealed partial class ScenarioPage : Page + private Scenario? scenario; + private List? samples; + private Sample? sample; + private ModelDetails? selectedModelDetails; + private ModelDetails? selectedModelDetails2; + + public ScenarioPage() { - private Scenario? scenario; - private List? samples; - private Sample? sample; - private ModelDetails? selectedModelDetails; - private ModelDetails? selectedModelDetails2; + this.InitializeComponent(); + } - public ScenarioPage() + protected override void OnNavigatedTo(NavigationEventArgs e) + { + base.OnNavigatedTo(e); + if (e.Parameter is Scenario scenario) + { + this.scenario = scenario; + PopulateModelControls(); + } + else if (e.Parameter is SampleNavigationArgs sampleArgs) { - this.InitializeComponent(); + this.scenario = ScenarioCategoryHelpers.AllScenarioCategories.SelectMany(sc => sc.Scenarios).FirstOrDefault(s => s.ScenarioType == sampleArgs.Sample.Scenario); + PopulateModelControls(sampleArgs.ModelDetails); } + } - protected override void OnNavigatedTo(NavigationEventArgs e) + private void PopulateModelControls(ModelDetails? initialModelToLoad = null) + { + if (scenario == null) { - base.OnNavigatedTo(e); - if (e.Parameter is Scenario scenario) - { - this.scenario = scenario; - PopulateModelControls(); - } - else if (e.Parameter is SampleNavigationArgs sampleArgs) - { - this.scenario = ScenarioCategoryHelpers.AllScenarioCategories.SelectMany(sc => sc.Scenarios).FirstOrDefault(s => s.ScenarioType == sampleArgs.Sample.Scenario); - PopulateModelControls(sampleArgs.ModelDetails); - } + return; } - private void PopulateModelControls(ModelDetails? initialModelToLoad = null) + samples = SampleDetails.Samples.Where(sample => sample.Scenario == scenario.ScenarioType).ToList(); + List>>? modelsDetailsDict = null; + + if (samples.Count == 0) + { + return; + } + else if (samples.Count == 1) { - if (scenario == null) + modelsDetailsDict = ModelDetailsHelper.GetModelDetails(samples.First()); + } + else + { + var sample = samples.First(); + var modelKey = sample.Model1Types.First(); // First only? + + if (ModelTypeHelpers.ParentMapping.Values.Any(parent => parent.Contains(modelKey))) { - return; - } + var parentKey = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(modelKey)).Key; - samples = SampleDetails.Samples.Where(sample => sample.Scenario == scenario.ScenarioType).ToList(); - List>>? modelsDetailsDict = null; + var listModelDetails = new List(); - if (samples.Count == 0) - { - return; - } - else if (samples.Count == 1) - { - modelsDetailsDict = ModelDetailsHelper.GetModelDetails(samples.First()); + foreach (var s in samples) + { + listModelDetails.AddRange(ModelDetailsHelper.GetModelDetails(s).First().First().Value); + } + + modelsDetailsDict = + [ + new Dictionary> + { + [parentKey] = listModelDetails + } + ]; } - else + + if (sample.Model2Types != null) { - var sample = samples.First(); - var modelKey = sample.Model1Types.First(); // First only? + var modelKey2 = sample.Model2Types.First(); // First only? - if (ModelTypeHelpers.ParentMapping.Values.Any(parent => parent.Contains(modelKey))) + if (ModelTypeHelpers.ParentMapping.Values.Any(parent => parent.Contains(modelKey2))) { - var parentKey = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(modelKey)).Key; + var parentKey2 = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(modelKey2)).Key; - var listModelDetails = new List(); + var listModelDetails2 = new List(); foreach (var s in samples) { - listModelDetails.AddRange(ModelDetailsHelper.GetModelDetails(s).First().First().Value); + listModelDetails2.AddRange(ModelDetailsHelper.GetModelDetails(s).ElementAt(1).First().Value); } - modelsDetailsDict = - [ + modelsDetailsDict ??= []; + + modelsDetailsDict.Add( new Dictionary> { - [parentKey] = listModelDetails - } - ]; + [parentKey2] = listModelDetails2 + }); } + } + } - if (sample.Model2Types != null) - { - var modelKey2 = sample.Model2Types.First(); // First only? - - if (ModelTypeHelpers.ParentMapping.Values.Any(parent => parent.Contains(modelKey2))) - { - var parentKey2 = ModelTypeHelpers.ParentMapping.FirstOrDefault(parent => parent.Value.Contains(modelKey2)).Key; + if (modelsDetailsDict == null) + { + return; + } - var listModelDetails2 = new List(); + var models = modelsDetailsDict.First().SelectMany(g => g.Value).ToList(); - foreach (var s in samples) - { - listModelDetails2.AddRange(ModelDetailsHelper.GetModelDetails(s).ElementAt(1).First().Value); - } + selectedModelDetails = SelectLatestOrDefault(models); - modelsDetailsDict ??= []; + if (modelsDetailsDict.Count > 1) + { + var models2 = modelsDetailsDict.ElementAt(1).SelectMany(g => g.Value).ToList(); + selectedModelDetails2 = SelectLatestOrDefault(models2); + modelSelectionControl2.SetModels(models2, initialModelToLoad); + } - modelsDetailsDict.Add( - new Dictionary> - { - [parentKey2] = listModelDetails2 - }); - } - } - } + modelSelectionControl.SetModels(models, initialModelToLoad); + UpdatePlaceholderControl(); + } - if (modelsDetailsDict == null) - { - return; - } + private static ModelDetails? SelectLatestOrDefault(List models) + { + var latestModelOrApiUsageHistory = App.AppData.UsageHistory.FirstOrDefault(id => models.Any(m => m.Id == id)); - var models = modelsDetailsDict.First().SelectMany(g => g.Value).ToList(); + if (latestModelOrApiUsageHistory != null) + { + // select most recently used if there is one + return models.First(m => m.Id == latestModelOrApiUsageHistory); + } - selectedModelDetails = SelectLatestOrDefault(models); + return models.FirstOrDefault(); + } - if (modelsDetailsDict.Count > 1) - { - var models2 = modelsDetailsDict.ElementAt(1).SelectMany(g => g.Value).ToList(); - selectedModelDetails2 = SelectLatestOrDefault(models2); - modelSelectionControl2.SetModels(models2, initialModelToLoad); - } + private async void ModelSelectionControl_SelectedModelChanged(object sender, ModelDetails? modelDetails) + { + ModelDropDown.HideFlyout(); + ModelDropDown2.HideFlyout(); - modelSelectionControl.SetModels(models, initialModelToLoad); - UpdatePlaceholderControl(); + if (samples == null) + { + return; } - private static ModelDetails? SelectLatestOrDefault(List models) + if ((ModelSelectionControl)sender == modelSelectionControl) { - var latestModelOrApiUsageHistory = App.AppData.UsageHistory.FirstOrDefault(id => models.Any(m => m.Id == id)); - - if (latestModelOrApiUsageHistory != null) - { - // select most recently used if there is one - return models.First(m => m.Id == latestModelOrApiUsageHistory); - } - - return models.FirstOrDefault(); + selectedModelDetails = modelDetails; } - - private async void ModelSelectionControl_SelectedModelChanged(object sender, ModelDetails? modelDetails) + else { - ModelDropDown.HideFlyout(); - ModelDropDown2.HideFlyout(); - - if (samples == null) - { - return; - } - - if ((ModelSelectionControl)sender == modelSelectionControl) - { - selectedModelDetails = modelDetails; - } - else - { - selectedModelDetails2 = modelDetails; - } + selectedModelDetails2 = modelDetails; + } - if (selectedModelDetails != null) + if (selectedModelDetails != null) + { + foreach (var s in samples) { - foreach (var s in samples) + var extDict = ModelDetailsHelper.GetModelDetails(s).FirstOrDefault(dict => dict.Values.Any(listOfmd => listOfmd.Any(md => md.Id == selectedModelDetails.Id)))?.Values; + if (extDict != null) { - var extDict = ModelDetailsHelper.GetModelDetails(s).FirstOrDefault(dict => dict.Values.Any(listOfmd => listOfmd.Any(md => md.Id == selectedModelDetails.Id)))?.Values; - if (extDict != null) + var dict = extDict.FirstOrDefault(listOfmd => listOfmd.Any(md => md.Id == selectedModelDetails.Id)); + if (dict != null) { - var dict = extDict.FirstOrDefault(listOfmd => listOfmd.Any(md => md.Id == selectedModelDetails.Id)); - if (dict != null) - { - sample = s; - break; - } + sample = s; + break; } } } - else - { - sample = null; - } + } + else + { + sample = null; + } - if (sample == null) - { - return; - } + if (sample == null) + { + return; + } - if ((sample.Model2Types == null && selectedModelDetails == null) || - (sample.Model2Types != null && (selectedModelDetails == null || selectedModelDetails2 == null))) - { - UpdatePlaceholderControl(); + if ((sample.Model2Types == null && selectedModelDetails == null) || + (sample.Model2Types != null && (selectedModelDetails == null || selectedModelDetails2 == null))) + { + UpdatePlaceholderControl(); - VisualStateManager.GoToState(this, "NoModelSelected", true); - return; - } - else + VisualStateManager.GoToState(this, "NoModelSelected", true); + return; + } + else + { + VisualStateManager.GoToState(this, "ModelSelected", true); + ModelDropDown2.Visibility = Visibility.Collapsed; + + ModelDropDown.Model = selectedModelDetails; + List models = [selectedModelDetails!]; + + if (sample.Model2Types != null) { - VisualStateManager.GoToState(this, "ModelSelected", true); - ModelDropDown2.Visibility = Visibility.Collapsed; + models.Add(selectedModelDetails2!); + ModelDropDown2.Model = selectedModelDetails2; + ModelDropDown2.Visibility = Visibility.Visible; + } - ModelDropDown.Model = selectedModelDetails; - List models = [selectedModelDetails!]; + await SampleContainer.LoadSampleAsync(sample, models); - if (sample.Model2Types != null) + await App.AppData.AddMru( + new MostRecentlyUsedItem() { - models.Add(selectedModelDetails2!); - ModelDropDown2.Model = selectedModelDetails2; - ModelDropDown2.Visibility = Visibility.Visible; - } - - await SampleContainer.LoadSampleAsync(sample, models); + Type = MostRecentlyUsedItemType.Scenario, + ItemId = scenario!.Id, + Icon = scenario.Icon, + Description = scenario.Description, + SubItemId = selectedModelDetails!.Id, + DisplayName = scenario.Name + }, + selectedModelDetails.Id); + } + } - await App.AppData.AddMru( - new MostRecentlyUsedItem() - { - Type = MostRecentlyUsedItemType.Scenario, - ItemId = scenario!.Id, - Icon = scenario.Icon, - Description = scenario.Description, - SubItemId = selectedModelDetails!.Id, - DisplayName = scenario.Name - }, - selectedModelDetails.Id); - } + private void UpdatePlaceholderControl() + { + if (sample == null || (sample.Model2Types == null && selectedModelDetails == null)) + { + PlaceholderControl.SetModels(modelSelectionControl.Models); } + else + { + PlaceholderControl.SetModels(modelSelectionControl2.Models); + } + } - private void UpdatePlaceholderControl() + private void CopyButton_Click(object sender, RoutedEventArgs e) + { + var dataPackage = new DataPackage(); + dataPackage.SetText("aidevgallery://scenarios/" + scenario!.Id); + Clipboard.SetContentWithOptions(dataPackage, null); + } + + private void CodeToggle_Click(object sender, RoutedEventArgs args) + { + if (sender is ToggleButton btn) { - if (sample == null || (sample.Model2Types == null && selectedModelDetails == null)) + if (sample != null) { - PlaceholderControl.SetModels(modelSelectionControl.Models); + ToggleCodeButtonEvent.Log(sample.Name ?? string.Empty, btn.IsChecked == true); + } + + if (btn.IsChecked == true) + { + SampleContainer.ShowCode(); } else { - PlaceholderControl.SetModels(modelSelectionControl2.Models); + SampleContainer.HideCode(); } } + } - private void CopyButton_Click(object sender, RoutedEventArgs e) + private async void ExportSampleToggle_Click(object sender, RoutedEventArgs e) + { + if (sender is not Button button || + sample == null || + selectedModelDetails == null || + (sample.Model2Types != null && selectedModelDetails2 == null)) { - var dataPackage = new DataPackage(); - dataPackage.SetText("aidevgallery://scenarios/" + scenario!.Id); - Clipboard.SetContentWithOptions(dataPackage, null); + return; } - private void CodeToggle_Click(object sender, RoutedEventArgs args) + Dictionary cachedModels = []; + + (string Id, string Path, string Url, long ModelSize, HardwareAccelerator HardwareAccelerator) cachedModel; + + if (selectedModelDetails.Size == 0) + { + cachedModel = (selectedModelDetails.Id, selectedModelDetails.Url, selectedModelDetails.Url, 0, selectedModelDetails.HardwareAccelerators.FirstOrDefault()); + } + else { - if (sender is ToggleButton btn) + var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails.Url); + if (realCachedModel == null) { - if (sample != null) - { - ToggleCodeButtonEvent.Log(sample.Name ?? string.Empty, btn.IsChecked == true); - } - - if (btn.IsChecked == true) - { - SampleContainer.ShowCode(); - } - else - { - SampleContainer.HideCode(); - } + return; } + + cachedModel = (selectedModelDetails.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails.HardwareAccelerators.FirstOrDefault()); } - private async void ExportSampleToggle_Click(object sender, RoutedEventArgs e) + var cachedSampleItem = App.FindSampleItemById(cachedModel.Id); + + var model1Type = sample.Model1Types.Any(cachedSampleItem.Contains) + ? sample.Model1Types.First(cachedSampleItem.Contains) + : sample.Model1Types.First(); + cachedModels.Add(model1Type, cachedModel); + + if (sample.Model2Types != null) { - if (sender is not Button button || - sample == null || - selectedModelDetails == null || - (sample.Model2Types != null && selectedModelDetails2 == null)) + if (selectedModelDetails2 == null) { return; } - Dictionary cachedModels = []; - - (string Id, string Path, string Url, long ModelSize, HardwareAccelerator HardwareAccelerator) cachedModel; - - if (selectedModelDetails.Size == 0) + if (selectedModelDetails2.Size == 0) { - cachedModel = (selectedModelDetails.Id, selectedModelDetails.Url, selectedModelDetails.Url, 0, selectedModelDetails.HardwareAccelerators.FirstOrDefault()); + cachedModel = (selectedModelDetails2.Id, selectedModelDetails2.Url, selectedModelDetails2.Url, 0, selectedModelDetails2.HardwareAccelerators.FirstOrDefault()); } else { - var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails.Url); + var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails2.Url); if (realCachedModel == null) { return; } - cachedModel = (selectedModelDetails.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails.HardwareAccelerators.FirstOrDefault()); + cachedModel = (selectedModelDetails2.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails2.HardwareAccelerators.FirstOrDefault()); } - var cachedSampleItem = App.FindSampleItemById(cachedModel.Id); + var model2Type = sample.Model2Types.Any(cachedSampleItem.Contains) + ? sample.Model2Types.First(cachedSampleItem.Contains) + : sample.Model2Types.First(); - var model1Type = sample.Model1Types.Any(cachedSampleItem.Contains) - ? sample.Model1Types.First(cachedSampleItem.Contains) - : sample.Model1Types.First(); - cachedModels.Add(model1Type, cachedModel); + cachedModels.Add(model2Type, cachedModel); + } - if (sample.Model2Types != null) + ContentDialog? dialog = null; + try + { + var totalSize = cachedModels.Sum(cm => cm.Value.ModelSize); + if (totalSize == 0) { - if (selectedModelDetails2 == null) - { - return; - } - - if (selectedModelDetails2.Size == 0) - { - cachedModel = (selectedModelDetails2.Id, selectedModelDetails2.Url, selectedModelDetails2.Url, 0, selectedModelDetails2.HardwareAccelerators.FirstOrDefault()); - } - else - { - var realCachedModel = App.ModelCache.GetCachedModel(selectedModelDetails2.Url); - if (realCachedModel == null) - { - return; - } - - cachedModel = (selectedModelDetails2.Id, realCachedModel.Path, realCachedModel.Url, realCachedModel.ModelSize, selectedModelDetails2.HardwareAccelerators.FirstOrDefault()); - } - - var model2Type = sample.Model2Types.Any(cachedSampleItem.Contains) - ? sample.Model2Types.First(cachedSampleItem.Contains) - : sample.Model2Types.First(); - - cachedModels.Add(model2Type, cachedModel); + copyRadioButtons.Visibility = Visibility.Collapsed; } - - ContentDialog? dialog = null; - try + else { - var totalSize = cachedModels.Sum(cm => cm.Value.ModelSize); - if (totalSize == 0) - { - copyRadioButtons.Visibility = Visibility.Collapsed; - } - else - { - copyRadioButtons.Visibility = Visibility.Visible; - ModelExportSizeTxt.Text = AppUtils.FileSizeToString(totalSize); - } + copyRadioButtons.Visibility = Visibility.Visible; + ModelExportSizeTxt.Text = AppUtils.FileSizeToString(totalSize); + } - var output = await ExportDialog.ShowAsync(); + var output = await ExportDialog.ShowAsync(); - if (output == ContentDialogResult.Primary) + if (output == ContentDialogResult.Primary) + { + var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(App.MainWindow); + var picker = new FolderPicker(); + picker.FileTypeFilter.Add("*"); + WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd); + var folder = await picker.PickSingleFolderAsync(); + if (folder != null) { - var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(App.MainWindow); - var picker = new FolderPicker(); - picker.FileTypeFilter.Add("*"); - WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd); - var folder = await picker.PickSingleFolderAsync(); - if (folder != null) - { - var generator = new Generator(); + var generator = new Generator(); - dialog = new ContentDialog + dialog = new ContentDialog + { + XamlRoot = this.XamlRoot, + Title = "Creating Visual Studio project..", + Content = new ProgressRing { IsActive = true, Width = 48, Height = 48 } + }; + _ = dialog.ShowAsync(); + + Dictionary cachedModelsToGenerator = cachedModels + .Select(cm => (cm.Key, (cm.Value.Path, cm.Value.Url, cm.Value.HardwareAccelerator))) + .ToDictionary(x => x.Key, x => (x.Item2.Path, x.Item2.Url, x.Item2.HardwareAccelerator)); + + var projectPath = await generator.GenerateAsync( + sample, + cachedModelsToGenerator, + copyRadioButton.IsChecked == true && copyRadioButtons.Visibility == Visibility.Visible, + folder.Path, + CancellationToken.None); + + dialog.Closed += async (_, _) => + { + var confirmationDialog = new ContentDialog { XamlRoot = this.XamlRoot, - Title = "Creating Visual Studio project..", - Content = new ProgressRing { IsActive = true, Width = 48, Height = 48 } + Title = "Project exported", + Content = new TextBlock + { + Text = "The project has been successfully exported to the selected folder.", + TextWrapping = TextWrapping.WrapWholeWords + }, + PrimaryButtonText = "Open folder", + PrimaryButtonStyle = (Style)App.Current.Resources["AccentButtonStyle"], + CloseButtonText = "Close" }; - _ = dialog.ShowAsync(); - Dictionary cachedModelsToGenerator = cachedModels - .Select(cm => (cm.Key, (cm.Value.Path, cm.Value.Url, cm.Value.HardwareAccelerator))) - .ToDictionary(x => x.Key, x => (x.Item2.Path, x.Item2.Url, x.Item2.HardwareAccelerator)); - - var projectPath = await generator.GenerateAsync( - sample, - cachedModelsToGenerator, - copyRadioButton.IsChecked == true && copyRadioButtons.Visibility == Visibility.Visible, - folder.Path, - CancellationToken.None); - - dialog.Closed += async (_, _) => + var shouldOpenFolder = await confirmationDialog.ShowAsync(); + if (shouldOpenFolder == ContentDialogResult.Primary) { - var confirmationDialog = new ContentDialog - { - XamlRoot = this.XamlRoot, - Title = "Project exported", - Content = new TextBlock - { - Text = "The project has been successfully exported to the selected folder.", - TextWrapping = TextWrapping.WrapWholeWords - }, - PrimaryButtonText = "Open folder", - PrimaryButtonStyle = (Style)App.Current.Resources["AccentButtonStyle"], - CloseButtonText = "Close" - }; - - var shouldOpenFolder = await confirmationDialog.ShowAsync(); - if (shouldOpenFolder == ContentDialogResult.Primary) - { - await Windows.System.Launcher.LaunchFolderPathAsync(projectPath); - } - }; - dialog.Hide(); - dialog = null; - } + await Windows.System.Launcher.LaunchFolderPathAsync(projectPath); + } + }; + dialog.Hide(); + dialog = null; } } - catch (Exception ex) - { - Debug.WriteLine(ex); - dialog?.Hide(); + } + catch (Exception ex) + { + Debug.WriteLine(ex); + dialog?.Hide(); - var message = "Please try again, or report this issue."; - if (ex is IOException) - { - message = ex.Message; - } + var message = "Please try again, or report this issue."; + if (ex is IOException) + { + message = ex.Message; + } - var errorDialog = new ContentDialog - { - XamlRoot = this.XamlRoot, - Title = "Error while exporting project", - Content = new TextBlock - { - Text = $"An error occurred while exporting the project. {message}", - TextWrapping = TextWrapping.WrapWholeWords - }, - PrimaryButtonText = "Copy details", - CloseButtonText = "Close" - }; - - var result = await errorDialog.ShowAsync(); - if (result == ContentDialogResult.Primary) + var errorDialog = new ContentDialog + { + XamlRoot = this.XamlRoot, + Title = "Error while exporting project", + Content = new TextBlock { - var dataPackage = new DataPackage(); - dataPackage.SetText(ex.ToString()); - Clipboard.SetContentWithOptions(dataPackage, null); - } + Text = $"An error occurred while exporting the project. {message}", + TextWrapping = TextWrapping.WrapWholeWords + }, + PrimaryButtonText = "Copy details", + CloseButtonText = "Close" + }; + + var result = await errorDialog.ShowAsync(); + if (result == ContentDialogResult.Primary) + { + var dataPackage = new DataPackage(); + dataPackage.SetText(ex.ToString()); + Clipboard.SetContentWithOptions(dataPackage, null); } } + } + + private void ModelSelectionControl_ModelCollectionChanged(object sender) + { + PopulateModelControls(); + } - private void ModelSelectionControl_ModelCollectionChanged(object sender) + private void ActionButtonsGrid_SizeChanged(object sender, SizeChangedEventArgs e) + { + // Calculate if the modelselectors collide with the export/code buttons + if ((ModelPanel.ActualWidth + ButtonsPanel.ActualWidth) >= e.NewSize.Width) { - PopulateModelControls(); + VisualStateManager.GoToState(this, "NarrowLayout", true); } - - private void ActionButtonsGrid_SizeChanged(object sender, SizeChangedEventArgs e) + else { - // Calculate if the modelselectors collide with the export/code buttons - if ((ModelPanel.ActualWidth + ButtonsPanel.ActualWidth) >= e.NewSize.Width) - { - VisualStateManager.GoToState(this, "NarrowLayout", true); - } - else - { - VisualStateManager.GoToState(this, "WideLayout", true); - } + VisualStateManager.GoToState(this, "WideLayout", true); } } } \ No newline at end of file diff --git a/AIDevGallery/Pages/ScenarioSelectionPage.xaml.cs b/AIDevGallery/Pages/ScenarioSelectionPage.xaml.cs index 3a80c97..b790b89 100644 --- a/AIDevGallery/Pages/ScenarioSelectionPage.xaml.cs +++ b/AIDevGallery/Pages/ScenarioSelectionPage.xaml.cs @@ -11,207 +11,205 @@ using System.Collections.Generic; using System.Linq; -namespace AIDevGallery.Pages +namespace AIDevGallery.Pages; +internal sealed partial class ScenarioSelectionPage : Page { - internal sealed partial class ScenarioSelectionPage : Page - { - internal record FilterRecord(string? Tag, string Text); + internal record FilterRecord(string? Tag, string Text); - private readonly List filters = - [ - new(null, "All Scenarios" ), - new("npu", "NPU Scenarios" ), - new("gpu", "GPU Scenarios" ), - new("wcr-api", "WCR API Scenarios" ) - ]; + private readonly List filters = + [ + new(null, "All Scenarios" ), + new("npu", "NPU Scenarios" ), + new("gpu", "GPU Scenarios" ), + new("wcr-api", "WCR API Scenarios" ) + ]; - private static LastInternalNavigation? lastInternalNavigation; - private Scenario? selectedScenario; + private static LastInternalNavigation? lastInternalNavigation; + private Scenario? selectedScenario; - public ScenarioSelectionPage() - { - this.InitializeComponent(); - } + public ScenarioSelectionPage() + { + this.InitializeComponent(); + } - protected override void OnNavigatedTo(NavigationEventArgs e) - { - SetUpScenarios(); + protected override void OnNavigatedTo(NavigationEventArgs e) + { + SetUpScenarios(); + + NavigatedToPageEvent.Log(nameof(ScenarioSelectionPage)); - NavigatedToPageEvent.Log(nameof(ScenarioSelectionPage)); + this.NavView.Loaded += (sender, args) => + { + Scenario? scenario = null; + object? parameter = e.Parameter; - this.NavView.Loaded += (sender, args) => + if (e.Parameter == null && lastInternalNavigation?.Parameter != null) { - Scenario? scenario = null; - object? parameter = e.Parameter; + parameter = lastInternalNavigation.Parameter; + } - if (e.Parameter == null && lastInternalNavigation?.Parameter != null) + if (parameter is Scenario sc) + { + scenario = sc; + } + else if (parameter is MostRecentlyUsedItem mru) + { + scenario = App.FindScenarioById(mru.ItemId); + } + else if (parameter is Sample sample) + { + scenario = ScenarioCategoryHelpers.AllScenarioCategories.SelectMany(sc => sc.Scenarios).FirstOrDefault(s => s.ScenarioType == sample.Scenario); + } + else if (parameter is SampleNavigationArgs sampleArgs) + { + scenario = ScenarioCategoryHelpers.AllScenarioCategories.SelectMany(sc => sc.Scenarios).FirstOrDefault(s => s.ScenarioType == sampleArgs.Sample.Scenario); + if (scenario != null) { - parameter = lastInternalNavigation.Parameter; + NavigateToScenario(scenario, sampleArgs); } + } - if (parameter is Scenario sc) - { - scenario = sc; - } - else if (parameter is MostRecentlyUsedItem mru) - { - scenario = App.FindScenarioById(mru.ItemId); - } - else if (parameter is Sample sample) + if (scenario != null) + { + foreach (NavigationViewItem item in NavView.MenuItems) { - scenario = ScenarioCategoryHelpers.AllScenarioCategories.SelectMany(sc => sc.Scenarios).FirstOrDefault(s => s.ScenarioType == sample.Scenario); + SetSelectedScenarioInMenu(item, scenario); } - else if (parameter is SampleNavigationArgs sampleArgs) - { - scenario = ScenarioCategoryHelpers.AllScenarioCategories.SelectMany(sc => sc.Scenarios).FirstOrDefault(s => s.ScenarioType == sampleArgs.Sample.Scenario); - if (scenario != null) - { - NavigateToScenario(scenario, sampleArgs); - } - } - - if (scenario != null) + } + else + { + if (NavView.MenuItems[0] is NavigationViewItem item) { - foreach (NavigationViewItem item in NavView.MenuItems) + if (item.MenuItems != null && item.MenuItems.Count > 0) { - SetSelectedScenarioInMenu(item, scenario); + item.IsExpanded = true; + NavView.SelectedItem = item.MenuItems[0]; } - } - else - { - if (NavView.MenuItems[0] is NavigationViewItem item) + else { - if (item.MenuItems != null && item.MenuItems.Count > 0) - { - item.IsExpanded = true; - NavView.SelectedItem = item.MenuItems[0]; - } - else - { - NavView.SelectedItem = item; - } + NavView.SelectedItem = item; } } - }; - base.OnNavigatedTo(e); - } + } + }; + base.OnNavigatedTo(e); + } - private void SetUpScenarios(string? filter = null) + private void SetUpScenarios(string? filter = null) + { + NavView.MenuItems.Clear(); + + foreach (var scenarioCategory in ScenarioCategoryHelpers.AllScenarioCategories) { - NavView.MenuItems.Clear(); + var categoryMenu = new NavigationViewItem() { Content = scenarioCategory.Name, Icon = new FontIcon() { Glyph = scenarioCategory.Icon }, Tag = scenarioCategory }; + ToolTip categoryToolTip = new() { Content = scenarioCategory.Name }; + ToolTipService.SetToolTip(categoryMenu, categoryToolTip); - foreach (var scenarioCategory in ScenarioCategoryHelpers.AllScenarioCategories) + foreach (var scenario in scenarioCategory.Scenarios) { - var categoryMenu = new NavigationViewItem() { Content = scenarioCategory.Name, Icon = new FontIcon() { Glyph = scenarioCategory.Icon }, Tag = scenarioCategory }; - ToolTip categoryToolTip = new() { Content = scenarioCategory.Name }; - ToolTipService.SetToolTip(categoryMenu, categoryToolTip); - - foreach (var scenario in scenarioCategory.Scenarios) + if (filter != null) { - if (filter != null) + var models = GetModelsForScenario(scenario); + if (filter == "gpu" && !models.Any(m => m.HardwareAccelerators.Contains(HardwareAccelerator.DML))) { - var models = GetModelsForScenario(scenario); - if (filter == "gpu" && !models.Any(m => m.HardwareAccelerators.Contains(HardwareAccelerator.DML))) - { - continue; - } - - if (filter == "npu" && !models.Any(m => m.HardwareAccelerators.Contains(HardwareAccelerator.QNN) && !m.Url.StartsWith("file", System.StringComparison.InvariantCultureIgnoreCase))) - { - continue; - } - - if (filter == "wcr-api" && !models.Any(m => m.Url.StartsWith("file", System.StringComparison.InvariantCultureIgnoreCase))) - { - continue; - } + continue; } - NavigationViewItem currNavItem = new() { Content = scenario.Name, Tag = scenario }; - ToolTip secnarioToolTip = new() { Content = scenario.Name }; - ToolTipService.SetToolTip(currNavItem, secnarioToolTip); - categoryMenu.MenuItems.Add(currNavItem); - } - - categoryMenu.SelectsOnInvoked = false; + if (filter == "npu" && !models.Any(m => m.HardwareAccelerators.Contains(HardwareAccelerator.QNN) && !m.Url.StartsWith("file", System.StringComparison.InvariantCultureIgnoreCase))) + { + continue; + } - if (categoryMenu.MenuItems.Count > 0) - { - NavView.MenuItems.Add(categoryMenu); + if (filter == "wcr-api" && !models.Any(m => m.Url.StartsWith("file", System.StringComparison.InvariantCultureIgnoreCase))) + { + continue; + } } + + NavigationViewItem currNavItem = new() { Content = scenario.Name, Tag = scenario }; + ToolTip secnarioToolTip = new() { Content = scenario.Name }; + ToolTipService.SetToolTip(currNavItem, secnarioToolTip); + categoryMenu.MenuItems.Add(currNavItem); } - } - private List GetModelsForScenario(Scenario scenario) - { - var samples = SampleDetails.Samples.Where(sample => sample.Scenario == scenario.ScenarioType).ToList(); - List modelDetails = []; - foreach (var sample in samples) + categoryMenu.SelectsOnInvoked = false; + + if (categoryMenu.MenuItems.Count > 0) { - modelDetails.AddRange(ModelDetailsHelper.GetModelDetails(sample) - .SelectMany(d => d) - .GroupBy(kv => kv.Key) - .ToDictionary(g => g.Key, g => g.First().Value) - .Values - .SelectMany(v => v) - .ToList()); + NavView.MenuItems.Add(categoryMenu); } + } + } - return modelDetails; + private List GetModelsForScenario(Scenario scenario) + { + var samples = SampleDetails.Samples.Where(sample => sample.Scenario == scenario.ScenarioType).ToList(); + List modelDetails = []; + foreach (var sample in samples) + { + modelDetails.AddRange(ModelDetailsHelper.GetModelDetails(sample) + .SelectMany(d => d) + .GroupBy(kv => kv.Key) + .ToDictionary(g => g.Key, g => g.First().Value) + .Values + .SelectMany(v => v) + .ToList()); } - private void NavView_SelectionChanged(NavigationView sender, NavigationViewSelectionChangedEventArgs args) + return modelDetails; + } + + private void NavView_SelectionChanged(NavigationView sender, NavigationViewSelectionChangedEventArgs args) + { + if (args.SelectedItem is NavigationViewItem item && item.Tag is Scenario scenario && scenario != selectedScenario) { - if (args.SelectedItem is NavigationViewItem item && item.Tag is Scenario scenario && scenario != selectedScenario) - { - NavigateToScenario(scenario); - } + NavigateToScenario(scenario); } + } - private void NavigateToScenario(Scenario scenario, SampleNavigationArgs? sampleArgs = null) + private void NavigateToScenario(Scenario scenario, SampleNavigationArgs? sampleArgs = null) + { + selectedScenario = scenario; + lastInternalNavigation = new LastInternalNavigation(typeof(ScenarioPage), scenario); + if (sampleArgs != null) { - selectedScenario = scenario; - lastInternalNavigation = new LastInternalNavigation(typeof(ScenarioPage), scenario); - if (sampleArgs != null) - { - NavFrame.Navigate(typeof(ScenarioPage), sampleArgs); - } - else - { - NavFrame.Navigate(typeof(ScenarioPage), scenario); - } + NavFrame.Navigate(typeof(ScenarioPage), sampleArgs); } + else + { + NavFrame.Navigate(typeof(ScenarioPage), scenario); + } + } - private void SetSelectedScenarioInMenu(NavigationViewItem item, Scenario scenario) + private void SetSelectedScenarioInMenu(NavigationViewItem item, Scenario scenario) + { + foreach (var menuItem in item.MenuItems) { - foreach (var menuItem in item.MenuItems) + if (menuItem is NavigationViewItem navItem) { - if (menuItem is NavigationViewItem navItem) + if (navItem.Tag is Scenario modelSample && modelSample.Id.Equals(scenario.Id, System.StringComparison.OrdinalIgnoreCase)) { - if (navItem.Tag is Scenario modelSample && modelSample.Id.Equals(scenario.Id, System.StringComparison.OrdinalIgnoreCase)) - { - item.IsExpanded = true; - NavView.SelectedItem = navItem; - return; - } - else if (navItem.MenuItems.Count > 0) - { - SetSelectedScenarioInMenu(navItem, scenario); - } + item.IsExpanded = true; + NavView.SelectedItem = navItem; + return; + } + else if (navItem.MenuItems.Count > 0) + { + SetSelectedScenarioInMenu(navItem, scenario); } } } + } - private void ComboBox_SelectionChanged(object sender, SelectionChangedEventArgs e) + private void ComboBox_SelectionChanged(object sender, SelectionChangedEventArgs e) + { + var tag = (e.AddedItems[0] as FilterRecord)!.Tag; + SetUpScenarios(tag); + if (selectedScenario != null) { - var tag = (e.AddedItems[0] as FilterRecord)!.Tag; - SetUpScenarios(tag); - if (selectedScenario != null) + foreach (NavigationViewItem item in NavView.MenuItems) { - foreach (NavigationViewItem item in NavView.MenuItems) - { - SetSelectedScenarioInMenu(item, selectedScenario); - } + SetSelectedScenarioInMenu(item, selectedScenario); } } } diff --git a/AIDevGallery/Pages/SettingsPage.xaml.cs b/AIDevGallery/Pages/SettingsPage.xaml.cs index bd84677..7f2d374 100644 --- a/AIDevGallery/Pages/SettingsPage.xaml.cs +++ b/AIDevGallery/Pages/SettingsPage.xaml.cs @@ -16,110 +16,88 @@ using System.Threading; using Windows.Storage.Pickers; -namespace AIDevGallery.Pages -{ - internal sealed partial class SettingsPage : Page - { - private readonly ObservableCollection cachedModels = []; - private readonly RelayCommand endMoveCommand; - private string? cacheFolderPath; - private bool isMovingCache; +namespace AIDevGallery.Pages; - private CancellationTokenSource? _cts; +internal sealed partial class SettingsPage : Page +{ + private readonly ObservableCollection cachedModels = []; + private readonly RelayCommand endMoveCommand; + private string? cacheFolderPath; + private bool isMovingCache; - public SettingsPage() - { - this.InitializeComponent(); - endMoveCommand = new RelayCommand(() => _cts?.Cancel()); - } + private CancellationTokenSource? _cts; - protected override void OnNavigatedTo(NavigationEventArgs e) - { - base.OnNavigatedTo(e); + public SettingsPage() + { + this.InitializeComponent(); + endMoveCommand = new RelayCommand(() => _cts?.Cancel()); + } - NavigatedToPageEvent.Log(nameof(SettingsPage)); + protected override void OnNavigatedTo(NavigationEventArgs e) + { + base.OnNavigatedTo(e); - VersionTextRun.Text = AppUtils.GetAppVersion(); - GetStorageInfo(); + NavigatedToPageEvent.Log(nameof(SettingsPage)); - // DiagnosticDataToggleSwitch.IsOn = App.AppData.IsDiagnosticDataEnabled; - if (e.Parameter is string manageModels && manageModels == "ModelManagement") - { - ModelsExpander.IsExpanded = true; - } - } + VersionTextRun.Text = AppUtils.GetAppVersion(); + GetStorageInfo(); - protected override void OnNavigatingFrom(NavigatingCancelEventArgs e) + // DiagnosticDataToggleSwitch.IsOn = App.AppData.IsDiagnosticDataEnabled; + if (e.Parameter is string manageModels && manageModels == "ModelManagement") { - if (isMovingCache) - { - e.Cancel = true; - } - - base.OnNavigatingFrom(e); + ModelsExpander.IsExpanded = true; } + } - private void GetStorageInfo() + protected override void OnNavigatingFrom(NavigatingCancelEventArgs e) + { + if (isMovingCache) { - cachedModels.Clear(); - - cacheFolderPath = App.ModelCache.GetCacheFolder(); - FolderPathTxt.Content = cacheFolderPath; + e.Cancel = true; + } - long totalCacheSize = 0; + base.OnNavigatingFrom(e); + } - foreach (var cachedModel in App.ModelCache.Models.OrderBy(m => m.Details.Name)) - { - cachedModels.Add(cachedModel); - totalCacheSize += cachedModel.ModelSize; - } + private void GetStorageInfo() + { + cachedModels.Clear(); - if (App.ModelCache.Models.Count > 0) - { - ModelsExpander.IsExpanded = true; - } + cacheFolderPath = App.ModelCache.GetCacheFolder(); + FolderPathTxt.Content = cacheFolderPath; - TotalCacheTxt.Text = AppUtils.FileSizeToString(totalCacheSize); - } + long totalCacheSize = 0; - private void FolderPathTxt_Click(object sender, RoutedEventArgs e) + foreach (var cachedModel in App.ModelCache.Models.OrderBy(m => m.Details.Name)) { - if (cacheFolderPath != null) - { - Process.Start("explorer.exe", cacheFolderPath); - } + cachedModels.Add(cachedModel); + totalCacheSize += cachedModel.ModelSize; } - private async void DeleteModel_Click(object sender, RoutedEventArgs e) + if (App.ModelCache.Models.Count > 0) { - if (sender is Button button && button.Tag is CachedModel model) - { - ContentDialog deleteDialog = new() - { - Title = "Delete model", - Content = "Are you sure you want to delete this model?", - PrimaryButtonText = "Yes", - XamlRoot = this.Content.XamlRoot, - PrimaryButtonStyle = (Style)App.Current.Resources["AccentButtonStyle"], - CloseButtonText = "No" - }; + ModelsExpander.IsExpanded = true; + } - var result = await deleteDialog.ShowAsync(); + TotalCacheTxt.Text = AppUtils.FileSizeToString(totalCacheSize); + } - if (result == ContentDialogResult.Primary) - { - await App.ModelCache.DeleteModelFromCache(model); - GetStorageInfo(); - } - } + private void FolderPathTxt_Click(object sender, RoutedEventArgs e) + { + if (cacheFolderPath != null) + { + Process.Start("explorer.exe", cacheFolderPath); } + } - private async void ClearCache_Click(object sender, RoutedEventArgs e) + private async void DeleteModel_Click(object sender, RoutedEventArgs e) + { + if (sender is Button button && button.Tag is CachedModel model) { ContentDialog deleteDialog = new() { - Title = "Clear cache", - Content = "Are you sure you want to clear the entire cache? All downloaded models will be deleted.", + Title = "Delete model", + Content = "Are you sure you want to delete this model?", PrimaryButtonText = "Yes", XamlRoot = this.Content.XamlRoot, PrimaryButtonStyle = (Style)App.Current.Resources["AccentButtonStyle"], @@ -130,171 +108,192 @@ private async void ClearCache_Click(object sender, RoutedEventArgs e) if (result == ContentDialogResult.Primary) { - await App.ModelCache.ClearCache(); + await App.ModelCache.DeleteModelFromCache(model); GetStorageInfo(); } } + } - private void ModelFolder_Click(object sender, RoutedEventArgs e) + private async void ClearCache_Click(object sender, RoutedEventArgs e) + { + ContentDialog deleteDialog = new() { - if (sender is HyperlinkButton hyperlinkButton && hyperlinkButton.Tag is CachedModel model) - { - string? path = model.Path; + Title = "Clear cache", + Content = "Are you sure you want to clear the entire cache? All downloaded models will be deleted.", + PrimaryButtonText = "Yes", + XamlRoot = this.Content.XamlRoot, + PrimaryButtonStyle = (Style)App.Current.Resources["AccentButtonStyle"], + CloseButtonText = "No" + }; - if (model.IsFile) - { - path = Path.GetDirectoryName(path); - } + var result = await deleteDialog.ShowAsync(); - if (path != null) - { - Process.Start("explorer.exe", path); - } + if (result == ContentDialogResult.Primary) + { + await App.ModelCache.ClearCache(); + GetStorageInfo(); + } + } + + private void ModelFolder_Click(object sender, RoutedEventArgs e) + { + if (sender is HyperlinkButton hyperlinkButton && hyperlinkButton.Tag is CachedModel model) + { + string? path = model.Path; + + if (model.IsFile) + { + path = Path.GetDirectoryName(path); } + + if (path != null) + { + Process.Start("explorer.exe", path); + } + } + } + + // private async void DiagnosticDataToggleSwitch_Toggled(object sender, RoutedEventArgs e) + // { + // if (App.AppData.IsDiagnosticDataEnabled != DiagnosticDataToggleSwitch.IsOn) + // { + // App.AppData.IsDiagnosticDataEnabled = DiagnosticDataToggleSwitch.IsOn; + // await App.AppData.SaveAsync(); + // } + // } + private async void ChangeCacheFolder_Click(object sender, RoutedEventArgs e) + { + var downloadCount = App.ModelCache.DownloadQueue.GetDownloads().Count; + if (downloadCount > 0) + { + ContentDialog dialog = new() + { + Title = "Downloads in progress", + Content = $"There are currently {downloadCount} downloads in progress. Please cancel them or wait for them to complete before changing the cache path.", + XamlRoot = this.Content.XamlRoot, + CloseButtonText = "OK" + }; + await dialog.ShowAsync(); + return; } - // private async void DiagnosticDataToggleSwitch_Toggled(object sender, RoutedEventArgs e) - // { - // if (App.AppData.IsDiagnosticDataEnabled != DiagnosticDataToggleSwitch.IsOn) - // { - // App.AppData.IsDiagnosticDataEnabled = DiagnosticDataToggleSwitch.IsOn; - // await App.AppData.SaveAsync(); - // } - // } - private async void ChangeCacheFolder_Click(object sender, RoutedEventArgs e) + var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(App.MainWindow); + var picker = new FolderPicker(); + WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd); + var folder = await picker.PickSingleFolderAsync(); + if (folder != null && folder.Path != App.ModelCache.GetCacheFolder()) { - var downloadCount = App.ModelCache.DownloadQueue.GetDownloads().Count; - if (downloadCount > 0) + if (Directory.GetFiles(folder.Path).Length > 0 || Directory.GetDirectories(folder.Path).Length > 0) { - ContentDialog dialog = new() + ContentDialog confirmFolderDialog = new() { - Title = "Downloads in progress", - Content = $"There are currently {downloadCount} downloads in progress. Please cancel them or wait for them to complete before changing the cache path.", + Title = "Folder not empty", + Content = @"The destination folder contains files. Please select an empty folder for the destination.", XamlRoot = this.Content.XamlRoot, CloseButtonText = "OK" }; - await dialog.ShowAsync(); + + await confirmFolderDialog.ShowAsync(); return; } - var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(App.MainWindow); - var picker = new FolderPicker(); - WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd); - var folder = await picker.PickSingleFolderAsync(); - if (folder != null && folder.Path != App.ModelCache.GetCacheFolder()) - { - if (Directory.GetFiles(folder.Path).Length > 0 || Directory.GetDirectories(folder.Path).Length > 0) - { - ContentDialog confirmFolderDialog = new() - { - Title = "Folder not empty", - Content = @"The destination folder contains files. Please select an empty folder for the destination.", - XamlRoot = this.Content.XamlRoot, - CloseButtonText = "OK" - }; - - await confirmFolderDialog.ShowAsync(); - return; - } - - var cacheSize = App.ModelCache.Models.Sum(m => m.ModelSize); + var cacheSize = App.ModelCache.Models.Sum(m => m.ModelSize); - var sourceDrive = Path.GetPathRoot(App.ModelCache.GetCacheFolder()); - var destDrive = Path.GetPathRoot(folder.Path); + var sourceDrive = Path.GetPathRoot(App.ModelCache.GetCacheFolder()); + var destDrive = Path.GetPathRoot(folder.Path); - if (destDrive == null) - { - return; - } + if (destDrive == null) + { + return; + } - var driveInfo = new DriveInfo(destDrive); - var availableSpace = driveInfo.IsReady ? driveInfo.AvailableFreeSpace / 1024.0 / 1024.0 / 1024.0 : 0; + var driveInfo = new DriveInfo(destDrive); + var availableSpace = driveInfo.IsReady ? driveInfo.AvailableFreeSpace / 1024.0 / 1024.0 / 1024.0 : 0; - var cacheSizeInGb = cacheSize / 1024.0 / 1024.0 / 1024.0; + var cacheSizeInGb = cacheSize / 1024.0 / 1024.0 / 1024.0; - if (cacheSizeInGb > availableSpace && sourceDrive != destDrive) + if (cacheSizeInGb > availableSpace && sourceDrive != destDrive) + { + ContentDialog dialog = new() { - ContentDialog dialog = new() - { - Title = "Insufficient space", - Content = $@"You don't have enough space on drive {destDrive[0]}. + Title = "Insufficient space", + Content = $@"You don't have enough space on drive {destDrive[0]}. Required space {cacheSizeInGb:N1} GB Available space {availableSpace:N1} GB Please free up some space before moving the cache.", - XamlRoot = this.Content.XamlRoot, - CloseButtonText = "OK" - }; - await dialog.ShowAsync(); - return; - } + XamlRoot = this.Content.XamlRoot, + CloseButtonText = "OK" + }; + await dialog.ShowAsync(); + return; + } - var result = ContentDialogResult.Primary; + var result = ContentDialogResult.Primary; - if (cacheSizeInGb > 1 && sourceDrive != destDrive) + if (cacheSizeInGb > 1 && sourceDrive != destDrive) + { + ContentDialog confirmDialog = new() { - ContentDialog confirmDialog = new() - { - Title = "Confirm moving files", - Content = $@"You have {cacheSizeInGb:N1} GB to move, which may take a while. + Title = "Confirm moving files", + Content = $@"You have {cacheSizeInGb:N1} GB to move, which may take a while. You can speed things up by clearing the cache or deleting models from it first. Do you want to proceed with the move?", - PrimaryButtonText = "Confirm", - PrimaryButtonStyle = (Style)App.Current.Resources["AccentButtonStyle"], - XamlRoot = this.Content.XamlRoot, - CloseButtonText = "Cancel" - }; + PrimaryButtonText = "Confirm", + PrimaryButtonStyle = (Style)App.Current.Resources["AccentButtonStyle"], + XamlRoot = this.Content.XamlRoot, + CloseButtonText = "Cancel" + }; - result = await confirmDialog.ShowAsync(); - } + result = await confirmDialog.ShowAsync(); + } - if (result == ContentDialogResult.Primary) + if (result == ContentDialogResult.Primary) + { + try { - try + _cts = new CancellationTokenSource(); + StartMovingCache(); + await App.ModelCache.MoveCache(folder.Path, _cts.Token); + GetStorageInfo(); + EndMovingCache(); + } + catch (Exception ex) + { + EndMovingCache(); + if (ex is OperationCanceledException) { - _cts = new CancellationTokenSource(); - StartMovingCache(); - await App.ModelCache.MoveCache(folder.Path, _cts.Token); - GetStorageInfo(); - EndMovingCache(); + return; } - catch (Exception ex) + + ContentDialog errorDialog = new() { - EndMovingCache(); - if (ex is OperationCanceledException) - { - return; - } - - ContentDialog errorDialog = new() - { - Title = "Error moving files", - Content = $@"The cache folder could not be moved: + Title = "Error moving files", + Content = $@"The cache folder could not be moved: {ex.Message}", - XamlRoot = this.Content.XamlRoot, - CloseButtonText = "OK" - }; - await errorDialog.ShowAsync(); - } + XamlRoot = this.Content.XamlRoot, + CloseButtonText = "OK" + }; + await errorDialog.ShowAsync(); } } } + } - private void StartMovingCache() - { - isMovingCache = true; - _ = ProgressDialog.ShowAsync(); - } + private void StartMovingCache() + { + isMovingCache = true; + _ = ProgressDialog.ShowAsync(); + } - private void EndMovingCache() - { - isMovingCache = false; - ProgressDialog?.Hide(); - _cts?.Dispose(); - _cts = null; - } + private void EndMovingCache() + { + isMovingCache = false; + ProgressDialog?.Hide(); + _cts?.Dispose(); + _cts = null; } } \ No newline at end of file diff --git a/AIDevGallery/Program.cs b/AIDevGallery/Program.cs index 4f9f8fb..3901adf 100644 --- a/AIDevGallery/Program.cs +++ b/AIDevGallery/Program.cs @@ -12,84 +12,83 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery +namespace AIDevGallery; + +/// +/// Program class +/// +public class Program { - /// - /// Program class - /// - public class Program + // Replaces the standard App.g.i.cs. + // Note: We can't declare Main to be async because in a WinUI app + // this prevents Narrator from reading XAML elements. + [STAThread] + private static void Main() { - // Replaces the standard App.g.i.cs. - // Note: We can't declare Main to be async because in a WinUI app - // this prevents Narrator from reading XAML elements. - [STAThread] - private static void Main() - { - WinRT.ComWrappersSupport.InitializeComWrappers(); - bool isRedirect = DecideRedirection(); + WinRT.ComWrappersSupport.InitializeComWrappers(); + bool isRedirect = DecideRedirection(); - using OgaHandle ogaHandle = new(); + using OgaHandle ogaHandle = new(); - if (!isRedirect) + if (!isRedirect) + { + Application.Start((p) => { - Application.Start((p) => - { - var context = new DispatcherQueueSynchronizationContext( - DispatcherQueue.GetForCurrentThread()); - SynchronizationContext.SetSynchronizationContext(context); - _ = new App(); - }); - } + var context = new DispatcherQueueSynchronizationContext( + DispatcherQueue.GetForCurrentThread()); + SynchronizationContext.SetSynchronizationContext(context); + _ = new App(); + }); } + } - private static bool DecideRedirection() - { - bool isRedirect = false; - AppActivationArguments args = AppInstance.GetCurrent().GetActivatedEventArgs(); - AppInstance keyInstance = AppInstance.FindOrRegisterForKey("AIDevGalleryApp"); - - if (keyInstance.IsCurrent) - { - keyInstance.Activated += OnActivated; - } - else - { - isRedirect = true; - RedirectActivationTo(args, keyInstance); - } + private static bool DecideRedirection() + { + bool isRedirect = false; + AppActivationArguments args = AppInstance.GetCurrent().GetActivatedEventArgs(); + AppInstance keyInstance = AppInstance.FindOrRegisterForKey("AIDevGalleryApp"); - return isRedirect; + if (keyInstance.IsCurrent) + { + keyInstance.Activated += OnActivated; } + else + { + isRedirect = true; + RedirectActivationTo(args, keyInstance); + } + + return isRedirect; + } - [DllImport("user32.dll")] - private static extern bool SetForegroundWindow(IntPtr hWnd); + [DllImport("user32.dll")] + private static extern bool SetForegroundWindow(IntPtr hWnd); - // Do the redirection on another thread, and use a non-blocking - // wait method to wait for the redirection to complete. - private static void RedirectActivationTo(AppActivationArguments args, AppInstance keyInstance) + // Do the redirection on another thread, and use a non-blocking + // wait method to wait for the redirection to complete. + private static void RedirectActivationTo(AppActivationArguments args, AppInstance keyInstance) + { + var redirectSemaphore = new Semaphore(0, 1); + Task.Run(() => { - var redirectSemaphore = new Semaphore(0, 1); - Task.Run(() => - { - keyInstance.RedirectActivationToAsync(args).AsTask().Wait(); - redirectSemaphore.Release(); - }); - redirectSemaphore.WaitOne(); - redirectSemaphore.Dispose(); + keyInstance.RedirectActivationToAsync(args).AsTask().Wait(); + redirectSemaphore.Release(); + }); + redirectSemaphore.WaitOne(); + redirectSemaphore.Dispose(); - // Bring the window to the foreground - Process process = Process.GetProcessById((int)keyInstance.ProcessId); + // Bring the window to the foreground + Process process = Process.GetProcessById((int)keyInstance.ProcessId); - SetForegroundWindow(process.MainWindowHandle); - } + SetForegroundWindow(process.MainWindowHandle); + } - private static void OnActivated(object? sender, AppActivationArguments args) + private static void OnActivated(object? sender, AppActivationArguments args) + { + var activationParam = ActivationHelper.GetActivationParam(args); + if (App.MainWindow is MainWindow mainWindow) { - var activationParam = ActivationHelper.GetActivationParam(args); - if (App.MainWindow is MainWindow mainWindow) - { - mainWindow.NavigateToPage(activationParam); - } + mainWindow.NavigateToPage(activationParam); } } } \ No newline at end of file diff --git a/AIDevGallery/ProjectGenerator/Generator.cs b/AIDevGallery/ProjectGenerator/Generator.cs index 68f02e9..fe8affb 100644 --- a/AIDevGallery/ProjectGenerator/Generator.cs +++ b/AIDevGallery/ProjectGenerator/Generator.cs @@ -16,656 +16,655 @@ using System.Threading.Tasks; using Windows.ApplicationModel; -namespace AIDevGallery.ProjectGenerator +namespace AIDevGallery.ProjectGenerator; + +internal partial class Generator { - internal partial class Generator + private readonly string templatePath = Path.Join(Package.Current.InstalledLocation.Path, "ProjectGenerator", "Template"); + + [GeneratedRegex(@"[^a-zA-Z0-9_]")] + private static partial Regex SafeNameRegex(); + + private static string ToSafeVariableName(string input) { - private readonly string templatePath = Path.Join(Package.Current.InstalledLocation.Path, "ProjectGenerator", "Template"); + // Replace invalid characters with an underscore + string safeName = SafeNameRegex().Replace(input, "_"); - [GeneratedRegex(@"[^a-zA-Z0-9_]")] - private static partial Regex SafeNameRegex(); + // Ensure the name does not start with a digit + if (safeName.Length > 0 && char.IsDigit(safeName[0])) + { + safeName = "_" + safeName; + } - private static string ToSafeVariableName(string input) + // If the name is empty or only contains invalid characters, return a default name + if (string.IsNullOrEmpty(safeName)) { - // Replace invalid characters with an underscore - string safeName = SafeNameRegex().Replace(input, "_"); + safeName = "MySampleApp"; + } - // Ensure the name does not start with a digit - if (safeName.Length > 0 && char.IsDigit(safeName[0])) - { - safeName = "_" + safeName; - } + return safeName; + } - // If the name is empty or only contains invalid characters, return a default name - if (string.IsNullOrEmpty(safeName)) - { - safeName = "MySampleApp"; - } + internal Task GenerateAsync(Sample sample, Dictionary models, bool copyModelLocally, string outputPath, CancellationToken cancellationToken) + { + var packageReferences = new List<(string PackageName, string? Version)> + { + ("Microsoft.WindowsAppSDK", null), + ("Microsoft.Windows.SDK.BuildTools", null), + }; - return safeName; + foreach (var nugetPackageReference in sample.NugetPackageReferences) + { + packageReferences.Add(new(nugetPackageReference, null)); } - internal Task GenerateAsync(Sample sample, Dictionary models, bool copyModelLocally, string outputPath, CancellationToken cancellationToken) - { - var packageReferences = new List<(string PackageName, string? Version)> - { - ("Microsoft.WindowsAppSDK", null), - ("Microsoft.Windows.SDK.BuildTools", null), - }; + return GenerateAsyncInternal(sample, models, copyModelLocally, packageReferences, outputPath, cancellationToken); + } - foreach (var nugetPackageReference in sample.NugetPackageReferences) - { - packageReferences.Add(new(nugetPackageReference, null)); - } + internal const string DotNetVersion = "net9.0"; - return GenerateAsyncInternal(sample, models, copyModelLocally, packageReferences, outputPath, cancellationToken); + private async Task GenerateAsyncInternal(Sample sample, Dictionary models, bool copyModelLocally, List<(string PackageName, string? Version)> packageReferences, string outputPath, CancellationToken cancellationToken) + { + var projectName = $"{sample.Name}Sample"; + string safeProjectName = ToSafeVariableName(projectName); + string guid9 = Guid.NewGuid().ToString(); + string xmlEscapedPublisher = "MyTestPublisher"; + string xmlEscapedPublisherDistinguishedName = $"CN={xmlEscapedPublisher}"; + + outputPath = Path.Join(outputPath, safeProjectName); + var dirIndexCount = 1; + while (Directory.Exists(outputPath)) + { + outputPath = Path.Join(Path.GetDirectoryName(outputPath), $"{safeProjectName}_{dirIndexCount}"); + dirIndexCount++; } - internal const string DotNetVersion = "net9.0"; + var modelTypes = sample.Model1Types.Concat(sample.Model2Types ?? Enumerable.Empty()) + .Where(models.ContainsKey); - private async Task GenerateAsyncInternal(Sample sample, Dictionary models, bool copyModelLocally, List<(string PackageName, string? Version)> packageReferences, string outputPath, CancellationToken cancellationToken) + if (copyModelLocally) { - var projectName = $"{sample.Name}Sample"; - string safeProjectName = ToSafeVariableName(projectName); - string guid9 = Guid.NewGuid().ToString(); - string xmlEscapedPublisher = "MyTestPublisher"; - string xmlEscapedPublisherDistinguishedName = $"CN={xmlEscapedPublisher}"; - - outputPath = Path.Join(outputPath, safeProjectName); - var dirIndexCount = 1; - while (Directory.Exists(outputPath)) - { - outputPath = Path.Join(Path.GetDirectoryName(outputPath), $"{safeProjectName}_{dirIndexCount}"); - dirIndexCount++; - } - - var modelTypes = sample.Model1Types.Concat(sample.Model2Types ?? Enumerable.Empty()) - .Where(models.ContainsKey); - - if (copyModelLocally) + long sumTotalSize = 0; + foreach (var modelType in modelTypes) { - long sumTotalSize = 0; - foreach (var modelType in modelTypes) + if (!models.TryGetValue(modelType, out var modelInfo)) { - if (!models.TryGetValue(modelType, out var modelInfo)) - { - throw new ArgumentException($"Model type {modelType} not found in the models dictionary", nameof(models)); - } + throw new ArgumentException($"Model type {modelType} not found in the models dictionary", nameof(models)); + } - if (modelInfo.CachedModelDirectoryPath.Contains("file://", StringComparison.OrdinalIgnoreCase)) - { - continue; - } + if (modelInfo.CachedModelDirectoryPath.Contains("file://", StringComparison.OrdinalIgnoreCase)) + { + continue; + } - var cachedModelDirectoryAttributes = File.GetAttributes(modelInfo.CachedModelDirectoryPath); + var cachedModelDirectoryAttributes = File.GetAttributes(modelInfo.CachedModelDirectoryPath); - if (cachedModelDirectoryAttributes.HasFlag(FileAttributes.Directory)) - { - sumTotalSize += Directory.GetFiles(modelInfo.CachedModelDirectoryPath, "*", SearchOption.AllDirectories).Sum(f => new FileInfo(f).Length); - } - else - { - sumTotalSize += new FileInfo(modelInfo.CachedModelDirectoryPath).Length; - } + if (cachedModelDirectoryAttributes.HasFlag(FileAttributes.Directory)) + { + sumTotalSize += Directory.GetFiles(modelInfo.CachedModelDirectoryPath, "*", SearchOption.AllDirectories).Sum(f => new FileInfo(f).Length); } - - var availableSpace = DriveInfo.GetDrives().First(d => d.RootDirectory.FullName == Path.GetPathRoot(outputPath)).AvailableFreeSpace; - if (sumTotalSize > availableSpace) + else { - throw new IOException("Not enough disk space to copy the model files."); + sumTotalSize += new FileInfo(modelInfo.CachedModelDirectoryPath).Length; } } - Directory.CreateDirectory(outputPath); + var availableSpace = DriveInfo.GetDrives().First(d => d.RootDirectory.FullName == Path.GetPathRoot(outputPath)).AvailableFreeSpace; + if (sumTotalSize > availableSpace) + { + throw new IOException("Not enough disk space to copy the model files."); + } + } + + Directory.CreateDirectory(outputPath); - bool addLllmTypes = false; - Dictionary modelInfos = []; - string model1Id = string.Empty; - string model2Id = string.Empty; - foreach (var modelType in modelTypes) + bool addLllmTypes = false; + Dictionary modelInfos = []; + string model1Id = string.Empty; + string model2Id = string.Empty; + foreach (var modelType in modelTypes) + { + if (!models.TryGetValue(modelType, out var modelInfo)) { - if (!models.TryGetValue(modelType, out var modelInfo)) - { - throw new ArgumentException($"Model type {modelType} not found in the models dictionary", nameof(models)); - } + throw new ArgumentException($"Model type {modelType} not found in the models dictionary", nameof(models)); + } - PromptTemplate? modelPromptTemplate = null; - string modelId = string.Empty; - bool isSingleFile = false; + PromptTemplate? modelPromptTemplate = null; + string modelId = string.Empty; + bool isSingleFile = false; - if (ModelTypeHelpers.ModelDetails.TryGetValue(modelType, out var modelDetails)) + if (ModelTypeHelpers.ModelDetails.TryGetValue(modelType, out var modelDetails)) + { + modelPromptTemplate = modelDetails.PromptTemplate; + modelId = modelDetails.Id; + } + else if (ModelTypeHelpers.ModelDetails.FirstOrDefault(mf => mf.Value.Url == modelInfo.ModelUrl) is var modelDetails2 && modelDetails2.Value != null) + { + modelPromptTemplate = modelDetails2.Value.PromptTemplate; + if (modelPromptTemplate != null) { - modelPromptTemplate = modelDetails.PromptTemplate; - modelId = modelDetails.Id; + addLllmTypes = true; } - else if (ModelTypeHelpers.ModelDetails.FirstOrDefault(mf => mf.Value.Url == modelInfo.ModelUrl) is var modelDetails2 && modelDetails2.Value != null) - { - modelPromptTemplate = modelDetails2.Value.PromptTemplate; - if (modelPromptTemplate != null) - { - addLllmTypes = true; - } - modelId = modelDetails2.Value.Id; - } - else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(modelType, out var apiDefinitionDetails)) - { - modelId = apiDefinitionDetails.Id; - } + modelId = modelDetails2.Value.Id; + } + else if (ModelTypeHelpers.ApiDefinitionDetails.TryGetValue(modelType, out var apiDefinitionDetails)) + { + modelId = apiDefinitionDetails.Id; + } - string modelPathStr; + string modelPathStr; - if (copyModelLocally && !modelInfo.CachedModelDirectoryPath.Contains("file://", StringComparison.OrdinalIgnoreCase)) - { - var modelPath = Path.GetFileName(modelInfo.CachedModelDirectoryPath); - var cachedModelDirectoryAttributes = File.GetAttributes(modelInfo.CachedModelDirectoryPath); + if (copyModelLocally && !modelInfo.CachedModelDirectoryPath.Contains("file://", StringComparison.OrdinalIgnoreCase)) + { + var modelPath = Path.GetFileName(modelInfo.CachedModelDirectoryPath); + var cachedModelDirectoryAttributes = File.GetAttributes(modelInfo.CachedModelDirectoryPath); - if (cachedModelDirectoryAttributes.HasFlag(FileAttributes.Directory)) + if (cachedModelDirectoryAttributes.HasFlag(FileAttributes.Directory)) + { + isSingleFile = false; + var modelDirectory = Directory.CreateDirectory(Path.Join(outputPath, "Models", modelPath)); + foreach (var file in Directory.GetFiles(modelInfo.CachedModelDirectoryPath, "*", SearchOption.AllDirectories)) { - isSingleFile = false; - var modelDirectory = Directory.CreateDirectory(Path.Join(outputPath, "Models", modelPath)); - foreach (var file in Directory.GetFiles(modelInfo.CachedModelDirectoryPath, "*", SearchOption.AllDirectories)) + cancellationToken.ThrowIfCancellationRequested(); + var filePath = Path.Join(modelDirectory.FullName, Path.GetRelativePath(modelInfo.CachedModelDirectoryPath, file)); + var directory = Path.GetDirectoryName(filePath); + if (directory != null && !Directory.Exists(directory)) { - cancellationToken.ThrowIfCancellationRequested(); - var filePath = Path.Join(modelDirectory.FullName, Path.GetRelativePath(modelInfo.CachedModelDirectoryPath, file)); - var directory = Path.GetDirectoryName(filePath); - if (directory != null && !Directory.Exists(directory)) - { - Directory.CreateDirectory(directory); - } - - await CopyFileAsync(file, filePath, cancellationToken).ConfigureAwait(false); + Directory.CreateDirectory(directory); } - } - else - { - isSingleFile = true; - var modelDirectory = Directory.CreateDirectory(Path.Join(outputPath, "Models")); - await CopyFileAsync(modelInfo.CachedModelDirectoryPath, Path.Join(modelDirectory.FullName, modelPath), cancellationToken).ConfigureAwait(false); - } - modelPathStr = $"System.IO.Path.Join(Windows.ApplicationModel.Package.Current.InstalledLocation.Path, \"Models\", @\"{modelPath}\")"; - modelInfo.CachedModelDirectoryPath = modelPath; + await CopyFileAsync(file, filePath, cancellationToken).ConfigureAwait(false); + } } else { - modelPathStr = $"@\"{modelInfo.CachedModelDirectoryPath}\""; + isSingleFile = true; + var modelDirectory = Directory.CreateDirectory(Path.Join(outputPath, "Models")); + await CopyFileAsync(modelInfo.CachedModelDirectoryPath, Path.Join(modelDirectory.FullName, modelPath), cancellationToken).ConfigureAwait(false); } - modelInfos.Add(modelType, new(modelInfo.CachedModelDirectoryPath, modelInfo.ModelUrl, isSingleFile, modelPathStr, modelInfo.HardwareAccelerator, modelPromptTemplate)); + modelPathStr = $"System.IO.Path.Join(Windows.ApplicationModel.Package.Current.InstalledLocation.Path, \"Models\", @\"{modelPath}\")"; + modelInfo.CachedModelDirectoryPath = modelPath; + } + else + { + modelPathStr = $"@\"{modelInfo.CachedModelDirectoryPath}\""; + } - if (modelTypes.First() == modelType) - { - model1Id = modelId; - } - else - { - model2Id = modelId; - } + modelInfos.Add(modelType, new(modelInfo.CachedModelDirectoryPath, modelInfo.ModelUrl, isSingleFile, modelPathStr, modelInfo.HardwareAccelerator, modelPromptTemplate)); + + if (modelTypes.First() == modelType) + { + model1Id = modelId; } + else + { + model2Id = modelId; + } + } - SampleProjectGeneratedEvent.Log(sample.Id, model1Id, model2Id, copyModelLocally); + SampleProjectGeneratedEvent.Log(sample.Id, model1Id, model2Id, copyModelLocally); - string[] extensions = [".manifest", ".xaml", ".cs", ".appxmanifest", ".csproj", ".ico", ".png", ".json", ".pubxml"]; + string[] extensions = [".manifest", ".xaml", ".cs", ".appxmanifest", ".csproj", ".ico", ".png", ".json", ".pubxml"]; - // Get all files from the template directory with the allowed extensions - var files = Directory.GetFiles(templatePath, "*.*", SearchOption.AllDirectories).Where(file => extensions.Any(file.EndsWith)); + // Get all files from the template directory with the allowed extensions + var files = Directory.GetFiles(templatePath, "*.*", SearchOption.AllDirectories).Where(file => extensions.Any(file.EndsWith)); - var renames = new Dictionary - { - { "Package-managed.appxmanifest", "Package.appxmanifest" }, - { "ProjectTemplate.csproj", $"{safeProjectName}.csproj" } - }; + var renames = new Dictionary + { + { "Package-managed.appxmanifest", "Package.appxmanifest" }, + { "ProjectTemplate.csproj", $"{safeProjectName}.csproj" } + }; - var className = await AddFilesFromSampleAsync(sample, packageReferences, safeProjectName, outputPath, addLllmTypes, modelInfos, cancellationToken); + var className = await AddFilesFromSampleAsync(sample, packageReferences, safeProjectName, outputPath, addLllmTypes, modelInfos, cancellationToken); - foreach (var file in files) + foreach (var file in files) + { + var relativePath = file[(templatePath.Length + 1)..]; + + var fileName = Path.GetFileName(file); + if (renames.TryGetValue(fileName, out var newName)) { - var relativePath = file[(templatePath.Length + 1)..]; + relativePath = relativePath.Replace(fileName, newName); + } - var fileName = Path.GetFileName(file); - if (renames.TryGetValue(fileName, out var newName)) - { - relativePath = relativePath.Replace(fileName, newName); - } + var outputPathFile = Path.Join(outputPath, relativePath); - var outputPathFile = Path.Join(outputPath, relativePath); + // Create the directory if it doesn't exist + var directory = Path.GetDirectoryName(outputPathFile); + if (directory != null && !Directory.Exists(directory)) + { + Directory.CreateDirectory(directory); + } - // Create the directory if it doesn't exist - var directory = Path.GetDirectoryName(outputPathFile); - if (directory != null && !Directory.Exists(directory)) - { - Directory.CreateDirectory(directory); - } + // if image file, just copy + if (Path.GetExtension(file) is ".ico" or ".png") + { + File.Copy(file, outputPathFile); + continue; + } + else + { + // Read the file + var content = await File.ReadAllTextAsync(file, cancellationToken); - // if image file, just copy - if (Path.GetExtension(file) is ".ico" or ".png") - { - File.Copy(file, outputPathFile); - continue; - } - else - { - // Read the file - var content = await File.ReadAllTextAsync(file, cancellationToken); - - // Replace the variables - content = content.Replace("$projectname$", projectName); - content = content.Replace("$safeprojectname$", safeProjectName); - content = content.Replace("$guid9$", guid9); - content = content.Replace("$XmlEscapedPublisherDistinguishedName$", xmlEscapedPublisherDistinguishedName); - content = content.Replace("$XmlEscapedPublisher$", xmlEscapedPublisher); - content = content.Replace("$DotNetVersion$", DotNetVersion); - content = content.Replace("$MainSamplePage$", className); - - // Write the file - await File.WriteAllTextAsync(outputPathFile, content, cancellationToken); - } + // Replace the variables + content = content.Replace("$projectname$", projectName); + content = content.Replace("$safeprojectname$", safeProjectName); + content = content.Replace("$guid9$", guid9); + content = content.Replace("$XmlEscapedPublisherDistinguishedName$", xmlEscapedPublisherDistinguishedName); + content = content.Replace("$XmlEscapedPublisher$", xmlEscapedPublisher); + content = content.Replace("$DotNetVersion$", DotNetVersion); + content = content.Replace("$MainSamplePage$", className); + + // Write the file + await File.WriteAllTextAsync(outputPathFile, content, cancellationToken); } + } - var csproj = Path.Join(outputPath, $"{safeProjectName}.csproj"); + var csproj = Path.Join(outputPath, $"{safeProjectName}.csproj"); - // Add NuGet references - if (packageReferences.Count > 0 || copyModelLocally) + // Add NuGet references + if (packageReferences.Count > 0 || copyModelLocally) + { + var project = ProjectRootElement.Open(csproj); + var itemGroup = project.AddItemGroup(); + + static void AddPackageReference(ProjectItemGroupElement itemGroup, string packageName, string? version) { - var project = ProjectRootElement.Open(csproj); - var itemGroup = project.AddItemGroup(); + var packageReferenceItem = itemGroup.AddItem("PackageReference", packageName); - static void AddPackageReference(ProjectItemGroupElement itemGroup, string packageName, string? version) + if (packageName == "Microsoft.Windows.CsWin32") { - var packageReferenceItem = itemGroup.AddItem("PackageReference", packageName); - - if (packageName == "Microsoft.Windows.CsWin32") - { - packageReferenceItem.AddMetadata("PrivateAssets", "all", true); - } - else if (packageName == "Microsoft.AI.DirectML" || - packageName == "Microsoft.ML.OnnxRuntime.DirectML" || - packageName == "Microsoft.ML.OnnxRuntimeGenAI.DirectML") - { - packageReferenceItem.Condition = "$(Platform) == 'x64'"; - } - else if (packageName == "Microsoft.ML.OnnxRuntime.Qnn" || - packageName == "Microsoft.ML.OnnxRuntimeGenAI" || - packageName == "Microsoft.ML.OnnxRuntimeGenAI.Managed") - { - packageReferenceItem.Condition = "$(Platform) == 'ARM64'"; - } - - var versionStr = version ?? PackageVersionHelpers.PackageVersions[packageName]; - packageReferenceItem.AddMetadata("Version", versionStr, true); - - if (packageName == "Microsoft.ML.OnnxRuntimeGenAI") - { - var noneItem = itemGroup.AddItem("None", "$(PKGMicrosoft_ML_OnnxRuntimeGenAI)\\runtimes\\win-arm64\\native\\onnxruntime-genai.dll"); - noneItem.Condition = "$(Platform) == 'ARM64'"; - noneItem.AddMetadata("Link", "onnxruntime-genai.dll", false); - noneItem.AddMetadata("CopyToOutputDirectory", "PreserveNewest", false); - noneItem.AddMetadata("Visible", "false", false); - - packageReferenceItem.AddMetadata("GeneratePathProperty", "true", true); - packageReferenceItem.AddMetadata("ExcludeAssets", "all", true); - } + packageReferenceItem.AddMetadata("PrivateAssets", "all", true); } - - foreach (var packageReference in packageReferences) + else if (packageName == "Microsoft.AI.DirectML" || + packageName == "Microsoft.ML.OnnxRuntime.DirectML" || + packageName == "Microsoft.ML.OnnxRuntimeGenAI.DirectML") { - var packageName = packageReference.PackageName; - var version = packageReference.Version; - if (packageName == "Microsoft.ML.OnnxRuntime.DirectML") - { - AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntime.Qnn", null); - } - else if (packageName == "Microsoft.ML.OnnxRuntimeGenAI.DirectML") - { - AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntimeGenAI", null); - AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntimeGenAI.Managed", null); - } - - AddPackageReference(itemGroup, packageName, version); + packageReferenceItem.Condition = "$(Platform) == 'x64'"; } - - if (copyModelLocally) + else if (packageName == "Microsoft.ML.OnnxRuntime.Qnn" || + packageName == "Microsoft.ML.OnnxRuntimeGenAI" || + packageName == "Microsoft.ML.OnnxRuntimeGenAI.Managed") { - var modelContentItemGroup = project.AddItemGroup(); - foreach (var modelInfo in modelInfos) - { - if (modelInfo.Value.CachedModelDirectoryPath.Contains("file://", StringComparison.OrdinalIgnoreCase)) - { - continue; - } - - if (modelInfo.Value.IsSingleFile) - { - modelContentItemGroup.AddItem("Content", @$"Models\{modelInfo.Value.CachedModelDirectoryPath}"); - } - else - { - modelContentItemGroup.AddItem("Content", @$"Models\{modelInfo.Value.CachedModelDirectoryPath}\**"); - } - } + packageReferenceItem.Condition = "$(Platform) == 'ARM64'"; } - project.Save(); + var versionStr = version ?? PackageVersionHelpers.PackageVersions[packageName]; + packageReferenceItem.AddMetadata("Version", versionStr, true); + + if (packageName == "Microsoft.ML.OnnxRuntimeGenAI") + { + var noneItem = itemGroup.AddItem("None", "$(PKGMicrosoft_ML_OnnxRuntimeGenAI)\\runtimes\\win-arm64\\native\\onnxruntime-genai.dll"); + noneItem.Condition = "$(Platform) == 'ARM64'"; + noneItem.AddMetadata("Link", "onnxruntime-genai.dll", false); + noneItem.AddMetadata("CopyToOutputDirectory", "PreserveNewest", false); + noneItem.AddMetadata("Visible", "false", false); + + packageReferenceItem.AddMetadata("GeneratePathProperty", "true", true); + packageReferenceItem.AddMetadata("ExcludeAssets", "all", true); + } } - // Fix PublishProfiles. This shouldn't be necessary once the templates are fixed - foreach (var file in Directory.GetFiles(outputPath, "*.pubxml", SearchOption.AllDirectories)) + foreach (var packageReference in packageReferences) { - var pubxml = ProjectRootElement.Open(file); - var firstPg = pubxml.PropertyGroups.FirstOrDefault(); - firstPg ??= pubxml.AddPropertyGroup(); - - if (!firstPg.Children.Any(p => p.ElementName == "RuntimeIdentifier")) + var packageName = packageReference.PackageName; + var version = packageReference.Version; + if (packageName == "Microsoft.ML.OnnxRuntime.DirectML") + { + AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntime.Qnn", null); + } + else if (packageName == "Microsoft.ML.OnnxRuntimeGenAI.DirectML") { - var runtimeIdentifier = Path.GetFileNameWithoutExtension(file).Split('-').Last(); - firstPg.AddProperty("RuntimeIdentifier", $"win-{runtimeIdentifier}"); + AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntimeGenAI", null); + AddPackageReference(itemGroup, "Microsoft.ML.OnnxRuntimeGenAI.Managed", null); } - pubxml.Save(); + AddPackageReference(itemGroup, packageName, version); } - // Styles - List styles = []; - foreach (var file in Directory.GetFiles(outputPath, "*.xaml", SearchOption.TopDirectoryOnly)) + if (copyModelLocally) { - var content = await File.ReadAllTextAsync(file, cancellationToken); - if (!content.StartsWith(" 0) + project.Save(); + } + + // Fix PublishProfiles. This shouldn't be necessary once the templates are fixed + foreach (var file in Directory.GetFiles(outputPath, "*.pubxml", SearchOption.AllDirectories)) + { + var pubxml = ProjectRootElement.Open(file); + var firstPg = pubxml.PropertyGroups.FirstOrDefault(); + firstPg ??= pubxml.AddPropertyGroup(); + + if (!firstPg.Children.Any(p => p.ElementName == "RuntimeIdentifier")) { - var appXamlPath = Path.Join(outputPath, "App.xaml"); - var appXaml = await File.ReadAllTextAsync(appXamlPath, cancellationToken); - appXaml = appXaml.Replace( - " ", - string.Join(Environment.NewLine, styles.Select(s => $" "))); - await File.WriteAllTextAsync(appXamlPath, appXaml, cancellationToken); + var runtimeIdentifier = Path.GetFileNameWithoutExtension(file).Split('-').Last(); + firstPg.AddProperty("RuntimeIdentifier", $"win-{runtimeIdentifier}"); } - return outputPath; + pubxml.Save(); } - private string GetChatClientLoaderString(Sample sample, string modelPath, string promptTemplate) + // Styles + List styles = []; + foreach (var file in Directory.GetFiles(outputPath, "*.xaml", SearchOption.TopDirectoryOnly)) { - if (!sample.SharedCode.Contains(SharedCodeEnum.GenAIModel)) + var content = await File.ReadAllTextAsync(file, cancellationToken); + if (!content.StartsWith(" 0) + { + var appXamlPath = Path.Join(outputPath, "App.xaml"); + var appXaml = await File.ReadAllTextAsync(appXamlPath, cancellationToken); + appXaml = appXaml.Replace( + " ", + string.Join(Environment.NewLine, styles.Select(s => $" "))); + await File.WriteAllTextAsync(appXamlPath, appXaml, cancellationToken); } - private static async Task CopyFileAsync(string sourceFile, string destinationFile, CancellationToken cancellationToken) + return outputPath; + } + + private string GetChatClientLoaderString(Sample sample, string modelPath, string promptTemplate) + { + if (!sample.SharedCode.Contains(SharedCodeEnum.GenAIModel)) { - using var sourceStream = new FileStream(sourceFile, FileMode.Open, FileAccess.Read, FileShare.Read, bufferSize: 4096, useAsync: true); - using var destinationStream = new FileStream(destinationFile, FileMode.CreateNew, FileAccess.Write, FileShare.None, bufferSize: 4096, useAsync: true); - await sourceStream.CopyToAsync(destinationStream, 81920, cancellationToken).ConfigureAwait(false); + return string.Empty; } - private static string EscapeNewLines(string str) + return $"GenAIModel.CreateAsync({modelPath}, {promptTemplate})"; + } + + private static async Task CopyFileAsync(string sourceFile, string destinationFile, CancellationToken cancellationToken) + { + using var sourceStream = new FileStream(sourceFile, FileMode.Open, FileAccess.Read, FileShare.Read, bufferSize: 4096, useAsync: true); + using var destinationStream = new FileStream(destinationFile, FileMode.CreateNew, FileAccess.Write, FileShare.None, bufferSize: 4096, useAsync: true); + await sourceStream.CopyToAsync(destinationStream, 81920, cancellationToken).ConfigureAwait(false); + } + + private static string EscapeNewLines(string str) + { + str = str + .Replace("\r", "\\r") + .Replace("\n", "\\n"); + return str; + } + + private string GetPromptTemplateString(PromptTemplate? promptTemplate, int spaceCount) + { + if (promptTemplate == null) { - str = str - .Replace("\r", "\\r") - .Replace("\n", "\\n"); - return str; + return "null"; } - private string GetPromptTemplateString(PromptTemplate? promptTemplate, int spaceCount) + StringBuilder modelPromptTemplateSb = new(); + var spaces = new string(' ', spaceCount); + modelPromptTemplateSb.AppendLine("new LlmPromptTemplate"); + modelPromptTemplateSb.Append(spaces); + modelPromptTemplateSb.AppendLine("{"); + if (!string.IsNullOrEmpty(promptTemplate.System)) { - if (promptTemplate == null) - { - return "null"; - } + modelPromptTemplateSb.Append(spaces); + modelPromptTemplateSb.AppendLine( + string.Format( + CultureInfo.InvariantCulture, + """ + System = "{0}", + """, + EscapeNewLines(promptTemplate.System))); + } - StringBuilder modelPromptTemplateSb = new(); - var spaces = new string(' ', spaceCount); - modelPromptTemplateSb.AppendLine("new LlmPromptTemplate"); + if (!string.IsNullOrEmpty(promptTemplate.User)) + { modelPromptTemplateSb.Append(spaces); - modelPromptTemplateSb.AppendLine("{"); - if (!string.IsNullOrEmpty(promptTemplate.System)) - { - modelPromptTemplateSb.Append(spaces); - modelPromptTemplateSb.AppendLine( - string.Format( - CultureInfo.InvariantCulture, - """ - System = "{0}", - """, - EscapeNewLines(promptTemplate.System))); - } + modelPromptTemplateSb.AppendLine(string.Format( + CultureInfo.InvariantCulture, + """ + User = "{0}", + """, + EscapeNewLines(promptTemplate.User))); + } - if (!string.IsNullOrEmpty(promptTemplate.User)) - { - modelPromptTemplateSb.Append(spaces); - modelPromptTemplateSb.AppendLine(string.Format( - CultureInfo.InvariantCulture, - """ - User = "{0}", - """, - EscapeNewLines(promptTemplate.User))); - } + if (!string.IsNullOrEmpty(promptTemplate.Assistant)) + { + modelPromptTemplateSb.Append(spaces); + modelPromptTemplateSb.AppendLine(string.Format( + CultureInfo.InvariantCulture, + """ + Assistant = "{0}", + """, + EscapeNewLines(promptTemplate.Assistant))); + } - if (!string.IsNullOrEmpty(promptTemplate.Assistant)) - { - modelPromptTemplateSb.Append(spaces); - modelPromptTemplateSb.AppendLine(string.Format( + if (promptTemplate.Stop != null && promptTemplate.Stop.Length > 0) + { + modelPromptTemplateSb.Append(spaces); + var stopStr = string.Join(", ", promptTemplate.Stop.Select(s => + string.Format( CultureInfo.InvariantCulture, """ - Assistant = "{0}", + "{0}" """, - EscapeNewLines(promptTemplate.Assistant))); - } + EscapeNewLines(s)))); + modelPromptTemplateSb.Append(" Stop = [ "); + modelPromptTemplateSb.Append(stopStr); + modelPromptTemplateSb.AppendLine("]"); + } - if (promptTemplate.Stop != null && promptTemplate.Stop.Length > 0) - { - modelPromptTemplateSb.Append(spaces); - var stopStr = string.Join(", ", promptTemplate.Stop.Select(s => - string.Format( - CultureInfo.InvariantCulture, - """ - "{0}" - """, - EscapeNewLines(s)))); - modelPromptTemplateSb.Append(" Stop = [ "); - modelPromptTemplateSb.Append(stopStr); - modelPromptTemplateSb.AppendLine("]"); - } + modelPromptTemplateSb.Append(spaces); + modelPromptTemplateSb.Append('}'); - modelPromptTemplateSb.Append(spaces); - modelPromptTemplateSb.Append('}'); + return modelPromptTemplateSb.ToString(); + } - return modelPromptTemplateSb.ToString(); + private async Task AddFilesFromSampleAsync( + Sample sample, + List<(string PackageName, string? Version)> packageReferences, + string safeProjectName, + string outputPath, + bool addLllmTypes, + Dictionary modelInfos, + CancellationToken cancellationToken) + { + var sharedCode = sample.SharedCode.ToList(); + if (!sharedCode.Contains(SharedCodeEnum.LlmPromptTemplate) && + (addLllmTypes || sample.SharedCode.Contains(SharedCodeEnum.GenAIModel))) + { + // Always used inside GenAIModel.cs + sharedCode.Add(SharedCodeEnum.LlmPromptTemplate); } - private async Task AddFilesFromSampleAsync( - Sample sample, - List<(string PackageName, string? Version)> packageReferences, - string safeProjectName, - string outputPath, - bool addLllmTypes, - Dictionary modelInfos, - CancellationToken cancellationToken) + if (sharedCode.Contains(SharedCodeEnum.DeviceUtils) && !sharedCode.Contains(SharedCodeEnum.NativeMethods)) { - var sharedCode = sample.SharedCode.ToList(); - if (!sharedCode.Contains(SharedCodeEnum.LlmPromptTemplate) && - (addLllmTypes || sample.SharedCode.Contains(SharedCodeEnum.GenAIModel))) + sharedCode.Add(SharedCodeEnum.NativeMethods); + var csWin32 = "Microsoft.Windows.CsWin32"; + if (!packageReferences.Any(packageReferences => packageReferences.PackageName == csWin32)) { - // Always used inside GenAIModel.cs - sharedCode.Add(SharedCodeEnum.LlmPromptTemplate); + packageReferences.Add((csWin32, null)); } + } - if (sharedCode.Contains(SharedCodeEnum.DeviceUtils) && !sharedCode.Contains(SharedCodeEnum.NativeMethods)) + foreach (var sharedCodeEnum in sharedCode) + { + var fileName = SharedCodeHelpers.GetName(sharedCodeEnum); + var source = SharedCodeHelpers.GetSource(sharedCodeEnum); + if (fileName.EndsWith(".xaml", StringComparison.OrdinalIgnoreCase)) { - sharedCode.Add(SharedCodeEnum.NativeMethods); - var csWin32 = "Microsoft.Windows.CsWin32"; - if (!packageReferences.Any(packageReferences => packageReferences.PackageName == csWin32)) - { - packageReferences.Add((csWin32, null)); - } + source = CleanXamlSource(source, $"{safeProjectName}.SharedCode", out _); } - - foreach (var sharedCodeEnum in sharedCode) + else { - var fileName = SharedCodeHelpers.GetName(sharedCodeEnum); - var source = SharedCodeHelpers.GetSource(sharedCodeEnum); - if (fileName.EndsWith(".xaml", StringComparison.OrdinalIgnoreCase)) - { - source = CleanXamlSource(source, $"{safeProjectName}.SharedCode", out _); - } - else - { - source = CleanCsSource(source, $"{safeProjectName}.SharedCode", false); - } - - await File.WriteAllTextAsync(Path.Join(outputPath, fileName), source, cancellationToken); + source = CleanCsSource(source, $"{safeProjectName}.SharedCode", false); } - string className = "Sample"; - if (!string.IsNullOrEmpty(sample.XAMLCode)) - { - var xamlSource = CleanXamlSource(sample.XAMLCode, safeProjectName, out className); - xamlSource = xamlSource.Replace($"{Environment.NewLine} xmlns:samples=\"using:AIDevGallery.Samples\"", string.Empty); - xamlSource = xamlSource.Replace("", ""); + await File.WriteAllTextAsync(Path.Join(outputPath, fileName), source, cancellationToken); + } - await File.WriteAllTextAsync(Path.Join(outputPath, $"{className}.xaml"), xamlSource, cancellationToken); - } + string className = "Sample"; + if (!string.IsNullOrEmpty(sample.XAMLCode)) + { + var xamlSource = CleanXamlSource(sample.XAMLCode, safeProjectName, out className); + xamlSource = xamlSource.Replace($"{Environment.NewLine} xmlns:samples=\"using:AIDevGallery.Samples\"", string.Empty); + xamlSource = xamlSource.Replace("", ""); - if (!string.IsNullOrEmpty(sample.CSCode)) - { - var cleanCsSource = CleanCsSource(sample.CSCode, safeProjectName, true); - cleanCsSource = cleanCsSource.Replace("sampleParams.NotifyCompletion();", "App.Window?.ModelLoaded();"); - cleanCsSource = cleanCsSource.Replace(": BaseSamplePage", ": Microsoft.UI.Xaml.Controls.Page"); - cleanCsSource = cleanCsSource.Replace( - "Task LoadModelAsync(SampleNavigationParameters sampleParams)", - "void OnNavigatedTo(Microsoft.UI.Xaml.Navigation.NavigationEventArgs e)"); - cleanCsSource = cleanCsSource.Replace( - "Task LoadModelAsync(MultiModelSampleNavigationParameters sampleParams)", - "void OnNavigatedTo(Microsoft.UI.Xaml.Navigation.NavigationEventArgs e)"); - cleanCsSource = RegexReturnTaskCompletedTask().Replace(cleanCsSource, string.Empty); - - string modelPath; - if (modelInfos.Count > 1) - { - int i = 0; - foreach (var modelInfo in modelInfos) - { - cleanCsSource = cleanCsSource.Replace($"sampleParams.HardwareAccelerators[{i}]", $"HardwareAccelerator.{modelInfo.Value.HardwareAccelerator}"); - cleanCsSource = cleanCsSource.Replace($"sampleParams.ModelPaths[{i}]", modelInfo.Value.ModelPathStr); - i++; - } + await File.WriteAllTextAsync(Path.Join(outputPath, $"{className}.xaml"), xamlSource, cancellationToken); + } - modelPath = modelInfos.First().Value.ModelPathStr; - } - else + if (!string.IsNullOrEmpty(sample.CSCode)) + { + var cleanCsSource = CleanCsSource(sample.CSCode, safeProjectName, true); + cleanCsSource = cleanCsSource.Replace("sampleParams.NotifyCompletion();", "App.Window?.ModelLoaded();"); + cleanCsSource = cleanCsSource.Replace(": BaseSamplePage", ": Microsoft.UI.Xaml.Controls.Page"); + cleanCsSource = cleanCsSource.Replace( + "Task LoadModelAsync(SampleNavigationParameters sampleParams)", + "void OnNavigatedTo(Microsoft.UI.Xaml.Navigation.NavigationEventArgs e)"); + cleanCsSource = cleanCsSource.Replace( + "Task LoadModelAsync(MultiModelSampleNavigationParameters sampleParams)", + "void OnNavigatedTo(Microsoft.UI.Xaml.Navigation.NavigationEventArgs e)"); + cleanCsSource = RegexReturnTaskCompletedTask().Replace(cleanCsSource, string.Empty); + + string modelPath; + if (modelInfos.Count > 1) + { + int i = 0; + foreach (var modelInfo in modelInfos) { - var modelInfo = modelInfos.Values.First(); - cleanCsSource = cleanCsSource.Replace("sampleParams.HardwareAccelerator", $"HardwareAccelerator.{modelInfo.HardwareAccelerator}"); - cleanCsSource = cleanCsSource.Replace("sampleParams.ModelPath", modelInfo.ModelPathStr); - modelPath = modelInfo.ModelPathStr; + cleanCsSource = cleanCsSource.Replace($"sampleParams.HardwareAccelerators[{i}]", $"HardwareAccelerator.{modelInfo.Value.HardwareAccelerator}"); + cleanCsSource = cleanCsSource.Replace($"sampleParams.ModelPaths[{i}]", modelInfo.Value.ModelPathStr); + i++; } - cleanCsSource = cleanCsSource.Replace("sampleParams.CancellationToken", "CancellationToken.None"); + modelPath = modelInfos.First().Value.ModelPathStr; + } + else + { + var modelInfo = modelInfos.Values.First(); + cleanCsSource = cleanCsSource.Replace("sampleParams.HardwareAccelerator", $"HardwareAccelerator.{modelInfo.HardwareAccelerator}"); + cleanCsSource = cleanCsSource.Replace("sampleParams.ModelPath", modelInfo.ModelPathStr); + modelPath = modelInfo.ModelPathStr; + } - var search = "sampleParams.GetIChatClientAsync()"; - int index = cleanCsSource.IndexOf(search, StringComparison.OrdinalIgnoreCase); - if (index > 0) - { - int newLineIndex = cleanCsSource[..index].LastIndexOf(Environment.NewLine, StringComparison.OrdinalIgnoreCase); - var subStr = cleanCsSource[(newLineIndex + Environment.NewLine.Length)..]; - var subStrWithoutSpaces = subStr.TrimStart(); - var spaceCount = subStr.Length - subStrWithoutSpaces.Length; - var promptTemplate = GetPromptTemplateString(modelInfos.Values.First().ModelPromptTemplate, spaceCount); - var chatClientLoader = GetChatClientLoaderString(sample, modelPath, promptTemplate); - if (chatClientLoader != null) - { - cleanCsSource = cleanCsSource.Replace(search, chatClientLoader); - } - } + cleanCsSource = cleanCsSource.Replace("sampleParams.CancellationToken", "CancellationToken.None"); - if (sample.SharedCode.Contains(SharedCodeEnum.GenAIModel)) + var search = "sampleParams.GetIChatClientAsync()"; + int index = cleanCsSource.IndexOf(search, StringComparison.OrdinalIgnoreCase); + if (index > 0) + { + int newLineIndex = cleanCsSource[..index].LastIndexOf(Environment.NewLine, StringComparison.OrdinalIgnoreCase); + var subStr = cleanCsSource[(newLineIndex + Environment.NewLine.Length)..]; + var subStrWithoutSpaces = subStr.TrimStart(); + var spaceCount = subStr.Length - subStrWithoutSpaces.Length; + var promptTemplate = GetPromptTemplateString(modelInfos.Values.First().ModelPromptTemplate, spaceCount); + var chatClientLoader = GetChatClientLoaderString(sample, modelPath, promptTemplate); + if (chatClientLoader != null) { - cleanCsSource = RegexInitializeComponent().Replace(cleanCsSource, $"$1this.InitializeComponent();$1GenAIModel.InitializeGenAI();"); + cleanCsSource = cleanCsSource.Replace(search, chatClientLoader); } + } - await File.WriteAllTextAsync(Path.Join(outputPath, $"{className}.xaml.cs"), cleanCsSource, cancellationToken); + if (sample.SharedCode.Contains(SharedCodeEnum.GenAIModel)) + { + cleanCsSource = RegexInitializeComponent().Replace(cleanCsSource, $"$1this.InitializeComponent();$1GenAIModel.InitializeGenAI();"); } - return className; + await File.WriteAllTextAsync(Path.Join(outputPath, $"{className}.xaml.cs"), cleanCsSource, cancellationToken); } - [GeneratedRegex(@"x:Class=""(@?[a-z_A-Z]\w+(?:\.@?[a-z_A-Z]\w+)*)""")] - private static partial Regex XClass(); + return className; + } + + [GeneratedRegex(@"x:Class=""(@?[a-z_A-Z]\w+(?:\.@?[a-z_A-Z]\w+)*)""")] + private static partial Regex XClass(); - [GeneratedRegex(@"xmlns:local=""using:(\w.+)""")] - private static partial Regex XamlLocalUsing(); + [GeneratedRegex(@"xmlns:local=""using:(\w.+)""")] + private static partial Regex XamlLocalUsing(); - [GeneratedRegex(@"[\r\n][\s]*return Task.CompletedTask;")] - private static partial Regex RegexReturnTaskCompletedTask(); + [GeneratedRegex(@"[\r\n][\s]*return Task.CompletedTask;")] + private static partial Regex RegexReturnTaskCompletedTask(); - [GeneratedRegex(@"(\s*)this.InitializeComponent\(\);")] - private static partial Regex RegexInitializeComponent(); + [GeneratedRegex(@"(\s*)this.InitializeComponent\(\);")] + private static partial Regex RegexInitializeComponent(); - private string CleanXamlSource(string xamlCode, string newNamespace, out string className) + private string CleanXamlSource(string xamlCode, string newNamespace, out string className) + { + var match = XClass().Match(xamlCode); + if (match.Success) { - var match = XClass().Match(xamlCode); - if (match.Success) - { - var oldClassFullName = match.Groups[1].Value; - _ = oldClassFullName[..oldClassFullName.LastIndexOf('.')]; - className = oldClassFullName[(oldClassFullName.LastIndexOf('.') + 1)..]; + var oldClassFullName = match.Groups[1].Value; + _ = oldClassFullName[..oldClassFullName.LastIndexOf('.')]; + className = oldClassFullName[(oldClassFullName.LastIndexOf('.') + 1)..]; - xamlCode = xamlCode.Replace(match.Value, @$"x:Class=""{newNamespace}.{className}"""); - } - else - { - className = "Sample"; - } + xamlCode = xamlCode.Replace(match.Value, @$"x:Class=""{newNamespace}.{className}"""); + } + else + { + className = "Sample"; + } - xamlCode = XamlLocalUsing().Replace(xamlCode, $"xmlns:local=\"using:{newNamespace}\""); + xamlCode = XamlLocalUsing().Replace(xamlCode, $"xmlns:local=\"using:{newNamespace}\""); - xamlCode = xamlCode.Replace("xmlns:shared=\"using:AIDevGallery.Samples.SharedCode\"", $"xmlns:shared=\"using:{newNamespace}.SharedCode\""); + xamlCode = xamlCode.Replace("xmlns:shared=\"using:AIDevGallery.Samples.SharedCode\"", $"xmlns:shared=\"using:{newNamespace}.SharedCode\""); - return xamlCode; - } + return xamlCode; + } - [GeneratedRegex(@"using AIDevGallery\S*;\r?\n", RegexOptions.Multiline)] - private static partial Regex UsingAIDevGalleryGNamespace(); + [GeneratedRegex(@"using AIDevGallery\S*;\r?\n", RegexOptions.Multiline)] + private static partial Regex UsingAIDevGalleryGNamespace(); - [GeneratedRegex(@"namespace AIDevGallery(?:[^;\r\n])*(;?)\r\n", RegexOptions.Multiline)] - private static partial Regex AIDevGalleryGNamespace(); + [GeneratedRegex(@"namespace AIDevGallery(?:[^;\r\n])*(;?)\r\n", RegexOptions.Multiline)] + private static partial Regex AIDevGalleryGNamespace(); - private static string CleanCsSource(string source, string newNamespace, bool addSharedSourceNamespace) - { - // Remove the using statements for the AIDevGallery.* namespaces - source = UsingAIDevGalleryGNamespace().Replace(source, string.Empty); + private static string CleanCsSource(string source, string newNamespace, bool addSharedSourceNamespace) + { + // Remove the using statements for the AIDevGallery.* namespaces + source = UsingAIDevGalleryGNamespace().Replace(source, string.Empty); - source = source.Replace("\r\r", "\r"); + source = source.Replace("\r\r", "\r"); - // Replace the AIDevGallery namespace with the namespace of the new project - // consider the 1st capture group to add the ; or not - var match = AIDevGalleryGNamespace().Match(source); - if (match.Success) - { - source = AIDevGalleryGNamespace().Replace(source, $"namespace {newNamespace}{match.Groups[1].Value}{Environment.NewLine}"); - } + // Replace the AIDevGallery namespace with the namespace of the new project + // consider the 1st capture group to add the ; or not + var match = AIDevGalleryGNamespace().Match(source); + if (match.Success) + { + source = AIDevGalleryGNamespace().Replace(source, $"namespace {newNamespace}{match.Groups[1].Value}{Environment.NewLine}"); + } - if (addSharedSourceNamespace) + if (addSharedSourceNamespace) + { + var namespaceLine = $"using {newNamespace}.SharedCode;"; + if (!source.Contains(namespaceLine)) { - var namespaceLine = $"using {newNamespace}.SharedCode;"; - if (!source.Contains(namespaceLine)) - { - source = namespaceLine + Environment.NewLine + source; - } + source = namespaceLine + Environment.NewLine + source; } - - return source; } + + return source; } } \ No newline at end of file diff --git a/AIDevGallery/ProjectGenerator/Template/App.xaml.cs b/AIDevGallery/ProjectGenerator/Template/App.xaml.cs index 4c78b38..c2ad9bb 100644 --- a/AIDevGallery/ProjectGenerator/Template/App.xaml.cs +++ b/AIDevGallery/ProjectGenerator/Template/App.xaml.cs @@ -19,32 +19,31 @@ // To learn more about WinUI, the WinUI project structure, // and more about our project templates, see: http://aka.ms/winui-project-info. -namespace $safeprojectname$ +namespace $safeprojectname$; + +/// +/// Provides application-specific behavior to supplement the default Application class. +/// +public partial class App : Application { /// - /// Provides application-specific behavior to supplement the default Application class. + /// Initializes the singleton application object. This is the first line of authored code + /// executed, and as such is the logical equivalent of main() or WinMain(). /// - public partial class App : Application + public App() { - /// - /// Initializes the singleton application object. This is the first line of authored code - /// executed, and as such is the logical equivalent of main() or WinMain(). - /// - public App() - { - this.InitializeComponent(); - } - - /// - /// Invoked when the application is launched. - /// - /// Details about the launch request and process. - protected override void OnLaunched(Microsoft.UI.Xaml.LaunchActivatedEventArgs args) - { - Window = new MainWindow(); - Window.Activate(); - } + this.InitializeComponent(); + } - internal static MainWindow? Window { get; private set; } + /// + /// Invoked when the application is launched. + /// + /// Details about the launch request and process. + protected override void OnLaunched(Microsoft.UI.Xaml.LaunchActivatedEventArgs args) + { + Window = new MainWindow(); + Window.Activate(); } -} + + internal static MainWindow? Window { get; private set; } +} \ No newline at end of file diff --git a/AIDevGallery/ProjectGenerator/Template/MainWindow.xaml.cs b/AIDevGallery/ProjectGenerator/Template/MainWindow.xaml.cs index ba7e49a..52f4cf0 100644 --- a/AIDevGallery/ProjectGenerator/Template/MainWindow.xaml.cs +++ b/AIDevGallery/ProjectGenerator/Template/MainWindow.xaml.cs @@ -1,21 +1,20 @@ using Microsoft.UI.Xaml; -namespace $safeprojectname$ +namespace $safeprojectname$; + +public sealed partial class MainWindow : Window { - public sealed partial class MainWindow : Window + public MainWindow() { - public MainWindow() + this.InitializeComponent(); + this.RootFrame.Loaded += (sender, args) => { - this.InitializeComponent(); - this.RootFrame.Loaded += (sender, args) => - { - RootFrame.Navigate(typeof($MainSamplePage$)); - }; - } + RootFrame.Navigate(typeof($MainSamplePage$)); + }; + } - internal void ModelLoaded() - { - ProgressRingGrid.Visibility = Visibility.Collapsed; - } + internal void ModelLoaded() + { + ProgressRingGrid.Visibility = Visibility.Collapsed; } -} +} \ No newline at end of file diff --git a/AIDevGallery/Samples/Open Source Models/Image Models/MultiHRNetPose/Multipose.xaml.cs b/AIDevGallery/Samples/Open Source Models/Image Models/MultiHRNetPose/Multipose.xaml.cs index 9149f70..86c6c94 100644 --- a/AIDevGallery/Samples/Open Source Models/Image Models/MultiHRNetPose/Multipose.xaml.cs +++ b/AIDevGallery/Samples/Open Source Models/Image Models/MultiHRNetPose/Multipose.xaml.cs @@ -18,291 +18,290 @@ using Windows.Storage.Pickers; using Windows.System; -namespace AIDevGallery.Samples.OpenSourceModels.MultiHRNetPose +namespace AIDevGallery.Samples.OpenSourceModels.MultiHRNetPose; + +[GallerySample( + Model1Types = [ModelType.HRNetPose], + Model2Types = [ModelType.YOLO], + Scenario = ScenarioType.ImageDetectPoses, + SharedCode = [ + SharedCodeEnum.Prediction, + SharedCodeEnum.BitmapFunctions, + SharedCodeEnum.DeviceUtils, + SharedCodeEnum.PoseHelper, + SharedCodeEnum.YOLOHelpers, + SharedCodeEnum.RCNNLabelMap + ], + NugetPackageReferences = [ + "System.Drawing.Common", + "Microsoft.ML.OnnxRuntime.DirectML", + "Microsoft.ML.OnnxRuntime.Extensions" + ], + Name = "Multiple Pose Detection", + Id = "9b74ccc0-f5f7-430f-bed0-71211c063508", + Icon = "\uE8B3")] +internal sealed partial class Multipose : BaseSamplePage { - [GallerySample( - Model1Types = [ModelType.HRNetPose], - Model2Types = [ModelType.YOLO], - Scenario = ScenarioType.ImageDetectPoses, - SharedCode = [ - SharedCodeEnum.Prediction, - SharedCodeEnum.BitmapFunctions, - SharedCodeEnum.DeviceUtils, - SharedCodeEnum.PoseHelper, - SharedCodeEnum.YOLOHelpers, - SharedCodeEnum.RCNNLabelMap - ], - NugetPackageReferences = [ - "System.Drawing.Common", - "Microsoft.ML.OnnxRuntime.DirectML", - "Microsoft.ML.OnnxRuntime.Extensions" - ], - Name = "Multiple Pose Detection", - Id = "9b74ccc0-f5f7-430f-bed0-71211c063508", - Icon = "\uE8B3")] - internal sealed partial class Multipose : BaseSamplePage - { - private InferenceSession? _detectionSession; - private InferenceSession? _poseSession; + private InferenceSession? _detectionSession; + private InferenceSession? _poseSession; - public Multipose() + public Multipose() + { + this.Unloaded += (s, e) => { - this.Unloaded += (s, e) => - { - _detectionSession?.Dispose(); - _poseSession?.Dispose(); - }; + _detectionSession?.Dispose(); + _poseSession?.Dispose(); + }; - this.Loaded += (s, e) => Page_Loaded(); // - this.InitializeComponent(); - } + this.Loaded += (s, e) => Page_Loaded(); // + this.InitializeComponent(); + } - // - private void Page_Loaded() - { - UploadButton.Focus(FocusState.Programmatic); - } + // + private void Page_Loaded() + { + UploadButton.Focus(FocusState.Programmatic); + } - // - protected override async Task LoadModelAsync(MultiModelSampleNavigationParameters sampleParams) - { - await InitModels(sampleParams.ModelPaths[0], sampleParams.HardwareAccelerators[0], sampleParams.ModelPaths[1], sampleParams.HardwareAccelerators[1]); - sampleParams.NotifyCompletion(); + // + protected override async Task LoadModelAsync(MultiModelSampleNavigationParameters sampleParams) + { + await InitModels(sampleParams.ModelPaths[0], sampleParams.HardwareAccelerators[0], sampleParams.ModelPaths[1], sampleParams.HardwareAccelerators[1]); + sampleParams.NotifyCompletion(); - await RunPipeline(Path.Join(Windows.ApplicationModel.Package.Current.InstalledLocation.Path, "Assets", "team.jpg")); - } + await RunPipeline(Path.Join(Windows.ApplicationModel.Package.Current.InstalledLocation.Path, "Assets", "team.jpg")); + } - private Task InitModels(string poseModelPath, HardwareAccelerator poseHardwareAccelerator, string detectionModelPath, HardwareAccelerator detectionHardwareAccelerator) + private Task InitModels(string poseModelPath, HardwareAccelerator poseHardwareAccelerator, string detectionModelPath, HardwareAccelerator detectionHardwareAccelerator) + { + return Task.Run(() => { - return Task.Run(() => + if (_poseSession != null) { - if (_poseSession != null) - { - return; - } + return; + } - SessionOptions poseOptions = new(); - poseOptions.RegisterOrtExtensions(); - if (poseHardwareAccelerator == HardwareAccelerator.DML) - { - poseOptions.AppendExecutionProvider_DML(DeviceUtils.GetBestDeviceId()); - } - else if (poseHardwareAccelerator == HardwareAccelerator.QNN) - { - Dictionary options = new() - { - { "backend_path", "QnnHtp.dll" }, - { "htp_performance_mode", "high_performance" }, - { "htp_graph_finalization_optimization_mode", "3" } - }; - poseOptions.AppendExecutionProvider("QNN", options); - } - - _poseSession = new InferenceSession(poseModelPath, poseOptions); - - if (_detectionSession != null) + SessionOptions poseOptions = new(); + poseOptions.RegisterOrtExtensions(); + if (poseHardwareAccelerator == HardwareAccelerator.DML) + { + poseOptions.AppendExecutionProvider_DML(DeviceUtils.GetBestDeviceId()); + } + else if (poseHardwareAccelerator == HardwareAccelerator.QNN) + { + Dictionary options = new() { - return; - } + { "backend_path", "QnnHtp.dll" }, + { "htp_performance_mode", "high_performance" }, + { "htp_graph_finalization_optimization_mode", "3" } + }; + poseOptions.AppendExecutionProvider("QNN", options); + } - SessionOptions detectionOptions = new(); - detectionOptions.RegisterOrtExtensions(); - if (detectionHardwareAccelerator == HardwareAccelerator.DML) - { - detectionOptions.AppendExecutionProvider_DML(DeviceUtils.GetBestDeviceId()); - } + _poseSession = new InferenceSession(poseModelPath, poseOptions); - _detectionSession = new InferenceSession(detectionModelPath, detectionOptions); - }); - } + if (_detectionSession != null) + { + return; + } - private async void UploadButton_Click(object sender, RoutedEventArgs e) - { - var window = new Window(); - var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(window); + SessionOptions detectionOptions = new(); + detectionOptions.RegisterOrtExtensions(); + if (detectionHardwareAccelerator == HardwareAccelerator.DML) + { + detectionOptions.AppendExecutionProvider_DML(DeviceUtils.GetBestDeviceId()); + } - var picker = new FileOpenPicker(); - WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd); + _detectionSession = new InferenceSession(detectionModelPath, detectionOptions); + }); + } - picker.FileTypeFilter.Add(".png"); - picker.FileTypeFilter.Add(".jpeg"); - picker.FileTypeFilter.Add(".jpg"); + private async void UploadButton_Click(object sender, RoutedEventArgs e) + { + var window = new Window(); + var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(window); - picker.ViewMode = PickerViewMode.Thumbnail; + var picker = new FileOpenPicker(); + WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd); - var file = await picker.PickSingleFileAsync(); - if (file != null) - { - // Call function to run inference and classify image - UploadButton.Focus(FocusState.Programmatic); - await RunPipeline(file.Path); - } + picker.FileTypeFilter.Add(".png"); + picker.FileTypeFilter.Add(".jpeg"); + picker.FileTypeFilter.Add(".jpg"); + + picker.ViewMode = PickerViewMode.Thumbnail; + + var file = await picker.PickSingleFileAsync(); + if (file != null) + { + // Call function to run inference and classify image + UploadButton.Focus(FocusState.Programmatic); + await RunPipeline(file.Path); } + } - private async Task RunPipeline(string filePath) + private async Task RunPipeline(string filePath) + { + if (!File.Exists(filePath)) { - if (!File.Exists(filePath)) - { - return; - } + return; + } - DispatcherQueue.TryEnqueue(() => - { - DefaultImage.Source = new BitmapImage(new Uri(filePath)); - Loader.IsActive = true; - Loader.Visibility = Visibility.Visible; - UploadButton.Visibility = Visibility.Collapsed; - }); + DispatcherQueue.TryEnqueue(() => + { + DefaultImage.Source = new BitmapImage(new Uri(filePath)); + Loader.IsActive = true; + Loader.Visibility = Visibility.Visible; + UploadButton.Visibility = Visibility.Collapsed; + }); - Bitmap originalImage = new(filePath); + Bitmap originalImage = new(filePath); - // Step 1: Detect where the "person" tag is found in the image - List predictions = await FindPeople(originalImage); - predictions = predictions.Where(x => x.Label == "person").ToList(); + // Step 1: Detect where the "person" tag is found in the image + List predictions = await FindPeople(originalImage); + predictions = predictions.Where(x => x.Label == "person").ToList(); - // Step 2: For each person detected, crop the region and run pose - foreach (var prediction in predictions) + // Step 2: For each person detected, crop the region and run pose + foreach (var prediction in predictions) + { + if (prediction.Box != null) { - if (prediction.Box != null) - { - using Bitmap croppedImage = BitmapFunctions.CropImage(originalImage, prediction.Box); + using Bitmap croppedImage = BitmapFunctions.CropImage(originalImage, prediction.Box); - using Bitmap poseOverlay = await DetectPose(croppedImage, originalImage); + using Bitmap poseOverlay = await DetectPose(croppedImage, originalImage); - originalImage = BitmapFunctions.OverlayImage(originalImage, poseOverlay, prediction.Box); - } + originalImage = BitmapFunctions.OverlayImage(originalImage, poseOverlay, prediction.Box); } + } - // Step 3: Convert the processed image back to BitmapImage - BitmapImage outputImage = BitmapFunctions.ConvertBitmapToBitmapImage(originalImage); + // Step 3: Convert the processed image back to BitmapImage + BitmapImage outputImage = BitmapFunctions.ConvertBitmapToBitmapImage(originalImage); - DispatcherQueue.TryEnqueue(() => - { - DefaultImage.Source = outputImage; - Loader.IsActive = false; - Loader.Visibility = Visibility.Collapsed; - UploadButton.Visibility = Visibility.Visible; - }); + DispatcherQueue.TryEnqueue(() => + { + DefaultImage.Source = outputImage; + Loader.IsActive = false; + Loader.Visibility = Visibility.Collapsed; + UploadButton.Visibility = Visibility.Visible; + }); - originalImage.Dispose(); - } + originalImage.Dispose(); + } - private async Task> FindPeople(Bitmap image) + private async Task> FindPeople(Bitmap image) + { + if (_detectionSession == null) { - if (_detectionSession == null) - { - return []; - } + return []; + } - int originalWidth = image.Width; - int originalHeight = image.Height; + int originalWidth = image.Width; + int originalHeight = image.Height; - var predictions = await Task.Run(() => - { - // Set up - var inputName = _detectionSession.InputNames[0]; - var inputDimensions = _detectionSession.InputMetadata[inputName].Dimensions; + var predictions = await Task.Run(() => + { + // Set up + var inputName = _detectionSession.InputNames[0]; + var inputDimensions = _detectionSession.InputMetadata[inputName].Dimensions; - // Set batch size - int batchSize = 1; - inputDimensions[0] = batchSize; + // Set batch size + int batchSize = 1; + inputDimensions[0] = batchSize; - // I know the input dimensions to be [batchSize, 416, 416, 3] - int inputWidth = inputDimensions[1]; - int inputHeight = inputDimensions[2]; + // I know the input dimensions to be [batchSize, 416, 416, 3] + int inputWidth = inputDimensions[1]; + int inputHeight = inputDimensions[2]; - using var resizedImage = BitmapFunctions.ResizeWithPadding(image, inputWidth, inputHeight); + using var resizedImage = BitmapFunctions.ResizeWithPadding(image, inputWidth, inputHeight); - // Preprocessing - Tensor input = new DenseTensor(inputDimensions); - input = BitmapFunctions.PreprocessBitmapForYOLO(resizedImage, input); + // Preprocessing + Tensor input = new DenseTensor(inputDimensions); + input = BitmapFunctions.PreprocessBitmapForYOLO(resizedImage, input); - // Setup inputs and outputs - var inputMetadataName = _detectionSession!.InputNames[0]; - var inputs = new List - { - NamedOnnxValue.CreateFromTensor(inputMetadataName, input) - }; + // Setup inputs and outputs + var inputMetadataName = _detectionSession!.InputNames[0]; + var inputs = new List + { + NamedOnnxValue.CreateFromTensor(inputMetadataName, input) + }; - // Run inference - using IDisposableReadOnlyCollection results = _detectionSession!.Run(inputs); + // Run inference + using IDisposableReadOnlyCollection results = _detectionSession!.Run(inputs); - // Extract tensors from inference results - var outputTensor1 = results[0].AsTensor(); - var outputTensor2 = results[1].AsTensor(); - var outputTensor3 = results[2].AsTensor(); + // Extract tensors from inference results + var outputTensor1 = results[0].AsTensor(); + var outputTensor2 = results[1].AsTensor(); + var outputTensor3 = results[2].AsTensor(); - // Define anchors (as per your model) - var anchors = new List<(float Width, float Height)> - { - (12, 16), (19, 36), (40, 28), // Small grid (52x52) - (36, 75), (76, 55), (72, 146), // Medium grid (26x26) - (142, 110), (192, 243), (459, 401) // Large grid (13x13) - }; + // Define anchors (as per your model) + var anchors = new List<(float Width, float Height)> + { + (12, 16), (19, 36), (40, 28), // Small grid (52x52) + (36, 75), (76, 55), (72, 146), // Medium grid (26x26) + (142, 110), (192, 243), (459, 401) // Large grid (13x13) + }; - // Combine tensors into a list for processing - var gridTensors = new List> { outputTensor1, outputTensor2, outputTensor3 }; + // Combine tensors into a list for processing + var gridTensors = new List> { outputTensor1, outputTensor2, outputTensor3 }; - // Postprocessing steps - var extractedPredictions = YOLOHelpers.ExtractPredictions(gridTensors, anchors, inputWidth, inputHeight, originalWidth, originalHeight); + // Postprocessing steps + var extractedPredictions = YOLOHelpers.ExtractPredictions(gridTensors, anchors, inputWidth, inputHeight, originalWidth, originalHeight); - // Extra step for filtering overlapping predictions - var filteredPredictions = YOLOHelpers.ApplyNms(extractedPredictions, .4f); + // Extra step for filtering overlapping predictions + var filteredPredictions = YOLOHelpers.ApplyNms(extractedPredictions, .4f); - // Return the final predictions - return filteredPredictions; - }); + // Return the final predictions + return filteredPredictions; + }); - return predictions; - } + return predictions; + } - private async Task DetectPose(Bitmap image, Bitmap baseImage) + private async Task DetectPose(Bitmap image, Bitmap baseImage) + { + if (image == null) { - if (image == null) - { - return new Bitmap(0, 0); - } + return new Bitmap(0, 0); + } - var inputName = _poseSession!.InputNames[0]; - var inputDimensions = _poseSession.InputMetadata[inputName].Dimensions; + var inputName = _poseSession!.InputNames[0]; + var inputDimensions = _poseSession.InputMetadata[inputName].Dimensions; - var originalImageWidth = image.Width; - var originalImageHeight = image.Height; + var originalImageWidth = image.Width; + var originalImageHeight = image.Height; - int modelInputWidth = inputDimensions[2]; - int modelInputHeight = inputDimensions[3]; + int modelInputWidth = inputDimensions[2]; + int modelInputHeight = inputDimensions[3]; - // Resize Bitmap - using Bitmap resizedImage = BitmapFunctions.ResizeBitmap(image, modelInputWidth, modelInputHeight); + // Resize Bitmap + using Bitmap resizedImage = BitmapFunctions.ResizeBitmap(image, modelInputWidth, modelInputHeight); - var predictions = await Task.Run(() => - { - // Preprocessing - Tensor input = new DenseTensor(inputDimensions); - input = BitmapFunctions.PreprocessBitmapWithStdDev(resizedImage, input); + var predictions = await Task.Run(() => + { + // Preprocessing + Tensor input = new DenseTensor(inputDimensions); + input = BitmapFunctions.PreprocessBitmapWithStdDev(resizedImage, input); - // Setup inputs - var inputs = new List - { - NamedOnnxValue.CreateFromTensor(inputName, input) - }; + // Setup inputs + var inputs = new List + { + NamedOnnxValue.CreateFromTensor(inputName, input) + }; - // Run inference - using IDisposableReadOnlyCollection results = _poseSession!.Run(inputs); - var heatmaps = results[0].AsTensor(); + // Run inference + using IDisposableReadOnlyCollection results = _poseSession!.Run(inputs); + var heatmaps = results[0].AsTensor(); - var outputName = _poseSession!.OutputNames[0]; - var outputDimensions = _poseSession!.OutputMetadata[outputName].Dimensions; + var outputName = _poseSession!.OutputNames[0]; + var outputDimensions = _poseSession!.OutputMetadata[outputName].Dimensions; - float outputWidth = outputDimensions[2]; - float outputHeight = outputDimensions[3]; + float outputWidth = outputDimensions[2]; + float outputHeight = outputDimensions[3]; - List<(float X, float Y)> keypointCoordinates = PoseHelper.PostProcessResults(heatmaps, originalImageWidth, originalImageHeight, outputWidth, outputHeight); - return keypointCoordinates; - }); + List<(float X, float Y)> keypointCoordinates = PoseHelper.PostProcessResults(heatmaps, originalImageWidth, originalImageHeight, outputWidth, outputHeight); + return keypointCoordinates; + }); - // Render predictions and create output bitmap - return PoseHelper.RenderPredictions(image, predictions, .015f, baseImage); - } + // Render predictions and create output bitmap + return PoseHelper.RenderPredictions(image, predictions, .015f, baseImage); } } \ No newline at end of file diff --git a/AIDevGallery/Samples/Open Source Models/Image Models/YOLOv4/YOLOObjectionDetection.xaml.cs b/AIDevGallery/Samples/Open Source Models/Image Models/YOLOv4/YOLOObjectionDetection.xaml.cs index dbc1b31..df1d1bf 100644 --- a/AIDevGallery/Samples/Open Source Models/Image Models/YOLOv4/YOLOObjectionDetection.xaml.cs +++ b/AIDevGallery/Samples/Open Source Models/Image Models/YOLOv4/YOLOObjectionDetection.xaml.cs @@ -15,189 +15,188 @@ using System.Threading.Tasks; using Windows.Storage.Pickers; -namespace AIDevGallery.Samples.OpenSourceModels.YOLOv4 +namespace AIDevGallery.Samples.OpenSourceModels.YOLOv4; + +[GallerySample( + Model1Types = [ModelType.YOLO], + Scenario = ScenarioType.ImageDetectObjects, + SharedCode = [ + SharedCodeEnum.Prediction, + SharedCodeEnum.BitmapFunctions, + SharedCodeEnum.RCNNLabelMap, + SharedCodeEnum.YOLOHelpers, + SharedCodeEnum.DeviceUtils + ], + NugetPackageReferences = [ + "System.Drawing.Common", + "Microsoft.ML.OnnxRuntime.DirectML", + "Microsoft.ML.OnnxRuntime.Extensions" + ], + Name = "YOLO Object Detection", + Id = "9b74ccc0-15f7-430f-bed0-7581fd163508", + Icon = "\uE8B3")] + +internal sealed partial class YOLOObjectionDetection : BaseSamplePage { - [GallerySample( - Model1Types = [ModelType.YOLO], - Scenario = ScenarioType.ImageDetectObjects, - SharedCode = [ - SharedCodeEnum.Prediction, - SharedCodeEnum.BitmapFunctions, - SharedCodeEnum.RCNNLabelMap, - SharedCodeEnum.YOLOHelpers, - SharedCodeEnum.DeviceUtils - ], - NugetPackageReferences = [ - "System.Drawing.Common", - "Microsoft.ML.OnnxRuntime.DirectML", - "Microsoft.ML.OnnxRuntime.Extensions" - ], - Name = "YOLO Object Detection", - Id = "9b74ccc0-15f7-430f-bed0-7581fd163508", - Icon = "\uE8B3")] - - internal sealed partial class YOLOObjectionDetection : BaseSamplePage - { - private InferenceSession? _inferenceSession; + private InferenceSession? _inferenceSession; - public YOLOObjectionDetection() - { - this.Unloaded += (s, e) => _inferenceSession?.Dispose(); + public YOLOObjectionDetection() + { + this.Unloaded += (s, e) => _inferenceSession?.Dispose(); - this.Loaded += (s, e) => Page_Loaded(); // - this.InitializeComponent(); - } + this.Loaded += (s, e) => Page_Loaded(); // + this.InitializeComponent(); + } - private void Page_Loaded() - { - UploadButton.Focus(FocusState.Programmatic); - } + private void Page_Loaded() + { + UploadButton.Focus(FocusState.Programmatic); + } - // - protected override async Task LoadModelAsync(SampleNavigationParameters sampleParams) - { - var hardwareAccelerator = sampleParams.HardwareAccelerator; - await InitModel(sampleParams.ModelPath, hardwareAccelerator); + // + protected override async Task LoadModelAsync(SampleNavigationParameters sampleParams) + { + var hardwareAccelerator = sampleParams.HardwareAccelerator; + await InitModel(sampleParams.ModelPath, hardwareAccelerator); - sampleParams.NotifyCompletion(); + sampleParams.NotifyCompletion(); - // Loads inference on default image - await DetectObjects(Windows.ApplicationModel.Package.Current.InstalledLocation.Path + "\\Assets\\team.jpg"); - } + // Loads inference on default image + await DetectObjects(Windows.ApplicationModel.Package.Current.InstalledLocation.Path + "\\Assets\\team.jpg"); + } - private Task InitModel(string modelPath, HardwareAccelerator hardwareAccelerator) + private Task InitModel(string modelPath, HardwareAccelerator hardwareAccelerator) + { + return Task.Run(() => { - return Task.Run(() => + if (_inferenceSession != null) { - if (_inferenceSession != null) - { - return; - } - - SessionOptions sessionOptions = new(); - sessionOptions.RegisterOrtExtensions(); - if (hardwareAccelerator == HardwareAccelerator.DML) - { - sessionOptions.AppendExecutionProvider_DML(DeviceUtils.GetBestDeviceId()); - } - - _inferenceSession = new InferenceSession(modelPath, sessionOptions); - }); - } + return; + } - private async void UploadButton_Click(object sender, RoutedEventArgs e) - { - var window = new Window(); - var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(window); + SessionOptions sessionOptions = new(); + sessionOptions.RegisterOrtExtensions(); + if (hardwareAccelerator == HardwareAccelerator.DML) + { + sessionOptions.AppendExecutionProvider_DML(DeviceUtils.GetBestDeviceId()); + } - // Create a FileOpenPicker - var picker = new FileOpenPicker(); + _inferenceSession = new InferenceSession(modelPath, sessionOptions); + }); + } - WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd); + private async void UploadButton_Click(object sender, RoutedEventArgs e) + { + var window = new Window(); + var hwnd = WinRT.Interop.WindowNative.GetWindowHandle(window); - // Set the file type filter - picker.FileTypeFilter.Add(".png"); - picker.FileTypeFilter.Add(".jpeg"); - picker.FileTypeFilter.Add(".jpg"); - picker.FileTypeFilter.Add(".bmp"); + // Create a FileOpenPicker + var picker = new FileOpenPicker(); - picker.ViewMode = PickerViewMode.Thumbnail; + WinRT.Interop.InitializeWithWindow.Initialize(picker, hwnd); - // Pick a file - var file = await picker.PickSingleFileAsync(); - if (file != null) - { - // Call function to run inference and classify image - UploadButton.Focus(FocusState.Programmatic); - await DetectObjects(file.Path); - } + // Set the file type filter + picker.FileTypeFilter.Add(".png"); + picker.FileTypeFilter.Add(".jpeg"); + picker.FileTypeFilter.Add(".jpg"); + picker.FileTypeFilter.Add(".bmp"); + + picker.ViewMode = PickerViewMode.Thumbnail; + + // Pick a file + var file = await picker.PickSingleFileAsync(); + if (file != null) + { + // Call function to run inference and classify image + UploadButton.Focus(FocusState.Programmatic); + await DetectObjects(file.Path); } + } - private async Task DetectObjects(string filePath) + private async Task DetectObjects(string filePath) + { + if (_inferenceSession == null) { - if (_inferenceSession == null) - { - return; - } + return; + } - Loader.IsActive = true; - Loader.Visibility = Visibility.Visible; - UploadButton.Visibility = Visibility.Collapsed; + Loader.IsActive = true; + Loader.Visibility = Visibility.Visible; + UploadButton.Visibility = Visibility.Collapsed; - DefaultImage.Source = new BitmapImage(new Uri(filePath)); - NarratorHelper.AnnounceImageChanged(DefaultImage, "Image changed: new upload."); // + DefaultImage.Source = new BitmapImage(new Uri(filePath)); + NarratorHelper.AnnounceImageChanged(DefaultImage, "Image changed: new upload."); // - Bitmap image = new(filePath); + Bitmap image = new(filePath); - int originalWidth = image.Width; - int originalHeight = image.Height; + int originalWidth = image.Width; + int originalHeight = image.Height; - var predictions = await Task.Run(() => - { - // Set up - var inputName = _inferenceSession.InputNames[0]; - var inputDimensions = _inferenceSession.InputMetadata[inputName].Dimensions; + var predictions = await Task.Run(() => + { + // Set up + var inputName = _inferenceSession.InputNames[0]; + var inputDimensions = _inferenceSession.InputMetadata[inputName].Dimensions; - // Set batch size - int batchSize = 1; - inputDimensions[0] = batchSize; + // Set batch size + int batchSize = 1; + inputDimensions[0] = batchSize; - // I know the input dimensions to be [batchSize, 416, 416, 3] - int inputWidth = inputDimensions[1]; - int inputHeight = inputDimensions[2]; + // I know the input dimensions to be [batchSize, 416, 416, 3] + int inputWidth = inputDimensions[1]; + int inputHeight = inputDimensions[2]; - using var resizedImage = BitmapFunctions.ResizeWithPadding(image, inputWidth, inputHeight); + using var resizedImage = BitmapFunctions.ResizeWithPadding(image, inputWidth, inputHeight); - // Preprocessing - Tensor input = new DenseTensor(inputDimensions); - input = BitmapFunctions.PreprocessBitmapForYOLO(resizedImage, input); + // Preprocessing + Tensor input = new DenseTensor(inputDimensions); + input = BitmapFunctions.PreprocessBitmapForYOLO(resizedImage, input); - // Setup inputs and outputs - var inputMetadataName = _inferenceSession!.InputNames[0]; - var inputs = new List - { - NamedOnnxValue.CreateFromTensor(inputMetadataName, input) - }; + // Setup inputs and outputs + var inputMetadataName = _inferenceSession!.InputNames[0]; + var inputs = new List + { + NamedOnnxValue.CreateFromTensor(inputMetadataName, input) + }; - // Run inference - using IDisposableReadOnlyCollection results = _inferenceSession!.Run(inputs); + // Run inference + using IDisposableReadOnlyCollection results = _inferenceSession!.Run(inputs); - // Extract tensors from inference results - var outputTensor1 = results[0].AsTensor(); - var outputTensor2 = results[1].AsTensor(); - var outputTensor3 = results[2].AsTensor(); + // Extract tensors from inference results + var outputTensor1 = results[0].AsTensor(); + var outputTensor2 = results[1].AsTensor(); + var outputTensor3 = results[2].AsTensor(); - // Define anchors (as per your model) - var anchors = new List<(float Width, float Height)> - { - (12, 16), (19, 36), (40, 28), // Small grid (52x52) - (36, 75), (76, 55), (72, 146), // Medium grid (26x26) - (142, 110), (192, 243), (459, 401) // Large grid (13x13) - }; + // Define anchors (as per your model) + var anchors = new List<(float Width, float Height)> + { + (12, 16), (19, 36), (40, 28), // Small grid (52x52) + (36, 75), (76, 55), (72, 146), // Medium grid (26x26) + (142, 110), (192, 243), (459, 401) // Large grid (13x13) + }; - // Combine tensors into a list for processing - var gridTensors = new List> { outputTensor1, outputTensor2, outputTensor3 }; + // Combine tensors into a list for processing + var gridTensors = new List> { outputTensor1, outputTensor2, outputTensor3 }; - // Postprocessing steps - var extractedPredictions = YOLOHelpers.ExtractPredictions(gridTensors, anchors, inputWidth, inputHeight, originalWidth, originalHeight); - var filteredPredictions = YOLOHelpers.ApplyNms(extractedPredictions, .4f); + // Postprocessing steps + var extractedPredictions = YOLOHelpers.ExtractPredictions(gridTensors, anchors, inputWidth, inputHeight, originalWidth, originalHeight); + var filteredPredictions = YOLOHelpers.ApplyNms(extractedPredictions, .4f); - // Return the final predictions - return filteredPredictions; - }); + // Return the final predictions + return filteredPredictions; + }); - BitmapImage outputImage = BitmapFunctions.RenderPredictions(image, predictions); + BitmapImage outputImage = BitmapFunctions.RenderPredictions(image, predictions); - DispatcherQueue.TryEnqueue(() => - { - DefaultImage.Source = outputImage; - Loader.IsActive = false; - Loader.Visibility = Visibility.Collapsed; - UploadButton.Visibility = Visibility.Visible; - }); - - NarratorHelper.AnnounceImageChanged(DefaultImage, "Image changed: objects detected."); // - image.Dispose(); - } + DispatcherQueue.TryEnqueue(() => + { + DefaultImage.Source = outputImage; + Loader.IsActive = false; + Loader.Visibility = Visibility.Collapsed; + UploadButton.Visibility = Visibility.Visible; + }); + + NarratorHelper.AnnounceImageChanged(DefaultImage, "Image changed: objects detected."); // + image.Dispose(); } } \ No newline at end of file diff --git a/AIDevGallery/Samples/SharedCode/PoseHelper.cs b/AIDevGallery/Samples/SharedCode/PoseHelper.cs index a534212..f5a27e6 100644 --- a/AIDevGallery/Samples/SharedCode/PoseHelper.cs +++ b/AIDevGallery/Samples/SharedCode/PoseHelper.cs @@ -5,94 +5,93 @@ using System.Collections.Generic; using System.Drawing; -namespace AIDevGallery.Samples.SharedCode +namespace AIDevGallery.Samples.SharedCode; + +internal class PoseHelper { - internal class PoseHelper + public static List<(float X, float Y)> PostProcessResults(Tensor heatmaps, float originalWidth, float originalHeight, float outputWidth, float outputHeight) { - public static List<(float X, float Y)> PostProcessResults(Tensor heatmaps, float originalWidth, float originalHeight, float outputWidth, float outputHeight) - { - List<(float X, float Y)> keypointCoordinates = []; + List<(float X, float Y)> keypointCoordinates = []; - // Scaling factors from heatmap (64x48) directly to original image size - float scale_x = originalWidth / outputWidth; - float scale_y = originalHeight / outputHeight; + // Scaling factors from heatmap (64x48) directly to original image size + float scale_x = originalWidth / outputWidth; + float scale_y = originalHeight / outputHeight; - int numKeypoints = heatmaps.Dimensions[1]; - int heatmapWidth = heatmaps.Dimensions[2]; - int heatmapHeight = heatmaps.Dimensions[3]; + int numKeypoints = heatmaps.Dimensions[1]; + int heatmapWidth = heatmaps.Dimensions[2]; + int heatmapHeight = heatmaps.Dimensions[3]; - for (int i = 0; i < numKeypoints; i++) - { - float maxVal = float.MinValue; - int maxX = 0, maxY = 0; + for (int i = 0; i < numKeypoints; i++) + { + float maxVal = float.MinValue; + int maxX = 0, maxY = 0; - for (int x = 0; x < heatmapWidth; x++) + for (int x = 0; x < heatmapWidth; x++) + { + for (int y = 0; y < heatmapHeight; y++) { - for (int y = 0; y < heatmapHeight; y++) + float value = heatmaps[0, i, y, x]; + if (value > maxVal) { - float value = heatmaps[0, i, y, x]; - if (value > maxVal) - { - maxVal = value; - maxX = x; - maxY = y; - } + maxVal = value; + maxX = x; + maxY = y; } } - - float scaledX = maxX * scale_x; - float scaledY = maxY * scale_y; - - keypointCoordinates.Add((scaledX, scaledY)); } - return keypointCoordinates; + float scaledX = maxX * scale_x; + float scaledY = maxY * scale_y; + + keypointCoordinates.Add((scaledX, scaledY)); } - public static Bitmap RenderPredictions(Bitmap originalImage, List<(float X, float Y)> keypoints, float markerRatio, Bitmap? baseImage = null) - { - Bitmap outputImage = new(originalImage); + return keypointCoordinates; + } - using (Graphics g = Graphics.FromImage(outputImage)) - { - // If refernce is multipose, use base image not cropped image for scaling - // If reference is one person pose, use original image as base image isn't used. - var imageValue = baseImage != null ? baseImage.Width + baseImage.Height : originalImage.Width + originalImage.Height; - int markerSize = (int)(imageValue * markerRatio / 2); - Brush brush = Brushes.Red; - - using Pen linePen = new(Color.Blue, markerSize / 2); - List<(int StartIdx, int EndIdx)> connections = - [ - (5, 6), // Left shoulder to right shoulder - (5, 7), // Left shoulder to left elbow - (7, 9), // Left elbow to left wrist - (6, 8), // Right shoulder to right elbow - (8, 10), // Right elbow to right wrist - (11, 12), // Left hip to right hip - (5, 11), // Left shoulder to left hip - (6, 12), // Right shoulder to right hip - (11, 13), // Left hip to left knee - (13, 15), // Left knee to left ankle - (12, 14), // Right hip to right knee - (14, 16) // Right knee to right ankle - ]; - - foreach (var (startIdx, endIdx) in connections) - { - var (startPointX, startPointY) = keypoints[startIdx]; - var (endPointX, endPointY) = keypoints[endIdx]; + public static Bitmap RenderPredictions(Bitmap originalImage, List<(float X, float Y)> keypoints, float markerRatio, Bitmap? baseImage = null) + { + Bitmap outputImage = new(originalImage); - g.DrawLine(linePen, startPointX, startPointY, endPointX, endPointY); - } + using (Graphics g = Graphics.FromImage(outputImage)) + { + // If refernce is multipose, use base image not cropped image for scaling + // If reference is one person pose, use original image as base image isn't used. + var imageValue = baseImage != null ? baseImage.Width + baseImage.Height : originalImage.Width + originalImage.Height; + int markerSize = (int)(imageValue * markerRatio / 2); + Brush brush = Brushes.Red; + + using Pen linePen = new(Color.Blue, markerSize / 2); + List<(int StartIdx, int EndIdx)> connections = + [ + (5, 6), // Left shoulder to right shoulder + (5, 7), // Left shoulder to left elbow + (7, 9), // Left elbow to left wrist + (6, 8), // Right shoulder to right elbow + (8, 10), // Right elbow to right wrist + (11, 12), // Left hip to right hip + (5, 11), // Left shoulder to left hip + (6, 12), // Right shoulder to right hip + (11, 13), // Left hip to left knee + (13, 15), // Left knee to left ankle + (12, 14), // Right hip to right knee + (14, 16) // Right knee to right ankle + ]; + + foreach (var (startIdx, endIdx) in connections) + { + var (startPointX, startPointY) = keypoints[startIdx]; + var (endPointX, endPointY) = keypoints[endIdx]; - foreach (var (x, y) in keypoints) - { - g.FillEllipse(brush, x - markerSize / 2, y - markerSize / 2, markerSize, markerSize); - } + g.DrawLine(linePen, startPointX, startPointY, endPointX, endPointY); } - return outputImage; + foreach (var (x, y) in keypoints) + { + g.FillEllipse(brush, x - markerSize / 2, y - markerSize / 2, markerSize, markerSize); + } } + + return outputImage; } } \ No newline at end of file diff --git a/AIDevGallery/Samples/SharedCode/YOLOHelpers.cs b/AIDevGallery/Samples/SharedCode/YOLOHelpers.cs index db065f4..629c26a 100644 --- a/AIDevGallery/Samples/SharedCode/YOLOHelpers.cs +++ b/AIDevGallery/Samples/SharedCode/YOLOHelpers.cs @@ -6,144 +6,143 @@ using System.Collections.Generic; using System.Linq; -namespace AIDevGallery.Samples.SharedCode +namespace AIDevGallery.Samples.SharedCode; + +internal class YOLOHelpers { - internal class YOLOHelpers + public static List ExtractPredictions(List> gridTensors, List<(float Width, float Height)> anchors, int inputWidth, int inputHeight, int originalWidth, int originalHeight) { - public static List ExtractPredictions(List> gridTensors, List<(float Width, float Height)> anchors, int inputWidth, int inputHeight, int originalWidth, int originalHeight) - { - var predictions = new List(); - int anchorCounter = 0; - float confidenceThreshold = .5f; + var predictions = new List(); + int anchorCounter = 0; + float confidenceThreshold = .5f; - foreach (var tensor in gridTensors) - { - var gridSize = tensor.Dimensions[2]; + foreach (var tensor in gridTensors) + { + var gridSize = tensor.Dimensions[2]; - int gridX = gridSize; - int gridY = gridSize; - int numAnchors = tensor.Dimensions[3]; + int gridX = gridSize; + int gridY = gridSize; + int numAnchors = tensor.Dimensions[3]; - for (int anchor = 0; anchor < numAnchors; anchor++) + for (int anchor = 0; anchor < numAnchors; anchor++) + { + for (int i = 0; i < gridX; i++) { - for (int i = 0; i < gridX; i++) + for (int j = 0; j < gridY; j++) { - for (int j = 0; j < gridY; j++) + // Access prediction vector + var predictionVector = new List(); + for (int k = 0; k < tensor.Dimensions[^1]; k++) { - // Access prediction vector - var predictionVector = new List(); - for (int k = 0; k < tensor.Dimensions[^1]; k++) - { - predictionVector.Add(tensor[0, i, j, anchor, k]); - } - - // Extract bounding box and confidence - float bx = Sigmoid(predictionVector[0]); // x offset - float by = Sigmoid(predictionVector[1]); // y offset - float bw = (float)Math.Exp(predictionVector[2]) * anchors[anchorCounter + anchor].Width; - float bh = (float)Math.Exp(predictionVector[3]) * anchors[anchorCounter + anchor].Height; - float confidence = Sigmoid(predictionVector[4]); - - // Skip low-confidence predictions - if (confidence < confidenceThreshold) - { - continue; - } - - // Get class probabilities - var classProbs = predictionVector.Skip(5).Select(Sigmoid).ToArray(); - float maxProb = classProbs.Max(); - int classIndex = Array.IndexOf(classProbs, maxProb); - - // Skip if class probability is low - if (maxProb * confidence < confidenceThreshold) - { - continue; - } - - // Adjust bounding box to image dimensions - bx = (bx + j) * (inputWidth / gridX); // Convert to absolute x - by = (by + i) * (inputHeight / gridY); // Convert to absolute y - bw *= inputWidth / 416; // Normalize to input width - bh *= inputHeight / 416; // Normalize to input height - - float scale = Math.Min((float)inputWidth / originalWidth, (float)inputHeight / originalHeight); - int offsetX = (inputWidth - (int)(originalWidth * scale)) / 2; - int offsetY = (inputHeight - (int)(originalHeight * scale)) / 2; - - float xmin = (bx - bw / 2 - offsetX) / scale; - float ymin = (by - bh / 2 - offsetY) / scale; - float xmax = (bx + bw / 2 - offsetX) / scale; - float ymax = (by + bh / 2 - offsetY) / scale; - - // Define your class labels (replace with your model's labels) - string[] labels = RCNNLabelMap.Labels.Skip(1).ToArray(); - - // Add prediction - predictions.Add(new Prediction - { - Box = new Box(xmin, ymin, xmax, ymax), - Label = labels[classIndex], // Use label from the provided labels array - Confidence = confidence * maxProb - }); + predictionVector.Add(tensor[0, i, j, anchor, k]); } - } - } - // Increment anchorCounter for the next grid level - anchorCounter += numAnchors; - } + // Extract bounding box and confidence + float bx = Sigmoid(predictionVector[0]); // x offset + float by = Sigmoid(predictionVector[1]); // y offset + float bw = (float)Math.Exp(predictionVector[2]) * anchors[anchorCounter + anchor].Width; + float bh = (float)Math.Exp(predictionVector[3]) * anchors[anchorCounter + anchor].Height; + float confidence = Sigmoid(predictionVector[4]); - return predictions; - } + // Skip low-confidence predictions + if (confidence < confidenceThreshold) + { + continue; + } - private static float Sigmoid(float x) - { - return 1f / (1f + (float)Math.Exp(-x)); - } + // Get class probabilities + var classProbs = predictionVector.Skip(5).Select(Sigmoid).ToArray(); + float maxProb = classProbs.Max(); + int classIndex = Array.IndexOf(classProbs, maxProb); - public static List ApplyNms(List predictions, float nmsThreshold) - { - var filteredPredictions = new List(); + // Skip if class probability is low + if (maxProb * confidence < confidenceThreshold) + { + continue; + } - // Group predictions by class - var groupedPredictions = predictions.GroupBy(p => p.Label); + // Adjust bounding box to image dimensions + bx = (bx + j) * (inputWidth / gridX); // Convert to absolute x + by = (by + i) * (inputHeight / gridY); // Convert to absolute y + bw *= inputWidth / 416; // Normalize to input width + bh *= inputHeight / 416; // Normalize to input height - foreach (var group in groupedPredictions) - { - var sortedGroup = group.OrderByDescending(p => p.Confidence).ToList(); + float scale = Math.Min((float)inputWidth / originalWidth, (float)inputHeight / originalHeight); + int offsetX = (inputWidth - (int)(originalWidth * scale)) / 2; + int offsetY = (inputHeight - (int)(originalHeight * scale)) / 2; - while (sortedGroup.Count > 0) - { - // Take the highest confidence prediction - var bestPrediction = sortedGroup[0]; - filteredPredictions.Add(bestPrediction); - sortedGroup.RemoveAt(0); - - // Remove overlapping predictions - sortedGroup = sortedGroup - .Where(p => IoU(bestPrediction.Box!, p.Box!) < nmsThreshold) - .ToList(); + float xmin = (bx - bw / 2 - offsetX) / scale; + float ymin = (by - bh / 2 - offsetY) / scale; + float xmax = (bx + bw / 2 - offsetX) / scale; + float ymax = (by + bh / 2 - offsetY) / scale; + + // Define your class labels (replace with your model's labels) + string[] labels = RCNNLabelMap.Labels.Skip(1).ToArray(); + + // Add prediction + predictions.Add(new Prediction + { + Box = new Box(xmin, ymin, xmax, ymax), + Label = labels[classIndex], // Use label from the provided labels array + Confidence = confidence * maxProb + }); + } } } - return filteredPredictions; + // Increment anchorCounter for the next grid level + anchorCounter += numAnchors; } - // Function to compute Intersection Over Union (IoU) - private static float IoU(Box boxA, Box boxB) - { - float x1 = Math.Max(boxA.Xmin, boxB.Xmin); - float y1 = Math.Max(boxA.Ymin, boxB.Ymin); - float x2 = Math.Min(boxA.Xmax, boxB.Xmax); - float y2 = Math.Min(boxA.Ymax, boxB.Ymax); + return predictions; + } + + private static float Sigmoid(float x) + { + return 1f / (1f + (float)Math.Exp(-x)); + } - float intersection = Math.Max(0, x2 - x1) * Math.Max(0, y2 - y1); - float union = (boxA.Xmax - boxA.Xmin) * (boxA.Ymax - boxA.Ymin) + - (boxB.Xmax - boxB.Xmin) * (boxB.Ymax - boxB.Ymin) - - intersection; + public static List ApplyNms(List predictions, float nmsThreshold) + { + var filteredPredictions = new List(); - return intersection / union; + // Group predictions by class + var groupedPredictions = predictions.GroupBy(p => p.Label); + + foreach (var group in groupedPredictions) + { + var sortedGroup = group.OrderByDescending(p => p.Confidence).ToList(); + + while (sortedGroup.Count > 0) + { + // Take the highest confidence prediction + var bestPrediction = sortedGroup[0]; + filteredPredictions.Add(bestPrediction); + sortedGroup.RemoveAt(0); + + // Remove overlapping predictions + sortedGroup = sortedGroup + .Where(p => IoU(bestPrediction.Box!, p.Box!) < nmsThreshold) + .ToList(); + } } + + return filteredPredictions; + } + + // Function to compute Intersection Over Union (IoU) + private static float IoU(Box boxA, Box boxB) + { + float x1 = Math.Max(boxA.Xmin, boxB.Xmin); + float y1 = Math.Max(boxA.Ymin, boxB.Ymin); + float x2 = Math.Min(boxA.Xmax, boxB.Xmax); + float y2 = Math.Min(boxA.Ymax, boxB.Ymax); + + float intersection = Math.Max(0, x2 - x1) * Math.Max(0, y2 - y1); + float union = (boxA.Xmax - boxA.Xmin) * (boxA.Ymax - boxA.Ymin) + + (boxB.Xmax - boxB.Xmin) * (boxB.Ymax - boxB.Ymin) - + intersection; + + return intersection / union; } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ButtonClickedEvent.cs b/AIDevGallery/Telemetry/Events/ButtonClickedEvent.cs index d918ee5..a3e7a4a 100644 --- a/AIDevGallery/Telemetry/Events/ButtonClickedEvent.cs +++ b/AIDevGallery/Telemetry/Events/ButtonClickedEvent.cs @@ -6,31 +6,30 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ButtonClickedEvent : EventBase { - [EventData] - internal class ButtonClickedEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string ButtonName - { - get; - } + public string ButtonName + { + get; + } - private ButtonClickedEvent(string buttonName) - { - ButtonName = buttonName; - } + private ButtonClickedEvent(string buttonName) + { + ButtonName = buttonName; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string buttonName) - { - TelemetryFactory.Get().Log("ButtonClicked_Event", LogLevel.Measure, new ButtonClickedEvent(buttonName)); - } + public static void Log(string buttonName) + { + TelemetryFactory.Get().Log("ButtonClicked_Event", LogLevel.Measure, new ButtonClickedEvent(buttonName)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/DeepLinkActivatedEvent.cs b/AIDevGallery/Telemetry/Events/DeepLinkActivatedEvent.cs index fc7040f..229c661 100644 --- a/AIDevGallery/Telemetry/Events/DeepLinkActivatedEvent.cs +++ b/AIDevGallery/Telemetry/Events/DeepLinkActivatedEvent.cs @@ -6,31 +6,30 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class DeepLinkActivatedEvent : EventBase { - [EventData] - internal class DeepLinkActivatedEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string Uri - { - get; - } + public string Uri + { + get; + } - private DeepLinkActivatedEvent(string uri) - { - Uri = uri; - } + private DeepLinkActivatedEvent(string uri) + { + Uri = uri; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string uri) - { - TelemetryFactory.Get().Log("DeepLinkActivated_Event", LogLevel.Measure, new DeepLinkActivatedEvent(uri)); - } + public static void Log(string uri) + { + TelemetryFactory.Get().Log("DeepLinkActivated_Event", LogLevel.Measure, new DeepLinkActivatedEvent(uri)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/DownloadSearchedModelEvent.cs b/AIDevGallery/Telemetry/Events/DownloadSearchedModelEvent.cs index 576761f..290baa0 100644 --- a/AIDevGallery/Telemetry/Events/DownloadSearchedModelEvent.cs +++ b/AIDevGallery/Telemetry/Events/DownloadSearchedModelEvent.cs @@ -6,28 +6,27 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class DownloadSearchedModelEvent : EventBase { - [EventData] - internal class DownloadSearchedModelEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string ModelName { get; } + public string ModelName { get; } - private DownloadSearchedModelEvent(string modelName) - { - ModelName = modelName; - } + private DownloadSearchedModelEvent(string modelName) + { + ModelName = modelName; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string modelName) - { - TelemetryFactory.Get().Log("DownloadSearchedModel_Event", LogLevel.Measure, new DownloadSearchedModelEvent(modelName)); - } + public static void Log(string modelName) + { + TelemetryFactory.Get().Log("DownloadSearchedModel_Event", LogLevel.Measure, new DownloadSearchedModelEvent(modelName)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/EmptyEvent.cs b/AIDevGallery/Telemetry/Events/EmptyEvent.cs index fc9ba9d..a8947b1 100644 --- a/AIDevGallery/Telemetry/Events/EmptyEvent.cs +++ b/AIDevGallery/Telemetry/Events/EmptyEvent.cs @@ -5,21 +5,20 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry +namespace AIDevGallery.Telemetry; + +[EventData] +internal class EmptyEvent : EventBase { - [EventData] - internal class EmptyEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags { get; } + public override PartA_PrivTags PartA_PrivTags { get; } - public EmptyEvent(PartA_PrivTags tags) - { - PartA_PrivTags = tags; - } + public EmptyEvent(PartA_PrivTags tags) + { + PartA_PrivTags = tags; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive string - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive string } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/EventBase.cs b/AIDevGallery/Telemetry/Events/EventBase.cs index 5fc22b1..557624e 100644 --- a/AIDevGallery/Telemetry/Events/EventBase.cs +++ b/AIDevGallery/Telemetry/Events/EventBase.cs @@ -5,39 +5,38 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry +namespace AIDevGallery.Telemetry; + +/// +/// Base class for all telemetry events to ensure they are properly tagged. +/// +/// +/// The public properties of each event are logged in the telemetry. +/// We should not change an event's properties, as that could break the processing of that event's data. +/// +[EventData] +public abstract class EventBase { /// - /// Base class for all telemetry events to ensure they are properly tagged. + /// Gets the privacy datatype tag for the telemetry event. /// - /// - /// The public properties of each event are logged in the telemetry. - /// We should not change an event's properties, as that could break the processing of that event's data. - /// - [EventData] - public abstract class EventBase - { - /// - /// Gets the privacy datatype tag for the telemetry event. - /// #pragma warning disable CA1707 // Identifiers should not contain underscores - public abstract PartA_PrivTags PartA_PrivTags + public abstract PartA_PrivTags PartA_PrivTags #pragma warning restore CA1707 // Identifiers should not contain underscores - { - get; - } - - /// - /// Replaces all the strings in this event that may contain PII using the provided function. - /// - /// - /// This is called by before logging the event. - /// It is the responsibility of each event to ensure we replace all strings with possible PII; - /// we ensure we at least consider this by forcing to implement this. - /// - /// - /// A function that replaces all the sensitive strings in a given string with tokens - /// - public abstract void ReplaceSensitiveStrings(Func replaceSensitiveStrings); + { + get; } + + /// + /// Replaces all the strings in this event that may contain PII using the provided function. + /// + /// + /// This is called by before logging the event. + /// It is the responsibility of each event to ensure we replace all strings with possible PII; + /// we ensure we at least consider this by forcing to implement this. + /// + /// + /// A function that replaces all the sensitive strings in a given string with tokens + /// + public abstract void ReplaceSensitiveStrings(Func replaceSensitiveStrings); } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelCacheDeletedEvent.cs b/AIDevGallery/Telemetry/Events/ModelCacheDeletedEvent.cs index 650ce06..22a816c 100644 --- a/AIDevGallery/Telemetry/Events/ModelCacheDeletedEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelCacheDeletedEvent.cs @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +internal static class ModelCacheDeletedEvent { - internal static class ModelCacheDeletedEvent + public static void Log() { - public static void Log() - { - TelemetryFactory.Get().LogCritical("ModelCacheDeleted_Event"); - } + TelemetryFactory.Get().LogCritical("ModelCacheDeleted_Event"); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelCacheMovedEvent.cs b/AIDevGallery/Telemetry/Events/ModelCacheMovedEvent.cs index 7f1997f..243e5b4 100644 --- a/AIDevGallery/Telemetry/Events/ModelCacheMovedEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelCacheMovedEvent.cs @@ -4,29 +4,28 @@ using Microsoft.Diagnostics.Telemetry.Internal; using System; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +internal class ModelCacheMovedEvent : EventBase { - internal class ModelCacheMovedEvent : EventBase - { #pragma warning disable IDE0052 // Remove unread private members - private readonly string newPath; + private readonly string newPath; #pragma warning restore IDE0052 // Remove unread private members - private ModelCacheMovedEvent(string newPath) - { - this.newPath = newPath; - } + private ModelCacheMovedEvent(string newPath) + { + this.newPath = newPath; + } - public override PartA_PrivTags PartA_PrivTags => throw new NotImplementedException(); + public override PartA_PrivTags PartA_PrivTags => throw new NotImplementedException(); - public static void Log(string newPath) - { - TelemetryFactory.Get().Log("ModelCacheMoved_Event", LogLevel.Measure, new ModelCacheMovedEvent(newPath)); - } + public static void Log(string newPath) + { + TelemetryFactory.Get().Log("ModelCacheMoved_Event", LogLevel.Measure, new ModelCacheMovedEvent(newPath)); + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelDeletedEvent.cs b/AIDevGallery/Telemetry/Events/ModelDeletedEvent.cs index de6862a..44f6fef 100644 --- a/AIDevGallery/Telemetry/Events/ModelDeletedEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelDeletedEvent.cs @@ -6,28 +6,27 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ModelDeletedEvent : EventBase { - [EventData] - internal class ModelDeletedEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string ModelName { get; } + public string ModelName { get; } - private ModelDeletedEvent(string modelName) - { - ModelName = modelName; - } + private ModelDeletedEvent(string modelName) + { + ModelName = modelName; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string modelName) - { - TelemetryFactory.Get().Log("ModelDeleted_Event", LogLevel.Measure, new ModelDeletedEvent(modelName)); - } + public static void Log(string modelName) + { + TelemetryFactory.Get().Log("ModelDeleted_Event", LogLevel.Measure, new ModelDeletedEvent(modelName)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelDetailsLinkClickedEvent.cs b/AIDevGallery/Telemetry/Events/ModelDetailsLinkClickedEvent.cs index fe2769c..67d271a 100644 --- a/AIDevGallery/Telemetry/Events/ModelDetailsLinkClickedEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelDetailsLinkClickedEvent.cs @@ -6,31 +6,30 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ModelDetailsLinkClickedEvent : EventBase { - [EventData] - internal class ModelDetailsLinkClickedEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string Link - { - get; - } + public string Link + { + get; + } - private ModelDetailsLinkClickedEvent(string link) - { - Link = link; - } + private ModelDetailsLinkClickedEvent(string link) + { + Link = link; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string link) - { - TelemetryFactory.Get().Log("ModelDetailsLinkClicked_Event", LogLevel.Measure, new ModelDetailsLinkClickedEvent(link)); - } + public static void Log(string link) + { + TelemetryFactory.Get().Log("ModelDetailsLinkClicked_Event", LogLevel.Measure, new ModelDetailsLinkClickedEvent(link)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelDownloadCancelEvent.cs b/AIDevGallery/Telemetry/Events/ModelDownloadCancelEvent.cs index 0524324..1315cc4 100644 --- a/AIDevGallery/Telemetry/Events/ModelDownloadCancelEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelDownloadCancelEvent.cs @@ -6,30 +6,29 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ModelDownloadCancelEvent : EventBase { - [EventData] - internal class ModelDownloadCancelEvent : EventBase + internal ModelDownloadCancelEvent(string modelUrl, DateTime canceledTime) { - internal ModelDownloadCancelEvent(string modelUrl, DateTime canceledTime) - { - ModelUrl = modelUrl; - CanceledTime = canceledTime; - } + ModelUrl = modelUrl; + CanceledTime = canceledTime; + } - public string ModelUrl { get; private set; } + public string ModelUrl { get; private set; } - public DateTime CanceledTime { get; private set; } + public DateTime CanceledTime { get; private set; } - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + } - public static void Log(string modelUrl) - { - TelemetryFactory.Get().Log("ModelDownloadCancel_Event", LogLevel.Measure, new ModelDownloadCancelEvent(modelUrl, DateTime.Now)); - } + public static void Log(string modelUrl) + { + TelemetryFactory.Get().Log("ModelDownloadCancel_Event", LogLevel.Measure, new ModelDownloadCancelEvent(modelUrl, DateTime.Now)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelDownloadCompleteEvent.cs b/AIDevGallery/Telemetry/Events/ModelDownloadCompleteEvent.cs index 386d464..c8288d2 100644 --- a/AIDevGallery/Telemetry/Events/ModelDownloadCompleteEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelDownloadCompleteEvent.cs @@ -6,30 +6,29 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ModelDownloadCompleteEvent : EventBase { - [EventData] - internal class ModelDownloadCompleteEvent : EventBase + internal ModelDownloadCompleteEvent(string modelUrl, DateTime completeTime) { - internal ModelDownloadCompleteEvent(string modelUrl, DateTime completeTime) - { - ModelUrl = modelUrl; - CompleteTime = completeTime; - } + ModelUrl = modelUrl; + CompleteTime = completeTime; + } - public string ModelUrl { get; private set; } + public string ModelUrl { get; private set; } - public DateTime CompleteTime { get; private set; } + public DateTime CompleteTime { get; private set; } - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + } - public static void Log(string modelUrl) - { - TelemetryFactory.Get().Log("ModelDownloadComplete_Event", LogLevel.Measure, new ModelDownloadCompleteEvent(modelUrl, DateTime.Now)); - } + public static void Log(string modelUrl) + { + TelemetryFactory.Get().Log("ModelDownloadComplete_Event", LogLevel.Measure, new ModelDownloadCompleteEvent(modelUrl, DateTime.Now)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelDownloadEnqueueEvent.cs b/AIDevGallery/Telemetry/Events/ModelDownloadEnqueueEvent.cs index 7511186..7cbb437 100644 --- a/AIDevGallery/Telemetry/Events/ModelDownloadEnqueueEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelDownloadEnqueueEvent.cs @@ -6,30 +6,29 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ModelDownloadEnqueueEvent : EventBase { - [EventData] - internal class ModelDownloadEnqueueEvent : EventBase + internal ModelDownloadEnqueueEvent(string modelUrl, DateTime enqueuedTime) { - internal ModelDownloadEnqueueEvent(string modelUrl, DateTime enqueuedTime) - { - ModelUrl = modelUrl; - EnqueuedTime = enqueuedTime; - } + ModelUrl = modelUrl; + EnqueuedTime = enqueuedTime; + } - public string ModelUrl { get; private set; } + public string ModelUrl { get; private set; } - public DateTime EnqueuedTime { get; private set; } + public DateTime EnqueuedTime { get; private set; } - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + } - public static void Log(string modelUrl) - { - TelemetryFactory.Get().Log("ModelDownloadEnqueue_Event", LogLevel.Measure, new ModelDownloadEnqueueEvent(modelUrl, DateTime.Now)); - } + public static void Log(string modelUrl) + { + TelemetryFactory.Get().Log("ModelDownloadEnqueue_Event", LogLevel.Measure, new ModelDownloadEnqueueEvent(modelUrl, DateTime.Now)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelDownloadFailedEvent.cs b/AIDevGallery/Telemetry/Events/ModelDownloadFailedEvent.cs index 1d94046..5529dc6 100644 --- a/AIDevGallery/Telemetry/Events/ModelDownloadFailedEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelDownloadFailedEvent.cs @@ -6,32 +6,31 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ModelDownloadFailedEvent : EventBase { - [EventData] - internal class ModelDownloadFailedEvent : EventBase + internal ModelDownloadFailedEvent(string modelUrl, DateTime errorTime) { - internal ModelDownloadFailedEvent(string modelUrl, DateTime errorTime) - { - ModelUrl = modelUrl; - ErrorTime = errorTime; - } + ModelUrl = modelUrl; + ErrorTime = errorTime; + } - public string ModelUrl { get; private set; } + public string ModelUrl { get; private set; } - public DateTime ErrorTime { get; private set; } + public DateTime ErrorTime { get; private set; } - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + } - public static void Log(string modelUrl, Exception ex) - { - var relatedActivityId = Guid.NewGuid(); - TelemetryFactory.Get().LogError("ModelDownloadFailed_Event", LogLevel.Critical, new ModelDownloadFailedEvent(modelUrl, DateTime.Now), relatedActivityId); - TelemetryFactory.Get().LogException("ModelDownloadFailed_Event", ex, relatedActivityId); - } + public static void Log(string modelUrl, Exception ex) + { + var relatedActivityId = Guid.NewGuid(); + TelemetryFactory.Get().LogError("ModelDownloadFailed_Event", LogLevel.Critical, new ModelDownloadFailedEvent(modelUrl, DateTime.Now), relatedActivityId); + TelemetryFactory.Get().LogException("ModelDownloadFailed_Event", ex, relatedActivityId); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ModelDownloadStartEvent.cs b/AIDevGallery/Telemetry/Events/ModelDownloadStartEvent.cs index ef6017a..e7e62cb 100644 --- a/AIDevGallery/Telemetry/Events/ModelDownloadStartEvent.cs +++ b/AIDevGallery/Telemetry/Events/ModelDownloadStartEvent.cs @@ -6,30 +6,29 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ModelDownloadStartEvent : EventBase { - [EventData] - internal class ModelDownloadStartEvent : EventBase + internal ModelDownloadStartEvent(string modelUrl, DateTime startTime) { - internal ModelDownloadStartEvent(string modelUrl, DateTime startTime) - { - ModelUrl = modelUrl; - StartTime = startTime; - } + ModelUrl = modelUrl; + StartTime = startTime; + } - public string ModelUrl { get; private set; } + public string ModelUrl { get; private set; } - public DateTime StartTime { get; private set; } + public DateTime StartTime { get; private set; } - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + } - public static void Log(string modelUrl) - { - TelemetryFactory.Get().Log("ModelDownloadStart_Event", LogLevel.Measure, new ModelDownloadStartEvent(modelUrl, DateTime.Now)); - } + public static void Log(string modelUrl) + { + TelemetryFactory.Get().Log("ModelDownloadStart_Event", LogLevel.Measure, new ModelDownloadStartEvent(modelUrl, DateTime.Now)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/NavigatedToPageEvent.cs b/AIDevGallery/Telemetry/Events/NavigatedToPageEvent.cs index 193a2f7..6d096c6 100644 --- a/AIDevGallery/Telemetry/Events/NavigatedToPageEvent.cs +++ b/AIDevGallery/Telemetry/Events/NavigatedToPageEvent.cs @@ -6,31 +6,30 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class NavigatedToPageEvent : EventBase { - [EventData] - internal class NavigatedToPageEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string PageName - { - get; - } + public string PageName + { + get; + } - private NavigatedToPageEvent(string pageName) - { - PageName = pageName; - } + private NavigatedToPageEvent(string pageName) + { + PageName = pageName; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string pageName) - { - TelemetryFactory.Get().Log("NavigatedToPage_Event", LogLevel.Measure, new NavigatedToPageEvent(pageName)); - } + public static void Log(string pageName) + { + TelemetryFactory.Get().Log("NavigatedToPage_Event", LogLevel.Measure, new NavigatedToPageEvent(pageName)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/NavigatedToSampleEvent.cs b/AIDevGallery/Telemetry/Events/NavigatedToSampleEvent.cs index 728301a..76ab2c5 100644 --- a/AIDevGallery/Telemetry/Events/NavigatedToSampleEvent.cs +++ b/AIDevGallery/Telemetry/Events/NavigatedToSampleEvent.cs @@ -6,31 +6,30 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class NavigatedToSampleEvent : EventBase { - [EventData] - internal class NavigatedToSampleEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string Name { get; } + public string Name { get; } - public DateTime StartTime { get; } + public DateTime StartTime { get; } - private NavigatedToSampleEvent(string name, DateTime startTime) - { - Name = name; - StartTime = startTime; - } + private NavigatedToSampleEvent(string name, DateTime startTime) + { + Name = name; + StartTime = startTime; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string name) - { - TelemetryFactory.Get().Log("NavigatedToSample_Event", LogLevel.Measure, new NavigatedToSampleEvent(name, DateTime.Now)); - } + public static void Log(string name) + { + TelemetryFactory.Get().Log("NavigatedToSample_Event", LogLevel.Measure, new NavigatedToSampleEvent(name, DateTime.Now)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/NavigatedToSampleLoadedEvent.cs b/AIDevGallery/Telemetry/Events/NavigatedToSampleLoadedEvent.cs index d999ef3..0da256e 100644 --- a/AIDevGallery/Telemetry/Events/NavigatedToSampleLoadedEvent.cs +++ b/AIDevGallery/Telemetry/Events/NavigatedToSampleLoadedEvent.cs @@ -6,31 +6,30 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class NavigatedToSampleLoadedEvent : EventBase { - [EventData] - internal class NavigatedToSampleLoadedEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string Name { get; private set; } + public string Name { get; private set; } - public DateTime CompleteTime { get; private set; } + public DateTime CompleteTime { get; private set; } - private NavigatedToSampleLoadedEvent(string name, DateTime completeTime) - { - Name = name; - CompleteTime = completeTime; - } + private NavigatedToSampleLoadedEvent(string name, DateTime completeTime) + { + Name = name; + CompleteTime = completeTime; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string name) - { - TelemetryFactory.Get().Log("NavigatedToSampleLoaded_Event", LogLevel.Measure, new NavigatedToSampleLoadedEvent(name, DateTime.Now)); - } + public static void Log(string name) + { + TelemetryFactory.Get().Log("NavigatedToSampleLoaded_Event", LogLevel.Measure, new NavigatedToSampleLoadedEvent(name, DateTime.Now)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/OpenModelFolderEvent.cs b/AIDevGallery/Telemetry/Events/OpenModelFolderEvent.cs index 3868233..910fcc4 100644 --- a/AIDevGallery/Telemetry/Events/OpenModelFolderEvent.cs +++ b/AIDevGallery/Telemetry/Events/OpenModelFolderEvent.cs @@ -6,31 +6,30 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class OpenModelFolderEvent : EventBase { - [EventData] - internal class OpenModelFolderEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string ModelUrl - { - get; - } + public string ModelUrl + { + get; + } - private OpenModelFolderEvent(string modelUrl) - { - ModelUrl = modelUrl; - } + private OpenModelFolderEvent(string modelUrl) + { + ModelUrl = modelUrl; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string modelUrl) - { - TelemetryFactory.Get().Log("OpenModelFolder_Event", LogLevel.Measure, new OpenModelFolderEvent(modelUrl)); - } + public static void Log(string modelUrl) + { + TelemetryFactory.Get().Log("OpenModelFolder_Event", LogLevel.Measure, new OpenModelFolderEvent(modelUrl)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/SampleProjectGeneratedEvent.cs b/AIDevGallery/Telemetry/Events/SampleProjectGeneratedEvent.cs index 2247da5..1a15068 100644 --- a/AIDevGallery/Telemetry/Events/SampleProjectGeneratedEvent.cs +++ b/AIDevGallery/Telemetry/Events/SampleProjectGeneratedEvent.cs @@ -6,34 +6,33 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class SampleProjectGeneratedEvent : EventBase { - [EventData] - internal class SampleProjectGeneratedEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string SampleId { get; } - public string Model1Id { get; } - public string Model2Id { get; } - public bool CopyModelLocally { get; } + public string SampleId { get; } + public string Model1Id { get; } + public string Model2Id { get; } + public bool CopyModelLocally { get; } - private SampleProjectGeneratedEvent(string sampleId, string model1Id, string model2Id, bool copyModelLocally) - { - SampleId = sampleId; - Model1Id = model1Id; - Model2Id = model2Id; - CopyModelLocally = copyModelLocally; - } + private SampleProjectGeneratedEvent(string sampleId, string model1Id, string model2Id, bool copyModelLocally) + { + SampleId = sampleId; + Model1Id = model1Id; + Model2Id = model2Id; + CopyModelLocally = copyModelLocally; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string sampleId, string model1Id, string model2Id, bool copyModelLocally) - { - TelemetryFactory.Get().Log("SampleProjectGenerated_Event", LogLevel.Measure, new SampleProjectGeneratedEvent(sampleId, model1Id, model2Id, copyModelLocally)); - } + public static void Log(string sampleId, string model1Id, string model2Id, bool copyModelLocally) + { + TelemetryFactory.Get().Log("SampleProjectGenerated_Event", LogLevel.Measure, new SampleProjectGeneratedEvent(sampleId, model1Id, model2Id, copyModelLocally)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/SearchModelEvent.cs b/AIDevGallery/Telemetry/Events/SearchModelEvent.cs index 711d4ed..180f889 100644 --- a/AIDevGallery/Telemetry/Events/SearchModelEvent.cs +++ b/AIDevGallery/Telemetry/Events/SearchModelEvent.cs @@ -6,28 +6,27 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class SearchModelEvent : EventBase { - [EventData] - internal class SearchModelEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string Query { get; } + public string Query { get; } - private SearchModelEvent(string query) - { - Query = query; - } + private SearchModelEvent(string query) + { + Query = query; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string query) - { - TelemetryFactory.Get().Log("SearchModel_Event", LogLevel.Measure, new SearchModelEvent(query)); - } + public static void Log(string query) + { + TelemetryFactory.Get().Log("SearchModel_Event", LogLevel.Measure, new SearchModelEvent(query)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Events/ToggleCodeButtonEvent.cs b/AIDevGallery/Telemetry/Events/ToggleCodeButtonEvent.cs index 30e7147..2428ab3 100644 --- a/AIDevGallery/Telemetry/Events/ToggleCodeButtonEvent.cs +++ b/AIDevGallery/Telemetry/Events/ToggleCodeButtonEvent.cs @@ -6,30 +6,29 @@ using System; using System.Diagnostics.Tracing; -namespace AIDevGallery.Telemetry.Events +namespace AIDevGallery.Telemetry.Events; + +[EventData] +internal class ToggleCodeButtonEvent : EventBase { - [EventData] - internal class ToggleCodeButtonEvent : EventBase - { - public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; + public override PartA_PrivTags PartA_PrivTags => PrivTags.ProductAndServiceUsage; - public string Name { get; } - public bool IsChecked { get; } + public string Name { get; } + public bool IsChecked { get; } - private ToggleCodeButtonEvent(string name, bool isChecked) - { - Name = name; - IsChecked = isChecked; - } + private ToggleCodeButtonEvent(string name, bool isChecked) + { + Name = name; + IsChecked = isChecked; + } - public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) - { - // No sensitive strings to replace. - } + public override void ReplaceSensitiveStrings(Func replaceSensitiveStrings) + { + // No sensitive strings to replace. + } - public static void Log(string name, bool isChecked) - { - TelemetryFactory.Get().Log("ToggleCodeButton_Event", LogLevel.Measure, new ToggleCodeButtonEvent(name, isChecked)); - } + public static void Log(string name, bool isChecked) + { + TelemetryFactory.Get().Log("ToggleCodeButton_Event", LogLevel.Measure, new ToggleCodeButtonEvent(name, isChecked)); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/ITelemetry.cs b/AIDevGallery/Telemetry/ITelemetry.cs index d691097..ff073ab 100644 --- a/AIDevGallery/Telemetry/ITelemetry.cs +++ b/AIDevGallery/Telemetry/ITelemetry.cs @@ -4,78 +4,77 @@ using System; using System.Diagnostics.CodeAnalysis; -namespace AIDevGallery.Telemetry +namespace AIDevGallery.Telemetry; + +internal interface ITelemetry { - internal interface ITelemetry - { - /// - /// Add a string that we should try stripping out of some of our telemetry for sensitivity reasons (ex. VM name, etc.). - /// We can never be 100% sure we can remove every string, but this should greatly reduce us collecting PII. - /// Note that the order in which AddSensitive is called matters, as later when we call ReplaceSensitiveStrings, it will try - /// finding and replacing the earlier strings first. This can be helpful, since we can target specific - /// strings (like username) first, which should help preserve more information helpful for diagnosis. - /// - /// Sensitive string to add (ex. "c:\xyz") - /// string to replace it with (ex. "-path-") - public void AddSensitiveString(string name, string replaceWith); + /// + /// Add a string that we should try stripping out of some of our telemetry for sensitivity reasons (ex. VM name, etc.). + /// We can never be 100% sure we can remove every string, but this should greatly reduce us collecting PII. + /// Note that the order in which AddSensitive is called matters, as later when we call ReplaceSensitiveStrings, it will try + /// finding and replacing the earlier strings first. This can be helpful, since we can target specific + /// strings (like username) first, which should help preserve more information helpful for diagnosis. + /// + /// Sensitive string to add (ex. "c:\xyz") + /// string to replace it with (ex. "-path-") + public void AddSensitiveString(string name, string replaceWith); - /// - /// Gets a value indicating whether telemetry is on - /// For future use if we add a registry key or some other setting to check if telemetry is turned on. - /// - public bool IsTelemetryOn { get; } + /// + /// Gets a value indicating whether telemetry is on + /// For future use if we add a registry key or some other setting to check if telemetry is turned on. + /// + public bool IsTelemetryOn { get; } - /// - /// Gets or sets a value indicating whether diagnostic telemetry is on. - /// - public bool IsDiagnosticTelemetryOn { get; set; } + /// + /// Gets or sets a value indicating whether diagnostic telemetry is on. + /// + public bool IsDiagnosticTelemetryOn { get; set; } - /// - /// Logs an exception at Measure level. To log at Critical level, the event name needs approval. - /// - /// What we trying to do when the exception occurred. - /// Exception object - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them - public void LogException(string action, Exception e, Guid? relatedActivityId = null); + /// + /// Logs an exception at Measure level. To log at Critical level, the event name needs approval. + /// + /// What we trying to do when the exception occurred. + /// Exception object + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them + public void LogException(string action, Exception e, Guid? relatedActivityId = null); - /// - /// Log the time an action took (ex. time spent on a tool). - /// - /// The measurement we're performing (ex. "DeployTime"). - /// How long the action took in milliseconds. - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them - public void LogTimeTaken(string eventName, uint timeTakenMilliseconds, Guid? relatedActivityId = null); + /// + /// Log the time an action took (ex. time spent on a tool). + /// + /// The measurement we're performing (ex. "DeployTime"). + /// How long the action took in milliseconds. + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them + public void LogTimeTaken(string eventName, uint timeTakenMilliseconds, Guid? relatedActivityId = null); - /// - /// Log an event with no additional data. - /// - /// The name of the event to log - /// Set to true if an error condition raised this event. - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them - public void LogCritical(string eventName, bool isError = false, Guid? relatedActivityId = null); + /// + /// Log an event with no additional data. + /// + /// The name of the event to log + /// Set to true if an error condition raised this event. + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them + public void LogCritical(string eventName, bool isError = false, Guid? relatedActivityId = null); - /// - /// Log an informational event. Typically used for just a single event that's only called one place in the code. - /// If you are logging the same event multiple times, it's best to add a helper method in Telemetry - /// - /// Name of the error event - /// Determines whether to upload the data to our servers, and on how many machines. - /// Values to send to the telemetry system. - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them - /// Anonymous type. - public void Log<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId = null) - where T : EventBase; + /// + /// Log an informational event. Typically used for just a single event that's only called one place in the code. + /// If you are logging the same event multiple times, it's best to add a helper method in Telemetry + /// + /// Name of the error event + /// Determines whether to upload the data to our servers, and on how many machines. + /// Values to send to the telemetry system. + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them + /// Anonymous type. + public void Log<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId = null) + where T : EventBase; - /// - /// Log an error event. Typically used for just a single event that's only called one place in the code. - /// If you are logging the same event multiple times, it's best to add a helper method in Telemetry - /// - /// Name of the error event - /// Determines whether to upload the data to our servers, and on how many machines. - /// Values to send to the telemetry system. - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them - /// Anonymous type. - public void LogError<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId = null) - where T : EventBase; - } + /// + /// Log an error event. Typically used for just a single event that's only called one place in the code. + /// If you are logging the same event multiple times, it's best to add a helper method in Telemetry + /// + /// Name of the error event + /// Determines whether to upload the data to our servers, and on how many machines. + /// Values to send to the telemetry system. + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them + /// Anonymous type. + public void LogError<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId = null) + where T : EventBase; } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/LogLevel.cs b/AIDevGallery/Telemetry/LogLevel.cs index a9c9df5..26bed73 100644 --- a/AIDevGallery/Telemetry/LogLevel.cs +++ b/AIDevGallery/Telemetry/LogLevel.cs @@ -1,40 +1,39 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -namespace AIDevGallery.Telemetry +namespace AIDevGallery.Telemetry; + +/// +/// Telemetry Levels. +/// These levels are defined by our telemetry system, so it's possible the sampling +/// could change in the future. +/// There aren't any convenient enums we can consume, so create our own. +/// +public enum LogLevel { /// - /// Telemetry Levels. - /// These levels are defined by our telemetry system, so it's possible the sampling - /// could change in the future. - /// There aren't any convenient enums we can consume, so create our own. + /// Local. + /// Only log telemetry locally on the machine (similar to an ETW event). /// - public enum LogLevel - { - /// - /// Local. - /// Only log telemetry locally on the machine (similar to an ETW event). - /// - Local, + Local, - /// - /// Info. - /// Send telemetry from internal and flighted machines, but no external retail machines. - /// - Info, + /// + /// Info. + /// Send telemetry from internal and flighted machines, but no external retail machines. + /// + Info, - /// - /// Measure. - /// Send telemetry from internal and flighted machines, plus a small, sample % of retail machines. - /// Should only be used for telemetry we use to derive measures from. - /// - Measure, + /// + /// Measure. + /// Send telemetry from internal and flighted machines, plus a small, sample % of retail machines. + /// Should only be used for telemetry we use to derive measures from. + /// + Measure, - /// - /// Critical. - /// Send telemetry from all devices sampled at 100%. - /// Should only be used for approved events. - /// - Critical, - } + /// + /// Critical. + /// Send telemetry from all devices sampled at 100%. + /// Should only be used for approved events. + /// + Critical, } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/PrivacyConsentHelpers.cs b/AIDevGallery/Telemetry/PrivacyConsentHelpers.cs index 7ccf746..07a50a4 100644 --- a/AIDevGallery/Telemetry/PrivacyConsentHelpers.cs +++ b/AIDevGallery/Telemetry/PrivacyConsentHelpers.cs @@ -5,54 +5,53 @@ using System.Linq; using Windows.Globalization; -namespace AIDevGallery.Telemetry +namespace AIDevGallery.Telemetry; + +internal static class PrivacyConsentHelpers { - internal static class PrivacyConsentHelpers - { - private static readonly string[] PrivacySensitiveRegions = - [ - "AUT", - "BEL", - "BGR", - "BRA", - "CAN", - "HRV", - "CYP", - "CZE", - "DNK", - "EST", - "FIN", - "FRA", - "DEU", - "GRC", - "HUN", - "ISL", - "IRL", - "ITA", - "KOR", // Double Check - "LVA", - "LIE", - "LTU", - "LUX", - "MLT", - "NLD", - "NOR", - "POL", - "PRT", - "ROU", - "SVK", - "SVN", - "ESP", - "SWE", - "CHE", - "GBR", - ]; + private static readonly string[] PrivacySensitiveRegions = + [ + "AUT", + "BEL", + "BGR", + "BRA", + "CAN", + "HRV", + "CYP", + "CZE", + "DNK", + "EST", + "FIN", + "FRA", + "DEU", + "GRC", + "HUN", + "ISL", + "IRL", + "ITA", + "KOR", // Double Check + "LVA", + "LIE", + "LTU", + "LUX", + "MLT", + "NLD", + "NOR", + "POL", + "PRT", + "ROU", + "SVK", + "SVN", + "ESP", + "SWE", + "CHE", + "GBR", + ]; - public static bool IsPrivacySensitiveRegion() - { - var geographicRegion = new GeographicRegion(); + public static bool IsPrivacySensitiveRegion() + { + var geographicRegion = new GeographicRegion(); - return PrivacySensitiveRegions.Contains(geographicRegion.CodeThreeLetter, StringComparer.OrdinalIgnoreCase); - } + return PrivacySensitiveRegions.Contains(geographicRegion.CodeThreeLetter, StringComparer.OrdinalIgnoreCase); } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/Telemetry.cs b/AIDevGallery/Telemetry/Telemetry.cs index 0a965b6..c9ccd61 100644 --- a/AIDevGallery/Telemetry/Telemetry.cs +++ b/AIDevGallery/Telemetry/Telemetry.cs @@ -10,341 +10,340 @@ using System.Linq; using System.Text; -namespace AIDevGallery.Telemetry +namespace AIDevGallery.Telemetry; + +/// +/// To create an instance call TelemetryFactory.Get<ITelemetry>() +/// +internal sealed class Telemetry : ITelemetry { + private const string ProviderName = "Microsoft.Windows.AIDevGallery"; + /// - /// To create an instance call TelemetryFactory.Get<ITelemetry>() + /// Time Taken Event Name /// - internal sealed class Telemetry : ITelemetry - { - private const string ProviderName = "Microsoft.Windows.AIDevGallery"; - - /// - /// Time Taken Event Name - /// - private const string TimeTakenEventName = "TimeTaken"; - - /// - /// Exception Thrown Event Name - /// - private const string ExceptionThrownEventName = "ExceptionThrown"; - - private static readonly Guid DefaultRelatedActivityId = Guid.Empty; - - /// - /// Can only have one EventSource alive per process, so just create one statically. - /// - private static readonly EventSource TelemetryEventSourceInstance = new TelemetryEventSource(ProviderName); - - /// - /// Logs telemetry locally, but shouldn't upload it. Similar to an ETW event. - /// Should be the same as EventSourceOptions(), as Verbose is the default level. - /// - private static readonly EventSourceOptions LocalOption = new() { Level = EventLevel.Verbose }; - - /// - /// Logs error telemetry locally, but shouldn't upload it. Similar to an ETW event. - /// - private static readonly EventSourceOptions LocalErrorOption = new() { Level = EventLevel.Error }; - - /// - /// Logs telemetry. - /// Currently this is at 0% sampling for both internal and external retail devices. - /// - private static readonly EventSourceOptions InfoOption = new() { Keywords = TelemetryEventSource.TelemetryKeyword }; - - /// - /// Logs error telemetry. - /// Currently this is at 0% sampling for both internal and external retail devices. - /// - private static readonly EventSourceOptions InfoErrorOption = new() { Level = EventLevel.Error, Keywords = TelemetryEventSource.TelemetryKeyword }; - - /// - /// Logs measure telemetry. - /// This should be sent back on internal devices, and a small, sampled % of external retail devices. - /// - private static readonly EventSourceOptions MeasureOption = new() { Keywords = TelemetryEventSource.MeasuresKeyword }; - - /// - /// Logs measure error telemetry. - /// This should be sent back on internal devices, and a small, sampled % of external retail devices. - /// - private static readonly EventSourceOptions MeasureErrorOption = new() { Level = EventLevel.Error, Keywords = TelemetryEventSource.MeasuresKeyword }; - - /// - /// Logs critical telemetry. - /// This should be sent back on all devices sampled at 100%. - /// - private static readonly EventSourceOptions CriticalDataOption = new() { Keywords = TelemetryEventSource.CriticalDataKeyword }; - - /// - /// Logs critical error telemetry. - /// This should be sent back on all devices sampled at 100%. - /// - private static readonly EventSourceOptions CriticalDataErrorOption = new() { Level = EventLevel.Error, Keywords = TelemetryEventSource.CriticalDataKeyword }; - - /// - /// ActivityId so we can correlate all events in the same run - /// - private static Guid activityId = Guid.NewGuid(); - - /// - /// List of strings we should try removing for sensitivity reasons. - /// - private readonly List> sensitiveStrings = []; - - /// - /// Initializes a new instance of the class. - /// Prevents a default instance of the Telemetry class from being created. - /// - internal Telemetry() - { - } + private const string TimeTakenEventName = "TimeTaken"; - /// - /// Gets a value indicating whether telemetry is on - /// For future use if we add a registry key or some other setting to check if telemetry is turned on. - /// - public bool IsTelemetryOn { get; } = true; - - /// - /// Gets or sets a value indicating whether diagnostic telemetry is on. - /// - public bool IsDiagnosticTelemetryOn { get; set; } - - /// - /// Add a string that we should try stripping out of some of our telemetry for sensitivity reasons (ex. VM name, etc.). - /// We can never be 100% sure we can remove every string, but this should greatly reduce us collecting PII. - /// Note that the order in which AddSensitive is called matters, as later when we call ReplaceSensitiveStrings, it will try - /// finding and replacing the earlier strings first. This can be helpful, since we can target specific - /// strings (like username) first, which should help preserve more information helpful for diagnosis. - /// - /// Sensitive string to add (ex. "c:\xyz") - /// string to replace it with (ex. "-path-") - public void AddSensitiveString(string name, string replaceWith) - { - // Make sure the name isn't blank, hasn't already been added, and is greater than three characters. - // Otherwise they could name their VM "a", and then we would end up replacing every "a" with another string. - if (!string.IsNullOrWhiteSpace(name) && name.Length > 3 && !this.sensitiveStrings.Exists(item => name.Equals(item.Key, StringComparison.Ordinal))) - { - this.sensitiveStrings.Add(new KeyValuePair(name, replaceWith ?? string.Empty)); - } - } + /// + /// Exception Thrown Event Name + /// + private const string ExceptionThrownEventName = "ExceptionThrown"; - /// - /// Logs an exception at Measure level. To log at Critical level, the event name needs approval. - /// - /// What we trying to do when the exception occurred. - /// Exception object - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and corelate them - public void LogException(string action, Exception e, Guid? relatedActivityId = null) - { - var innerMessage = this.ReplaceSensitiveStrings(e.InnerException?.Message); - StringBuilder innerStackTrace = new(); - Exception? innerException = e.InnerException; - while (innerException != null) - { - innerStackTrace.Append(innerException.StackTrace); + private static readonly Guid DefaultRelatedActivityId = Guid.Empty; - // Separating by 2 new lines to distinguish between different exceptions. - innerStackTrace.AppendLine(); - innerStackTrace.AppendLine(); - innerException = innerException.InnerException; - } + /// + /// Can only have one EventSource alive per process, so just create one statically. + /// + private static readonly EventSource TelemetryEventSourceInstance = new TelemetryEventSource(ProviderName); - this.LogInternal( - ExceptionThrownEventName, - LogLevel.Critical, - new - { - action, - name = e.GetType().Name, - stackTrace = e.StackTrace, - innerName = e.InnerException?.GetType().Name, - innerMessage, - innerStackTrace = innerStackTrace.ToString(), - message = this.ReplaceSensitiveStrings(e.Message), - PartA_PrivTags = PartA_PrivTags.ProductAndServicePerformance, - }, - relatedActivityId, - isError: true); - } + /// + /// Logs telemetry locally, but shouldn't upload it. Similar to an ETW event. + /// Should be the same as EventSourceOptions(), as Verbose is the default level. + /// + private static readonly EventSourceOptions LocalOption = new() { Level = EventLevel.Verbose }; - /// - /// Log the time an action took (ex. deploy time). - /// - /// The measurement we're performing (ex. "DeployTime"). - /// How long the action took in milliseconds. - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and corelate them - public void LogTimeTaken(string eventName, uint timeTakenMilliseconds, Guid? relatedActivityId = null) - { - this.LogInternal( - TimeTakenEventName, - LogLevel.Critical, - new - { - eventName, - timeTakenMilliseconds, - PartA_PrivTags = PartA_PrivTags.ProductAndServicePerformance, - }, - relatedActivityId, - isError: false); - } + /// + /// Logs error telemetry locally, but shouldn't upload it. Similar to an ETW event. + /// + private static readonly EventSourceOptions LocalErrorOption = new() { Level = EventLevel.Error }; - /// - /// Log an informal event with no additional data at log level measure. - /// - /// The name of the event to log - /// Set to true if an error condition raised this event. - /// GUID to correlate activities. - public void LogCritical(string eventName, bool isError = false, Guid? relatedActivityId = null) - { - this.LogInternal(eventName, LogLevel.Critical, new EmptyEvent(PartA_PrivTags.ProductAndServiceUsage), relatedActivityId, isError); - } + /// + /// Logs telemetry. + /// Currently this is at 0% sampling for both internal and external retail devices. + /// + private static readonly EventSourceOptions InfoOption = new() { Keywords = TelemetryEventSource.TelemetryKeyword }; - /// - /// Log an informational event. Typically used for just a single event that's only called one place in the code. - /// - /// Name of the error event - /// Determines whether to upload the data to our servers, and on how many machines. - /// Values to send to the telemetry system. - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them - /// Anonymous type. - public void Log<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId = null) - where T : EventBase - { - data.ReplaceSensitiveStrings(this.ReplaceSensitiveStrings); - this.LogInternal(eventName, level, data, relatedActivityId, isError: false); - } + /// + /// Logs error telemetry. + /// Currently this is at 0% sampling for both internal and external retail devices. + /// + private static readonly EventSourceOptions InfoErrorOption = new() { Level = EventLevel.Error, Keywords = TelemetryEventSource.TelemetryKeyword }; - /// - /// Log an error event. Typically used for just a single event that's only called one place in the code. - /// - /// Name of the error event - /// Determines whether to upload the data to our servers, and on how many machines. - /// Values to send to the telemetry system. - /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them - /// Anonymous type. - public void LogError<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId = null) - where T : EventBase + /// + /// Logs measure telemetry. + /// This should be sent back on internal devices, and a small, sampled % of external retail devices. + /// + private static readonly EventSourceOptions MeasureOption = new() { Keywords = TelemetryEventSource.MeasuresKeyword }; + + /// + /// Logs measure error telemetry. + /// This should be sent back on internal devices, and a small, sampled % of external retail devices. + /// + private static readonly EventSourceOptions MeasureErrorOption = new() { Level = EventLevel.Error, Keywords = TelemetryEventSource.MeasuresKeyword }; + + /// + /// Logs critical telemetry. + /// This should be sent back on all devices sampled at 100%. + /// + private static readonly EventSourceOptions CriticalDataOption = new() { Keywords = TelemetryEventSource.CriticalDataKeyword }; + + /// + /// Logs critical error telemetry. + /// This should be sent back on all devices sampled at 100%. + /// + private static readonly EventSourceOptions CriticalDataErrorOption = new() { Level = EventLevel.Error, Keywords = TelemetryEventSource.CriticalDataKeyword }; + + /// + /// ActivityId so we can correlate all events in the same run + /// + private static Guid activityId = Guid.NewGuid(); + + /// + /// List of strings we should try removing for sensitivity reasons. + /// + private readonly List> sensitiveStrings = []; + + /// + /// Initializes a new instance of the class. + /// Prevents a default instance of the Telemetry class from being created. + /// + internal Telemetry() + { + } + + /// + /// Gets a value indicating whether telemetry is on + /// For future use if we add a registry key or some other setting to check if telemetry is turned on. + /// + public bool IsTelemetryOn { get; } = true; + + /// + /// Gets or sets a value indicating whether diagnostic telemetry is on. + /// + public bool IsDiagnosticTelemetryOn { get; set; } + + /// + /// Add a string that we should try stripping out of some of our telemetry for sensitivity reasons (ex. VM name, etc.). + /// We can never be 100% sure we can remove every string, but this should greatly reduce us collecting PII. + /// Note that the order in which AddSensitive is called matters, as later when we call ReplaceSensitiveStrings, it will try + /// finding and replacing the earlier strings first. This can be helpful, since we can target specific + /// strings (like username) first, which should help preserve more information helpful for diagnosis. + /// + /// Sensitive string to add (ex. "c:\xyz") + /// string to replace it with (ex. "-path-") + public void AddSensitiveString(string name, string replaceWith) + { + // Make sure the name isn't blank, hasn't already been added, and is greater than three characters. + // Otherwise they could name their VM "a", and then we would end up replacing every "a" with another string. + if (!string.IsNullOrWhiteSpace(name) && name.Length > 3 && !this.sensitiveStrings.Exists(item => name.Equals(item.Key, StringComparison.Ordinal))) { - data.ReplaceSensitiveStrings(this.ReplaceSensitiveStrings); - this.LogInternal(eventName, level, data, relatedActivityId, isError: true); + this.sensitiveStrings.Add(new KeyValuePair(name, replaceWith ?? string.Empty)); } + } - private void LogInternal<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId, bool isError) + /// + /// Logs an exception at Measure level. To log at Critical level, the event name needs approval. + /// + /// What we trying to do when the exception occurred. + /// Exception object + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and corelate them + public void LogException(string action, Exception e, Guid? relatedActivityId = null) + { + var innerMessage = this.ReplaceSensitiveStrings(e.InnerException?.Message); + StringBuilder innerStackTrace = new(); + Exception? innerException = e.InnerException; + while (innerException != null) { - this.WriteTelemetryEvent(eventName, level, relatedActivityId ?? DefaultRelatedActivityId, isError, data); + innerStackTrace.Append(innerException.StackTrace); + + // Separating by 2 new lines to distinguish between different exceptions. + innerStackTrace.AppendLine(); + innerStackTrace.AppendLine(); + innerException = innerException.InnerException; } - /// - /// Replaces sensitive strings in a string with non sensitive strings. - /// - /// Before, unstripped string. - /// After, stripped string - private string? ReplaceSensitiveStrings(string? message) + this.LogInternal( + ExceptionThrownEventName, + LogLevel.Critical, + new + { + action, + name = e.GetType().Name, + stackTrace = e.StackTrace, + innerName = e.InnerException?.GetType().Name, + innerMessage, + innerStackTrace = innerStackTrace.ToString(), + message = this.ReplaceSensitiveStrings(e.Message), + PartA_PrivTags = PartA_PrivTags.ProductAndServicePerformance, + }, + relatedActivityId, + isError: true); + } + + /// + /// Log the time an action took (ex. deploy time). + /// + /// The measurement we're performing (ex. "DeployTime"). + /// How long the action took in milliseconds. + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and corelate them + public void LogTimeTaken(string eventName, uint timeTakenMilliseconds, Guid? relatedActivityId = null) + { + this.LogInternal( + TimeTakenEventName, + LogLevel.Critical, + new + { + eventName, + timeTakenMilliseconds, + PartA_PrivTags = PartA_PrivTags.ProductAndServicePerformance, + }, + relatedActivityId, + isError: false); + } + + /// + /// Log an informal event with no additional data at log level measure. + /// + /// The name of the event to log + /// Set to true if an error condition raised this event. + /// GUID to correlate activities. + public void LogCritical(string eventName, bool isError = false, Guid? relatedActivityId = null) + { + this.LogInternal(eventName, LogLevel.Critical, new EmptyEvent(PartA_PrivTags.ProductAndServiceUsage), relatedActivityId, isError); + } + + /// + /// Log an informational event. Typically used for just a single event that's only called one place in the code. + /// + /// Name of the error event + /// Determines whether to upload the data to our servers, and on how many machines. + /// Values to send to the telemetry system. + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them + /// Anonymous type. + public void Log<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId = null) + where T : EventBase + { + data.ReplaceSensitiveStrings(this.ReplaceSensitiveStrings); + this.LogInternal(eventName, level, data, relatedActivityId, isError: false); + } + + /// + /// Log an error event. Typically used for just a single event that's only called one place in the code. + /// + /// Name of the error event + /// Determines whether to upload the data to our servers, and on how many machines. + /// Values to send to the telemetry system. + /// Optional relatedActivityId which will allow to correlate this telemetry with other telemetry in the same action/activity or thread and correlate them + /// Anonymous type. + public void LogError<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId = null) + where T : EventBase + { + data.ReplaceSensitiveStrings(this.ReplaceSensitiveStrings); + this.LogInternal(eventName, level, data, relatedActivityId, isError: true); + } + + private void LogInternal<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, T data, Guid? relatedActivityId, bool isError) + { + this.WriteTelemetryEvent(eventName, level, relatedActivityId ?? DefaultRelatedActivityId, isError, data); + } + + /// + /// Replaces sensitive strings in a string with non sensitive strings. + /// + /// Before, unstripped string. + /// After, stripped string + private string? ReplaceSensitiveStrings(string? message) + { + if (message != null) { - if (message != null) + foreach (KeyValuePair pair in this.sensitiveStrings) { - foreach (KeyValuePair pair in this.sensitiveStrings) + // There's no String.Replace() with case insensitivity. + // We could use Regular Expressions here for searching for case-insensitive string matches, + // but it's not easy to specify the RegEx timeout value for .net 4.0. And we were worried + // about rare cases where the user could accidentally lock us up with RegEx, since we're using strings + // provided by the user, so just use a simple non-RegEx replacement algorithm instead. + var sb = new StringBuilder(); + var i = 0; + while (true) { - // There's no String.Replace() with case insensitivity. - // We could use Regular Expressions here for searching for case-insensitive string matches, - // but it's not easy to specify the RegEx timeout value for .net 4.0. And we were worried - // about rare cases where the user could accidentally lock us up with RegEx, since we're using strings - // provided by the user, so just use a simple non-RegEx replacement algorithm instead. - var sb = new StringBuilder(); - var i = 0; - while (true) + // Find the string to strip out. + var foundPosition = message.IndexOf(pair.Key, i, StringComparison.OrdinalIgnoreCase); + if (foundPosition < 0) { - // Find the string to strip out. - var foundPosition = message.IndexOf(pair.Key, i, StringComparison.OrdinalIgnoreCase); - if (foundPosition < 0) - { - sb.Append(message, i, message.Length - i); - message = sb.ToString(); - break; - } - - // Replace the string. - sb.Append(message, i, foundPosition - i); - sb.Append(pair.Value); - i = foundPosition + pair.Key.Length; + sb.Append(message, i, message.Length - i); + message = sb.ToString(); + break; } + + // Replace the string. + sb.Append(message, i, foundPosition - i); + sb.Append(pair.Value); + i = foundPosition + pair.Key.Length; } } - - return message; } - /// - /// Writes the telemetry event info using the TraceLogging API. - /// - /// Anonymous type. - /// Name of the event. - /// Determines whether to upload the data to our servers, and the sample set of host machines. - /// GUID to correlate activities. - /// Set to true if an error condition raised this event. - /// Values to send to the telemetry system. - private void WriteTelemetryEvent<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, Guid relatedActivityId, bool isError, T data) + return message; + } + + /// + /// Writes the telemetry event info using the TraceLogging API. + /// + /// Anonymous type. + /// Name of the event. + /// Determines whether to upload the data to our servers, and the sample set of host machines. + /// GUID to correlate activities. + /// Set to true if an error condition raised this event. + /// Values to send to the telemetry system. + private void WriteTelemetryEvent<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] T>(string eventName, LogLevel level, Guid relatedActivityId, bool isError, T data) + { + EventSourceOptions telemetryOptions; + if (this.IsTelemetryOn) { - EventSourceOptions telemetryOptions; - if (this.IsTelemetryOn) + if (!IsDiagnosticTelemetryOn) { - if (!IsDiagnosticTelemetryOn) - { - if (!isError && (level == LogLevel.Measure || level == LogLevel.Info)) - { - level = LogLevel.Local; - } - } - - switch (level) + if (!isError && (level == LogLevel.Measure || level == LogLevel.Info)) { - case LogLevel.Critical: - telemetryOptions = isError ? Telemetry.CriticalDataErrorOption : Telemetry.CriticalDataOption; - break; - case LogLevel.Measure: - telemetryOptions = isError ? Telemetry.MeasureErrorOption : Telemetry.MeasureOption; - break; - case LogLevel.Info: - telemetryOptions = isError ? Telemetry.InfoErrorOption : Telemetry.InfoOption; - break; - case LogLevel.Local: - default: - telemetryOptions = isError ? Telemetry.LocalErrorOption : Telemetry.LocalOption; - break; + level = LogLevel.Local; } } - else + + switch (level) { - // The telemetry is not turned on, downgrade to local telemetry - telemetryOptions = isError ? Telemetry.LocalErrorOption : Telemetry.LocalOption; + case LogLevel.Critical: + telemetryOptions = isError ? Telemetry.CriticalDataErrorOption : Telemetry.CriticalDataOption; + break; + case LogLevel.Measure: + telemetryOptions = isError ? Telemetry.MeasureErrorOption : Telemetry.MeasureOption; + break; + case LogLevel.Info: + telemetryOptions = isError ? Telemetry.InfoErrorOption : Telemetry.InfoOption; + break; + case LogLevel.Local: + default: + telemetryOptions = isError ? Telemetry.LocalErrorOption : Telemetry.LocalOption; + break; } + } + else + { + // The telemetry is not turned on, downgrade to local telemetry + telemetryOptions = isError ? Telemetry.LocalErrorOption : Telemetry.LocalOption; + } #pragma warning disable IL2026 - TelemetryEventSourceInstance.Write(eventName, ref telemetryOptions, ref activityId, ref relatedActivityId, ref data); + TelemetryEventSourceInstance.Write(eventName, ref telemetryOptions, ref activityId, ref relatedActivityId, ref data); #pragma warning restore IL2026 - } + } - internal void AddWellKnownSensitiveStrings() + internal void AddWellKnownSensitiveStrings() + { + try { - try - { - // This should convert "c:\users\johndoe" to "". - var userDirectory = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); - this.AddSensitiveString(userDirectory.ToString(), ""); - - // Include both these names, since they should cover the logged on user, and the user who is running the tools built on top of these API's - // These names should almost always be the same, but technically could be different. - this.AddSensitiveString(Environment.UserName, ""); - var currentUserName = System.Security.Principal.WindowsIdentity.GetCurrent().Name.Split('\\').Last(); - this.AddSensitiveString(currentUserName, ""); - } - catch (Exception e) - { - // Catch and log exception - this.LogException("AddSensitiveStrings", e); - } + // This should convert "c:\users\johndoe" to "". + var userDirectory = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + this.AddSensitiveString(userDirectory.ToString(), ""); + + // Include both these names, since they should cover the logged on user, and the user who is running the tools built on top of these API's + // These names should almost always be the same, but technically could be different. + this.AddSensitiveString(Environment.UserName, ""); + var currentUserName = System.Security.Principal.WindowsIdentity.GetCurrent().Name.Split('\\').Last(); + this.AddSensitiveString(currentUserName, ""); + } + catch (Exception e) + { + // Catch and log exception + this.LogException("AddSensitiveStrings", e); } } } \ No newline at end of file diff --git a/AIDevGallery/Telemetry/TelemetryFactory.cs b/AIDevGallery/Telemetry/TelemetryFactory.cs index 8ff9458..20f6d0e 100644 --- a/AIDevGallery/Telemetry/TelemetryFactory.cs +++ b/AIDevGallery/Telemetry/TelemetryFactory.cs @@ -3,42 +3,41 @@ using System.Threading; -namespace AIDevGallery.Telemetry +namespace AIDevGallery.Telemetry; + +/// +/// Creates instance of Telemetry +/// This would be useful for the future when interfaces have been updated for logger like ITelemetry2, ITelemetry3 and so on +/// +internal class TelemetryFactory { - /// - /// Creates instance of Telemetry - /// This would be useful for the future when interfaces have been updated for logger like ITelemetry2, ITelemetry3 and so on - /// - internal class TelemetryFactory - { - private static readonly Lock LockObj = new(); + private static readonly Lock LockObj = new(); - private static Telemetry? telemetryInstance; + private static Telemetry? telemetryInstance; - private static Telemetry GetTelemetryInstance() + private static Telemetry GetTelemetryInstance() + { + if (telemetryInstance == null) { - if (telemetryInstance == null) + lock (LockObj) { - lock (LockObj) - { - telemetryInstance ??= new Telemetry(); - telemetryInstance.AddWellKnownSensitiveStrings(); - } + telemetryInstance ??= new Telemetry(); + telemetryInstance.AddWellKnownSensitiveStrings(); } - - return telemetryInstance; } - /// - /// Gets a singleton instance of Telemetry - /// This would be useful for the future when interfaces have been updated for logger like ITelemetry2, ITelemetry3 and so on - /// - /// The type of telemetry interface. - /// A singleton instance of the specified telemetry interface. - public static T Get() - where T : ITelemetry - { - return (T)(object)GetTelemetryInstance(); - } + return telemetryInstance; + } + + /// + /// Gets a singleton instance of Telemetry + /// This would be useful for the future when interfaces have been updated for logger like ITelemetry2, ITelemetry3 and so on + /// + /// The type of telemetry interface. + /// A singleton instance of the specified telemetry interface. + public static T Get() + where T : ITelemetry + { + return (T)(object)GetTelemetryInstance(); } } \ No newline at end of file diff --git a/AIDevGallery/Utils/AppDataSourceGenerationContext.cs b/AIDevGallery/Utils/AppDataSourceGenerationContext.cs index db3aaff..3ba9a4e 100644 --- a/AIDevGallery/Utils/AppDataSourceGenerationContext.cs +++ b/AIDevGallery/Utils/AppDataSourceGenerationContext.cs @@ -5,12 +5,11 @@ using System.Collections.Generic; using System.Text.Json.Serialization; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +[JsonSourceGenerationOptions(WriteIndented = true, AllowTrailingCommas = true)] +[JsonSerializable(typeof(AppData))] +[JsonSerializable(typeof(List))] +internal partial class AppDataSourceGenerationContext : JsonSerializerContext { - [JsonSourceGenerationOptions(WriteIndented = true, AllowTrailingCommas = true)] - [JsonSerializable(typeof(AppData))] - [JsonSerializable(typeof(List))] - internal partial class AppDataSourceGenerationContext : JsonSerializerContext - { - } } \ No newline at end of file diff --git a/AIDevGallery/Utils/HttpClientExtensions.cs b/AIDevGallery/Utils/HttpClientExtensions.cs index 76fad75..4f9a144 100644 --- a/AIDevGallery/Utils/HttpClientExtensions.cs +++ b/AIDevGallery/Utils/HttpClientExtensions.cs @@ -8,80 +8,79 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +internal static class HttpClientExtensions { - internal static class HttpClientExtensions + public static async Task DownloadAsync(this HttpClient client, string requestUri, Stream destination, IProgress? progress = null, IProgress? progressBytesRead = null, CancellationToken cancellationToken = default) { - public static async Task DownloadAsync(this HttpClient client, string requestUri, Stream destination, IProgress? progress = null, IProgress? progressBytesRead = null, CancellationToken cancellationToken = default) + // Get the http headers first to examine the content length + using (var response = await client.GetAsync(requestUri, HttpCompletionOption.ResponseHeadersRead, cancellationToken)) { - // Get the http headers first to examine the content length - using (var response = await client.GetAsync(requestUri, HttpCompletionOption.ResponseHeadersRead, cancellationToken)) - { - var contentLength = response.Content.Headers.ContentLength; + var contentLength = response.Content.Headers.ContentLength; - using (var download = await response.Content.ReadAsStreamAsync(cancellationToken)) + using (var download = await response.Content.ReadAsStreamAsync(cancellationToken)) + { + // Ignore progress reporting when no progress reporter was + // passed or when the content length is unknown + if (!contentLength.HasValue) { - // Ignore progress reporting when no progress reporter was - // passed or when the content length is unknown - if (!contentLength.HasValue) - { - await download.CopyToAsync(destination, cancellationToken); - return; - } + await download.CopyToAsync(destination, cancellationToken); + return; + } - // Convert absolute progress (bytes downloaded) into relative progress (0% - 100%) - var relativeProgress = new Progress(totalBytes => - { - progress?.Report((float)totalBytes / contentLength.Value); - progressBytesRead?.Report(totalBytes); - }); + // Convert absolute progress (bytes downloaded) into relative progress (0% - 100%) + var relativeProgress = new Progress(totalBytes => + { + progress?.Report((float)totalBytes / contentLength.Value); + progressBytesRead?.Report(totalBytes); + }); - // Use extension method to report progress while downloading - await download.CopyToAsync(destination, 81920, relativeProgress, cancellationToken); - progress?.Report(1); - } + // Use extension method to report progress while downloading + await download.CopyToAsync(destination, 81920, relativeProgress, cancellationToken); + progress?.Report(1); } } + } - public static async Task CopyToAsync(this Stream source, Stream destination, int bufferSize, IProgress? progress = null, CancellationToken cancellationToken = default) - { - ArgumentNullException.ThrowIfNull(source); + public static async Task CopyToAsync(this Stream source, Stream destination, int bufferSize, IProgress? progress = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(source); - if (!source.CanRead) - { - throw new ArgumentException("Has to be readable", nameof(source)); - } + if (!source.CanRead) + { + throw new ArgumentException("Has to be readable", nameof(source)); + } - ArgumentNullException.ThrowIfNull(destination); + ArgumentNullException.ThrowIfNull(destination); - if (!destination.CanWrite) - { - throw new ArgumentException("Has to be writable", nameof(destination)); - } + if (!destination.CanWrite) + { + throw new ArgumentException("Has to be writable", nameof(destination)); + } - ArgumentOutOfRangeException.ThrowIfNegative(bufferSize); + ArgumentOutOfRangeException.ThrowIfNegative(bufferSize); - var buffer = new byte[bufferSize]; - long totalBytesRead = 0; - int bytesRead; - try + var buffer = new byte[bufferSize]; + long totalBytesRead = 0; + int bytesRead; + try + { + while ((bytesRead = await source.ReadAsync(buffer, cancellationToken).ConfigureAwait(false)) != 0) { - while ((bytesRead = await source.ReadAsync(buffer, cancellationToken).ConfigureAwait(false)) != 0) + if (cancellationToken.IsCancellationRequested) { - if (cancellationToken.IsCancellationRequested) - { - return; - } - - await destination.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false); - totalBytesRead += bytesRead; - progress?.Report(totalBytesRead); + return; } + + await destination.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken).ConfigureAwait(false); + totalBytesRead += bytesRead; + progress?.Report(totalBytesRead); } - catch (TaskCanceledException) - { - Debug.WriteLine("Download cancelled"); - } + } + catch (TaskCanceledException) + { + Debug.WriteLine("Download cancelled"); } } } \ No newline at end of file diff --git a/AIDevGallery/Utils/ModelCache.cs b/AIDevGallery/Utils/ModelCache.cs index bee1298..1effbd6 100644 --- a/AIDevGallery/Utils/ModelCache.cs +++ b/AIDevGallery/Utils/ModelCache.cs @@ -9,195 +9,194 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +internal class ModelCache { - internal class ModelCache - { - private readonly AppData _appData; + private readonly AppData _appData; - /* private long _movedSize; */ + /* private long _movedSize; */ - public ModelDownloadQueue DownloadQueue { get; } - public ModelCacheStore CacheStore { get; private set; } - public IReadOnlyList Models => CacheStore.Models; + public ModelDownloadQueue DownloadQueue { get; } + public ModelCacheStore CacheStore { get; private set; } + public IReadOnlyList Models => CacheStore.Models; - private ModelCache(AppData appData, ModelDownloadQueue modelDownloadQueue, ModelCacheStore modelCacheStore) - { - _appData = appData; - DownloadQueue = modelDownloadQueue; - CacheStore = modelCacheStore; - } + private ModelCache(AppData appData, ModelDownloadQueue modelDownloadQueue, ModelCacheStore modelCacheStore) + { + _appData = appData; + DownloadQueue = modelDownloadQueue; + CacheStore = modelCacheStore; + } + + public static async Task CreateForApp(AppData appData) + { + var downloadQueue = new ModelDownloadQueue(appData.ModelCachePath); - public static async Task CreateForApp(AppData appData) + var modelCacheStore = await ModelCacheStore.CreateForApp(appData.ModelCachePath); + var instance = new ModelCache(appData, downloadQueue, modelCacheStore) { - var downloadQueue = new ModelDownloadQueue(appData.ModelCachePath); + CacheStore = modelCacheStore + }; + instance.DownloadQueue.ModelDownloadCompleted += instance.ModelDownloadQueue_ModelDownloadCompleted; - var modelCacheStore = await ModelCacheStore.CreateForApp(appData.ModelCachePath); - var instance = new ModelCache(appData, downloadQueue, modelCacheStore) - { - CacheStore = modelCacheStore - }; - instance.DownloadQueue.ModelDownloadCompleted += instance.ModelDownloadQueue_ModelDownloadCompleted; + return instance; + } - return instance; - } + private async void ModelDownloadQueue_ModelDownloadCompleted(object? sender, ModelDownloadCompletedEventArgs e) + { + var cachedModel = e.CachedModel; + await CacheStore.AddModel(cachedModel); + } - private async void ModelDownloadQueue_ModelDownloadCompleted(object? sender, ModelDownloadCompletedEventArgs e) - { - var cachedModel = e.CachedModel; - await CacheStore.AddModel(cachedModel); - } + public string GetCacheFolder() + { + return _appData.ModelCachePath; + } - public string GetCacheFolder() - { - return _appData.ModelCachePath; - } + public async Task SetCacheFolderPath(string newPath, List? models = null) + { + _appData.ModelCachePath = newPath; + await _appData.SaveAsync(); - public async Task SetCacheFolderPath(string newPath, List? models = null) - { - _appData.ModelCachePath = newPath; - await _appData.SaveAsync(); + CacheStore = await ModelCacheStore.CreateForApp(newPath, models); - CacheStore = await ModelCacheStore.CreateForApp(newPath, models); + // cancel existing downloads + DownloadQueue.CacheDir = newPath; + DownloadQueue.GetDownloads().ToList().ForEach(DownloadQueue.CancelModelDownload); + } - // cancel existing downloads - DownloadQueue.CacheDir = newPath; - DownloadQueue.GetDownloads().ToList().ForEach(DownloadQueue.CancelModelDownload); + public ModelDownload? AddModelToDownloadQueue(ModelDetails modelDetails) + { + if (IsModelCached(modelDetails.Url)) + { + return null; } - public ModelDownload? AddModelToDownloadQueue(ModelDetails modelDetails) + var existingDownload = DownloadQueue.GetDownload(modelDetails.Url); + if (existingDownload != null) { - if (IsModelCached(modelDetails.Url)) - { - return null; - } + return existingDownload; + } - var existingDownload = DownloadQueue.GetDownload(modelDetails.Url); - if (existingDownload != null) - { - return existingDownload; - } + var download = DownloadQueue.EnqueueModelDownload(modelDetails); + return download; + } - var download = DownloadQueue.EnqueueModelDownload(modelDetails); - return download; - } + public CachedModel? GetCachedModel(string url) + { + url = UrlHelpers.GetFullUrl(url); + return CacheStore.Models.FirstOrDefault(m => m.Url == url); + } - public CachedModel? GetCachedModel(string url) - { - url = UrlHelpers.GetFullUrl(url); - return CacheStore.Models.FirstOrDefault(m => m.Url == url); - } + public bool IsModelCached(string url) + { + url = UrlHelpers.GetFullUrl(url); + return CacheStore.Models.Any(m => m.Url == url); + } - public bool IsModelCached(string url) + public async Task DeleteModelFromCache(string url) + { + if (IsModelCached(url)) { url = UrlHelpers.GetFullUrl(url); - return CacheStore.Models.Any(m => m.Url == url); + await DeleteModelFromCache(CacheStore.Models.First(m => m.Url == url)); } + } - public async Task DeleteModelFromCache(string url) + public async Task DeleteModelFromCache(CachedModel model) + { + ModelDeletedEvent.Log(model.Url); + await CacheStore.RemoveModel(model); + if (model.IsFile && File.Exists(model.Path)) { - if (IsModelCached(url)) - { - url = UrlHelpers.GetFullUrl(url); - await DeleteModelFromCache(CacheStore.Models.First(m => m.Url == url)); - } + File.Delete(model.Path); } - - public async Task DeleteModelFromCache(CachedModel model) + else if (Directory.Exists(model.Path)) { - ModelDeletedEvent.Log(model.Url); - await CacheStore.RemoveModel(model); - if (model.IsFile && File.Exists(model.Path)) - { - File.Delete(model.Path); - } - else if (Directory.Exists(model.Path)) - { - Directory.Delete(model.Path, true); - } + Directory.Delete(model.Path, true); } + } - public async Task ClearCache() - { - ModelCacheDeletedEvent.Log(); + public async Task ClearCache() + { + ModelCacheDeletedEvent.Log(); - var cacheDir = GetCacheFolder(); - Directory.Delete(cacheDir, true); - await CacheStore.ClearAsync(); - } + var cacheDir = GetCacheFolder(); + Directory.Delete(cacheDir, true); + await CacheStore.ClearAsync(); + } - public async Task MoveCache(string path, CancellationToken ct) - { - ModelCacheMovedEvent.Log(path); - var sourceFolder = GetCacheFolder(); - /* _movedSize = 0; */ + public async Task MoveCache(string path, CancellationToken ct) + { + ModelCacheMovedEvent.Log(path); + var sourceFolder = GetCacheFolder(); + /* _movedSize = 0; */ - await Task.Run( - () => + await Task.Run( + () => + { + if (Directory.Exists(sourceFolder)) { - if (Directory.Exists(sourceFolder)) + if (Directory.Exists(path)) { - if (Directory.Exists(path)) - { - Directory.Delete(path, true); - } - - MoveFolder(sourceFolder, path, ct); - if (!ct.IsCancellationRequested && Directory.Exists(sourceFolder)) - { - Directory.Delete(sourceFolder, true); - } + Directory.Delete(path, true); } - }, - ct); - var newModels = CacheStore.Models.Select(m => new CachedModel(m.Details, m.Path.Replace(sourceFolder, path), m.IsFile, m.ModelSize)); - await SetCacheFolderPath(path, newModels.ToList()); - } + MoveFolder(sourceFolder, path, ct); + if (!ct.IsCancellationRequested && Directory.Exists(sourceFolder)) + { + Directory.Delete(sourceFolder, true); + } + } + }, + ct); + + var newModels = CacheStore.Models.Select(m => new CachedModel(m.Details, m.Path.Replace(sourceFolder, path), m.IsFile, m.ModelSize)); + await SetCacheFolderPath(path, newModels.ToList()); + } - private void MoveFolder(string sourcePath, string destinationPath, CancellationToken ct) + private void MoveFolder(string sourcePath, string destinationPath, CancellationToken ct) + { + if (Path.GetPathRoot(sourcePath) != Path.GetPathRoot(destinationPath)) { - if (Path.GetPathRoot(sourcePath) != Path.GetPathRoot(destinationPath)) - { - CopyFolder(sourcePath, destinationPath, ct); - } - else - { - Directory.Move(sourcePath, destinationPath); - } + CopyFolder(sourcePath, destinationPath, ct); + } + else + { + Directory.Move(sourcePath, destinationPath); } + } - private void CopyFolder(string sourceFolder, string destFolder, CancellationToken ct) + private void CopyFolder(string sourceFolder, string destFolder, CancellationToken ct) + { + if (!Directory.Exists(destFolder)) { - if (!Directory.Exists(destFolder)) - { - Directory.CreateDirectory(destFolder); - } + Directory.CreateDirectory(destFolder); + } - ct.ThrowIfCancellationRequested(); + ct.ThrowIfCancellationRequested(); - string[] files = Directory.GetFiles(sourceFolder); - ct.ThrowIfCancellationRequested(); + string[] files = Directory.GetFiles(sourceFolder); + ct.ThrowIfCancellationRequested(); - foreach (string file in files) - { - string name = Path.GetFileName(file); - string dest = Path.Combine(destFolder, name); - File.Copy(file, dest); - /* _movedSize += new FileInfo(dest).Length; */ - ct.ThrowIfCancellationRequested(); - } - - string[] folders = Directory.GetDirectories(sourceFolder); + foreach (string file in files) + { + string name = Path.GetFileName(file); + string dest = Path.Combine(destFolder, name); + File.Copy(file, dest); + /* _movedSize += new FileInfo(dest).Length; */ ct.ThrowIfCancellationRequested(); + } - foreach (string folder in folders) - { - string name = Path.GetFileName(folder); - string dest = Path.Combine(destFolder, name); - MoveFolder(folder, dest, ct); - ct.ThrowIfCancellationRequested(); - } + string[] folders = Directory.GetDirectories(sourceFolder); + ct.ThrowIfCancellationRequested(); + + foreach (string folder in folders) + { + string name = Path.GetFileName(folder); + string dest = Path.Combine(destFolder, name); + MoveFolder(folder, dest, ct); + ct.ThrowIfCancellationRequested(); } } } \ No newline at end of file diff --git a/AIDevGallery/Utils/ModelCacheStore.cs b/AIDevGallery/Utils/ModelCacheStore.cs index 18c3b88..3dc272a 100644 --- a/AIDevGallery/Utils/ModelCacheStore.cs +++ b/AIDevGallery/Utils/ModelCacheStore.cs @@ -9,112 +9,111 @@ using System.Text.Json; using System.Threading.Tasks; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +internal class ModelCacheStore { - internal class ModelCacheStore - { - public IReadOnlyList Models => _models.AsReadOnly(); + public IReadOnlyList Models => _models.AsReadOnly(); - public delegate void ModelsChangedHandler(ModelCacheStore sender); - public event ModelsChangedHandler? ModelsChanged; + public delegate void ModelsChangedHandler(ModelCacheStore sender); + public event ModelsChangedHandler? ModelsChanged; - private readonly List _models = []; + private readonly List _models = []; - public string CacheDir { get; init; } = null!; + public string CacheDir { get; init; } = null!; - private ModelCacheStore(string cacheDir, List? models) - { - CacheDir = cacheDir; - _models = models ?? []; - } + private ModelCacheStore(string cacheDir, List? models) + { + CacheDir = cacheDir; + _models = models ?? []; + } - public static async Task CreateForApp(string cacheDir, List? models = null) - { - ModelCacheStore? modelCacheStore = null; + public static async Task CreateForApp(string cacheDir, List? models = null) + { + ModelCacheStore? modelCacheStore = null; - try + try + { + if (models == null) { - if (models == null) + var cacheFile = Path.Combine(cacheDir, "cache.json"); + if (File.Exists(cacheFile)) { - var cacheFile = Path.Combine(cacheDir, "cache.json"); - if (File.Exists(cacheFile)) - { - var json = await File.ReadAllTextAsync(cacheFile); + var json = await File.ReadAllTextAsync(cacheFile); - modelCacheStore = new ModelCacheStore(cacheDir, JsonSerializer.Deserialize(json, AppDataSourceGenerationContext.Default.ListCachedModel)); - } - } - else - { - modelCacheStore = new(cacheDir, models); + modelCacheStore = new ModelCacheStore(cacheDir, JsonSerializer.Deserialize(json, AppDataSourceGenerationContext.Default.ListCachedModel)); } } - catch + else { + modelCacheStore = new(cacheDir, models); } - - modelCacheStore ??= new ModelCacheStore(cacheDir, null); - await modelCacheStore.ValidateAndSaveAsync(); - - return modelCacheStore; } - - private async Task SaveAsync() + catch { - var cacheFile = Path.Combine(CacheDir, "cache.json"); - - var str = JsonSerializer.Serialize(_models, AppDataSourceGenerationContext.Default.ListCachedModel); - - if (!Path.Exists(CacheDir)) - { - Directory.CreateDirectory(CacheDir); - } - - await File.WriteAllTextAsync(cacheFile, str); } - public async Task AddModel(CachedModel model) - { - var existingModel = _models.Where(m => m.Url == model.Url).ToList(); - foreach (var cachedModel in existingModel) - { - _models.Remove(cachedModel); - } + modelCacheStore ??= new ModelCacheStore(cacheDir, null); + await modelCacheStore.ValidateAndSaveAsync(); - _models.Add(model); + return modelCacheStore; + } - ModelsChanged?.Invoke(this); + private async Task SaveAsync() + { + var cacheFile = Path.Combine(CacheDir, "cache.json"); - await SaveAsync(); - } + var str = JsonSerializer.Serialize(_models, AppDataSourceGenerationContext.Default.ListCachedModel); - public async Task RemoveModel(CachedModel model) + if (!Path.Exists(CacheDir)) { - _models.Remove(model); - ModelsChanged?.Invoke(this); - await SaveAsync(); + Directory.CreateDirectory(CacheDir); } - public async Task ClearAsync() + await File.WriteAllTextAsync(cacheFile, str); + } + + public async Task AddModel(CachedModel model) + { + var existingModel = _models.Where(m => m.Url == model.Url).ToList(); + foreach (var cachedModel in existingModel) { - _models.Clear(); - ModelsChanged?.Invoke(this); - await SaveAsync(); + _models.Remove(cachedModel); } - private async Task ValidateAndSaveAsync() - { - List models = [.. _models]; + _models.Add(model); + + ModelsChanged?.Invoke(this); + + await SaveAsync(); + } + + public async Task RemoveModel(CachedModel model) + { + _models.Remove(model); + ModelsChanged?.Invoke(this); + await SaveAsync(); + } - foreach (var cachedModel in models) + public async Task ClearAsync() + { + _models.Clear(); + ModelsChanged?.Invoke(this); + await SaveAsync(); + } + + private async Task ValidateAndSaveAsync() + { + List models = [.. _models]; + + foreach (var cachedModel in models) + { + if (!Path.Exists(cachedModel.Path)) { - if (!Path.Exists(cachedModel.Path)) - { - _models.Remove(cachedModel); - } + _models.Remove(cachedModel); } - - await SaveAsync(); } + + await SaveAsync(); } } \ No newline at end of file diff --git a/AIDevGallery/Utils/ModelDownload.cs b/AIDevGallery/Utils/ModelDownload.cs index 2b58749..9502bc6 100644 --- a/AIDevGallery/Utils/ModelDownload.cs +++ b/AIDevGallery/Utils/ModelDownload.cs @@ -6,43 +6,42 @@ using System.Text.Json.Serialization; using System.Threading; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +internal class ModelDownload : IDisposable { - internal class ModelDownload : IDisposable - { - public ModelDetails Details { get; set; } - public DownloadStatus DownloadStatus { get; set; } = DownloadStatus.Waiting; - public float DownloadProgress { get; set; } + public ModelDetails Details { get; set; } + public DownloadStatus DownloadStatus { get; set; } = DownloadStatus.Waiting; + public float DownloadProgress { get; set; } + + public CancellationTokenSource CancellationTokenSource { get; } = new CancellationTokenSource(); - public CancellationTokenSource CancellationTokenSource { get; } = new CancellationTokenSource(); + public ModelUrl ModelUrl { get; set; } - public ModelUrl ModelUrl { get; set; } + public void Dispose() + { + CancellationTokenSource.Dispose(); + } - public void Dispose() + public ModelDownload(ModelDetails details) + { + Details = details; + if (details.Url.StartsWith("https://github.com", StringComparison.InvariantCulture)) { - CancellationTokenSource.Dispose(); + ModelUrl = new GitHubUrl(details.Url); } - - public ModelDownload(ModelDetails details) + else { - Details = details; - if (details.Url.StartsWith("https://github.com", StringComparison.InvariantCulture)) - { - ModelUrl = new GitHubUrl(details.Url); - } - else - { - ModelUrl = new HuggingFaceUrl(details.Url); - } + ModelUrl = new HuggingFaceUrl(details.Url); } } +} - [JsonConverter(typeof(JsonStringEnumConverter))] - internal enum DownloadStatus - { - Waiting, - InProgress, - Completed, - Canceled - } +[JsonConverter(typeof(JsonStringEnumConverter))] +internal enum DownloadStatus +{ + Waiting, + InProgress, + Completed, + Canceled } \ No newline at end of file diff --git a/AIDevGallery/Utils/ModelDownloadQueue.cs b/AIDevGallery/Utils/ModelDownloadQueue.cs index ae5db96..75ffc7f 100644 --- a/AIDevGallery/Utils/ModelDownloadQueue.cs +++ b/AIDevGallery/Utils/ModelDownloadQueue.cs @@ -13,136 +13,124 @@ using System.Threading; using System.Threading.Tasks; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +internal class ModelDownloadQueue(string cacheDir) { - internal class ModelDownloadQueue(string cacheDir) - { - private readonly List _queue = []; - public event EventHandler? ModelDownloadProgressChanged; - public event EventHandler? ModelDownloadCompleted; + private readonly List _queue = []; + public event EventHandler? ModelDownloadProgressChanged; + public event EventHandler? ModelDownloadCompleted; - public delegate void ModelsChangedHandler(ModelDownloadQueue sender); - public event ModelsChangedHandler? ModelsChanged; + public delegate void ModelsChangedHandler(ModelDownloadQueue sender); + public event ModelsChangedHandler? ModelsChanged; - public string CacheDir { get; set; } = cacheDir; + public string CacheDir { get; set; } = cacheDir; - private Task? processingTask; + private Task? processingTask; - public ModelDownload EnqueueModelDownload(ModelDetails modelDetails) - { - var url = UrlHelpers.GetFullUrl(modelDetails.Url); + public ModelDownload EnqueueModelDownload(ModelDetails modelDetails) + { + var url = UrlHelpers.GetFullUrl(modelDetails.Url); - var modelDownload = new ModelDownload(modelDetails); + var modelDownload = new ModelDownload(modelDetails); - _queue.Add(modelDownload); - ModelDownloadEnqueueEvent.Log(modelDetails.Url); - ModelsChanged?.Invoke(this); + _queue.Add(modelDownload); + ModelDownloadEnqueueEvent.Log(modelDetails.Url); + ModelsChanged?.Invoke(this); - lock (this) - { - if (processingTask == null || processingTask.IsFaulted) - { - processingTask = Task.Run(ProcessDownloads); - } - } - - return modelDownload; - } - - public void CancelModelDownload(string url) + lock (this) { - var download = GetDownload(url); - if (download != null) + if (processingTask == null || processingTask.IsFaulted) { - CancelModelDownload(download); + processingTask = Task.Run(ProcessDownloads); } } - public void CancelModelDownload(ModelDownload download) - { - if (download.DownloadStatus != DownloadStatus.Canceled) - { - download.CancellationTokenSource.Cancel(); - download.DownloadStatus = DownloadStatus.Canceled; - } - - ModelDownloadCancelEvent.Log(download.Details.Url); - _queue.Remove(download); - ModelsChanged?.Invoke(this); - download.Dispose(); - OnModelDownloadProgressChanged(download, 0, DownloadStatus.Canceled); - } + return modelDownload; + } - public IReadOnlyList GetDownloads() + public void CancelModelDownload(string url) + { + var download = GetDownload(url); + if (download != null) { - return _queue.AsReadOnly(); + CancelModelDownload(download); } + } - public ModelDownload? GetDownload(string url) + public void CancelModelDownload(ModelDownload download) + { + if (download.DownloadStatus != DownloadStatus.Canceled) { - url = UrlHelpers.GetFullUrl(url); - return _queue.FirstOrDefault(d => UrlHelpers.GetFullUrl(d.Details.Url) == url); + download.CancellationTokenSource.Cancel(); + download.DownloadStatus = DownloadStatus.Canceled; } - private async Task ProcessDownloads() + ModelDownloadCancelEvent.Log(download.Details.Url); + _queue.Remove(download); + ModelsChanged?.Invoke(this); + download.Dispose(); + OnModelDownloadProgressChanged(download, 0, DownloadStatus.Canceled); + } + + public IReadOnlyList GetDownloads() + { + return _queue.AsReadOnly(); + } + + public ModelDownload? GetDownload(string url) + { + url = UrlHelpers.GetFullUrl(url); + return _queue.FirstOrDefault(d => UrlHelpers.GetFullUrl(d.Details.Url) == url); + } + + private async Task ProcessDownloads() + { + while (_queue.Count > 0) { - while (_queue.Count > 0) + var download = _queue[0]; + TaskCompletionSource tcs = new(); + App.MainWindow.DispatcherQueue.TryEnqueue(async () => { - var download = _queue[0]; - TaskCompletionSource tcs = new(); - App.MainWindow.DispatcherQueue.TryEnqueue(async () => + try { - try - { - await Download(download, download.CancellationTokenSource.Token); - _queue.Remove(download); - ModelsChanged?.Invoke(this); - download.Dispose(); - tcs.SetResult(true); - } - catch (TaskCanceledException) - { - } - catch (Exception e) - { - tcs.SetException(e); - } - }); - - await tcs.Task; - } + await Download(download, download.CancellationTokenSource.Token); + _queue.Remove(download); + ModelsChanged?.Invoke(this); + download.Dispose(); + tcs.SetResult(true); + } + catch (TaskCanceledException) + { + } + catch (Exception e) + { + tcs.SetException(e); + } + }); - processingTask = null; + await tcs.Task; } - private async Task Download(ModelDownload modelDownload, CancellationToken cancellationToken) - { - modelDownload.DownloadStatus = DownloadStatus.InProgress; + processingTask = null; + } - Progress progress = new(p => - { - modelDownload.DownloadProgress = p; - OnModelDownloadProgressChanged(modelDownload, p, DownloadStatus.InProgress); - }); - ModelDownloadStartEvent.Log(modelDownload.Details.Url); - CachedModel cachedModel; - try - { - cachedModel = await DownloadModel(modelDownload.Details, CacheDir, progress, cancellationToken); + private async Task Download(ModelDownload modelDownload, CancellationToken cancellationToken) + { + modelDownload.DownloadStatus = DownloadStatus.InProgress; - if (cancellationToken.IsCancellationRequested) - { - modelDownload.DownloadStatus = DownloadStatus.Canceled; - var localPath = modelDownload.ModelUrl.GetLocalPath(CacheDir); - if (Directory.Exists(localPath)) - { - Directory.Delete(localPath, true); - } - - return; - } - } - catch (Exception e) + Progress progress = new(p => + { + modelDownload.DownloadProgress = p; + OnModelDownloadProgressChanged(modelDownload, p, DownloadStatus.InProgress); + }); + ModelDownloadStartEvent.Log(modelDownload.Details.Url); + CachedModel cachedModel; + try + { + cachedModel = await DownloadModel(modelDownload.Details, CacheDir, progress, cancellationToken); + + if (cancellationToken.IsCancellationRequested) { modelDownload.DownloadStatus = DownloadStatus.Canceled; var localPath = modelDownload.ModelUrl.GetLocalPath(CacheDir); @@ -151,131 +139,142 @@ private async Task Download(ModelDownload modelDownload, CancellationToken cance Directory.Delete(localPath, true); } - ModelDownloadFailedEvent.Log(modelDownload.Details.Url, e); return; } - - modelDownload.DownloadStatus = DownloadStatus.Completed; - - ModelDownloadCompleteEvent.Log(cachedModel.Url); - ModelDownloadCompleted?.Invoke(this, new ModelDownloadCompletedEventArgs - { - CachedModel = cachedModel - }); - OnModelDownloadProgressChanged(modelDownload, 1, DownloadStatus.Completed); - SendNotification(modelDownload.Details); } - - private void OnModelDownloadProgressChanged(ModelDownload modelDownload, float p, DownloadStatus downloadStatus) + catch (Exception e) { - ModelDownloadProgressChanged?.Invoke(this, new ModelDownloadProgressEventArgs + modelDownload.DownloadStatus = DownloadStatus.Canceled; + var localPath = modelDownload.ModelUrl.GetLocalPath(CacheDir); + if (Directory.Exists(localPath)) { - ModelUrl = modelDownload.Details.Url, - Progress = p, - Status = downloadStatus - }); + Directory.Delete(localPath, true); + } + + ModelDownloadFailedEvent.Log(modelDownload.Details.Url, e); + return; } - public static async Task DownloadModel(ModelDetails model, string cacheDir, IProgress? progress = null, CancellationToken cancellationToken = default) + modelDownload.DownloadStatus = DownloadStatus.Completed; + + ModelDownloadCompleteEvent.Log(cachedModel.Url); + ModelDownloadCompleted?.Invoke(this, new ModelDownloadCompletedEventArgs { - ModelUrl url; - List filesToDownload; - if (model.Url.StartsWith("https://github.com", StringComparison.InvariantCulture)) - { - var ghUrl = new GitHubUrl(model.Url); - filesToDownload = await ModelInformationHelper.GetDownloadFilesFromGitHub(ghUrl, cancellationToken); - url = ghUrl; - } - else + CachedModel = cachedModel + }); + OnModelDownloadProgressChanged(modelDownload, 1, DownloadStatus.Completed); + SendNotification(modelDownload.Details); + } + + private void OnModelDownloadProgressChanged(ModelDownload modelDownload, float p, DownloadStatus downloadStatus) + { + ModelDownloadProgressChanged?.Invoke(this, new ModelDownloadProgressEventArgs + { + ModelUrl = modelDownload.Details.Url, + Progress = p, + Status = downloadStatus + }); + } + + public static async Task DownloadModel(ModelDetails model, string cacheDir, IProgress? progress = null, CancellationToken cancellationToken = default) + { + ModelUrl url; + List filesToDownload; + if (model.Url.StartsWith("https://github.com", StringComparison.InvariantCulture)) + { + var ghUrl = new GitHubUrl(model.Url); + filesToDownload = await ModelInformationHelper.GetDownloadFilesFromGitHub(ghUrl, cancellationToken); + url = ghUrl; + } + else + { + var hfUrl = new HuggingFaceUrl(model.Url); + using var socketsHttpHandler = new SocketsHttpHandler { - var hfUrl = new HuggingFaceUrl(model.Url); - using var socketsHttpHandler = new SocketsHttpHandler - { - MaxConnectionsPerServer = 4 - }; - filesToDownload = await ModelInformationHelper.GetDownloadFilesFromHuggingFace(hfUrl, socketsHttpHandler, cancellationToken); - url = hfUrl; - } + MaxConnectionsPerServer = 4 + }; + filesToDownload = await ModelInformationHelper.GetDownloadFilesFromHuggingFace(hfUrl, socketsHttpHandler, cancellationToken); + url = hfUrl; + } - var localFolderPath = $"{cacheDir}\\{url.Organization}--{url.Repo}\\{url.Ref}"; - Directory.CreateDirectory(localFolderPath); + var localFolderPath = $"{cacheDir}\\{url.Organization}--{url.Repo}\\{url.Ref}"; + Directory.CreateDirectory(localFolderPath); - var existingFiles = Directory.GetFiles(localFolderPath, "*", SearchOption.AllDirectories); + var existingFiles = Directory.GetFiles(localFolderPath, "*", SearchOption.AllDirectories); - filesToDownload = ModelInformationHelper.FilterFiles(filesToDownload, model.FileFilters); + filesToDownload = ModelInformationHelper.FilterFiles(filesToDownload, model.FileFilters); - long modelSize = filesToDownload.Sum(f => f.Size); - long bytesDownloaded = 0; + long modelSize = filesToDownload.Sum(f => f.Size); + long bytesDownloaded = 0; - var internalProgress = new Progress(p => - { - var percentage = (float)(bytesDownloaded + p) / (float)modelSize; - progress?.Report(percentage); - }); + var internalProgress = new Progress(p => + { + var percentage = (float)(bytesDownloaded + p) / (float)modelSize; + progress?.Report(percentage); + }); - using var client = new HttpClient(); + using var client = new HttpClient(); - foreach (var downloadableFile in filesToDownload) + foreach (var downloadableFile in filesToDownload) + { + if (downloadableFile.DownloadUrl == null) { - if (downloadableFile.DownloadUrl == null) - { - continue; - } + continue; + } - var filePath = Path.Combine(localFolderPath, downloadableFile.Path!.Replace("/", "\\")); + var filePath = Path.Combine(localFolderPath, downloadableFile.Path!.Replace("/", "\\")); - var existingFile = existingFiles.Where(f => f == filePath).FirstOrDefault(); - if (existingFile != null) - { - // check if the file is the same size as the one on the server - var existingFileInfo = new FileInfo(existingFile); - if (existingFileInfo.Length == downloadableFile.Size) - { - continue; - } - } - - Directory.CreateDirectory(Path.GetDirectoryName(filePath)!); - using (FileStream file = new(filePath, FileMode.Create, FileAccess.Write, FileShare.None)) + var existingFile = existingFiles.Where(f => f == filePath).FirstOrDefault(); + if (existingFile != null) + { + // check if the file is the same size as the one on the server + var existingFileInfo = new FileInfo(existingFile); + if (existingFileInfo.Length == downloadableFile.Size) { - await client.DownloadAsync(downloadableFile.DownloadUrl, file, null, internalProgress, cancellationToken); - file.Close(); + continue; } + } - var fileInfo = new FileInfo(filePath); - if (fileInfo.Length != downloadableFile.Size) - { - // file did not download properly, should retry - } + Directory.CreateDirectory(Path.GetDirectoryName(filePath)!); + using (FileStream file = new(filePath, FileMode.Create, FileAccess.Write, FileShare.None)) + { + await client.DownloadAsync(downloadableFile.DownloadUrl, file, null, internalProgress, cancellationToken); + file.Close(); + } - bytesDownloaded += downloadableFile.Size; + var fileInfo = new FileInfo(filePath); + if (fileInfo.Length != downloadableFile.Size) + { + // file did not download properly, should retry } - var modelDirectory = url.GetLocalPath(cacheDir); - return new CachedModel(model, url.IsFile ? $"{modelDirectory}\\{filesToDownload.First().Name}" : modelDirectory, url.IsFile, modelSize); + bytesDownloaded += downloadableFile.Size; } - private static void SendNotification(ModelDetails model) - { - var builder = new AppNotificationBuilder() - .AddText(model.Name + " is ready to use.") - .AddButton(new AppNotificationButton("Try it out") - .AddArgument("model", model.Id)); - - var notificationManager = AppNotificationManager.Default; - notificationManager.Show(builder.BuildNotification()); - } + var modelDirectory = url.GetLocalPath(cacheDir); + return new CachedModel(model, url.IsFile ? $"{modelDirectory}\\{filesToDownload.First().Name}" : modelDirectory, url.IsFile, modelSize); } - internal class ModelDownloadProgressEventArgs + private static void SendNotification(ModelDetails model) { - public required string ModelUrl { get; init; } - public required float Progress { get; init; } - public required DownloadStatus Status { get; init; } - } + var builder = new AppNotificationBuilder() + .AddText(model.Name + " is ready to use.") + .AddButton(new AppNotificationButton("Try it out") + .AddArgument("model", model.Id)); - internal class ModelDownloadCompletedEventArgs - { - public required CachedModel CachedModel { get; init; } + var notificationManager = AppNotificationManager.Default; + notificationManager.Show(builder.BuildNotification()); } +} + +internal class ModelDownloadProgressEventArgs +{ + public required string ModelUrl { get; init; } + public required float Progress { get; init; } + public required DownloadStatus Status { get; init; } +} + +internal class ModelDownloadCompletedEventArgs +{ + public required CachedModel CachedModel { get; init; } } \ No newline at end of file diff --git a/AIDevGallery/Utils/SourceGenerationContext.cs b/AIDevGallery/Utils/SourceGenerationContext.cs index 6248e59..6386d30 100644 --- a/AIDevGallery/Utils/SourceGenerationContext.cs +++ b/AIDevGallery/Utils/SourceGenerationContext.cs @@ -5,11 +5,10 @@ using System.Collections.Generic; using System.Text.Json.Serialization; -namespace AIDevGallery.Utils +namespace AIDevGallery.Utils; + +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(GenAIConfig))] +internal partial class SourceGenerationContext : JsonSerializerContext { - [JsonSerializable(typeof(List))] - [JsonSerializable(typeof(GenAIConfig))] - internal partial class SourceGenerationContext : JsonSerializerContext - { - } } \ No newline at end of file