Newer
Older
Import / research / reflection / source / asm.cpp
#include "vm.hpp"
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>


// Build switch: when non-zero, main() builds the hard-coded AssembleProgramTest()
// program instead of parsing the input file with ProcessInput().
#define TEST  0


// Section of the assembly source that is currently being parsed.
// ProcessInput only ever switches to TEXT (".text") and RODATA (".data");
// the remaining values are rejected when a label is placed in them.
enum class SectionType
{
  NONE   = 0,  // no section directive seen yet
  TEXT   = 1,  // instructions
  RODATA = 2,  // constant data (db strings)
  BSS    = 3,
  DATA   = 4,
  STACK  = 5,
};


// Kind of an instruction operand, decided by the operand's first character.
enum class ArgType
{
  Register    = 0,  // starts with '%'
  Label       = 1,  // starts with '@'
  Immediate   = 2,  // starts with '#'
  Dereference = 3,  // starts with '['
  Define      = 4,  // anything else (a named constant)
};


// A label reference whose target address is filled in after the whole input
// has been parsed. Instances are brace-initialised positionally, so the member
// order below must not change.
struct Fixup
{
  // Index into Program::m_text of the instruction to re-encode.
  uint32_t            m_position;
  // Index into Program::m_text_x86 where that instruction's x86 bytes start.
  uint32_t            m_position_x86;
  // Opcode and register operands needed to re-create the instruction.
  OpCode              m_opCode;
  uint32_t            m_regA;
  uint32_t            m_regB;
  std::string         m_forwardLabel; // name of the label whose address is to be looked up
};


// One parsed instruction operand. Only the member that matches m_type carries
// a meaningful value; the others keep their defaults so that copying an Arg
// never reads an indeterminate value.
struct Arg
{
  ArgType      m_type      = ArgType::Define;
  std::string  m_label;          // set for ArgType::Label (name without the '@')
  uint32_t     m_register  = 0;  // set for ArgType::Register / ArgType::Dereference
  uint32_t     m_immediate = 0;  // set for ArgType::Immediate / ArgType::Define
};


// All mutable state of one assembler run: the program being built plus the
// symbol tables and pending label references gathered while parsing.
struct AssemblerState
{
  // Section selected by the most recent ".text" / ".data" directive.
  SectionType                           m_currentSection = SectionType::NONE;
  // Output being assembled (VM text, x86 text, read-only data, header).
  Program                               m_program;

  // Mangled mnemonic (e.g. "movRI") -> opcode; filled by InitializedAssembler().
  std::map<std::string, OpCode>         m_opCodeMap;
  // Text label -> m_program.m_text.size() at the point the label was seen.
  std::map<std::string,size_t>          m_textLabelOffsetMap;
  // Text label -> m_program.m_text_x86.size() at the point the label was seen.
  std::map<std::string,size_t>          m_textLabelOffsetMap_x86;
  // Data label -> m_program.m_rodata.size() at the point the label was seen.
  std::map<std::string,size_t>          m_dataLabelOffsetMap;
  // Label references recorded during parsing and patched at the end of ProcessInput().
  std::vector<Fixup>                    m_fixups;
};


