//===- AArch64SLSHardening.cpp - Harden Straight Line Misspeculation ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a pass that inserts code to mitigate side-channel
// vulnerabilities that may arise under straight-line misspeculation.
//
//===----------------------------------------------------------------------===//

#include "AArch64InstrInfo.h"
#include "AArch64Subtarget.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/CodeGen/IndirectThunks.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/RegisterScavenging.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Target/TargetMachine.h"
#include <cassert>
#include <climits>
#include <tuple>

using namespace llvm;

#define DEBUG_TYPE "aarch64-sls-hardening"

#define AARCH64_SLS_HARDENING_NAME "AArch64 sls hardening pass"

// Common name prefix of all thunks generated by this pass.
//
// The generic form of a thunk name is
// __llvm_slsblr_thunk_xN            for BLR thunks
// __llvm_slsblr_thunk_(aaz|abz)_xN  for BLRAAZ and BLRABZ thunks
// __llvm_slsblr_thunk_(aa|ab)_xN_xM for BLRAA and BLRAB thunks
//
// createThunkName() builds names of this form and parseThunkName() decodes
// them again, so both have to be kept in sync with this description.
static constexpr StringRef CommonNamePrefix = "__llvm_slsblr_thunk_";

namespace {

// Describes one flavour of thunk: which indirect branch it ends with and how
// that flavour is encoded in the thunk's name.
//
// NOTE: the constants below are defined via aggregate initialisation, so the
// order of the data members must not be changed without updating them.
struct ThunkKind {
  enum ThunkKindId {
    ThunkBR,
    ThunkBRAA,
    ThunkBRAB,
    ThunkBRAAZ,
    ThunkBRABZ,
  };

  // Identifier of this kind; ThunksSet uses it to select the bitmask of
  // already inserted thunks.
  ThunkKindId Id;
  // Part of the thunk name placed between the common prefix and the register
  // name(s), such as "aa_". Empty for plain BR thunks.
  StringRef NameInfix;
  // True if the branch takes a second register operand (Xm) besides Xn.
  bool HasXmOperand;
  // True if the thunk function is created with the "+pauth" target attribute.
  bool NeedsPAuth;

  // Opcode to perform indirect jump from inside the thunk.
  unsigned BROpcode;

  // The supported thunk kinds, one per BR* opcode.
  static const ThunkKind BR;
  static const ThunkKind BRAA;
  static const ThunkKind BRAB;
  static const ThunkKind BRAAZ;
  static const ThunkKind BRABZ;
};

// Records which thunks have been inserted so far, keyed by the thunk kind and
// by the register operand(s) of the branch.
class ThunksSet {
public:
  static constexpr unsigned NumXRegisters = 32;

  // Maps an Xn register to its index n.
  static unsigned indexOfXReg(Register Xn);
  // Maps an index n back to the Xn register.
  static Register xRegByIndex(unsigned N);

  // Adds every thunk recorded in Other to this set.
  ThunksSet &operator|=(const ThunksSet &Other) {
    BLRThunks |= Other.BLRThunks;
    BLRAAZThunks |= Other.BLRAAZThunks;
    BLRABZThunks |= Other.BLRABZThunks;
    for (unsigned M = 0; M != NumXRegisters; ++M) {
      BLRAAThunks[M] |= Other.BLRAAThunks[M];
      BLRABThunks[M] |= Other.BLRABThunks[M];
    }
    return *this;
  }

  // Returns true if the thunk for (Kind, Xn, Xm) is recorded in this set.
  // Xm is only looked at for the kinds that have a second operand.
  bool get(ThunkKind::ThunkKindId Kind, Register Xn, Register Xm) {
    unsigned N = indexOfXReg(Xn);
    return ((getBitmask(Kind, Xm) >> N) & reg_bitmask_t(1)) != 0;
  }

