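Previously, our server-side programs registered IoC services and Options by scanning the DLLs in a directory, enumerating the classes marked with specific attributes, and registering them according to the attribute arguments.
AOT does not support dynamic loading, and everything is packed into a single file, so scanning DLLs is no longer possible and the old approach basically does not work.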
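For contrast, the old reflection-based approach might have looked roughly like the sketch below. This is only an illustration: DemoRegistAttribute, IClock, SystemClock and ReflectionScanner are hypothetical stand-ins, not the real CNB.PubTools types. The Assembly.LoadFrom call in ScanDirectory is the dynamic-loading step that AOT rules out.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;

// Hypothetical stand-in for the real attribute; name and shape are assumptions.
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public sealed class DemoRegistAttribute : Attribute
{
    public Type? ForType { get; }
    public DemoRegistAttribute(Type? forType = null) => ForType = forType;
}

// Hypothetical sample types used only to exercise the scanner.
public interface IClock { }

[DemoRegist(typeof(IClock))]
public class SystemClock : IClock { }

public static class ReflectionScanner
{
    // Returns (service, implementation) pairs for every attributed class in one assembly.
    public static IEnumerable<(Type Service, Type Implementation)> Scan(Assembly assembly)
    {
        foreach (var type in assembly.GetTypes())
        {
            foreach (var attr in type.GetCustomAttributes<DemoRegistAttribute>(inherit: false))
            {
                // Fall back to the class itself when no service type was given.
                yield return (attr.ForType ?? type, type);
            }
        }
    }

    // Loads every DLL in a directory at run time and scans each one.
    public static IEnumerable<(Type Service, Type Implementation)> ScanDirectory(string directory)
    {
        return Directory.GetFiles(directory, "*.dll")
            .Select(Assembly.LoadFrom)
            .SelectMany(Scan);
    }
}
```

The source generator described next moves the same discovery step from run time to compile time.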
I use IIncrementalGenerator here rather than ISourceGenerator. The documentation for both is unfriendly, and studying just one of them already took a long time.
[Generator(LanguageNames.CSharp)]
public class RegistServiceGenerator : IIncrementalGenerator
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
//Debugger.Launch();
var provider = context.CompilationProvider.Select((compilation, cts) =>
{
var ns1 = new List<INamespaceSymbol>() {
compilation.Assembly.GlobalNamespace
};
var ns2 = compilation.SourceModule
.ReferencedAssemblySymbols
.Where(s => s.Name.Contains("CNB.PubTools"))
.Select(s => s.GlobalNamespace);
var ns = ns1.Concat(ns2);
var symbols = ns.SelectMany(s => Helper.GetAllTypeSymbol(s));
return Helper.Transform(symbols.ToArray());
});
context.RegisterSourceOutput(provider, (spc, datas) =>
{
if (datas.Any())
{
var code = Helper.GetCode(datas);
spc.AddSource("Regist.g.cs", code);
}
});
}
}
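The code above mainly does the following:
1. Get the GlobalNamespace of the current assembly and of each referenced assembly whose name contains "CNB.PubTools".
2. Walk the type symbols under those namespaces.
3. Convert those symbols into intermediate data.
4. Turn the intermediate data into code through RegisterSourceOutput.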
Walk every namespace recursively to find all the type symbols:
public static IEnumerable<INamedTypeSymbol> GetAllTypeSymbol(INamespaceSymbol namespaceSymbol)
{
var typeMemberList = namespaceSymbol.GetTypeMembers();
foreach (var typeSymbol in typeMemberList)
{
yield return typeSymbol;
}
foreach (var namespaceMember in namespaceSymbol.GetNamespaceMembers())
{
foreach (var typeSymbol in GetAllTypeSymbol(namespaceMember))
{
yield return typeSymbol;
}
}
}
Convert to intermediate data:
private static readonly string REGISTATTRIBUTE_NAME = "CNB.PubTools.Common.Attributes.RegistAttribute";
...
public static IEnumerable<RegistDescriptor> Transform(params INamedTypeSymbol[] symbols)
{
if (symbols?.Any() != true)
return Enumerable.Empty<RegistDescriptor>();
var typeSymbols = symbols
.Select(s => new
{
ClassName = s.ToDisplayString(),
Attributes = s.GetAttributes().Where(a => a.AttributeClass?.ToDisplayString() == REGISTATTRIBUTE_NAME || a.AttributeClass?.BaseType?.ToDisplayString() == REGISTATTRIBUTE_NAME)
})
.Where(s => s.Attributes?.Any() == true);
var descriptors = new List<RegistDescriptor>();
foreach (var typeSymbol in typeSymbols)
{
var tmp = typeSymbol.Attributes.Select(a =>
{
var name = a.AttributeClass.Name;
return name switch
{
"RegistAttribute" => a.AttributeClass.IsGenericType
? new RegistDescriptor()
{
Lifetime = (int)a.ConstructorArguments[0].Value,
ForType = a.AttributeClass.TypeArguments.First().ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat),
TargetType = typeSymbol.ClassName,
Tag = a.AttributeClass.Name
}
: new RegistDescriptor()
{
Lifetime = (int)a.ConstructorArguments[0].Value,
ForType = ((ITypeSymbol)a.ConstructorArguments[1].Value)?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) ?? typeSymbol.ClassName,
TargetType = typeSymbol.ClassName,
Tag = a.AttributeClass.Name
},
"RegistInitializerAttribute" => new RegistDescriptor()
{
TargetType = typeSymbol.ClassName,
Tag = a.AttributeClass.Name
},
"RegistViewAttribute" => new RegistDescriptor()
{
TargetType = typeSymbol.ClassName,
Tag = a.AttributeClass.Name,
Name = typeSymbol.ClassName
},
_ => new RegistDescriptor(),
};
});
descriptors.AddRange(tmp);
}
return descriptors;
}
.Where(a => a.AttributeClass?.ToDisplayString() == REGISTATTRIBUTE_NAME || a.AttributeClass?.BaseType?.ToDisplayString() == REGISTATTRIBUTE_NAME)
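This part checks whether the current symbol carries the specific attribute, or an attribute whose base class is that attribute.
RegistAttribute has many variants, and parsing all of them fully would cost too much effort and time, so I settled on a compromise: the switch block, which produces different data depending on the attribute's name.
Besides ConstructorArguments there is also NamedArguments, but it varies too much to be practical here, so RegistAttribute needs to stay simple enough.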
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistAttribute : Attribute/*, INamed*/
{
/// <summary>
/// Registration mode
/// </summary>
public RegistMode Mode { get; }
/// <summary>
/// The type to register as
/// </summary>
public Type? ForType { get; }
public RegistAttribute(RegistMode mode, Type? forType = null)
{
this.Mode = mode;
this.ForType = forType;
}
}
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistAttribute<T> : RegistAttribute
{
public RegistAttribute(RegistMode mode) : base(mode, typeof(T))
{
}
}
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistInitializerAttribute : RegistAttribute
{
/// <summary>
///
/// </summary>
public RegistInitializerAttribute() : base(RegistMode.Singleton, typeof(IInitializer))
{
}
}
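As an illustration of how these attributes might be applied, here is a minimal sketch. The attribute and enum declarations are repeated in compact form only so the sample compiles on its own. The numeric order of the RegistMode members, the IInitializer shape, and the sample types (IOrderService, OrderService, IClock, SystemClock, ReportBuilder, CacheWarmup) are assumptions for the example. Generic attributes require C# 11 or later.

```csharp
using System;

// Compact stand-ins; member order of RegistMode is an assumption.
public enum RegistMode { Singleton, Scoped, PreRequest }
public interface IInitializer { }

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistAttribute : Attribute
{
    public RegistMode Mode { get; }
    public Type? ForType { get; }
    public RegistAttribute(RegistMode mode, Type? forType = null) { Mode = mode; ForType = forType; }
}

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistAttribute<T> : RegistAttribute
{
    public RegistAttribute(RegistMode mode) : base(mode, typeof(T)) { }
}

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)]
public class RegistInitializerAttribute : RegistAttribute
{
    public RegistInitializerAttribute() : base(RegistMode.Singleton, typeof(IInitializer)) { }
}

// Hypothetical services, one per attribute variant.
public interface IOrderService { }
public interface IClock { }

// Non-generic form with an explicit service type.
[Regist(RegistMode.Scoped, typeof(IOrderService))]
public class OrderService : IOrderService { }

// Generic form: the service type comes from the type argument.
[Regist<IClock>(RegistMode.Singleton)]
public class SystemClock : IClock { }

// No service type: Transform falls back to the class itself.
[Regist(RegistMode.PreRequest)]
public class ReportBuilder { }

// Derived attribute, matched by name in the switch.
[RegistInitializer]
public class CacheWarmup : IInitializer { }
```

Each of these shapes corresponds to one branch in Transform: the generic and non-generic "RegistAttribute" cases read the constructor arguments, while "RegistInitializerAttribute" only records the target type and the tag.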
Generate the code:
public static string GetCode(RegistDescriptor descriptor)
{
if (string.IsNullOrWhiteSpace(descriptor.ForType))
return $"RegistService(sc, typeof({descriptor.TargetType}), \"{descriptor.Tag}\", \"{descriptor.Name}\");";
else
return $"RegistService(sc, typeof({descriptor.ForType}), typeof({descriptor.TargetType}), {descriptor.Lifetime});";
}
public static string GetCode(IEnumerable<RegistDescriptor> ds)
{
if (ds?.Any() != true)
ds = Enumerable.Empty<RegistDescriptor>();
var arr = ds.Select(GetCode).Distinct().OrderBy(s => s);
var str = string.Join("\r\n ", arr);
var code = $$"""
using CNB.PubTools.Common;
using Microsoft.Extensions.DependencyInjection;
using System;
using System.Diagnostics.CodeAnalysis;
namespace CNB.PubTools
{
internal partial class Regist
{
public static partial void RegistService(IServiceCollection sc)
{
{{str}}
}
}
}
""";
return code;
}
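To make the two output shapes concrete, the sketch below runs the single-descriptor GetCode logic on two hand-written descriptors. The RegistDescriptor class is not shown in this article, so its shape here is an assumption inferred from how the code above uses it. The Demo.* type names and the lifetime value 1 are made up for the example.

```csharp
using System;

// Assumed shape of RegistDescriptor, inferred from its usage above.
public class RegistDescriptor
{
    public int Lifetime { get; set; }
    public string? ForType { get; set; }
    public string? TargetType { get; set; }
    public string? Tag { get; set; }
    public string? Name { get; set; }
}

public static class GetCodeDemo
{
    // Same logic as the single-descriptor GetCode above.
    public static string GetCode(RegistDescriptor d) =>
        string.IsNullOrWhiteSpace(d.ForType)
            ? $"RegistService(sc, typeof({d.TargetType}), \"{d.Tag}\", \"{d.Name}\");"
            : $"RegistService(sc, typeof({d.ForType}), typeof({d.TargetType}), {d.Lifetime});";

    public static void Main()
    {
        // A descriptor with a service type selects the four-argument overload.
        var service = new RegistDescriptor
        {
            Lifetime = 1,
            ForType = "global::Demo.IClock",
            TargetType = "Demo.SystemClock",
            Tag = "RegistAttribute"
        };
        // A descriptor without a service type selects the tag-based overload.
        var view = new RegistDescriptor
        {
            TargetType = "Demo.MainView",
            Tag = "RegistViewAttribute",
            Name = "Demo.MainView"
        };

        Console.WriteLine(GetCode(service));
        // RegistService(sc, typeof(global::Demo.IClock), typeof(Demo.SystemClock), 1);
        Console.WriteLine(GetCode(view));
        // RegistService(sc, typeof(Demo.MainView), "RegistViewAttribute", "Demo.MainView");
    }
}
```

These lines are what ends up inside the body of the generated RegistService(IServiceCollection sc) method, sorted and de-duplicated.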
To go with the SourceGenerator above, a separate partial class has to be defined:
/// <summary>
///
/// </summary>
internal partial class Regist
{
/// <summary>
///
/// </summary>
/// <param name="sc"></param>
/// <param name="serviceType"></param>
/// <param name="targetType"></param>
/// <param name="mode"></param>
private static void RegistService(IServiceCollection sc, Type serviceType, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, int mode)
{
switch ((RegistMode)mode)
{
case RegistMode.Scoped:
sc.AddScoped(serviceType, targetType);
break;
case RegistMode.Singleton:
sc.AddSingleton(serviceType, targetType);
break;
case RegistMode.PreRequest:
sc.AddTransient(serviceType, targetType);
break;
}
}
private static void RegistService(IServiceCollection sc, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, string tag, string? name)
{
switch (tag)
{
case "RegistInitializerAttribute":
sc.AddSingleton(typeof(IInitializer), targetType);
break;
case "RegistViewAttribute":
sc.TryAddKeyedTransient(typeof(IView), name, targetType);
break;
}
}
public static partial void RegistService(IServiceCollection sc);
}
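At startup, the application then only needs to call the generated entry point. The sketch below shows the idea under simplifying assumptions: the second partial half of Regist is written by hand in place of real generator output, the private helper ignores the mode, and IClock and SystemClock are hypothetical types. It needs the Microsoft.Extensions.DependencyInjection package.

```csharp
using System;
using Microsoft.Extensions.DependencyInjection;

// Hypothetical service used only for this demo.
public interface IClock { }
public class SystemClock : IClock { }

// Hand-written half: declares the partial method and a simplified helper.
internal partial class Regist
{
    private static void RegistService(IServiceCollection sc, Type serviceType, Type targetType, int mode)
    {
        // Simplified for the demo: always singleton, mode is ignored.
        sc.AddSingleton(serviceType, targetType);
    }

    public static partial void RegistService(IServiceCollection sc);
}

// Stand-in for the generated half; the real one comes from Regist.g.cs.
internal partial class Regist
{
    public static partial void RegistService(IServiceCollection sc)
    {
        RegistService(sc, typeof(global::IClock), typeof(SystemClock), 0);
    }
}

public static class Program
{
    public static void Main()
    {
        var services = new ServiceCollection();
        // One call replaces the old run-time DLL scan.
        Regist.RegistService(services);

        using var provider = services.BuildServiceProvider();
        var clock = provider.GetRequiredService<IClock>();
        Console.WriteLine(clock.GetType().Name); // SystemClock
    }
}
```

Because every registration is a plain typeof expression in compiled code, nothing has to be discovered by loading assemblies at run time.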
Then reference the SourceGenerator from the main project:
<ProjectReference Include="..\CNB.PubTools.SourceGenerator\CNB.PubTools.SourceGenerator.csproj" OutputItemType="Analyzer" ReferenceOutputAssembly="false" />
This is an analyzer project whose output DLL should not be referenced, hence OutputItemType="Analyzer" and ReferenceOutputAssembly="false".
Finally, you can add this to the project file:
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
This writes the generated code under the obj directory.
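If a fixed location is more convenient, the output folder can also be set explicitly. The fragment below is a sketch, and the folder name Generated is an assumption chosen for illustration. Keeping it under the intermediate (obj) path avoids the emitted files being picked up again as ordinary source files.

```xml
<PropertyGroup>
  <!-- Write source-generator output to disk. -->
  <EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
  <!-- Optional: choose the folder; "Generated" is an illustrative name. -->
  <CompilerGeneratedFilesOutputPath>$(BaseIntermediateOutputPath)Generated</CompilerGeneratedFilesOutputPath>
</PropertyGroup>
```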
Alternatively, you can find the files generated by the SourceGenerator in the following location: