JsonRpc/CodeGenerator/JsonRPCCodeGenerator.cs

507 lines
18 KiB
C#
Raw Permalink Normal View History

2025-10-14 21:05:08 +08:00
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Emit;
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.Cecil.Rocks;
using System.Text;
namespace CodeGenerator
{
    /// <summary>
    /// Post-processes a compiled assembly with Mono.Cecil, turning static methods selected by the target
    /// attribute (on the method, its type, or the whole assembly) into JSON-RPC calls.
    /// </summary>
    /// <remarks>
    /// For every processed method <c>M</c> the generator:
    /// <list type="number">
    /// <item><description>moves the original body of <c>M</c> into a new public static method whose name is
    /// built from the declaring type, method name and parameter type names;</description></item>
    /// <item><description>adds a <c>&lt;name&gt;_Handle</c> method that unpacks a <c>JsonRPCRequest</c>, calls the
    /// moved body and wraps the result in a <c>JsonRPCResponse</c>;</description></item>
    /// <item><description>adds a <c>&lt;name&gt;_Handle_Reg</c> method that registers the handler with
    /// <c>JsonRPC.RPC.Receiver</c>, and calls it from the <c>&lt;Module&gt;</c> static constructor;</description></item>
    /// <item><description>replaces the body of <c>M</c> with a client stub that sends the request through
    /// <c>JsonRPC.RPC.Sender</c>.</description></item>
    /// </list>
    /// The IL for the generated bodies is obtained by compiling small C# snippets with Roslyn and copying the
    /// resulting method bodies into the target module.
    /// </remarks>
    public class JsonRPCCodeGenerator
    {
        // Suffixes appended to the RPC method name for the generated handler / registration methods.
        private const string HandleSuffix = "_Handle";
        private const string RegSuffix = "_Reg";

        // Names of the throw-away class/method emitted by CompileFunc.
        private const string DynamicClassName = "DynamicClass";
        private const string DynamicMethodName = "HelloWorld";

        private readonly Action<string> logger;
        private ModuleDefinition moduleDefinition = null!;
        private List<TypeDefinition> types = null!;
        private string filePath = null!;
        // Static constructor of <Module>; every generated registration method is called from it.
        private MethodDefinition? moduleStaticConstructor;

        /// <summary>Creates a generator that reports skipped methods and compiler diagnostics through <paramref name="logFn"/>.</summary>
        public JsonRPCCodeGenerator(Action<string> logFn)
        {
            logger = logFn;
        }

        /// <summary>Reads the module at <paramref name="filePath"/>; <see cref="WriteModule"/> later writes it back to the same path.</summary>
        public void LoadAssembly(string filePath)
        {
            // InMemory: the whole file is read up front so the same path can be overwritten by WriteModule.
            var readerParameters = new ReaderParameters { InMemory = true };
            moduleDefinition = ModuleDefinition.ReadModule(filePath, readerParameters);
            // Only remember the path once the module was read successfully.
            this.filePath = filePath;
        }

        /// <summary>Writes the (rewritten) module back to the path it was loaded from.</summary>
        /// <exception cref="ArgumentNullException">LoadAssembly has not been called successfully.</exception>
        public void WriteModule()
        {
            if (filePath is null || moduleDefinition is null)
            {
                throw new ArgumentNullException(nameof(filePath), "WriteModule called before LoadAssembly.");
            }
            moduleDefinition.Write(filePath);
        }

        /// <summary>
        /// Finds or creates the static constructor of <paramref name="type"/> (the <c>&lt;Module&gt;</c> type) and
        /// remembers it as the place where registration calls are inserted.
        /// </summary>
        private void AddModuleStaticConstructor(TypeDefinition type)
        {
            var sctor = type.GetStaticConstructor();
            if (sctor == null)
            {
                // TypeSystem.Void refers to the target module's own corlib instead of the generator's runtime.
                sctor = new MethodDefinition(
                    ".cctor",
                    MethodAttributes.Static | MethodAttributes.Private | MethodAttributes.HideBySig,
                    moduleDefinition.TypeSystem.Void);
                sctor.IsRuntimeSpecialName = true;
                sctor.IsSpecialName = true;
                // A leading nop: AddModuleInitializerMethod inserts registration calls right after it.
                sctor.Body.Instructions.Add(Instruction.Create(OpCodes.Nop));
                sctor.Body.Instructions.Add(Instruction.Create(OpCodes.Ret));
                type.Methods.Add(sctor);
            }
            moduleStaticConstructor = sctor;
        }