// Returns a fresh AssemblerState whose mnemonic table is populated.
//
// Keys are "mangled" mnemonics: the source mnemonic followed by one letter per
// operand (R = register, L = label, I = immediate, D = dereference, S = define),
// exactly as ProcessInput builds them.
AssemblerState InitializedAssembler()
{
  AssemblerState state;

  state.m_opCodeMap = {
    // Zero args
    { "nop",    OpCode::Nop   },
    { "ret",    OpCode::Ret   },

    // One arg
    { "pushR",  OpCode::Push  },
    { "popR",   OpCode::Pop   },
    { "callL",  OpCode::Call  },
    { "jmpL",   OpCode::Jmp   },
    { "jeL",    OpCode::Je    },
    { "jneL",   OpCode::Jne   },
    { "ioctlS", OpCode::Ioctl },

    // Two args
    { "addRR",  OpCode::Add   },
    { "subRR",  OpCode::Sub   },
    { "shlRR",  OpCode::Shl   },
    { "shrRR",  OpCode::Shr   },
    { "mulRR",  OpCode::Mul   },
    { "divRR",  OpCode::Div   },
    { "modRR",  OpCode::Mod   },
    { "notRR",  OpCode::Not   },
    { "orRR",   OpCode::Or    },
    { "xorRR",  OpCode::Xor   },
    { "andRR",  OpCode::And   },
    { "cmpRR",  OpCode::Cmp   },

    // Moves: a label operand and an immediate operand share the same opcode
    { "movRR",  OpCode::MovRR },
    { "movRL",  OpCode::MovIR },
    { "movRI",  OpCode::MovIR },
    { "movRD",  OpCode::MovCR },
  };

  /*
  Notes on mov variants (only the four mappings above are registered):

  MovRR,  // reg         -> reg
  MovMR,  // mem[r]      -> reg
  MovI0,  // imm25       -> reg0
  MovM0,  // mem[addr]   -> reg0
  MovC0,  // const[addr] -> reg0
  MovIM,  // imm20       -> mem[r]
  MovRM,  // reg         -> mem[r]
  Mov0M,  // reg0        -> mem[addr]
  MovMM,  // mem[r]      -> mem[r]
  MovCM,  // const[r]    -> mem[r]
  */

  return state;
}


// Packs one 32-bit VM instruction word.
//
// Layout (all fields are OR-ed together):
//   bits 25 and up : opcode
//   bits 20..24    : a_regA (low 5 bits)
//   bits 15..19    : a_regB (low 5 bits)
//   bits  0..24    : a_address (low 25 bits)
// The address field overlaps the two register fields, so a large address and
// register numbers share bits in the resulting word.
uint32_t CreateInstruction(OpCode a_instruction, uint32_t a_regA = 0, uint32_t a_regB = 0, uint32_t a_address = 0)
{
  const uint32_t opCodeBits  = static_cast<uint32_t>(a_instruction) << 25;
  const uint32_t regABits    = (a_regA & 0x1F) << 20;
  const uint32_t regBBits    = (a_regB & 0x1F) << 15;
  const uint32_t addressBits = a_address & 0x1FFFFFF;
  return opCodeBits | regABits | regBBits | addressBits;
}


// Translation table consulted by MapInstruction_x86(): encoded VM instruction
// word (as produced by CreateInstruction) -> x86 byte sequence.
// NOTE(review): presumably filled by BuildX86CodeMap() through "opcodes.h" — confirm.
std::map<int, std::vector<uint8_t>> opcodes;


// Builds the x86 translation data by textually including "opcodes.h" into the
// function body. Two helper macros are defined for that header first:
//   OpCodeSequence          - expands to nothing
//   MakeOp(op,r1,r2,addr)   - forwards to CreateInstruction(op,r1,r2,addr)
// NOTE(review): presumably "opcodes.h" consists of statements that fill the
// global `opcodes` map — confirm against that header.
// The macros are not #undef'd, so they stay defined for the rest of this file.
void BuildX86CodeMap()
{
#define OpCodeSequence
//#define OpCodeSequence(len,bytes)     OpCodeSeq{ len, std::vector<uint8_t>(bytes) } 
#define MakeOp(op,r1,r2,addr)         CreateInstruction(op,r1,r2,addr)
#include "opcodes.h"
}


