[WRAPPER HELPER] Improve Record Parsing (#880)

This commit is contained in:
wannacu 2023-07-12 15:43:08 +08:00 committed by GitHub
parent a6b231ce56
commit b1c09acb0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 75 additions and 29 deletions

View File

@ -25,11 +25,15 @@
#include "gen.h" #include "gen.h"
#include "utils.h" #include "utils.h"
static void ParseParameter(clang::ASTContext* AST, WrapperGenerator* Gen, clang::ParmVarDecl* Decl, FuncInfo* Func) { static void ParseParameter(clang::ASTContext* AST, WrapperGenerator* Gen, clang::QualType ParmType, FuncInfo* Func) {
using namespace clang; using namespace clang;
(void)AST; (void)Func; (void)AST; (void)Func;
auto ParmType = Decl->getType(); if (ParmType->isFunctionPointerType()) {
if (ParmType->isPointerType()) { auto ProtoType = ParmType->getPointeeType()->getAs<FunctionProtoType>();
for (unsigned i = 0; i < ProtoType->getNumParams(); i++) {
ParseParameter(AST, Gen, ProtoType->getParamType(i), Func);
}
} else if (ParmType->isPointerType()) {
auto PointeeType = ParmType->getPointeeType(); auto PointeeType = ParmType->getPointeeType();
if (PointeeType->isRecordType()) { if (PointeeType->isRecordType()) {
if (Gen->records.find(StripTypedef(PointeeType)) == Gen->records.end()) { if (Gen->records.find(StripTypedef(PointeeType)) == Gen->records.end()) {
@ -91,7 +95,7 @@ static void ParseFunction(clang::ASTContext* AST, WrapperGenerator* Gen, clang::
} else { } else {
FuncInfo->callback_args[i] = nullptr; FuncInfo->callback_args[i] = nullptr;
} }
ParseParameter(AST, Gen, ParmDecl, FuncInfo); ParseParameter(AST, Gen, ParmDecl->getType(), FuncInfo);
} }
} }

View File

@ -289,6 +289,7 @@ std::string WrapperGenerator::GenDeclare(ASTContext *Ctx,
const RecordInfo &Record) { const RecordInfo &Record) {
(void)Ctx; (void)Ctx;
std::string RecordStr; std::string RecordStr;
std::string PreDecl;
RecordStr += "\ntypedef "; RecordStr += "\ntypedef ";
RecordStr += RecordStr +=
(Record.is_union ? "union " : "struct ") + Record.type_name + " {\n"; (Record.is_union ? "union " : "struct ") + Record.type_name + " {\n";
@ -327,10 +328,10 @@ std::string WrapperGenerator::GenDeclare(ASTContext *Ctx,
FieldStr += Name; FieldStr += Name;
RecordStr += FieldStr; RecordStr += FieldStr;
} else { } else {
RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr); RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
} }
} else { } else {
RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr); RecordStr += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
} }
RecordStr += ";\n"; RecordStr += ";\n";
} }
@ -547,6 +548,7 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
(void)Ctx; (void)Ctx;
std::string GuestRecord; std::string GuestRecord;
std::string HostRecord; std::string HostRecord;
std::string PreDecl;
std::vector<uint64_t> GuestFieldOff; std::vector<uint64_t> GuestFieldOff;
std::vector<uint64_t> HostFieldOff; std::vector<uint64_t> HostFieldOff;
GuestRecord += "typedef "; GuestRecord += "typedef ";
@ -599,7 +601,7 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
std::cout << "Err: unknown type size " << typeSize << std::endl; std::cout << "Err: unknown type size " << typeSize << std::endl;
break; break;
} }
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr); HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
} else if (Type->isFunctionPointerType()) { } else if (Type->isFunctionPointerType()) {
auto FuncType = StripTypedef(Type->getPointeeType()); auto FuncType = StripTypedef(Type->getPointeeType());
if (callbacks.count(FuncType)) { if (callbacks.count(FuncType)) {
@ -634,12 +636,12 @@ std::string WrapperGenerator::GenDeclareDiffTriple(
GuestRecord += FieldStr; GuestRecord += FieldStr;
HostRecord += "host_" + FieldStr; HostRecord += "host_" + FieldStr;
} else { } else {
GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr); GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr); HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
} }
} else { } else {
HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr); HostRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr); GuestRecord += TypeStringify(StripTypedef(Type), Field, nullptr, PreDecl);
} }
GuestRecord += ";\n"; GuestRecord += ";\n";
HostRecord += ";\n"; HostRecord += ";\n";
@ -685,12 +687,12 @@ std::string WrapperGenerator::GenRecordConvert(const RecordInfo &Record) {
if (!AlignDiffFields.size()) { if (!AlignDiffFields.size()) {
return res; return res;
} }
res += "void g2h_" + Record.type_name + "(" + "struct host_" + Record.type_name + "*d, struct" + Record.type_name + "*s) {\n"; res += "void g2h_" + Record.type_name + "(" + "struct host_" + Record.type_name + " *d, struct " + Record.type_name + " *s) {\n";
std::string body = " memcpy(d, s, offsetof(struct " + Record.type_name + std::string body = " memcpy(d, s, offsetof(struct " + Record.type_name +
", " + AlignDiffFields[0]->getNameAsString() + "));\n"; ", " + AlignDiffFields[0]->getNameAsString() + "));\n";
std::string offstr = "offsetof(struct " + Record.type_name + ", " + std::string offstr = "offsetof(struct " + Record.type_name + ", " +
AlignDiffFields[0]->getNameAsString() + ")"; AlignDiffFields[0]->getNameAsString() + ")";
for (size_t i = 1; i < AlignDiffFields.size() - 1; i++) { for (size_t i = 0; i < AlignDiffFields.size() - 1; i++) {
body += " memcpy(d->" + AlignDiffFields[i]->getNameAsString() + ", " + body += " memcpy(d->" + AlignDiffFields[i]->getNameAsString() + ", " +
"s->" + AlignDiffFields[i]->getNameAsString() + ", " + "s->" + AlignDiffFields[i]->getNameAsString() + ", " +
"offsetof(struct " + Record.type_name + ", " + "offsetof(struct " + Record.type_name + ", " +
@ -707,7 +709,7 @@ std::string WrapperGenerator::GenRecordConvert(const RecordInfo &Record) {
" - " + offstr + ");\n"; " - " + offstr + ");\n";
res += body + "}\n"; res += body + "}\n";
res += "void h2g_" + Record.type_name + "(struct" + Record.type_name + "*d, " + "struct host_" + Record.type_name + "*s) {\n"; res += "void h2g_" + Record.type_name + "(struct " + Record.type_name + " *d, " + "struct host_" + Record.type_name + " *s) {\n";
res += body; res += body;
res += "}\n"; res += "}\n";
} }
@ -775,6 +777,7 @@ void WrapperGenerator::ParseRecordRecursive(
std::string WrapperGenerator::TypeStringify(const Type *Type, std::string WrapperGenerator::TypeStringify(const Type *Type,
FieldDecl *FieldDecl, FieldDecl *FieldDecl,
ParmVarDecl *ParmDecl, ParmVarDecl *ParmDecl,
std::string& PreDecl,
std::string indent, std::string indent,
std::string Name) { std::string Name) {
std::string res; std::string res;
@ -807,7 +810,7 @@ std::string WrapperGenerator::TypeStringify(const Type *Type,
res += records[StripTypedef(Type->getCanonicalTypeInternal())].type_name; res += records[StripTypedef(Type->getCanonicalTypeInternal())].type_name;
res += " "; res += " ";
} else { } else {
res += AnonRecordDecl(Type->getAs<RecordType>(), indent); res += AnonRecordDecl(Type->getAs<RecordType>(), PreDecl, indent + " ");
} }
res += name; res += name;
} else if (Type->isConstantArrayType()) { } else if (Type->isConstantArrayType()) {
@ -816,6 +819,21 @@ std::string WrapperGenerator::TypeStringify(const Type *Type,
int EleSize = ArrayType->getSize().getZExtValue(); int EleSize = ArrayType->getSize().getZExtValue();
if (ArrayType->getElementType()->isPointerType()) { if (ArrayType->getElementType()->isPointerType()) {
res += "void *"; res += "void *";
} else if (ArrayType->getElementType()->isEnumeralType()) {
res += "int ";
} else if (ArrayType->getElementType()->isRecordType()) {
auto RecordType = ArrayType->getElementType()->getAs<clang::RecordType>();
auto RecordDecl = RecordType->getDecl();
if (RecordDecl->isCompleteDefinition()) {
auto& Ctx = RecordDecl->getDeclContext()->getParentASTContext();
PreDecl += "#include \"";
PreDecl += GetDeclHeaderFile(Ctx, RecordDecl);
PreDecl += "\"";
PreDecl += "\n";
}
res += StripTypedef(ArrayType->getElementType())
->getCanonicalTypeInternal()
.getAsString();
} else { } else {
res += StripTypedef(ArrayType->getElementType()) res += StripTypedef(ArrayType->getElementType())
->getCanonicalTypeInternal() ->getCanonicalTypeInternal()
@ -887,13 +905,13 @@ std::string WrapperGenerator::SimpleTypeStringify(const Type *Type,
return indent + res; return indent + res;
} }
std::string WrapperGenerator::AnonRecordDecl(const RecordType *Type, std::string indent) { std::string WrapperGenerator::AnonRecordDecl(const RecordType *Type, std::string& PreDecl, std::string indent) {
auto RecordDecl = Type->getDecl(); auto RecordDecl = Type->getDecl();
std::string res; std::string res;
res += Type->isUnionType() ? "union {\n" : "struct {\n"; res += Type->isUnionType() ? "union {\n" : "struct {\n";
for (const auto &field : RecordDecl->fields()) { for (const auto &field : RecordDecl->fields()) {
auto FieldType = field->getType(); auto FieldType = field->getType();
res += TypeStringify(StripTypedef(FieldType), field, nullptr, indent + " "); res += TypeStringify(StripTypedef(FieldType), field, nullptr, PreDecl, indent + " ");
res += ";\n"; res += ";\n";
} }
res += indent + "} "; res += indent + "} ";
@ -907,7 +925,7 @@ WrapperGenerator::SimpleAnonRecordDecl(const RecordType *Type, std::string inden
res += Type->isUnionType() ? "union {\n" : "struct {\n"; res += Type->isUnionType() ? "union {\n" : "struct {\n";
for (const auto &field : RecordDecl->fields()) { for (const auto &field : RecordDecl->fields()) {
auto FieldType = field->getType(); auto FieldType = field->getType();
res += SimpleTypeStringify(StripTypedef(FieldType), field, nullptr, indent + " "); res += SimpleTypeStringify(StripTypedef(FieldType), field, nullptr, indent + " ");
res += ";\n"; res += ";\n";
} }
res += indent + "} "; res += indent + "} ";
@ -917,15 +935,16 @@ WrapperGenerator::SimpleAnonRecordDecl(const RecordType *Type, std::string inden
// Get func info from FunctionType // Get func info from FunctionType
FuncDefinition WrapperGenerator::GetFuncDefinition(const Type *Type) { FuncDefinition WrapperGenerator::GetFuncDefinition(const Type *Type) {
FuncDefinition res; FuncDefinition res;
std::string PreDecl;
auto ProtoType = Type->getAs<FunctionProtoType>(); auto ProtoType = Type->getAs<FunctionProtoType>();
res.ret = StripTypedef(ProtoType->getReturnType()); res.ret = StripTypedef(ProtoType->getReturnType());
res.ret_str = res.ret_str =
TypeStringify(StripTypedef(ProtoType->getReturnType()), nullptr, nullptr); TypeStringify(StripTypedef(ProtoType->getReturnType()), nullptr, nullptr, PreDecl);
for (unsigned i = 0; i < ProtoType->getNumParams(); i++) { for (unsigned i = 0; i < ProtoType->getNumParams(); i++) {
auto ParamType = ProtoType->getParamType(i); auto ParamType = ProtoType->getParamType(i);
res.arg_types.push_back(StripTypedef(ParamType)); res.arg_types.push_back(StripTypedef(ParamType));
res.arg_types_str.push_back( res.arg_types_str.push_back(
TypeStringify(StripTypedef(ParamType), nullptr, nullptr)); TypeStringify(StripTypedef(ParamType), nullptr, nullptr, PreDecl));
res.arg_names.push_back(std::string("a") + std::to_string(i)); res.arg_names.push_back(std::string("a") + std::to_string(i));
} }
if (ProtoType->isVariadic()) { if (ProtoType->isVariadic()) {
@ -938,15 +957,16 @@ FuncDefinition WrapperGenerator::GetFuncDefinition(const Type *Type) {
// Get funcdecl info from FunctionDecl // Get funcdecl info from FunctionDecl
FuncDefinition WrapperGenerator::GetFuncDefinition(FunctionDecl *Decl) { FuncDefinition WrapperGenerator::GetFuncDefinition(FunctionDecl *Decl) {
FuncDefinition res; FuncDefinition res;
std::string PreDecl;
auto RetType = Decl->getReturnType(); auto RetType = Decl->getReturnType();
res.ret = RetType.getTypePtr(); res.ret = RetType.getTypePtr();
res.ret_str = TypeStringify(StripTypedef(RetType), nullptr, nullptr); res.ret_str = TypeStringify(StripTypedef(RetType), nullptr, nullptr, PreDecl);
for (unsigned i = 0; i < Decl->getNumParams(); i++) { for (unsigned i = 0; i < Decl->getNumParams(); i++) {
auto ParamDecl = Decl->getParamDecl(i); auto ParamDecl = Decl->getParamDecl(i);
auto ParamType = ParamDecl->getType(); auto ParamType = ParamDecl->getType();
res.arg_types.push_back(ParamType.getTypePtr()); res.arg_types.push_back(ParamType.getTypePtr());
res.arg_types_str.push_back( res.arg_types_str.push_back(
TypeStringify(StripTypedef(ParamType), nullptr, nullptr)); TypeStringify(StripTypedef(ParamType), nullptr, nullptr, PreDecl));
res.arg_names.push_back(ParamDecl->getNameAsString()); res.arg_names.push_back(ParamDecl->getNameAsString());
} }
if (Decl->isVariadic()) { if (Decl->isVariadic()) {
@ -960,7 +980,8 @@ std::vector<uint64_t> WrapperGenerator::GetRecordFieldOffDiff(
const Type *Type, const std::string &GuestTriple, const Type *Type, const std::string &GuestTriple,
const std::string &HostTriple, std::vector<uint64_t> &GuestFieldOff, const std::string &HostTriple, std::vector<uint64_t> &GuestFieldOff,
std::vector<uint64_t> &HostFieldOff) { std::vector<uint64_t> &HostFieldOff) {
std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;"); std::string PreDecl;
std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
std::vector<uint64_t> OffsetDiff; std::vector<uint64_t> OffsetDiff;
GuestFieldOff = GetRecordFieldOff(Code, GuestTriple); GuestFieldOff = GetRecordFieldOff(Code, GuestTriple);
HostFieldOff = GetRecordFieldOff(Code, HostTriple); HostFieldOff = GetRecordFieldOff(Code, HostTriple);
@ -978,15 +999,18 @@ std::vector<uint64_t> WrapperGenerator::GetRecordFieldOffDiff(
// Get the size under a specific triple // Get the size under a specific triple
uint64_t WrapperGenerator::GetRecordSize(const Type *Type, uint64_t WrapperGenerator::GetRecordSize(const Type *Type,
const std::string &Triple) { const std::string &Triple) {
std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;"); std::string PreDecl;
return ::GetRecordSize(Code, Triple); std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
auto Size = ::GetRecordSize(PreDecl + Code, Triple);
return Size;
} }
// Get the align under a specific triple // Get the align under a specific triple
CharUnits::QuantityType WrapperGenerator::GetRecordAlign(const Type *Type, CharUnits::QuantityType WrapperGenerator::GetRecordAlign(const Type *Type,
const std::string &Triple) { const std::string &Triple) {
std::string Code = TypeStringify(Type, nullptr, nullptr, "", "dummy;"); std::string PreDecl{};
return ::GetRecordAlign(Code, Triple); std::string Code = TypeStringify(Type, nullptr, nullptr, PreDecl, "", "dummy;");
return ::GetRecordAlign(PreDecl + Code, Triple);
} }
// Generate the func sig by type, used for export func // Generate the func sig by type, used for export func

View File

@ -129,9 +129,9 @@ private:
std::string GenCallbackWrap(clang::ASTContext* Ctx, const RecordInfo& Struct); std::string GenCallbackWrap(clang::ASTContext* Ctx, const RecordInfo& Struct);
void ParseRecordRecursive(clang::ASTContext* Ctx, const clang::Type* Type, bool& Special, std::set<const clang::Type*>& Visited); void ParseRecordRecursive(clang::ASTContext* Ctx, const clang::Type* Type, bool& Special, std::set<const clang::Type*>& Visited);
std::string TypeStringify(const clang::Type* Type, clang::FieldDecl* FieldDecl, clang::ParmVarDecl* ParmDecl, std::string indent = "", std::string Name = ""); std::string TypeStringify(const clang::Type* Type, clang::FieldDecl* FieldDecl, clang::ParmVarDecl* ParmDecl, std::string& PreDecl, std::string indent = "", std::string Name = "");
std::string SimpleTypeStringify(const clang::Type* Type, clang::FieldDecl* FieldDecl, clang::ParmVarDecl* ParmDecl, std::string indent = "", std::string Name = ""); std::string SimpleTypeStringify(const clang::Type* Type, clang::FieldDecl* FieldDecl, clang::ParmVarDecl* ParmDecl, std::string indent = "", std::string Name = "");
std::string AnonRecordDecl(const clang::RecordType* Type, std::string indent); std::string AnonRecordDecl(const clang::RecordType* Type, std::string& PreDecl, std::string indent);
std::string SimpleAnonRecordDecl(const clang::RecordType* Type, std::string indent); std::string SimpleAnonRecordDecl(const clang::RecordType* Type, std::string indent);
FuncDefinition GetFuncDefinition(const clang::Type* Type); FuncDefinition GetFuncDefinition(const clang::Type* Type);
FuncDefinition GetFuncDefinition(clang::FunctionDecl* Decl); FuncDefinition GetFuncDefinition(clang::FunctionDecl* Decl);

View File

@ -59,5 +59,11 @@ int main(int argc, const char* argv[]) {
std::string err; std::string err;
auto compile_db = clang::tooling::FixedCompilationDatabase::loadFromCommandLine(argc, argv, err); auto compile_db = clang::tooling::FixedCompilationDatabase::loadFromCommandLine(argc, argv, err);
clang::tooling::ClangTool Tool(*compile_db, {argv[1]}); clang::tooling::ClangTool Tool(*compile_db, {argv[1]});
Tool.appendArgumentsAdjuster([&guest_triple](const clang::tooling::CommandLineArguments &args, clang::StringRef) {
clang::tooling::CommandLineArguments adjusted_args = args;
adjusted_args.push_back(std::string{"-target"});
adjusted_args.push_back(guest_triple);
return adjusted_args;
});
return Tool.run(std::make_unique<MyFrontendActionFactory>(libname, host_triple, guest_triple).get()); return Tool.run(std::make_unique<MyFrontendActionFactory>(libname, host_triple, guest_triple).get());
} }

View File

@ -1,4 +1,5 @@
#pragma once #pragma once
#include <clang/AST/ASTContext.h>
#include <clang/AST/Decl.h> #include <clang/AST/Decl.h>
#include <clang/AST/Type.h> #include <clang/AST/Type.h>
#include <clang/Tooling/Tooling.h> #include <clang/Tooling/Tooling.h>
@ -19,3 +20,14 @@ static const clang::Type* StripTypedef(clang::QualType type) {
return type.getTypePtr(); return type.getTypePtr();
} }
} }
// FIXME: Need to support other triple except default target triple
static std::string GetDeclHeaderFile(clang::ASTContext& Ctx, clang::Decl* Decl) {
const auto& SourceManager = Ctx.getSourceManager();
const clang::FileID FileID = SourceManager.getFileID(Decl->getBeginLoc());
const clang::FileEntry *FileEntry = SourceManager.getFileEntryForID(FileID);
if (FileEntry) {
return FileEntry->getName().str();
}
return "";
}