        /// <summary>Inserts a call to <paramref name="methodReference"/> at the start of the module static constructor.</summary>
        /// <exception cref="ArgumentNullException">The module static constructor has not been set up.</exception>
        private void AddModuleInitializerMethod(MethodReference methodReference)
        {
            var cctor = moduleStaticConstructor
                ?? throw new ArgumentNullException(nameof(moduleStaticConstructor), "The <Module> static constructor has not been created.");
            var instructions = cctor.Body.Instructions;
            // Keep a leading nop first; otherwise insert at the very beginning.
            int index = instructions.Count > 0 && instructions[0].OpCode.Code == Code.Nop ? 1 : 0;
            instructions.Insert(index, Instruction.Create(OpCodes.Call, methodReference));
        }

        /// <summary>
        /// Rewrites every selected method in the loaded module, then strips the target attributes.
        /// A method is selected when the assembly, its declaring type, or the method itself carries the target attribute
        /// (as reported by the project's <c>ContainsTargetAttribute</c> extension). Compiler-generated types are skipped.
        /// </summary>
        /// <exception cref="ArgumentNullException">LoadAssembly has not been called successfully.</exception>
        public void ProcessAssembly()
        {
            if (moduleDefinition is null)
            {
                throw new ArgumentNullException(nameof(moduleDefinition), "ProcessAssembly called before LoadAssembly.");
            }

            types = moduleDefinition.GetTypes().ToList();
            var moduleType = types.FirstOrDefault(t => t.Name == "<Module>");
            if (moduleType != null)
            {
                AddModuleStaticConstructor(moduleType);
            }

            bool assemblyWide = moduleDefinition.Assembly.ContainsTargetAttribute();
            foreach (var type in types)
            {
                // Never rewrite <Module>: its static constructor hosts the generated registrations.
                if (type == moduleType || type.IsCompilerGenerated())
                {
                    continue;
                }

                bool typeWide = assemblyWide || type.ContainsTargetAttribute();
                // Snapshot the selection: ProcessMethod adds methods to type.Methods, which would break
                // ConcreteMethods() if it enumerates that collection lazily.
                var selected = type.ConcreteMethods()
                    .Where(m => typeWide || m.ContainsTargetAttribute())
                    .ToList();
                foreach (var method in selected)
                {
                    ProcessMethod(type, method);
                }
            }

            // Runs in every mode (previously skipped when the attribute was on the assembly).
            RemoveAttributes();
        }

        /// <summary>
        /// Compiles <c>public static {returnType} HelloWorld {funcBody}</c> inside a throw-away class with Roslyn,
        /// writes the assembly into <paramref name="ms"/> and returns the compiled <c>HelloWorld</c> method read back with Cecil.
        /// </summary>
        /// <param name="ms">Receives the compiled assembly. It must stay open while the returned method's body is used,
        /// because the module is read back from it.</param>
        /// <param name="funcBody">Parameter list and body of the method; may be followed by extra member declarations.</param>
        /// <param name="returnType">C# return type; <c>"Void"</c> or <c>"System.Void"</c> map to <c>void</c>.</param>
        /// <exception cref="InvalidOperationException">The snippet failed to compile; diagnostics are sent to the logger.</exception>
        private MethodDefinition CompileFunc(MemoryStream ms, string funcBody, string returnType)
        {
            string csharpReturnType = returnType is "Void" or "System.Void" ? "void" : returnType;
            string code = $@"
using System;
public class {DynamicClassName}
{{
public static {csharpReturnType} {DynamicMethodName} {funcBody}
}}";
            SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(code);

            // Reference every assembly loaded in the generator process that exists on disk
            // (dynamic or single-file-bundled assemblies have no Location and cannot be referenced).
            var references = AppDomain.CurrentDomain.GetAssemblies()
                .Where(a => !a.IsDynamic && !string.IsNullOrEmpty(a.Location))
                .Select(a => (MetadataReference)MetadataReference.CreateFromFile(a.Location))
                .ToList();

            var options = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary, allowUnsafe: true);
            CSharpCompilation compilation = CSharpCompilation.Create("DynamicAssembly", new[] { syntaxTree }, references, options);