// Returns the x86 byte sequence for one VM instruction.
//
// The instruction is encoded with CreateInstruction() and looked up in the
// global `opcodes` table. Entries that take an address/immediate are stored
// under the placeholder address 0xb4b5b6b7; in the returned copy the
// placeholder's four bytes (b7 b6 b5 b4, i.e. little-endian) are replaced by
// the little-endian bytes of a_address.
//
// a_address == 0 is ambiguous ("no operand" or "operand with value 0"), so the
// plain encoding is tried first and the placeholder encoding second.
//
// Terminates the process with exit(-1) when no mapping exists or when a
// mapping that needs substitution does not contain the placeholder.
std::vector<uint8_t> MapInstruction_x86(OpCode a_instruction, uint32_t a_regA = 0, uint32_t a_regB = 0, uint32_t a_address = 0) // uint32_t a_vmInstruction)
{
  const uint32_t placeholderAddress = 0xb4b5b6b7;

  bool addressSubstitutionNeeded = (a_address != 0);
  uint32_t vmInstruction = CreateInstruction(a_instruction, a_regA, a_regB, a_address ? placeholderAddress : 0);
  auto entry = opcodes.find(vmInstruction);
  if (entry == opcodes.end() && a_address == 0)
  {
    // search again with address substitution - might be either an address of 0, or an immediate value of 0
    vmInstruction = CreateInstruction(a_instruction, a_regA, a_regB, placeholderAddress);
    addressSubstitutionNeeded = true;
    entry = opcodes.find(vmInstruction);
  }
  if (entry == opcodes.end())
  {
    printf("error, x86 mapping of opcode not found: %i  =>  %i,%i,%i,%i \n",
           static_cast<int>(vmInstruction), static_cast<int>(a_instruction),
           static_cast<int>(a_regA), static_cast<int>(a_regB), static_cast<int>(a_address));
    exit(-1);
  }

  std::vector<uint8_t> bytes = entry->second;
  if (addressSubstitutionNeeded)
  {
    // Look for the complete 4-byte placeholder rather than the first 0xb7
    // byte, so an unrelated 0xb7 earlier in the sequence cannot be mistaken
    // for it and no index past the end of the sequence is ever touched.
    static const uint8_t placeholderBytes[4] = { 0xb7, 0xb6, 0xb5, 0xb4 };
    bool found = false;
    std::size_t at = 0;
    for (std::size_t i = 0; i + 4 <= bytes.size(); ++i)
    {
      if (bytes[i]     == placeholderBytes[0] &&
          bytes[i + 1] == placeholderBytes[1] &&
          bytes[i + 2] == placeholderBytes[2] &&
          bytes[i + 3] == placeholderBytes[3])
      {
        found = true;
        at = i;
        break;
      }
    }
    if (!found)
    {
      printf("error, x86 mapping of opcode has no address placeholder: %i\n", static_cast<int>(vmInstruction));
      exit(-1);
    }
    for (std::size_t k = 0; k < 4; ++k)
      bytes[at + k] = static_cast<uint8_t>((a_address >> (8 * k)) & 0xff);
  }
  return bytes;
}


// Appends one instruction to the program in both representations: the encoded
// VM word goes to m_text and its x86 byte sequence to m_text_x86.
// Returns the index of the new instruction within m_text.
size_t AddInstruction(Program& a_program, OpCode a_instruction, uint32_t a_regA = 0, uint32_t a_regB = 0, uint32_t a_address = 0)
{
  const size_t instructionIndex = a_program.m_text.size();
  a_program.m_text.emplace_back(CreateInstruction(a_instruction, a_regA, a_regB, a_address));

  const auto x86Bytes = MapInstruction_x86(a_instruction, a_regA, a_regB, a_address);
  for (auto it = x86Bytes.begin(); it != x86Bytes.end(); ++it)
  {
    a_program.m_text_x86.emplace_back(*it);
  }

  return instructionIndex; // the address of the instruction just added
}


// Appends a_string, including a terminating 0, to the program's read-only data.
// Returns the offset within m_rodata at which the string starts.
size_t AddConstantString(Program& a_program, const char* a_string)
{
  const size_t startOffset = a_program.m_rodata.size();
  for (const char* cursor = a_string; *cursor; ++cursor)
    a_program.m_rodata.emplace_back(*cursor);
  a_program.m_rodata.emplace_back(0); // terminator
  return startOffset;
}


