LLVMのswitchを解析してみたかった。

rutilicus.hatenablog.com
これの続きです。改めて解析対象のソースコードなど。

これから出力したLLVM IRも。

で、これを解析しようとしたコードがこれです。

#include <iostream>
#include <map>
#include <string>
#include <list>
#include <set>
#include <algorithm>

#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Instruction.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Support/SourceMgr.h>

int main(int argc, char** argv) {
    llvm::LLVMContext context;
    llvm::SMDiagnostic err;
    std::unique_ptr<llvm::Module> module = llvm::parseIRFile("aft.ll", err, context);

    if (!module) {
        err.print(argv[0], llvm::errs());
        return 1;
    }

    for (llvm::Function &function: module->getFunctionList()) {
        if (!function.isDeclaration()) {
            // Condition Id -> Condition String
            std::map<int, std::string> condIdStrMap;
            // Condition Id
            int condId = 0;
            // BasicBlock Name -> Block Condition Ids
            std::map<std::string, std::set<int>> blockCondsMap;
            // Condition sets derived from same origin Conditional Syntax
            std::list<std::set<int>> condSetSameOrigin;

            for (llvm::BasicBlock &block: function) {
                std::string currentBlockName = block.getName().data();

                // remove unnecessary block conditions
                if (blockCondsMap.find(currentBlockName) == blockCondsMap.end()) {
                    blockCondsMap[currentBlockName] = std::set<int>();
                }
                std::list<int> allBlockConditions;
                for (int i: blockCondsMap[currentBlockName]) {
                    allBlockConditions.push_back(i);
                }
                for (int i: allBlockConditions) {
                    for (std::set<int> s: condSetSameOrigin) {
                        if (s.find(i) != s.end() &&
                            std::all_of(s.begin(), s.end(),
                                        [&](int x) { return blockCondsMap[currentBlockName].find(x) !=
                                                     blockCondsMap[currentBlockName].end(); }))  {
                            std::for_each(s.begin(), s.end(), 
                                          [&](int x) { blockCondsMap[currentBlockName].erase(x); });
                        }
                    }
                }
                std::string blockConditionStr;
                if (blockCondsMap[currentBlockName].empty()) {
                    blockConditionStr = "N/A";
                } else {
                    for (int i: blockCondsMap[currentBlockName]) {
                        blockConditionStr += condIdStrMap[i] + "OR";
                    }
                    blockConditionStr.resize(blockConditionStr.length() - 2);
                }

                for (llvm::Instruction &instr: block) {
                    switch (instr.getOpcode()) {
                        case llvm::Instruction::Switch: {
                                llvm::SwitchInst &switchInst = 
                                    static_cast<llvm::SwitchInst &>(instr);
                                std::list<int> condIds;
                                for (auto &llvmCase: switchInst.cases()) {
                                    // cases
                                    std::string condStr = "(";
                                    condStr += switchInst.getOperand(0)->getName().data();
                                    condStr += "=";
                                    condStr += llvmCase.getCaseValue()->getZExtValue();
                                    condStr += + ")";
                                    condIdStrMap[condId] = condStr;
                                    condIds.push_back(condId);
                                    std::string blockName = llvmCase.getCaseSuccessor()->getName().data();
                                    if (blockCondsMap.find(blockName) != blockCondsMap.end()) {
                                        blockCondsMap[blockName].insert(condId);
                                    } else {
                                        blockCondsMap[blockName] = std::set<int>{condId};
                                    }
                                    condId++;
                                }
                                {
                                    // default
                                    std::string condStr = "(NOT(";
                                    for (int i: condIds) {
                                        condStr += condIdStrMap[i];
                                        condStr += "OR";
                                    }
                                    condStr.resize(condStr.length() - 2);
                                    condStr += "))";
                                    condIdStrMap[condId] = condStr;
                                    condIds.push_back(condId);
                                    std::string blockName = switchInst.getDefaultDest()->getName().data();
                                    if (blockCondsMap.find(blockName) != blockCondsMap.end()) {
                                        blockCondsMap[blockName].insert(condId);
                                    } else {
                                        blockCondsMap[blockName] = std::set<int>{condId};
                                    }
                                    condId++;
                                }
                            }
                            break;
                        case llvm::Instruction::Call: {
                                llvm::CallBase &callBase =
                                    static_cast<llvm::CallBase &>(instr);
                                std::cout << blockConditionStr << "\t" <<
                                    callBase.getCalledFunction()->getName().data() << 
                                    std::endl;
                            }
                            break;
                        default:
                            break;
                    }
                }
            }
        }
    }

    return 0;
}

ネストしたswitchには対応していません。ちょっといじれば対応できる作りにはしてありますが。
で、実行結果はこんな感じです。

N/A	getParam
(=x01)	hoge
(=x02)	newHoge
(NOT((=x01)OR(=x02)))	piyo
(NOT((=x01)OR(=x02)))	poyo
N/A	foo
N/A	bar

作りは正直かなり雑です。switchで分岐する各BasicBlockごとに条件を割り振っているのですが、その合流後の消し方も同一制御構文から発生した条件がすべてそろっている場合は消す、という方法をとっています。条件の書き方もデータ構造を用意するのではなくstringでベタベタに書いています。そして、見て分かりますが判断に使用したオペランドの名前を取れていません。これはどうやらそういうもののようです。(LLVM get operand and lvalue name of an instruction - Stack Overflow)

今後についてはこれを継続するかどうか悩んでいます。これ以上進むとなるとまぁ、全Instructionの分析が必要になりますよね……私の実業務で困ったことではあるものの、そこまで労力をつぎこんでいいものかどうか……

  • 参考にしました

tomo-wait-for-it-yuki.hatenablog.com