            EmitResult result = compilation.Emit(ms);
            if (!result.Success)
            {
                foreach (var diagnostic in result.Diagnostics)
                {
                    logger(diagnostic.ToString());
                }
                logger(code);
                throw new InvalidOperationException("Failed to compile a generated JSON-RPC snippet; see the log for compiler diagnostics.");
            }

            ms.Seek(0, SeekOrigin.Begin);
            var md = ModuleDefinition.ReadModule(ms);
            return md.GetType(DynamicClassName).Method(DynamicMethodName);
        }

        /// <summary>
        /// Builds the RPC name used for the moved method: <c>Namespace_Type__Method_</c> followed by
        /// <c>_ParamTypeName</c> for each parameter. Characters that are not valid in a C# identifier
        /// (e.g. '/' of nested types, '`' of generics, "[]" of arrays) are replaced with '_', because the
        /// name is also emitted as a C# identifier in the generated handler.
        /// </summary>
        private string GetMethodSignature(MethodDefinition method)
        {
            var sb = new StringBuilder();
            sb.Append(method.DeclaringType.FullName.Replace('.', '_'));
            sb.Append("__");
            sb.Append(method.Name);
            sb.Append('_');
            foreach (var param in method.Parameters)
            {
                sb.Append('_');
                sb.Append(param.ParameterType.Name);
            }

            for (int i = 0; i < sb.Length; i++)
            {
                if (!char.IsLetterOrDigit(sb[i]) && sb[i] != '_')
                {
                    sb[i] = '_';
                }
            }
            return sb.ToString();
        }

        /// <summary>
        /// Creates a public static method on <paramref name="type"/> that takes over the instructions, locals,
        /// exception handlers and parameters of <paramref name="method"/>, and returns it.
        /// </summary>
        /// <remarks>
        /// The instruction, variable, handler and parameter objects are moved (shared), not cloned; the caller
        /// is expected to replace <paramref name="method"/>'s body afterwards.
        /// </remarks>
        private MethodDefinition AddRPCRealMethod(TypeDefinition type, MethodDefinition method)
        {
            var source = method.Body;
            var realMethod = new MethodDefinition(GetMethodSignature(method), MethodAttributes.Public | MethodAttributes.Static, method.ReturnType);
            var target = realMethod.Body;
            target.InitLocals = source.InitLocals;

            foreach (var instruction in source.Instructions)
            {
                target.Instructions.Add(instruction);
            }
            foreach (var variable in source.Variables)
            {
                target.Variables.Add(variable);
            }
            // Without the handlers, leave/endfinally instructions of try blocks would be invalid IL.
            foreach (var handler in source.ExceptionHandlers)
            {
                target.ExceptionHandlers.Add(handler);
            }
            foreach (var param in method.Parameters)
            {
                realMethod.Parameters.Add(param);
            }

            type.Methods.Add(realMethod);
            return realMethod;
        }

        /// <summary>
        /// Adds <c>{method.Name}_Handle(JsonRPCRequest req)</c> to <paramref name="type"/>. It checks the parameter
        /// count (a null <c>Params</c> counts as zero), casts each <c>req.Params[i]</c> to the parameter type,
        /// calls <paramref name="method"/>, and returns a <c>JsonRPCResponse</c> (or <c>JsonRPCResponse&lt;T&gt;</c> with
        /// the result). Then adds and registers the matching registration method.
        /// </summary>
        private void AddRPCHandleMethod(TypeDefinition type, MethodDefinition method)
        {
            var handleMethod = new MethodDefinition(
                method.Name + HandleSuffix,
                MethodAttributes.Static | MethodAttributes.Public,
                moduleDefinition.ImportReference(typeof(JsonRPC.Protocol.JsonRPCResponse)));
            handleMethod.Parameters.Add(new("req", ParameterAttributes.In, moduleDefinition.ImportReference(typeof(JsonRPC.Protocol.JsonRPCRequest))));
            type.Methods.Add(handleMethod);

            bool returnVoid = IsVoid(method.ReturnType);
            string returnTypeName = ToCSharpTypeName(method.ReturnType);
            int parameterCount = method.Parameters.Count;
            string arguments = string.Join(", ", method.Parameters.Select((p, i) => $"({ToCSharpTypeName(p.ParameterType)})req.Params[{i}]"));
            string parameters = string.Join(", ", method.Parameters.Select(p => $"{ToCSharpTypeName(p.ParameterType)} {ParameterIdentifier(p)}"));
            string stubReturnType = returnVoid ? "void" : returnTypeName;
            string stubReturnValue = returnVoid ? "" : "default";

            string invocation = $"{method.Name}({arguments});";
            string success = returnVoid
                ? $@"{invocation}
    return new JsonRPC.Protocol.JsonRPCResponse()
    {{
        Id = req.Id,
    }};"
                : $@"var ret = {invocation}
    return new JsonRPC.Protocol.JsonRPCResponse<{returnTypeName}>()
    {{
        Id = req.Id,
        Result = ret,
    }};";

            // The trailing static method is a placeholder with the real method's name and signature so the
            // snippet compiles; its call site is retargeted to the real method afterwards.
            string source = $@"
(JsonRPC.Protocol.JsonRPCRequest req)
{{
    if ((req.Params?.Count ?? 0) != {parameterCount})
    {{
        return new JsonRPC.Protocol.JsonRPCResponse()
        {{
            Id = req.Id,
            Error = new JsonRPC.Protocol.JsonRPCError()
            {{
                Code = (int)JsonRPC.Protocol.EErrorCode.InvalidParam,
                Message = ""req Parameters {parameterCount}"",
            }}
        }};
    }}
    {success}
}}
public static {stubReturnType} {method.Name}({parameters}) {{ return {stubReturnValue}; }}
";

            using var ms = new MemoryStream();
            var tempFunc = CompileFunc(ms, source, "JsonRPC.Protocol.JsonRPCResponse");
            ReplaceMethodBody(tempFunc, handleMethod);
            RetargetPlaceholder(handleMethod, Code.Call, method);
            AddRPCRegMethod(type, handleMethod);
        }