  // Records the thunk for (Kind, Xn, Xm) in this set.
  void set(ThunkKind::ThunkKindId Kind, Register Xn, Register Xm) {
    unsigned N = indexOfXReg(Xn);
    reg_bitmask_t &Mask = getBitmask(Kind, Xm);
    Mask |= reg_bitmask_t(1) << N;
  }

private:
  using reg_bitmask_t = uint32_t;
  static_assert(NumXRegisters <= sizeof(reg_bitmask_t) * CHAR_BIT,
                "Bitmask is not wide enough to hold all Xn registers");

  // Each bitmask has its n-th bit set when a thunk with Xn as the first
  // operand exists. Kinds with a second operand (Xm) use an array of such
  // bitmasks indexed by m.
  // The bits and array slots of the forbidden x16, x17 and x30 registers are
  // never set; they are kept so that the indexing has no holes.
  reg_bitmask_t BLRThunks = 0;
  reg_bitmask_t BLRAAZThunks = 0;
  reg_bitmask_t BLRABZThunks = 0;
  reg_bitmask_t BLRAAThunks[NumXRegisters] = {};
  reg_bitmask_t BLRABThunks[NumXRegisters] = {};

  // Selects the bitmask that holds the Xn bits for the given kind (and Xm).
  reg_bitmask_t &getBitmask(ThunkKind::ThunkKindId Kind, Register Xm) {
    switch (Kind) {
    // Single-operand kinds: one bitmask each, Xm is ignored.
    case ThunkKind::ThunkBR:
      return BLRThunks;
    case ThunkKind::ThunkBRAAZ:
      return BLRAAZThunks;
    case ThunkKind::ThunkBRABZ:
      return BLRABZThunks;
    // Two-operand kinds: one bitmask per Xm register.
    case ThunkKind::ThunkBRAA: {
      unsigned M = indexOfXReg(Xm);
      return BLRAAThunks[M];
    }
    case ThunkKind::ThunkBRAB: {
      unsigned M = indexOfXReg(Xm);
      return BLRABThunks[M];
    }
    }
    llvm_unreachable("Unknown ThunkKindId enum");
  }
};

// Thunk inserter that hardens returns, indirect branches and BLR* calls
// against straight-line speculation.
struct SLSHardeningInserter : ThunkInserter<SLSHardeningInserter, ThunksSet> {
  // Name prefix shared by all thunks created by this inserter.
  const char *getThunkPrefix() { return CommonNamePrefix.data(); }

  // Returns true if any SLS hardening is enabled for MF. As a side effect,
  // comdat thunks are switched off for good as soon as one function asks for
  // non-comdat thunks.
  bool mayUseThunk(const MachineFunction &MF) {
    const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
    if (ST.hardenSlsNoComdat())
      ComdatThunks = false;
    // Barriers are inserted in addition to the thunk calls, so
    // hardenSlsRetBr() has to be taken into account as well.
    return ST.hardenSlsBlr() || ST.hardenSlsRetBr();
  }

  // Hardens MF and returns ExistingThunks extended by the thunks MF needs.
  ThunksSet insertThunks(MachineModuleInfo &MMI, MachineFunction &MF,
                         ThunksSet ExistingThunks);

  // Fills in the body of the thunk function MF, based on its name.
  void populateThunk(MachineFunction &MF);

private:
  // Whether thunk functions are created as comdat.
  bool ComdatThunks = true;

  bool hardenReturnsAndBRs(MachineModuleInfo &MMI, MachineBasicBlock &MBB);
  bool hardenBLRs(MachineModuleInfo &MMI, MachineBasicBlock &MBB,
                  ThunksSet &Thunks);

  void convertBLRToBL(MachineModuleInfo &MMI, MachineBasicBlock &MBB,
                      MachineBasicBlock::instr_iterator MBBI,
                      ThunksSet &Thunks);
};

} // end anonymous namespace