#if TEST
// Hand-assembled sample program, compiled only when TEST is non-zero.
// It stores two strings in the constant section, calls a print routine on each
// and then issues the exit ioctl. The two calls are emitted as Nop
// place-holders and re-encoded as Call once the address of `print` is known.
// Note: the patch-up at the end rewrites m_text only; the x86 bytes emitted
// for the place-holders in m_text_x86 are left as they were.
// NOTE(review): opcodes are used unqualified here (MovIR, Nop, ...) while the
// rest of the file writes OpCode::... — this may not compile if OpCode is a
// scoped enum; confirm before enabling TEST.
void AssembleProgramTest(Program& a_program)
{
  // constant section
uint32_t str1 =
  AddConstantString(a_program, "hello world\n");
uint32_t str2 =
  AddConstantString(a_program, "blah blah\n");

  // a_program code
  AddInstruction(a_program, MovIR, 2, 0, str1);
uint32_t patch0 =
  AddInstruction(a_program, Nop);               // place-holder
  AddInstruction(a_program, MovIR, 2, 0, str2);
uint32_t patch1 =
  AddInstruction(a_program, Nop);               // place-holder
  AddInstruction(a_program, Ioctl, 0, 0, 0);    // exit

  // print routine: emits the 0-terminated constant string addressed by reg2
uint32_t print =
  AddInstruction(a_program, MovIR, 3, 0, 1);    // reg3 = 1
  AddInstruction(a_program, MovCR, 0, 2, 0);    // reg0 = c[reg2]    = str[0]
uint32_t loop =
  AddInstruction(a_program, Ioctl, 0, 0, 1);    // ioctl putc
  AddInstruction(a_program, Add,   2, 3, 0);    // reg2 += reg3
  AddInstruction(a_program, MovCR, 0, 2, 0);    // reg0 = c[reg2]
  AddInstruction(a_program, Cmp,   0, 4, 0);    // reg0 == reg4
  AddInstruction(a_program, Jne,   0, 0, loop); // loop
  AddInstruction(a_program, Ret);

  // patching up forward references
  a_program.m_text[patch0] = CreateInstruction(Call, 0, 0, print);
  a_program.m_text[patch1] = CreateInstruction(Call, 0, 0, print);
}
#endif


// Fills in the program header from the current section contents.
// Sizes are element counts of the respective containers; the VM text is
// placed directly after the header and the read-only data offset is computed
// as header size plus the VM text size in bytes.
// NOTE(review): m_rodataOffset does not account for the x86 text that
// WriteProgram() emits between the VM text and the read-only data — confirm
// how the loader locates the read-only data.
void UpdateProgramHeader(Program& a_program)
{
  auto& header = a_program.m_header;
  const auto headerBytes = sizeof(ProgramHeader);
  const auto vmTextBytes = a_program.m_text.size() * sizeof(uint32_t);

  header.m_headerSize   = headerBytes;
  header.m_textOffset   = headerBytes;
  header.m_textSize     = a_program.m_text.size();
  header.m_textSize_x86 = a_program.m_text_x86.size();
  header.m_rodataOffset = headerBytes + vmTextBytes;
  header.m_rodataSize   = a_program.m_rodata.size();
}


// Writes the program image to a_output in this order:
// header, VM text, x86 text, read-only data.
//
// Each section's byte count is derived from the container's own element type.
// The previous code multiplied every section by sizeof(uint32_t); for a
// section whose elements are narrower than that (the x86 text and the
// read-only data are filled one byte/char at a time) this read past the end
// of the buffer and wrote the extra bytes to the file.
void WriteProgram(Program& a_program, std::ostream& a_output)
{
  a_output.write(reinterpret_cast<const char*>(&a_program.m_header),
                 static_cast<std::streamsize>(sizeof(a_program.m_header)));
  a_output.write(reinterpret_cast<const char*>(a_program.m_text.data()),
                 static_cast<std::streamsize>(a_program.m_text.size() * sizeof(*a_program.m_text.data())));
  a_output.write(reinterpret_cast<const char*>(a_program.m_text_x86.data()),
                 static_cast<std::streamsize>(a_program.m_text_x86.size() * sizeof(*a_program.m_text_x86.data())));
  a_output.write(reinterpret_cast<const char*>(a_program.m_rodata.data()),
                 static_cast<std::streamsize>(a_program.m_rodata.size() * sizeof(*a_program.m_rodata.data())));
}