        /// <summary>
        /// Adds <c>{method.Name}_Reg()</c>, which registers <paramref name="method"/> (a <c>..._Handle</c> method) with
        /// <c>JsonRPC.RPC.Receiver.Instance.RegHandler</c> under the RPC name (the handler name without the
        /// <c>_Handle</c> suffix), and calls it from the module static constructor.
        /// </summary>
        private void AddRPCRegMethod(TypeDefinition type, MethodDefinition method)
        {
            var regMethod = new MethodDefinition(method.Name + RegSuffix, MethodAttributes.Static | MethodAttributes.Public, moduleDefinition.TypeSystem.Void);
            type.Methods.Add(regMethod);

            // Strip only the trailing suffix; "_Handle" may legitimately occur inside type or method names.
            string rpcName = method.Name.EndsWith(HandleSuffix, StringComparison.Ordinal)
                ? method.Name[..^HandleSuffix.Length]
                : method.Name;

            string source = $@"
()
{{
    var fn = new Func<JsonRPC.Protocol.JsonRPCRequest, JsonRPC.Protocol.JsonRPCResponse>({method.Name});
    JsonRPC.RPC.Receiver.Instance.RegHandler(""{rpcName}"", fn);
}}
public static JsonRPC.Protocol.JsonRPCResponse {method.Name}(JsonRPC.Protocol.JsonRPCRequest req) {{ return null; }}
";

            using var ms = new MemoryStream();
            var tempFunc = CompileFunc(ms, source, "Void");
            ReplaceMethodBody(tempFunc, regMethod);
            // The delegate is created with ldftn on the placeholder; point it at the real handler.
            RetargetPlaceholder(regMethod, Code.Ldftn, method);
            AddModuleInitializerMethod(regMethod);
        }

        /// <summary>
        /// Points every <paramref name="opcode"/> instruction in <paramref name="host"/> that references the
        /// Roslyn placeholder named like <paramref name="target"/> at <paramref name="target"/> instead.
        /// </summary>
        private void RetargetPlaceholder(MethodDefinition host, Code opcode, MethodDefinition target)
        {
            foreach (var ins in host.Body.Instructions)
            {
                if (ins.OpCode.Code == opcode
                    && ins.Operand is MethodReference mr
                    && mr.Name == target.Name
                    && mr.DeclaringType.Name == DynamicClassName)
                {
                    ins.Operand = moduleDefinition.ImportReference(target);
                }
            }
        }