// Descriptors of the supported thunk kinds. The initialisers follow the member
// order of ThunkKind: Id, NameInfix, HasXmOperand, NeedsPAuth, BROpcode.
const ThunkKind ThunkKind::BR = {
    ThunkBR, /*NameInfix=*/"", /*HasXmOperand=*/false, /*NeedsPAuth=*/false,
    /*BROpcode=*/AArch64::BR};
const ThunkKind ThunkKind::BRAA = {
    ThunkBRAA, /*NameInfix=*/"aa_", /*HasXmOperand=*/true, /*NeedsPAuth=*/true,
    /*BROpcode=*/AArch64::BRAA};
const ThunkKind ThunkKind::BRAB = {
    ThunkBRAB, /*NameInfix=*/"ab_", /*HasXmOperand=*/true, /*NeedsPAuth=*/true,
    /*BROpcode=*/AArch64::BRAB};
const ThunkKind ThunkKind::BRAAZ = {
    ThunkBRAAZ, /*NameInfix=*/"aaz_", /*HasXmOperand=*/false,
    /*NeedsPAuth=*/true, /*BROpcode=*/AArch64::BRAAZ};
const ThunkKind ThunkKind::BRABZ = {
    ThunkBRABZ, /*NameInfix=*/"abz_", /*HasXmOperand=*/false,
    /*NeedsPAuth=*/true, /*BROpcode=*/AArch64::BRABZ};

// Maps the opcode of a BLR* instruction to the kind of thunk that replaces it.
// Yields nullptr for every opcode that is not a BLR* instruction.
static const ThunkKind *getThunkKind(unsigned OriginalOpcode) {
  const ThunkKind *Kind = nullptr;
  switch (OriginalOpcode) {
  case AArch64::BLR:
  case AArch64::BLRNoIP:
    Kind = &ThunkKind::BR;
    break;
  case AArch64::BLRAA:
    Kind = &ThunkKind::BRAA;
    break;
  case AArch64::BLRAB:
    Kind = &ThunkKind::BRAB;
    break;
  case AArch64::BLRAAZ:
    Kind = &ThunkKind::BRAAZ;
    break;
  case AArch64::BLRABZ:
    Kind = &ThunkKind::BRABZ;
    break;
  }
  return Kind;
}

static bool isBLR(const MachineInstr &MI) {
  return getThunkKind(MI.getOpcode()) != nullptr;
}

// Returns n for the given Xn register. The register must be a 64-bit GPR other
// than x16, x17 and LR, which cannot be operands of a hardened BLR*.
unsigned ThunksSet::indexOfXReg(Register Reg) {
  assert(AArch64::GPR64RegClass.contains(Reg));
  assert(Reg != AArch64::X16 && Reg != AArch64::X17 && Reg != AArch64::LR);

  unsigned Result;
  if (Reg == AArch64::FP) {
    // FP is x29, but its id does not follow x28.
    Result = 29;
  } else if (Reg == AArch64::XZR) {
    // XZR takes the slot of register number 31.
    Result = 31;
  } else {
    // All remaining Xn registers have consecutive ids starting at X0.
    Result = static_cast<unsigned>(Reg) - static_cast<unsigned>(AArch64::X0);
  }

  assert(Result < NumXRegisters && "Internal register numbering changed");
  assert(AArch64::GPR64RegClass.getRegister(Result).id() == Reg &&
         "Internal register numbering changed");

  return Result;
}

// Returns the N-th register of the 64-bit GPR class; the inverse of
// indexOfXReg().
Register ThunksSet::xRegByIndex(unsigned N) {
  const auto &XRegs = AArch64::GPR64RegClass;
  return XRegs.getRegister(N);
}