// Need to parse something like this:
/*
;
; Comments using semi-colon
;

.text

  _main:
        mov    %r2, @str1
        call   @_print
        mov    %r2, @str2
        call   @_print
        ioctl  Syscall_exit

  _print:
        mov    %r3, #1
        mov    %r0, [%r2]
    loop:
        ioctl  Syscall_putc
        add    %r2, %r3
        mov    %r0, [%r2]
        cmp    %r0, %r4
        jne    @loop

.data

  str1:
        db  "hello world\n"
  str2:
        db  "blah blah\n"

*/

// Classifies an operand by its first character:
//   '%' register, '@' label, '#' immediate, '[' dereference,
//   anything else (including an empty string) a define.
ArgType ArgTypeFromString(const std::string& str)
{
  switch (str[0])
  {
    case '%': return ArgType::Register;
    case '@': return ArgType::Label;
    case '#': return ArgType::Immediate;
    case '[': return ArgType::Dereference;
    default:  return ArgType::Define;
  }
}


// Parses a single operand string into an Arg.
//
// Accepted forms:
//   %rN     register N
//   [%rN]   dereference of register N
//   @name   label reference (stored without the '@')
//   #value  decimal immediate
//   name    a define (currently only Syscall_exit and Syscall_putc)
//
// Register numbers may have several digits. Malformed operands are reported
// on stdout and yield a value of 0 instead of reading out of bounds or
// terminating on an uncaught exception.
Arg ArgFromString(const std::string& str)
{
  // TODO: defines are something to add to asm parsing
  static const std::map<std::string, uint32_t> definesMap = {
    { "Syscall_exit", 0 },
    { "Syscall_putc", 1 },
  };

  // Reads the decimal register number whose first digit is at str[a_firstDigit].
  const auto parseRegister = [&str](std::size_t a_firstDigit) -> uint32_t
  {
    uint32_t value = 0;
    std::size_t pos = a_firstDigit;
    while (pos < str.size() && str[pos] >= '0' && str[pos] <= '9')
    {
      value = value * 10 + static_cast<uint32_t>(str[pos] - '0');
      ++pos;
    }
    if (pos == a_firstDigit)
    {
      printf("invalid register operand: %s\n", str.c_str());
      return 0;
    }
    if (value > 0x1F) // the instruction encoding keeps only 5 bits per register
      printf("register number out of range: %s\n", str.c_str());
    return value;
  };

  Arg arg{}; // value-initialised so unused members are never indeterminate
  arg.m_type = ArgTypeFromString(str);
  switch (arg.m_type)
  {
    case ArgType::Register: // "%rN": digits start after "%r"
      arg.m_register = parseRegister(2);
      break;

    case ArgType::Dereference: // "[%rN]": digits start after "[%r"
      arg.m_register = parseRegister(3);
      break;

    case ArgType::Label:
      arg.m_label = str.substr(1);
      break;

    case ArgType::Immediate:
      try
      {
        arg.m_immediate = static_cast<uint32_t>(std::stoul(str.substr(1)));
      }
      catch (const std::exception&)
      {
        printf("invalid immediate operand: %s\n", str.c_str());
      }
      break;

    case ArgType::Define:
    {
      const auto define = definesMap.find(str);
      if (define != definesMap.end())
        arg.m_immediate = define->second;
      else
        printf("unknown define: %s\n", str.c_str());
      break;
    }
  }
  return arg;
}