        /// <summary>
        /// Replaces the body of <paramref name="dstMethod"/> with the body of <paramref name="srcMethod"/>
        /// (a method from a throw-away Roslyn module), importing every referenced member into the target module.
        /// </summary>
        /// <remarks>
        /// The source instruction objects are reused and their operands are rewritten in place, so branch
        /// targets (which point at those same objects) stay valid. Parameter operands are mapped by index onto
        /// <paramref name="dstMethod"/>'s parameters, which therefore must have the same order.
        /// </remarks>
        private void ReplaceMethodBody(MethodDefinition srcMethod, MethodDefinition dstMethod)
        {
            var source = srcMethod.Body;
            var target = dstMethod.Body;
            target.Instructions.Clear();
            target.Variables.Clear();
            // Handlers of the old body would reference instructions that are no longer present.
            target.ExceptionHandlers.Clear();

            // Expand short forms so operands are explicit ParameterDefinition / VariableDefinition objects.
            source.SimplifyMacros();
            target.MaxStackSize = source.MaxStackSize;
            target.InitLocals = source.InitLocals;

            foreach (var variable in source.Variables)
            {
                variable.VariableType = moduleDefinition.ImportReference(variable.VariableType);
                target.Variables.Add(variable);
            }
            foreach (var instruction in source.Instructions)
            {
                instruction.Operand = ImportOperand(instruction.Operand, srcMethod, dstMethod);
                target.Instructions.Add(instruction);
            }
            foreach (var handler in source.ExceptionHandlers)
            {
                if (handler.CatchType != null)
                {
                    handler.CatchType = moduleDefinition.ImportReference(handler.CatchType);
                }
                target.ExceptionHandlers.Add(handler);
            }

            target.Optimize();
        }

        /// <summary>
        /// Translates one instruction operand from the Roslyn module into the target module.
        /// Methods declared in the Roslyn module itself (the placeholders) are returned unchanged so the caller
        /// can retarget them; importing them would add a reference to the non-existent "DynamicAssembly".
        /// </summary>
        private object? ImportOperand(object? operand, MethodDefinition srcMethod, MethodDefinition dstMethod)
        {
            switch (operand)
            {
                case MethodDefinition placeholder when placeholder.Module == srcMethod.Module:
                    return placeholder;
                case MethodReference methodReference:
                    return moduleDefinition.ImportReference(methodReference);
                case FieldReference fieldReference:
                    return moduleDefinition.ImportReference(fieldReference);
                case TypeReference typeReference:
                    return moduleDefinition.ImportReference(typeReference);
                case ParameterDefinition parameter when parameter.Index >= 0 && parameter.Index < dstMethod.Parameters.Count:
                    return dstMethod.Parameters[parameter.Index];
                default:
                    return operand;
            }
        }

        /// <summary>
        /// Rewrites a single static method into an RPC client stub, after moving its original body into a
        /// separately named method and generating the server-side handler and registration.
        /// Yield, non-static, async, generic and by-reference methods are skipped with a log message.
        /// </summary>
        private void ProcessMethod(TypeDefinition type, MethodDefinition method)
        {
            bool hasTargetAttribute = method.ContainsTargetAttribute();
            if (method.IsYield())
            {
                if (hasTargetAttribute)
                {
                    logger.Invoke("Could not process '" + method.FullName + "' since methods that yield are currently not supported. Please remove the RPC attribute from that method.");
                    return;
                }
                logger.Invoke("Skipping '" + method.FullName + "' since methods that yield are not supported.");
                return;
            }
            if (!method.IsStatic)
            {
                logger.Invoke("Skipping '" + method.FullName + "' non static methods are not supported.");
                return;
            }
            if (method.IsAsync())
            {
                // Also skipped when selected via the type/assembly attribute: a Task cannot travel over JSON-RPC.
                logger.Invoke(hasTargetAttribute
                    ? "Could not process '" + method.FullName + "' async methods are not supported."
                    : "Skipping '" + method.FullName + "' async methods are not supported.");
                return;
            }
            if (method.HasGenericParameters
                || type.HasGenericParameters
                || method.ReturnType.IsByReference
                || method.Parameters.Any(p => p.ParameterType.IsByReference))
            {
                // Checked before any rewriting so a compile failure cannot leave the type half-modified.
                logger.Invoke("Skipping '" + method.FullName + "' generic methods and by-reference (ref/out/in) parameters or return values are not supported.");
                return;
            }

            var realMethod = AddRPCRealMethod(type, method);
            AddRPCHandleMethod(type, realMethod);

            bool returnVoid = IsVoid(method.ReturnType);
            string returnTypeName = ToCSharpTypeName(method.ReturnType);
            string parameters = string.Join(", ", method.Parameters.Select(p => $"{ToCSharpTypeName(p.ParameterType)} {ParameterIdentifier(p)}"));

            // Locals use a "__rpc" prefix so they cannot clash with the method's own parameter names.
            var body = new StringBuilder();
            body.AppendLine($"({parameters})");
            body.AppendLine("{");
            body.AppendLine($"JsonRPC.Protocol.JsonRPCRequest __rpcRequest = new(\"{realMethod.Name}\", JsonRPC.RPC.Sender.Instance.GetId());");
            foreach (var param in method.Parameters)
            {
                body.AppendLine($"__rpcRequest.AddParam({ParameterIdentifier(param)});");
            }
            body.AppendLine("var __rpcJson = JsonRPC.RPC.Sender.Instance.Send(__rpcRequest);");
            if (returnVoid)
            {
                body.AppendLine(@"
var __rpcResponse = JsonRPC.RPC.Sender.Instance.JsonDeserialize<JsonRPC.Protocol.JsonRPCResponse>(__rpcJson);
if(!JsonRPC.RPC.Sender.Instance.HnadleResponseError(__rpcResponse))
{
}
return;
}");
            }
            else
            {
                body.AppendLine($@"
var __rpcResponse = JsonRPC.RPC.Sender.Instance.JsonDeserialize<JsonRPC.Protocol.JsonRPCResponse<{returnTypeName}>>(__rpcJson);
if(!JsonRPC.RPC.Sender.Instance.HnadleResponseError(__rpcResponse))
{{
return default;
}}
return __rpcResponse.Result;
}}");
            }