// Inserts a speculation barrier before MBBI, unless the instruction at MBBI
// already is one. The instruction preceding MBBI must be an unconditional
// control flow terminator. The SB-based barrier is used when the subtarget has
// SB and AlwaysUseISBDSB is false; otherwise the ISB+DSB form is emitted.
static void insertSpeculationBarrier(const AArch64Subtarget *ST,
                                     MachineBasicBlock &MBB,
                                     MachineBasicBlock::iterator MBBI,
                                     DebugLoc DL,
                                     bool AlwaysUseISBDSB = false) {
  assert(MBBI != MBB.begin() &&
         "Must not insert SpeculationBarrierEndBB as only instruction in MBB.");
  assert(std::prev(MBBI)->isBarrier() &&
         "SpeculationBarrierEndBB must only follow unconditional control flow "
         "instructions.");
  assert(std::prev(MBBI)->isTerminator() &&
         "SpeculationBarrierEndBB must only follow terminators.");

  // Nothing to do if a barrier already sits at the insertion point.
  if (MBBI != MBB.end()) {
    const unsigned NextOpc = MBBI->getOpcode();
    if (NextOpc == AArch64::SpeculationBarrierSBEndBB ||
        NextOpc == AArch64::SpeculationBarrierISBDSBEndBB)
      return;
  }

  const bool UseSB = ST->hasSB() && !AlwaysUseISBDSB;
  const unsigned BarrierOpc = UseSB ? AArch64::SpeculationBarrierSBEndBB
                                    : AArch64::SpeculationBarrierISBDSBEndBB;
  const TargetInstrInfo *TII = ST->getInstrInfo();
  BuildMI(MBB, MBBI, DL, TII->get(BarrierOpc));
}

// Applies the hardening enabled on MF's subtarget to every basic block of MF.
// ExistingThunks is extended with any thunk that got created on the way and is
// handed back to the caller.
ThunksSet SLSHardeningInserter::insertThunks(MachineModuleInfo &MMI,
                                             MachineFunction &MF,
                                             ThunksSet ExistingThunks) {
  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const bool HardenRetBr = ST.hardenSlsRetBr();
  const bool HardenBlr = ST.hardenSlsBlr();

  for (MachineBasicBlock &MBB : MF) {
    if (HardenRetBr)
      hardenReturnsAndBRs(MMI, MBB);
    if (HardenBlr)
      hardenBLRs(MMI, MBB, ExistingThunks);
  }
  return ExistingThunks;
}

// Places a speculation barrier right after every return and indirect branch
// among the terminators of MBB. Returns true if such an instruction was found.
bool SLSHardeningInserter::hardenReturnsAndBRs(MachineModuleInfo &MMI,
                                               MachineBasicBlock &MBB) {
  const AArch64Subtarget &ST =
      MBB.getParent()->getSubtarget<AArch64Subtarget>();
  bool Changed = false;

  const MachineBasicBlock::iterator End = MBB.end();
  MachineBasicBlock::iterator Cur = MBB.getFirstTerminator();
  while (Cur != End) {
    MachineInstr &MI = *Cur;
    // Remember the successor before inserting, so that a freshly inserted
    // barrier is never visited by this loop.
    const MachineBasicBlock::iterator After = std::next(Cur);
    Cur = After;

    if (!MI.isReturn() && !isIndirectBranchOpcode(MI.getOpcode()))
      continue;

    assert(MI.isTerminator());
    insertSpeculationBarrier(&ST, MBB, After, MI.getDebugLoc());
    Changed = true;
  }
  return Changed;
}

// Currently, the longest possible thunk name is
//   __llvm_slsblr_thunk_aa_xNN_xMM
// which is 30 characters (without the '\0' character).
static SmallString<32> createThunkName(const ThunkKind &Kind, Register Xn,
                                       Register Xm) {
  unsigned N = ThunksSet::indexOfXReg(Xn);

  // Kinds with a second operand encode both register indices in the name.
  if (Kind.HasXmOperand) {
    unsigned M = ThunksSet::indexOfXReg(Xm);
    return formatv("{0}{1}x{2}_x{3}", CommonNamePrefix, Kind.NameInfix, N, M);
  }

  // Single-operand kinds: Xm is not looked at.
  return formatv("{0}{1}x{2}", CommonNamePrefix, Kind.NameInfix, N);
}