// Splits str at every character contained in delimiters.
// Empty tokens (between adjacent delimiters, or before/after a leading or
// trailing delimiter) are kept unless trimEmpty is true. An empty input
// yields no tokens at all.
// Based on https://stackoverflow.com/questions/236129/how-do-i-iterate-over-the-words-of-a-string
std::vector<std::string> tokenize(const std::string& str,
                                  const std::string& delimiters = " ",
                                  bool               trimEmpty = false)
{
  std::vector<std::string> result;
  const std::size_t total = str.size();
  if (total == 0)
    return result;

  std::size_t cursor = 0;
  while (cursor <= total)
  {
    std::size_t stop = str.find_first_of(delimiters, cursor);
    if (stop == std::string::npos)
      stop = total;

    const bool tokenIsEmpty = (stop == cursor);
    if (!(tokenIsEmpty && trimEmpty))
      result.emplace_back(str, cursor, stop - cursor);

    cursor = stop + 1; // step over the delimiter
  }
  return result;
}


// Concatenates the strings in v, inserting delimiter between consecutive
// elements. A delimiter of 0 means "no separator". Returns an empty string for
// an empty vector.
// The vector is taken by const reference so that it is no longer copied on
// every call.
std::string join(const std::vector<std::string>& v, const char& delimiter)
{
  std::string joined;
  for (std::size_t i = 0; i < v.size(); ++i)
  {
    if (i != 0 && delimiter)
      joined += delimiter;
    joined += v[i];
  }
  return joined;
}


// Replaces backslash escape sequences in str with the characters they denote.
//
// Supported: \a \b \f \n \r \t \v \' \" and \\.
// A backslash followed by any other character is dropped together with that
// character, and a trailing lone backslash is dropped.
//
// Previously a second backslash merely re-armed the escape state, so "\\\\"
// could never produce a backslash and the character after it was treated as
// escaped; the escape state is now consumed before the backslash test.
std::string expandEscapes(const std::string& str)
{
  std::string expandedStr;
  expandedStr.reserve(str.size());

  bool inEscape = false;
  for (char ch : str)
  {
    if (!inEscape)
    {
      if (ch == '\\')
        inEscape = true;
      else
        expandedStr += ch;
      continue;
    }

    // ch is the character that follows a backslash
    inEscape = false;
    switch (ch)
    {
      case 'a':  expandedStr += '\a'; break;
      case 'b':  expandedStr += '\b'; break;
      case 'f':  expandedStr += '\f'; break;
      case 'n':  expandedStr += '\n'; break;
      case 'r':  expandedStr += '\r'; break;
      case 't':  expandedStr += '\t'; break;
      case 'v':  expandedStr += '\v'; break;
      case '\'': expandedStr += '\''; break;
      case '"':  expandedStr += '"';  break;
      case '\\': expandedStr += '\\'; break;
      default:   break; // unsupported escape: dropped
       // Not implemented:
       // \c      Ignore rest of string
       // \num    Write a byte whose value is the 1-, 2-, or 3-digit octal number num.
    }
  }
  return expandedStr;
}


/*
   Looks up the opcode for a mangled mnemonic, i.e. the source mnemonic with
   one operand-kind letter appended per operand, for example:

   movRI
   movRD
   movRR

   An unknown mnemonic is reported on stdout and mapped to OpCode::Nop.
*/
OpCode OpCodeFromMangledInstruction(const AssemblerState& a_state, const std::string& str)
{
  const auto found = a_state.m_opCodeMap.find(str);
  if (found != a_state.m_opCodeMap.end())
    return found->second;

  printf("opcode not found for %s\n", str.c_str());
  return OpCode::Nop;
}