            using var ms = new MemoryStream();
            var sigFunc = CompileFunc(ms, body.ToString(), returnTypeName);
            ReplaceMethodBody(sigFunc, method);
        }

        /// <summary>True when <paramref name="type"/> is exactly <c>System.Void</c>.</summary>
        private static bool IsVoid(TypeReference type) => type.FullName == "System.Void";

        /// <summary>
        /// Escaped C# identifier for a parameter in generated code; '@' allows keyword names,
        /// unnamed parameters get a positional name.
        /// </summary>
        private static string ParameterIdentifier(ParameterDefinition parameter) =>
            string.IsNullOrEmpty(parameter.Name) ? "__arg" + parameter.Index : "@" + parameter.Name;

        /// <summary>
        /// Converts a Cecil type reference into C# type syntax. Cecil's <c>FullName</c> uses '/' for nested
        /// types and "`N" arity suffixes for generics, neither of which is valid C#. Plain types and
        /// single-dimensional arrays produce the same text as <c>FullName</c>.
        /// </summary>
        private static string ToCSharpTypeName(TypeReference type)
        {
            if (type is ArrayType arrayType)
            {
                return ToCSharpTypeName(arrayType.ElementType) + "[" + new string(',', arrayType.Rank - 1) + "]";
            }

            IList<TypeReference>? genericArguments = null;
            if (type is GenericInstanceType genericInstance)
            {
                genericArguments = genericInstance.GenericArguments;
                type = genericInstance.ElementType;
            }

            var builder = new StringBuilder();
            int nextArgument = 0;
            foreach (var segment in type.FullName.Split('/'))
            {
                if (builder.Length > 0)
                {
                    builder.Append('.');
                }

                int tick = segment.IndexOf('`');
                if (tick < 0 || genericArguments is null || !int.TryParse(segment.Substring(tick + 1), out int arity))
                {
                    builder.Append(segment);
                    continue;
                }

                // Each nesting level consumes its own share of the generic arguments, in order.
                builder.Append(segment, 0, tick);
                builder.Append('<');
                builder.Append(string.Join(", ", genericArguments.Skip(nextArgument).Take(arity).Select(ToCSharpTypeName)));
                builder.Append('>');
                nextArgument += arity;
            }
            return builder.ToString();
        }

        /// <summary>
        /// Removes the target attribute from the assembly, from every type captured by ProcessAssembly,
        /// and from each of their methods (via the project's <c>RemoveTargetAttribute</c> extensions).
        /// </summary>
        private void RemoveAttributes()
        {
            moduleDefinition.Assembly.RemoveTargetAttribute();
            foreach (var typeDefinition in types)
            {
                typeDefinition.RemoveTargetAttribute();
                foreach (var method in typeDefinition.Methods)
                {
                    method.RemoveTargetAttribute();
                }
            }
        }
    }
}