// Decodes a thunk name produced by createThunkName() back into the thunk kind
// and the register operands. Xm is AArch64::NoRegister for the kinds without a
// second operand. Malformed names are only diagnosed by assertions.
static std::tuple<const ThunkKind &, Register, Register>
parseThunkName(StringRef ThunkName) {
  assert(ThunkName.starts_with(CommonNamePrefix) &&
         "Should be filtered out by ThunkInserter");
  // Everything after the common prefix, such as "x1" or "aa_x2_x3".
  StringRef Suffix = ThunkName.drop_front(CommonNamePrefix.size());

  // The infix, if any, identifies the thunk kind; no infix means plain BR.
  const ThunkKind *KindPtr = StringSwitch<const ThunkKind *>(Suffix)
                                 .StartsWith("aa_", &ThunkKind::BRAA)
                                 .StartsWith("ab_", &ThunkKind::BRAB)
                                 .StartsWith("aaz_", &ThunkKind::BRAAZ)
                                 .StartsWith("abz_", &ThunkKind::BRABZ)
                                 .Default(&ThunkKind::BR);
  const ThunkKind &Kind = *KindPtr;

  // Turns a register name of the form "xN" into the Xn register.
  auto RegFromName = [](StringRef Name) {
    assert(Name.starts_with("x") && "xN register name expected");

    unsigned Index;
    bool Failed = Name.drop_front(1).getAsInteger(/*Radix=*/10, Index);
    assert(!Failed && Index < ThunksSet::NumXRegisters &&
           "Unexpected register");
    (void)Failed;

    return ThunksSet::xRegByIndex(Index);
  };

  // What remains is either "xN" or "xN_xM".
  StringRef RegNames = Suffix.drop_front(Kind.NameInfix.size());
  std::pair<StringRef, StringRef> Names = RegNames.split('_');

  Register Xn = RegFromName(Names.first);
  Register Xm = AArch64::NoRegister;
  if (Kind.HasXmOperand)
    Xm = RegFromName(Names.second);

  return std::tuple<const ThunkKind &, Register, Register>(Kind, Xn, Xm);
}

// Emits the body of the thunk function MF. The thunk kind and the register
// operands are recovered from the name of MF.
void SLSHardeningInserter::populateThunk(MachineFunction &MF) {
  assert(MF.getFunction().hasComdat() == ComdatThunks &&
         "ComdatThunks value changed since MF creation");

  auto Parsed = parseThunkName(MF.getName());
  const ThunkKind &Kind = std::get<0>(Parsed);
  Register Xn = std::get<1>(Parsed);
  Register Xm = std::get<2>(Parsed);

  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();

  // The thunk arrives either completely empty or with one basic block holding
  // a single return instruction, depending on whether this pass shares a
  // FunctionPassManager with the IR->MIR conversion. Bring both cases to the
  // same shape: exactly one, empty, basic block.
  if (MF.size() == 1) {
    MachineBasicBlock &OnlyBlock = MF.front();
    assert(OnlyBlock.size() == 1);
    assert(OnlyBlock.front().getOpcode() == AArch64::RET);
    OnlyBlock.erase(OnlyBlock.begin());
  } else {
    assert(MF.size() == 0);
    MF.push_back(MF.CreateMachineBasicBlock());
  }

  MachineBasicBlock *Entry = &MF.front();
  Entry->clear();

  //  The thunk body is:
  //  __llvm_slsblr_thunk_...:
  //      MOV x16, xN     ; BR* instructions are not compatible with "BTI c"
  //                      ; branch target unless xN is x16 or x17.
  //      BR* ...         ; One of: BR        x16
  //                      ;         BRA(A|B)  x16, xM
  //                      ;         BRA(A|B)Z x16
  //      barrierInsts
  Entry->addLiveIn(Xn);

  // MOV X16, Reg is encoded as ORR X16, XZR, Reg, LSL #0.
  BuildMI(Entry, DebugLoc(), TII->get(AArch64::ORRXrs), AArch64::X16)
      .addReg(AArch64::XZR)
      .addReg(Xn)
      .addImm(0);

  MachineInstrBuilder Branch =
      BuildMI(Entry, DebugLoc(), TII->get(Kind.BROpcode)).addReg(AArch64::X16);
  if (Xm != AArch64::NoRegister) {
    Entry->addLiveIn(Xm);
    Branch.addReg(Xm);
  }

  // A caller of the thunk may have the SB extension disabled locally even
  // though it is enabled for the rest of the module, so the thunk must not
  // rely on SB. Hence AlwaysUseISBDSB is set to true.
  insertSpeculationBarrier(&ST, *Entry, Entry->end(), DebugLoc(),
                           /*AlwaysUseISBDSB=*/true);
}