// Assembles the source read from a_input into a_state.m_program.
//
// Each line is handled on its own: everything from the first ';' on is a
// comment, the remainder is split on whitespace, and the first token decides
// what the line is:
//   .text / .data   section directive
//   name:           label in the current section
//   db "..."        0-terminated string constant (escape sequences expanded)
//   anything else   an instruction with comma-separated operands
//
// Instructions that reference a label are emitted with address 0 and recorded
// as fixups; once all input has been read every fixup is re-encoded with the
// real address, in both the VM text and the x86 text.
//
// Problems are reported on stdout and parsing continues.
void ProcessInput(AssemblerState& a_state, std::istream& a_input)
{
  Program& program = a_state.m_program;

  std::string line;
  while (std::getline(a_input, line))
  {
    const std::string uncommentedLine = line.substr(0, line.find(';'));
    // '\r' is treated as whitespace so that CRLF files assemble as well
    const std::vector<std::string> tokens = tokenize(uncommentedLine, " \t\v\r", true);
    if (tokens.empty())
      continue;

    const std::string& firstStr = tokens[0];

    // Everything after the first token, with all whitespace removed
    std::string rest;
    if (tokens.size() > 1)
      rest = join(std::vector<std::string>(tokens.begin() + 1, tokens.end()), 0);

    if (firstStr[0] == '.')
    {
      // section
      if (firstStr == ".text")
      {
        a_state.m_currentSection = SectionType::TEXT;
      }
      else if (firstStr == ".data")
      {
        a_state.m_currentSection = SectionType::RODATA;
      }
      else
      {
        printf("invalid/unhandled section type. Should be .text or .data\n");
      }
    }
    else if (firstStr.back() == ':')
    {
      // label: remember the current end of the section it is declared in
      const std::string label(firstStr.begin(), firstStr.end() - 1);
      switch (a_state.m_currentSection)
      {
        case SectionType::TEXT:
          a_state.m_textLabelOffsetMap[label] = program.m_text.size();
          a_state.m_textLabelOffsetMap_x86[label] = program.m_text_x86.size();
          break;
        case SectionType::RODATA:
          a_state.m_dataLabelOffsetMap[label] = program.m_rodata.size();
          break;
        case SectionType::NONE:
        case SectionType::DATA:
        case SectionType::BSS:
        case SectionType::STACK:
          printf("Invalid/unhandled section type to contain label. Put in .text or .data section\n");
          break;
      }
    }
    else if (firstStr == "db")
    {
      // data
      if (a_state.m_currentSection != SectionType::RODATA)
      {
        printf("Invalid/unhandled section type to contain data. Put db in .data section\n");
      }

      // Split the raw line (not the comment-stripped one, so ';' may appear
      // inside the string) on double quotes: exactly one quoted string gives
      // three pieces, the middle one being the string's content.
      const std::vector<std::string> strs = tokenize(line + " ", "\"", false);
      if (strs.size() == 3) // It is a db of a string (rather than list of bytes)
      {
        rest = expandEscapes(strs[1]);
      }
      else
      {
        printf("lists of db bytes not supported, only strings currently supported. Strings with double quotes also not supported.\n");
      }

      AddConstantString(program, rest.c_str());
    }
    else
    {
      // instruction
      if (a_state.m_currentSection != SectionType::TEXT)
      {
        printf("Invalid/unhandled section type to contain instructions. Put instructions in .text section\n");
      }

      // Parse the operands, building the mangled mnemonic (one letter per
      // operand kind) and collecting the operand values at the same time.
      std::string instruction = firstStr;
      bool containsLabel = false;
      int regIdx = 0;
      uint32_t regs[2] = { 0, 0 };
      uint32_t addr = 0;
      std::string label;
      for (const std::string& argStr : tokenize(rest, ",", false))
      {
        const Arg arg = ArgFromString(argStr);
        switch (arg.m_type)
        {
          case ArgType::Register:
          case ArgType::Dereference:
            instruction += (arg.m_type == ArgType::Register) ? 'R' : 'D';
            if (regIdx < 2)
              regs[regIdx++] = arg.m_register;
            else
              printf("too many register operands: %s\n", uncommentedLine.c_str());
            break;
          case ArgType::Label:
            instruction += 'L';
            containsLabel = true;
            label = arg.m_label;
            break;
          case ArgType::Immediate:
            instruction += 'I';
            addr = arg.m_immediate;
            break;
          case ArgType::Define:
            instruction += 'S';
            addr = arg.m_immediate;
            break;
        }
      }

      const OpCode opCode = OpCodeFromMangledInstruction(a_state, instruction);

      if (containsLabel)
      {
        const uint32_t pos_x86 = static_cast<uint32_t>(program.m_text_x86.size());
        // Because the x86 instructions are not all a fixed length, the placeholder is the
        // same instruction with a location/address of zero until the fixup is done.
        const uint32_t pos = static_cast<uint32_t>(AddInstruction(program, opCode, regs[0], regs[1], 0));

        // save the fixup with all information required to apply it once the label's position is known
        a_state.m_fixups.push_back(Fixup{ pos, pos_x86, opCode, regs[0], regs[1], label });
      }
      else
      {
        AddInstruction(program, opCode, regs[0], regs[1], addr);
      }
    }
  }

  // apply fixups
  for (const Fixup& fixup : a_state.m_fixups)
  {
    const auto textLabel     = a_state.m_textLabelOffsetMap.find(fixup.m_forwardLabel);
    const auto textLabel_x86 = a_state.m_textLabelOffsetMap_x86.find(fixup.m_forwardLabel);
    const auto dataLabel     = a_state.m_dataLabelOffsetMap.find(fixup.m_forwardLabel);

    // Text labels resolve to different positions in the VM text and in the
    // x86 text; data labels resolve to the same read-only data offset in both.
    uint32_t location = 0;
    uint32_t location_x86 = 0;
    if (textLabel != a_state.m_textLabelOffsetMap.end())
      location = static_cast<uint32_t>(textLabel->second);
    else if (dataLabel != a_state.m_dataLabelOffsetMap.end())
      location = static_cast<uint32_t>(dataLabel->second);
    else
      printf("unable to find label: %s\n", fixup.m_forwardLabel.c_str());

    if (textLabel_x86 != a_state.m_textLabelOffsetMap_x86.end())
      location_x86 = static_cast<uint32_t>(textLabel_x86->second);
    else if (dataLabel != a_state.m_dataLabelOffsetMap.end())
      location_x86 = static_cast<uint32_t>(dataLabel->second);

    // Do fixup for the VM text
    program.m_text[fixup.m_position] = CreateInstruction(fixup.m_opCode, fixup.m_regA, fixup.m_regB, location);

    // Do fixup again for the x86 text
    const auto bytes = MapInstruction_x86(fixup.m_opCode, fixup.m_regA, fixup.m_regB, location_x86);
    const std::size_t start = fixup.m_position_x86;
    if (start + bytes.size() > program.m_text_x86.size())
    {
      printf("x86 fixup for label %s does not fit in the emitted code\n", fixup.m_forwardLabel.c_str());
      continue;
    }
    for (std::size_t i = 0; i < bytes.size(); ++i)
      program.m_text_x86[start + i] = bytes[i];
  }
}



// Usage: <program> <input.asm> <output>
//
// Assembles the input file and writes the program image to the output file.
// Refuses to run when the output file already exists. Returns 0 on success
// and -1 on any error.
int main(int argc, char* argv[])
{
  if (argc < 3)
  {
    printf("bad number of arguments\n");
    return -1;
  }

  // Never overwrite an existing file
  if (std::ifstream(argv[2]).is_open())
  {
    printf("output file already exists\n");
    return -1;
  }

  std::ifstream input(argv[1]);
  if (!input.is_open())
  {
    printf("couldn't open input file\n");
    return -1;
  }

  // The image is raw binary data: without std::ios::binary some platforms
  // translate '\n' bytes on output and corrupt it.
  std::ofstream output(argv[2], std::ios::binary);
  if (!output.is_open())
  {
    printf("output file couldn't be created\n");
    return -1;
  }

  AssemblerState state = InitializedAssembler();
  BuildX86CodeMap();
#if TEST
  AssembleProgramTest(state.m_program);
#else
  ProcessInput(state, input);
#endif
  UpdateProgramHeader(state.m_program);
  WriteProgram(state.m_program, output);

  output.flush();
  if (!output)
  {
    printf("failed to write output file\n");
    return -1;
  }
  return 0;
}