Using a SourceGenerator to generate IoC / configuration registration code

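Our programs that previously ran on servers registered IoC services and Options by scanning the DLLs in their directory, iterating over the classes that carry specific Attributes, and then registering each class according to the parameters in its Attribute.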

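AOT does not support dynamic loading, and the app is published as a single file, so there are no DLLs left to scan. The previous approach basically no longer works.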

This uses IIncrementalGenerator rather than ISourceGenerator. The documentation for both is not very friendly, and figuring out just one of them already took quite a while.

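    using System.Collections.Generic;
    using System.Linq;
    using Microsoft.CodeAnalysis;

    [Generator(LanguageNames.CSharp)]
    public class RegistServiceGenerator : IIncrementalGenerator
    {

        public void Initialize(IncrementalGeneratorInitializationContext context)
        {

            // Uncomment to attach a debugger while the generator runs
            //Debugger.Launch();

            var provider = context.CompilationProvider.Select((compilation, cts) =>
            {

                // Global namespace of the assembly being compiled
                var ns1 = new List<INamespaceSymbol>() {
                    compilation.Assembly.GlobalNamespace
                };

                // Global namespaces of referenced assemblies whose name contains "CNB.PubTools"
                var ns2 = compilation.SourceModule
                            .ReferencedAssemblySymbols
                            .Where(s => s.Name.Contains("CNB.PubTools"))
                            .Select(s => s.GlobalNamespace);

                var ns = ns1.Concat(ns2);

                // Every type declared under those namespaces
                var symbols = ns.SelectMany(s => Helper.GetAllTypeSymbol(s));

                // Symbols -> intermediate descriptors
                return Helper.Transform(symbols.ToArray());
            });

            context.RegisterSourceOutput(provider, (spc, datas) =>
            {
                // Always emit the file: RegistService(IServiceCollection) is declared as a
                // partial method with an access modifier, so it must have an implementation
                // even when no [Regist] types are found (GetCode handles an empty list).
                var code = Helper.GetCode(datas);
                spc.AddSource("Regist.g.cs", code);
            });
        }
    }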

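The code above does the following:
1. Collects the GlobalNamespace of the current assembly and of every referenced assembly whose name contains "CNB.PubTools".
2. Walks the type Symbols under those namespaces.
3. Converts those Symbols into intermediate data.
4. Generates code from the intermediate data through RegisterSourceOutput.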
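Incremental generators decide whether to rerun a later step by comparing a step's new output with its previous output via Equals. Helper.Transform returns a List<RegistDescriptor>, which compares by reference, so the RegisterSourceOutput step runs again after every compilation change. Below is a minimal sketch of intermediate data that compares by value, assuming RegistDescriptor is turned into a record; EquatableArray<T> is a hypothetical helper written here, not a Roslyn type.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

#nullable enable

// Assumption: RegistDescriptor rewritten as a record with the fields used in Transform.
// Records get compiler-generated value-based Equals/GetHashCode.
public sealed record RegistDescriptor
{
    public int Lifetime { get; init; }
    public string? ForType { get; init; }
    public string? TargetType { get; init; }
    public string? Tag { get; init; }
    public string? Name { get; init; }
}

// Hypothetical wrapper: an immutable array compared element by element,
// so two pipeline results with the same contents count as "unchanged".
public readonly struct EquatableArray<T> : IEquatable<EquatableArray<T>>, IEnumerable<T>
    where T : IEquatable<T>
{
    private readonly T[]? _items;

    public EquatableArray(IEnumerable<T> items) => _items = items.ToArray();

    private T[] Items => _items ?? Array.Empty<T>();

    public bool Equals(EquatableArray<T> other) => Items.SequenceEqual(other.Items);

    public override bool Equals(object? obj) => obj is EquatableArray<T> other && Equals(other);

    public override int GetHashCode()
    {
        // Manual hash combine (System.HashCode is not available on netstandard2.0).
        unchecked
        {
            int hash = 17;
            foreach (var item in Items)
                hash = hash * 31 + (item is null ? 0 : item.GetHashCode());
            return hash;
        }
    }

    public IEnumerator<T> GetEnumerator() => ((IEnumerable<T>)Items).GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
```

In Initialize, the Select lambda would then return `new EquatableArray<RegistDescriptor>(Helper.Transform(...))`, and GetCode can consume it unchanged because it is still an IEnumerable<RegistDescriptor>. The Select over CompilationProvider itself still runs on each compilation; only code generation is skipped when the result is equal. A netstandard2.0 generator project also needs an IsExternalInit shim to compile records and init accessors.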

Walk all namespaces recursively to collect every type Symbol:

public static IEnumerable<INamedTypeSymbol> GetAllTypeSymbol(INamespaceSymbol namespaceSymbol)
{
    var typeMemberList = namespaceSymbol.GetTypeMembers();

    foreach (var typeSymbol in typeMemberList)
    {
        yield return typeSymbol;
    }

    foreach (var namespaceMember in namespaceSymbol.GetNamespaceMembers())
    {
        foreach (var typeSymbol in GetAllTypeSymbol(namespaceMember))
        {
            yield return typeSymbol;
        }
    }
}
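GetAllTypeSymbol calls GetTypeMembers() only on namespaces, so it yields top-level types; a class nested inside another class and marked with [Regist] would be skipped. Below is a sketch of a walk that also descends into nested types and uses explicit stacks instead of nested yield iterators. To keep it runnable without Roslyn, the hypothetical NamespaceNode/TypeNode classes stand in for INamespaceSymbol/INamedTypeSymbol; with Roslyn, Namespaces, Types and NestedTypes would correspond to GetNamespaceMembers(), GetTypeMembers() on the namespace, and GetTypeMembers() on the type.

```csharp
using System.Collections.Generic;

// Hypothetical stand-ins for INamedTypeSymbol / INamespaceSymbol.
public sealed class TypeNode
{
    public string Name { get; }
    public List<TypeNode> NestedTypes { get; } = new();
    public TypeNode(string name) => Name = name;
}

public sealed class NamespaceNode
{
    public List<TypeNode> Types { get; } = new();
    public List<NamespaceNode> Namespaces { get; } = new();
}

public static class TypeWalker
{
    // Visits every namespace and every type, including nested types,
    // with one iterator and two explicit stacks.
    public static IEnumerable<TypeNode> GetAllTypes(NamespaceNode root)
    {
        var namespaces = new Stack<NamespaceNode>();
        var types = new Stack<TypeNode>();
        namespaces.Push(root);

        while (namespaces.Count > 0)
        {
            var ns = namespaces.Pop();
            foreach (var child in ns.Namespaces)
                namespaces.Push(child);
            foreach (var t in ns.Types)
                types.Push(t);

            // Drain this namespace's types, pushing nested types as they are found.
            while (types.Count > 0)
            {
                var type = types.Pop();
                yield return type;
                foreach (var nested in type.NestedTypes)
                    types.Push(nested);
            }
        }
    }
}
```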

Convert the Symbols to intermediate data:

private static readonly string REGISTATTRIBUTE_NAME = "CNB.PubTools.Common.Attributes.RegistAttribute";
...
public static IEnumerable<RegistDescriptor> Transform(params INamedTypeSymbol[] symbols)
{
    if (symbols?.Any() != true)
        return Enumerable.Empty<RegistDescriptor>();

    var typeSymbols = symbols
                .Select(s => new
                {
                    ClassName = s.ToDisplayString(),
                    Attributes = s.GetAttributes().Where(a => a.AttributeClass.ToDisplayString() == REGISTATTRIBUTE_NAME || a.AttributeClass.BaseType.ToDisplayString() == REGISTATTRIBUTE_NAME)
                })
                .Where(s => s.Attributes?.Any() == true);

    var descriptors = new List<RegistDescriptor>();
    foreach (var typeSymbol in typeSymbols)
    {
        var tmp = typeSymbol.Attributes.Select(a =>
        {
            var name = a.AttributeClass.Name;
            return name switch
            {
                "RegistAttribute" => a.AttributeClass.IsGenericType
                    ? new RegistDescriptor()
                    {
                        Lifetime = (int)a.ConstructorArguments[0].Value,
                        ForType = a.AttributeClass.TypeArguments.First().ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
                        TargetType = typeSymbol.ClassName,
                        Tag = a.AttributeClass.Name
                    }
                    : new RegistDescriptor()
                    {
                        Lifetime = (int)a.ConstructorArguments[0].Value,
                        ForType = ((ITypeSymbol)a.ConstructorArguments[1].Value)?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) ?? typeSymbol.ClassName,
                        TargetType = typeSymbol.ClassName,
                        Tag = a.AttributeClass.Name
                    },
                "RegistInitializerAttribute" => new RegistDescriptor()
                {
                    TargetType = typeSymbol.ClassName,
                    Tag = a.AttributeClass.Name
                },
                "RegistViewAttribute" => new RegistDescriptor()
                {
                    TargetType = typeSymbol.ClassName,
                    Tag = a.AttributeClass.Name,
                    Name = typeSymbol.ClassName
                },
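                // Unknown attribute variants yield no descriptor: an empty one would
                // later be emitted as "typeof()" and break the build.
                _ => null,
            };
        });
        descriptors.AddRange(tmp.Where(d => d != null));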
    }

    return descriptors;
}

.Where(a => a.AttributeClass.ToDisplayString() == REGISTATTRIBUTE_NAME || a.AttributeClass.BaseType.ToDisplayString() == REGISTATTRIBUTE_NAME)

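This line checks whether the Symbol carries the specific Attribute itself, or an Attribute whose direct base class is that Attribute (only one level of inheritance is checked).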
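If deeper hierarchies are ever needed (an attribute deriving from an attribute that derives from RegistAttribute), the check can walk the whole base-type chain instead of one level. The sketch below uses runtime reflection so it runs without Roslyn, with a hypothetical attribute hierarchy; a Roslyn version would loop the same way, e.g. `for (var t = a.AttributeClass; t != null; t = t.BaseType)`, comparing each level's display name with REGISTATTRIBUTE_NAME.

```csharp
using System;

#nullable enable

// Hypothetical attribute hierarchy, used only to exercise the check.
public class BaseRegAttribute : Attribute { }
public class GenericRegAttribute<T> : BaseRegAttribute { }
public class ViewRegAttribute : BaseRegAttribute { }
public class SpecialViewRegAttribute : ViewRegAttribute { }

public static class AttributeChain
{
    // Follows BaseType until it runs out, so any depth of inheritance matches.
    // Generic types are compared through their open definition (e.g. GenericRegAttribute`1).
    public static bool IsOrDerivesFrom(Type? attributeType, string fullName)
    {
        for (var t = attributeType; t != null; t = t.BaseType)
        {
            var def = t.IsGenericType ? t.GetGenericTypeDefinition() : t;
            if (def.FullName == fullName)
                return true;
        }
        return false;
    }
}
```

With the one-level check, SpecialViewRegAttribute would not match, because its direct base is ViewRegAttribute.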

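RegistAttribute comes in many variants, and parsing all of them in full would cost too much effort and time. So a compromise is taken here, namely the switch block: it produces different data depending on the attribute's name.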
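The name-based compromise can be reduced to a small, runnable mapping. In this sketch the hypothetical AttrInfo record stands in for what the switch reads from one AttributeData, and Descriptor stands in for RegistDescriptor; unrecognised names map to null and are filtered out, so no empty descriptor reaches code generation.

```csharp
using System.Collections.Generic;
using System.Linq;

#nullable enable

// Hypothetical simplified input: the attribute name plus the constructor values the switch uses.
public sealed record AttrInfo(string Name, int? Lifetime, string? ForType);

// Hypothetical stand-in for RegistDescriptor.
public sealed record Descriptor(string TargetType, string Tag, int Lifetime = 0, string? ForType = null, string? Name = null);

public static class DescriptorMapper
{
    public static IEnumerable<Descriptor> Map(string targetType, IEnumerable<AttrInfo> attrs) =>
        attrs.Select(a => a.Name switch
            {
                // No explicit ForType: register the class as itself, as in Transform.
                "RegistAttribute" => new Descriptor(targetType, a.Name, a.Lifetime ?? 0, a.ForType ?? targetType),
                "RegistInitializerAttribute" => new Descriptor(targetType, a.Name),
                "RegistViewAttribute" => new Descriptor(targetType, a.Name, Name: targetType),
                // Unknown variant: produce nothing rather than an empty descriptor.
                _ => null
            })
            .OfType<Descriptor>(); // drops the nulls
}
```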

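Besides ConstructorArguments, an attribute also exposes NamedArguments (values assigned to properties or fields in the attribute usage), but their shape varies too much to handle conveniently here. RegistAttribute is therefore required to stay simple and take its values through constructor arguments.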

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistAttribute : Attribute/*, INamed*/
{

    /// <summary>
    /// Registration mode (service lifetime)
    /// </summary>
    public RegistMode Mode { get; }

    /// <summary>
    /// 注冊為哪個類型
    /// </summary>
    public Type? ForType { get; }

    public RegistAttribute(RegistMode mode, Type? forType = null)
    {
        this.Mode = mode;
        this.ForType = forType;
    }
}

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistAttribute<T> : RegistAttribute
{
    public RegistAttribute(RegistMode mode) : base(mode, typeof(T))
    {
    }
}

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistInitializerAttribute : RegistAttribute
{
    /// <summary>
    /// Registers the class as a singleton IInitializer.
    /// </summary>
    public RegistInitializerAttribute() : base(RegistMode.Singleton, typeof(IInitializer))
    {
    }
}
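To see what ConstructorArguments and NamedArguments hold for attributes shaped like these, the sketch below reads them through runtime reflection's CustomAttributeData, which mirrors Roslyn's AttributeData closely enough to run without a generator. DemoMode, DemoRegistAttribute, IDemoService and DemoService are hypothetical. Note that an omitted optional constructor argument is still present (as null), because the compiler fills in the default value; that is why reading ConstructorArguments[1] unconditionally is safe.

```csharp
using System;
using System.Linq;
using System.Reflection;

#nullable enable

// Hypothetical stand-ins shaped like RegistMode / RegistAttribute above.
public enum DemoMode { Scoped, Singleton, Transient }

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class DemoRegistAttribute : Attribute
{
    public DemoMode Mode { get; }
    public Type? ForType { get; }
    public string? Key { get; set; } // settable, so it can appear as a named argument

    public DemoRegistAttribute(DemoMode mode, Type? forType = null)
    {
        Mode = mode;
        ForType = forType;
    }
}

public interface IDemoService { }

[DemoRegist(DemoMode.Singleton)]
[DemoRegist(DemoMode.Scoped, typeof(IDemoService), Key = "x")]
public class DemoService : IDemoService { }

public static class AttributeArgsDemo
{
    // Formats each usage as "mode|forType|namedArgs" from the raw metadata arguments.
    public static string[] Describe(Type type) =>
        CustomAttributeData.GetCustomAttributes(type)
            .Where(a => a.AttributeType == typeof(DemoRegistAttribute))
            .Select(a =>
            {
                var mode = Convert.ToInt32(a.ConstructorArguments[0].Value);
                var forType = (a.ConstructorArguments[1].Value as Type)?.Name ?? "(null)";
                var named = string.Join(",", a.NamedArguments.Select(n => $"{n.MemberName}={n.TypedValue.Value}"));
                return $"{mode}|{forType}|{named}";
            })
            .OrderBy(s => s)
            .ToArray();
}
```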

Generate the code:

public static string GetCode(RegistDescriptor descriptor)
{
    if (string.IsNullOrWhiteSpace(descriptor.ForType))
        return $"RegistService(sc, typeof({descriptor.TargetType}), \"{descriptor.Tag}\", \"{descriptor.Name}\");";
    else
        return $"RegistService(sc, typeof({descriptor.ForType}), typeof({descriptor.TargetType}), {descriptor.Lifetime});";
}

public static string GetCode(IEnumerable<RegistDescriptor> ds)
{
    if (ds?.Any() != true)
        ds = Enumerable.Empty<RegistDescriptor>();

    var arr = ds.Select(GetCode).Distinct().OrderBy(s => s);
    var str = string.Join("\r\n            ", arr);

    var code = $$"""
            using CNB.PubTools.Common;
            using Microsoft.Extensions.DependencyInjection;
            using System;
            using System.Diagnostics.CodeAnalysis;

            namespace CNB.PubTools
            {
                internal partial class Regist
                {

                    public static partial void RegistService(IServiceCollection sc)
                    {
                        {{str}}
                    }
                }
            }
            """;

    return code;
}
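To see what the generated method body looks like, here is a runnable reduction of the two GetCode methods working on a hypothetical Desc record (not the author's RegistDescriptor). As a suggested change, it sorts with an ordinal comparer instead of OrderBy(s => s), so the line order does not depend on the build machine's culture.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

#nullable enable

// Hypothetical minimal descriptor carrying the fields GetCode reads.
public sealed record Desc(string TargetType, string? ForType = null, int Lifetime = 0, string? Tag = null, string? Name = null);

public static class CodeSketch
{
    // Same branching as GetCode(RegistDescriptor): with a ForType the call carries the
    // lifetime; without one, the hand-written helper dispatches on the attribute tag.
    public static string Line(Desc d) =>
        string.IsNullOrWhiteSpace(d.ForType)
            ? $"RegistService(sc, typeof({d.TargetType}), \"{d.Tag}\", \"{d.Name}\");"
            : $"RegistService(sc, typeof({d.ForType}), typeof({d.TargetType}), {d.Lifetime});";

    // Distinct removes duplicate registrations; ordinal sorting keeps the output stable.
    public static string Body(IEnumerable<Desc> ds) =>
        string.Join("\n", ds.Select(Line).Distinct().OrderBy(s => s, StringComparer.Ordinal));
}
```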

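To go with the SourceGenerator above, a hand-written partial class is also needed. It must be in the CNB.PubTools namespace, matching the generated code: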

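    using System;
    using System.Diagnostics.CodeAnalysis;
    using Microsoft.Extensions.DependencyInjection;
    using Microsoft.Extensions.DependencyInjection.Extensions; // TryAddKeyedTransient
    // (plus the usings for RegistMode, IInitializer and IView, whose namespaces are not shown)

    namespace CNB.PubTools;

    /// <summary>
    /// Hand-written half of Regist; the generated Regist.g.cs supplies the body of
    /// RegistService(IServiceCollection).
    /// </summary>
    internal partial class Regist
    {

        /// <summary>
        /// Registers targetType as serviceType with the lifetime given by mode.
        /// </summary>
        /// <param name="sc">The service collection to add to.</param>
        /// <param name="serviceType">The service type to register as.</param>
        /// <param name="targetType">The implementation type; annotated so the trimmer keeps its public constructors.</param>
        /// <param name="mode">A RegistMode value, passed as int by the generated code.</param>
        private static void RegistService(IServiceCollection sc, Type serviceType, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, int mode)
        {
            switch ((RegistMode)mode)
            {
                case RegistMode.Scoped:
                    sc.AddScoped(serviceType, targetType);
                    break;
                case RegistMode.Singleton:
                    sc.AddSingleton(serviceType, targetType);
                    break;
                case RegistMode.PreRequest:
                    sc.AddTransient(serviceType, targetType);
                    break;
            }
        }

        /// <summary>
        /// Registers targetType according to the attribute tag (used when there is no ForType).
        /// </summary>
        private static void RegistService(IServiceCollection sc, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, string tag, string? name)
        {
            switch (tag)
            {
                case "RegistInitializerAttribute":
                    sc.AddSingleton(typeof(IInitializer), targetType);
                    break;
                case "RegistViewAttribute":
                    sc.TryAddKeyedTransient(typeof(IView), name, targetType);
                    break;
            }
        }

        // Implemented by the generated Regist.g.cs
        public static partial void RegistService(IServiceCollection sc);
    }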
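The generated file and the hand-written class meet through a partial method. The sketch below shows that pattern in isolation, with a List<string> standing in for IServiceCollection and the hypothetical names RegistDemo and RegistAll: the hand-written half declares the method and owns the private helper, and the "generated" half provides the body. Because the declaration carries an access modifier (C# 9 extended partial methods), the body is mandatory; if no implementing half existed, the build would fail with CS8795.

```csharp
using System.Collections.Generic;

// Hand-written half: declares the partial method and owns the helper it calls.
internal static partial class RegistDemo
{
    public static partial void RegistAll(List<string> sc);

    private static void RegistService(List<string> sc, string serviceType, string targetType, int mode) =>
        sc.Add($"{serviceType} -> {targetType} ({mode})");
}

// "Generated" half (a generator would emit this as a separate file):
// supplies the body, calling the private helper from the other half.
internal static partial class RegistDemo
{
    public static partial void RegistAll(List<string> sc)
    {
        RegistService(sc, "IFoo", "Foo", 1);
    }
}
```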

Then reference the SourceGenerator project from the main project:

<ProjectReference Include="..\CNB.PubTools.SourceGenerator\CNB.PubTools.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />

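Because this is an analyzer project, its output DLL is not referenced: OutputItemType="Analyzer" loads the project as an analyzer/generator, and ReferenceOutputAssembly="false" keeps its DLL out of the main project's references.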

Finally, you can add the following to a PropertyGroup in the project file:

<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>

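This writes the generated code to disk under the obj directory (by default obj/<Configuration>/<TargetFramework>/generated/).
Alternatively, the files generated by the SourceGenerator can be found at the location shown below:

[Screenshot: location of the generated files (image not preserved)]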