void SLSHardeningInserter::convertBLRToBL(
    MachineModuleInfo &MMI, MachineBasicBlock &MBB,
    MachineBasicBlock::instr_iterator MBBI, ThunksSet &Thunks) {
  // Replaces the BLR* instruction at MBBI (BLR, BLRAA/BLRAB or BLRAAZ/BLRABZ)
  // by a BL to a thunk that ends in the matching BR, BRAA/BRAB or BRAAZ/BRABZ.
  //
  // Call site before:                 Call site after:
  //   ...                               ...
  //   instI                             instI
  //   BLR* xN   or   BLR* xN, xM        BL __llvm_slsblr_thunk_...
  //   instJ                             instJ
  //   ...                               ...
  //
  // Thunk:
  //   __llvm_slsblr_thunk_...:
  //     MOV x16, xN
  //     BR* x16   or   BR* x16, xM
  //     barrierInsts
  //
  // The BL gets the correct thunk name, and the thunk function is created
  // lazily the first time it is needed.
  //
  // Linkers are allowed to clobber X16 and X17 on function calls, so this
  // mitigation only works if neither X16 nor X17 is an operand of the original
  // BLR* instruction. Earlier code generation has to guarantee that no such
  // BLR* is produced when the mitigation is enabled.

  MachineInstr &BLR = *MBBI;
  assert(isBLR(BLR));
  const ThunkKind &Kind = *getThunkKind(BLR.getOpcode());

  const unsigned NumRegOperands = Kind.HasXmOperand ? 2 : 1;
  assert(BLR.getNumExplicitOperands() == NumRegOperands &&
         "Expected one or two register inputs");
  Register Xn = BLR.getOperand(0).getReg();
  Register Xm = AArch64::NoRegister;
  if (Kind.HasXmOperand)
    Xm = BLR.getOperand(1).getReg();

  DebugLoc DL = BLR.getDebugLoc();

  MachineFunction &MF = *MBBI->getMF();
  MCContext &Context = MBB.getParent()->getContext();
  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();

  auto ThunkName = createThunkName(Kind, Xn, Xm);
  MCSymbol *Sym = Context.getOrCreateSymbol(ThunkName);

  // Create the thunk function the first time this kind/operand combination
  // is seen.
  const bool AlreadyInserted = Thunks.get(Kind.Id, Xn, Xm);
  if (!AlreadyInserted) {
    StringRef TargetAttrs = Kind.NeedsPAuth ? "+pauth" : "";
    Thunks.set(Kind.Id, Xn, Xm);
    createThunkFunction(MMI, ThunkName, ComdatThunks, TargetAttrs);
  }

  MachineInstr *BL = BuildMI(MBB, MBBI, DL, TII->get(AArch64::BL)).addSym(Sym);

  // The implicit operands of the BLR are about to be copied to the BL. Both
  // instructions implicitly use SP and implicitly define LR, so a plain copy
  // would leave SP and LR on the BL twice. That is probably harmless, but for
  // cleanliness the BL's own implicit LR def and SP use are dropped first.
  int ImpLRDefIdx = -1;
  int ImpSPUseIdx = -1;
  for (unsigned Idx = BL->getNumExplicitOperands(),
                NumOps = BL->getNumOperands();
       Idx < NumOps; ++Idx) {
    const MachineOperand &MO = BL->getOperand(Idx);
    if (!MO.isReg())
      continue;
    if (MO.isDef()) {
      if (MO.getReg() == AArch64::LR)
        ImpLRDefIdx = Idx;
    } else if (MO.getReg() == AArch64::SP) {
      ImpSPUseIdx = Idx;
    }
  }
  assert(ImpLRDefIdx != -1);
  assert(ImpSPUseIdx != -1);

  // Remove the higher index first so that the lower one stays valid.
  const int HigherIdx = std::max(ImpLRDefIdx, ImpSPUseIdx);
  const int LowerIdx = std::min(ImpLRDefIdx, ImpSPUseIdx);
  BL->removeOperand(HigherIdx);
  BL->removeOperand(LowerIdx);

  // Take over the implicit operands and the call site info of the BLR.
  BL->copyImplicitOps(MF, BLR);
  MF.moveCallSiteInfo(&BLR, BL);

  // The register operands of the original BLR* instruction are used by the
  // called thunk, so add them to the BL as implicit uses.
  for (unsigned Idx = 0; Idx != NumRegOperands; ++Idx) {
    const MachineOperand &RegOp = BLR.getOperand(Idx);
    BL->addOperand(MachineOperand::CreateReg(RegOp.getReg(), /*isDef=*/false,
                                             /*isImp=*/true, RegOp.isKill()));
  }

  // Finally, drop the BLR itself.
  MBB.erase(MBBI);
}

// Replaces every BLR* instruction in MBB by a BL to the matching thunk,
// recording newly created thunks in Thunks. Returns true if MBB was changed.
bool SLSHardeningInserter::hardenBLRs(MachineModuleInfo &MMI,
                                      MachineBasicBlock &MBB,
                                      ThunksSet &Thunks) {
  bool Changed = false;

  const MachineBasicBlock::instr_iterator End = MBB.instr_end();
  MachineBasicBlock::instr_iterator Cur = MBB.instr_begin();
  while (Cur != End) {
    // Step past the candidate first: convertBLRToBL() erases it.
    const MachineBasicBlock::instr_iterator Candidate = Cur++;
    if (!isBLR(*Candidate))
      continue;

    convertBLRToBL(MMI, MBB, Candidate, Thunks);
    Changed = true;
  }
  return Changed;
}

namespace {
// Pass that drives SLSHardeningInserter through the generic ThunkInserterPass
// template.
class AArch64SLSHardening : public ThunkInserterPass<SLSHardeningInserter> {
public:
  // Pass identification, handed to the ThunkInserterPass constructor.
  static char ID;

  AArch64SLSHardening() : ThunkInserterPass(ID) {}

  // Human-readable name of the pass.
  StringRef getPassName() const override { return AARCH64_SLS_HARDENING_NAME; }
};

} // end anonymous namespace

// Definition of the pass identifier declared in AArch64SLSHardening.
char AArch64SLSHardening::ID = 0;

// Sets up pass initialisation under the "aarch64-sls-hardening" argument name
// and the AARCH64_SLS_HARDENING_NAME description.
INITIALIZE_PASS(AArch64SLSHardening, "aarch64-sls-hardening",
                AARCH64_SLS_HARDENING_NAME, false, false)

// Creates a new instance of the AArch64 SLS hardening pass; the returned
// object is heap-allocated with new.
FunctionPass *llvm::createAArch64SLSHardeningPass() {
  auto *Pass = new AArch64SLSHardening();
  return Pass;
}
