Newer
Older
Import / research / reflection / source / vm.cpp
#include "vm.hpp"


const int cNumberOfRegisters = 32;  // size of the integer (and float) register files; register fields are 5 bits
const int cNumberOfFunctions = 32;  // number of slots in the ioctl (OS function) handler table
const int cMemorySize = 0x1000000;  // 16M 32-bit words (64 MiB) per memory vector, not 16 MB


// Stream types and the ioctl handler signature. With USE_STL the std:: streams and a
// std::function handler are used; otherwise the project's min:: streams and a plain
// function pointer. A handler receives the machine's integer register file.
#if !USE_STL
using istream = min::istream;
using ifstream = min::ifstream;
using OSFunction = void(*)(uint32_t a_registers[cNumberOfRegisters]);
#else
#include <iostream>
#include <fstream>
using istream = std::istream;
using ifstream = std::ifstream;
using OSFunction = std::function<void(uint32_t a_registers[cNumberOfRegisters])>;
#endif


// Processor status flags packed as bit-fields into a single 32-bit word; the field widths
// sum to exactly 32 bits and the static_assert below enforces the 4-byte size.
// NOTE: the placement/order of bit-fields within the word is implementation-defined, so do
// not rely on a particular flag occupying a particular bit of the raw 32-bit value.
// In this file only m_equal and m_sign are written (by Cmp in ExecuteInstruction) and only
// m_equal is read (by Je/Jne); the remaining fields are not referenced here.
struct Flags
{
  // Condition flags set by Cmp.
  uint32_t       m_equal       : 1; // 1 = zero/equal
  uint32_t       m_sign        : 1; // 0 = less, 1 = greater or equal (Cmp compares as unsigned)

  // x86-style status/control bits; not used by any instruction in this file.
  uint32_t       m_overflow    : 1;
  uint32_t       m_carry       : 1;
  uint32_t       m_parity      : 1;
  uint32_t       m_trap        : 1;
  uint32_t       m_interrupts  : 1;
  uint32_t       m_direction   : 1;
  uint32_t       m_cpuid       : 8;  // 8-bits of CPU identification

  // Mode bits; not referenced in this file.
  uint32_t       m_privilege   : 1;
  uint32_t       m_nested      : 1;
  uint32_t       m_resume      : 1;
  uint32_t       m_virtual     : 1;

  uint32_t       m_instructionMask : 12; // Mask of supported instructions
};
// MachineState overlays Flags with a uint32_t, so the packed size must stay exactly one word.
static_assert(sizeof(Flags) == 4, "Bad flags size");


// Complete state of one virtual machine: flags, special and general registers, the loaded
// program and constant sections, data memory, stack, and the ioctl handler table.
// InitializeMachine sizes m_memory/m_stack and installs handlers; LoadProgram fills
// m_program/m_constants.
struct MachineState
{
  // m_flagsValue aliases the packed Flags word so every flag can be cleared with one store
  // (InitializeMachine does this). NOTE: writing one union member and reading the other is
  // type punning that standard C++ leaves undefined, although it is widely supported by compilers.
  union {
    Flags                 m_flags;
    uint32_t              m_flagsValue;
  };

  // special regs
  // m_ip: index of the current instruction in m_program. Old_RunProgram/ExecuteInstruction
  //       advance it; RunProgram only resets it and tracks its position in a local pointer.
  // m_sp: index of the next free slot in m_stack (Push/Call store, then increment).
  uint32_t                m_ip;
  uint32_t                m_sp;

  // general regs
  // Register operands are 5-bit instruction fields, so all cNumberOfRegisters entries are addressable.
  uint32_t                m_registers[cNumberOfRegisters];
  // NOTE(review): no instruction in this file reads or writes the float registers.
  float                   m_fpRegisters[cNumberOfRegisters];

  // elf sections / memory segments -- all sized and indexed in 32-bit words, not bytes.
  // "ro" is a convention only: nothing in the type prevents writes.
  vector<uint32_t>        m_program;   // ro  .text  / code
  vector<uint32_t>        m_constants; // ro  .rodata  / immediates,consts,strings

                                       // rw  bss   / zeroed data (block started by symbol) -- no separate member; presumably lives in m_memory
  vector<uint32_t>        m_memory;    // rw  heap  / data
  vector<uint32_t>        m_stack;     // rw  stack / stack

  // int vector table
  // Handlers selected by an Ioctl instruction's address field; each is passed m_registers.
  OSFunction              m_ioctls[cNumberOfFunctions];
};


// Executes one decoded instruction against a_machine. This is the per-instruction step used by
// Old_RunProgram (RunProgram inlines an equivalent switch).
//
//   a_instruction  opcode (bits 25-31 of the encoded word)
//   a_regA         register A index (bits 20-24), the destination for most instructions
//   a_regB         register B index (bits 15-19)
//   a_address      low 25 bits of the word: memory/constant address, ioctl number, jump
//                  target, or (masked to 20 bits) an immediate
//
// Control transfers store "target - 1" in m_ip because the caller increments m_ip after every
// instruction. Call/Push store at m_stack[m_sp] and then increment m_sp; Pop/Ret decrement first.
//
// Returns true when execution should stop: Ret popped the 0xFFFFFFFF sentinel return address,
// or an Ioctl named a slot outside the handler table or one with no handler installed.
// Returns false otherwise.
// NOTE: register/memory/stack indices are not bounds-checked, and a zero divisor in Div/Mod is
// undefined behaviour (typically a hardware trap).
bool ExecuteInstruction(MachineState& a_machine, OpCode a_instruction, uint8_t a_regA, uint8_t a_regB, uint32_t a_address)
{
  // Return address that marks the outermost frame; popping it ends execution.
  const uint32_t cReturnSentinel = 0xFFFFFFFFu;
  // Shifting a 32-bit value by 32 or more is undefined behaviour in C++, so only the low
  // 5 bits of the shift count are used.
  const uint32_t cShiftMask = 31u;

  switch (a_instruction)
  {
    case Nop:   break;
    case Add:   a_machine.m_registers[a_regA] +=  a_machine.m_registers[a_regB]; break;
    case Sub:   a_machine.m_registers[a_regA] -=  a_machine.m_registers[a_regB]; break;
    case Shl:   a_machine.m_registers[a_regA] <<= (a_machine.m_registers[a_regB] & cShiftMask); break;
    case Shr:   a_machine.m_registers[a_regA] >>= (a_machine.m_registers[a_regB] & cShiftMask); break;
    case Mul:   a_machine.m_registers[a_regA] *=  a_machine.m_registers[a_regB]; break;
    case Div:   a_machine.m_registers[a_regA] /=  a_machine.m_registers[a_regB]; break;
    case Mod:   a_machine.m_registers[a_regA] %=  a_machine.m_registers[a_regB]; break;
    case Not:   a_machine.m_registers[a_regA] =  ~a_machine.m_registers[a_regB]; break;
    case Or:    a_machine.m_registers[a_regA] |=  a_machine.m_registers[a_regB]; break;
    case Xor:   a_machine.m_registers[a_regA] ^=  a_machine.m_registers[a_regB]; break;
    case And:   a_machine.m_registers[a_regA] &=  a_machine.m_registers[a_regB]; break;

    case MovIR: a_machine.m_registers[a_regA]  =  a_address & 0xFFFFF;                                  break;
    case MovRR: a_machine.m_registers[a_regA]  =  a_machine.m_registers[a_regB];                        break;
    //   Mov0R: a_machine.m_registers[a_regA]  =  a_machine.m_registers[0];                             break;
    case MovMR: a_machine.m_registers[a_regA]  =  a_machine.m_memory[a_machine.m_registers[a_regB]];    break;
    case MovCR: a_machine.m_registers[a_regA]  =  a_machine.m_constants[a_machine.m_registers[a_regB]]; break;

    case MovI0: a_machine.m_registers[0]       =  a_address & 0x1FFFFFF;                                break;
    //   MovR0: a_machine.m_registers[0]       =  a_machine.m_registers[a_regA];                        break;
    //   Mov00: a_machine.m_registers[0]       =  a_machine.m_registers[0];                             break;
    case MovM0: a_machine.m_registers[0]       =  a_machine.m_memory[a_address];                        break;
    case MovC0: a_machine.m_registers[0]       =  a_machine.m_constants[a_address];                     break;

    case MovIM: a_machine.m_memory[a_machine.m_registers[a_regA]] = a_address & 0xFFFFF;                                  break;
    case MovRM: a_machine.m_memory[a_machine.m_registers[a_regA]] = a_machine.m_registers[a_regB];                        break;
    case Mov0M: a_machine.m_memory[a_address]                     = a_machine.m_registers[0];                             break;
    case MovMM: a_machine.m_memory[a_machine.m_registers[a_regA]] = a_machine.m_memory[a_machine.m_registers[a_regB]];    break;
    case MovCM: a_machine.m_memory[a_machine.m_registers[a_regA]] = a_machine.m_constants[a_machine.m_registers[a_regB]]; break;

    case Cmp:
      // Both comparisons are unsigned (registers are uint32_t).
      a_machine.m_flags.m_equal = a_machine.m_registers[a_regA] == a_machine.m_registers[a_regB];
      a_machine.m_flags.m_sign  = a_machine.m_registers[a_regA] >= a_machine.m_registers[a_regB];
      break;
    case Ioctl:
      // a_address is a 25-bit field but the handler table has only cNumberOfFunctions slots:
      // refuse out-of-range or empty slots instead of indexing past the array or calling a
      // null handler.
      if (a_address >= static_cast<uint32_t>(cNumberOfFunctions) || !a_machine.m_ioctls[a_address])
      {
        printf("invalid ioctl %u\n", static_cast<unsigned>(a_address));
        return true;
      }
      a_machine.m_ioctls[a_address](a_machine.m_registers);
      break;
    case Call:
      // Push the index of this Call; Ret restores it and the caller's increment resumes after it.
      a_machine.m_stack[a_machine.m_sp] = a_machine.m_ip;
      ++a_machine.m_sp;
      a_machine.m_ip = a_address - 1;
      break;
    case Jmp:
      a_machine.m_ip = a_address - 1;
      break;
    case Je:
      if (a_machine.m_flags.m_equal)
        a_machine.m_ip = a_address - 1;
      break;
    case Jne:
      if (!a_machine.m_flags.m_equal)
        a_machine.m_ip = a_address - 1;
      break;
    case Push:
      a_machine.m_stack[a_machine.m_sp] = a_machine.m_registers[a_regA];
      ++a_machine.m_sp;
      break;
    case Pop:
      --a_machine.m_sp;
      a_machine.m_registers[a_regA] = a_machine.m_stack[a_machine.m_sp];
      break;
    case Ret:
      --a_machine.m_sp;
      a_machine.m_ip = a_machine.m_stack[a_machine.m_sp];
      if (a_machine.m_ip == cReturnSentinel)
        return true;
      break;
  }
  return false;
}


// Ioctl handler: terminates the host process, using register 0 as the exit status.
void os_exit(uint32_t m_registers[cNumberOfRegisters])
{
  const uint32_t status = m_registers[0];
  exit(status);
}


// Ioctl handler: writes the character held in register 0 to stdout.
void os_putc(uint32_t m_registers[cNumberOfRegisters])
{
  const int ch = static_cast<int>(m_registers[0]);
  putchar(ch);
}


// Puts a_machine into its start-up state:
//   - clears the flags and the instruction and stack pointers;
//   - zeroes every integer and floating-point register;
//   - sizes m_memory and m_stack to cMemorySize 32-bit words each;
//   - empties every ioctl slot, then installs the built-in exit and putc handlers.
// NOTE: for std::vector, resize() only value-initializes newly added elements, so calling
// this on a machine that has already run does not clear existing memory/stack contents.
void InitializeMachine(MachineState& a_machine)
{
  a_machine.m_flagsValue = 0;
  a_machine.m_ip = 0;
  a_machine.m_sp = 0;
  for (int i = 0; i < cNumberOfRegisters; ++i)
  {
    a_machine.m_registers[i] = 0;
    a_machine.m_fpRegisters[i] = 0.0f;
  }
  a_machine.m_memory.resize(cMemorySize);
  a_machine.m_stack.resize(cMemorySize);
  // Reset every handler slot first. In the non-STL build a slot is a plain function pointer
  // whose value is indeterminate in a default-initialized MachineState, so an Ioctl to an
  // unregistered slot would otherwise call through garbage. OSFunction() is a null pointer or
  // an empty std::function, respectively.
  for (int i = 0; i < cNumberOfFunctions; ++i)
    a_machine.m_ioctls[i] = OSFunction();
  a_machine.m_ioctls[Syscall_exit] = os_exit;
  a_machine.m_ioctls[Syscall_putc] = os_putc;
}


// Copies the program's code (.text) and read-only data (.rodata) sections into the machine,
// replacing whatever was loaded before. The x86 text section (m_text_x86) is not loaded.
void LoadProgram(MachineState& a_machine, const Program& a_program)
{
  MachineState& vm = a_machine;
  vm.m_constants = a_program.m_rodata;
  vm.m_program   = a_program.m_text;
}


// Reference (unoptimized) interpreter. It decodes each word of m_program and executes it via
// ExecuteInstruction, keeping the instruction and stack pointers in a_machine.m_ip / m_sp.
//
// Instruction word layout:
//   bits 25-31  opcode
//   bits 20-24  register A
//   bits 15-19  register B
//   bits 0-24   address/immediate (overlaps the register fields)
//
// A sentinel return address (0xFFFFFFFF) is pushed first, so the outermost Ret signals
// completion.
//
// Returns true once ExecuteInstruction reports that execution has finished. Returns false
// (after printing a message) if the instruction pointer leaves m_program — e.g. an empty
// program, code without a terminating Ret, or a jump past the end.
bool Old_RunProgram(MachineState& a_machine)
{
  const uint32_t cReturnSentinel = 0xFFFFFFFFu;
  const auto programSize = a_machine.m_program.size();

  a_machine.m_ip = 0;
  a_machine.m_stack[0] = cReturnSentinel;
  a_machine.m_sp = 1;
  bool finished = false;
  while (!finished)
  {
    // Running off the end of the code would read past m_program; stop instead.
    if (a_machine.m_ip >= programSize)
    {
      printf("instruction pointer out of range\n");
      return false;
    }
    uint32_t currentInstruction = a_machine.m_program[a_machine.m_ip];
    uint8_t op = currentInstruction >> 25;
    uint8_t dst = (currentInstruction >> 20) & 0x1F;
    uint8_t src = (currentInstruction >> 15) & 0x1F;
    uint32_t addr = currentInstruction & 0x1FFFFFF;
    finished = ExecuteInstruction(a_machine, static_cast<OpCode>(op), dst, src, addr);
    // Jumps stored "target - 1", so this increment lands on the target.
    ++a_machine.m_ip;
  }
  return true;
}


// Optimized interpreter loop: same instruction semantics as Old_RunProgram/ExecuteInstruction.
// The differences are where state lives:
//   - the instruction and stack pointers are local pointers;
//   - the equal flag is a local bool.
// a_machine.m_ip / m_sp are reset to 0 on entry but are not updated while running.
//
// Next level speed up might be to try to un-roll / expand the opcodes in to cases for each reg/reg combo
// so that can then map vm regs to h/w regs better. At the moment the regs are in an array, so may be going
// via memory / cache.
//
// Next-next level would be to implement a JIT compiler as part of the VM
//
// Returns true when the outermost Ret pops the sentinel return address pushed on entry.
// Returns false if no program is loaded, or (after printing a message) if an Ioctl names a slot
// outside the handler table or one with no handler installed.
// NOTE: jump/call targets and stack depth are not bounds-checked on this hot path.
bool RunProgram(MachineState& a_machine)
{
  // Return address marking the outermost frame; popping it ends execution.
  const uint32_t cReturnSentinel = 0xFFFFFFFFu;
  // Only the low 5 bits of a shift count are used (shifting by >= 32 is undefined behaviour).
  const uint32_t cShiftMask = 31u;

  a_machine.m_ip = 0;
  a_machine.m_sp = 0;
  // Taking &m_program[0] of an empty program would be undefined behaviour.
  if (a_machine.m_program.size() == 0)
    return false;

  const uint32_t* const code_start = &a_machine.m_program[0];
  const uint32_t* ip = code_start + a_machine.m_ip;
  uint32_t* sp = &a_machine.m_stack[a_machine.m_sp];
  uint32_t* const regs = a_machine.m_registers;

  // Only the equal flag has a consumer (Je/Jne), so Cmp does not compute the sign flag here.
  bool equal = false;

  *sp = cReturnSentinel;
  ++sp;

  while (true)
  {
    // Fetch the current word and advance, so ip points at the successor while it executes.
    // Only the current word is read: the last instruction never triggers a read past m_program.
    const uint32_t currentInstruction = *ip;
    ++ip;
    const uint8_t a_regA = (currentInstruction >> 20) & 0x1F;
    const uint8_t a_regB = (currentInstruction >> 15) & 0x1F;
    const uint32_t address = currentInstruction & 0x1FFFFFF;
    switch (static_cast<OpCode>(currentInstruction >> 25))
    {
      case Nop:   break;
      case Add:   regs[a_regA] +=  regs[a_regB]; break;
      case Sub:   regs[a_regA] -=  regs[a_regB]; break;
      case Shl:   regs[a_regA] <<= (regs[a_regB] & cShiftMask); break;
      case Shr:   regs[a_regA] >>= (regs[a_regB] & cShiftMask); break;
      case Mul:   regs[a_regA] *=  regs[a_regB]; break;
      // NOTE: a zero divisor is undefined behaviour for Div/Mod (typically a hardware trap).
      case Div:   regs[a_regA] /=  regs[a_regB]; break;
      case Mod:   regs[a_regA] %=  regs[a_regB]; break;
      case Not:   regs[a_regA] =  ~regs[a_regB]; break;
      case Or:    regs[a_regA] |=  regs[a_regB]; break;
      case Xor:   regs[a_regA] ^=  regs[a_regB]; break;
      case And:   regs[a_regA] &=  regs[a_regB]; break;

      case MovIR: regs[a_regA] = currentInstruction & 0xFFFFF;                  break;
      case MovRR: regs[a_regA] = regs[a_regB];                                  break;
      //   Mov0R: regs[a_regA] = regs[0];                                       break;
      case MovMR: regs[a_regA] = a_machine.m_memory[regs[a_regB]];              break;
      case MovCR: regs[a_regA] = a_machine.m_constants[regs[a_regB]];           break;

      case MovI0: regs[0]      = address;                                       break;
      //   MovR0: regs[0]      = regs[a_regA];                                  break;
      //   Mov00: regs[0]      = regs[0];                                       break;
      case MovM0: regs[0]      = a_machine.m_memory[address];                   break;
      case MovC0: regs[0]      = a_machine.m_constants[address];                break;

      case MovIM: a_machine.m_memory[regs[a_regA]] = currentInstruction & 0xFFFFF;              break;
      case MovRM: a_machine.m_memory[regs[a_regA]] = regs[a_regB];                              break;
      case Mov0M: a_machine.m_memory[address]      = regs[0];                                   break;
      case MovMM: a_machine.m_memory[regs[a_regA]] = a_machine.m_memory[regs[a_regB]];          break;
      case MovCM: a_machine.m_memory[regs[a_regA]] = a_machine.m_constants[regs[a_regB]];       break;

      case Cmp:
        equal = regs[a_regA] == regs[a_regB];
        break;
      case Ioctl:
        // The address field is 25 bits but the table has only cNumberOfFunctions slots.
        if (address >= static_cast<uint32_t>(cNumberOfFunctions) || !a_machine.m_ioctls[address])
        {
          printf("invalid ioctl %u\n", static_cast<unsigned>(address));
          return false;
        }
        a_machine.m_ioctls[address](regs);
        break;
      case Call:
        // Push the index of this Call instruction; Ret resumes at the word after it.
        *sp = static_cast<uint32_t>(ip - code_start - 1);
        ++sp;
        ip = code_start + address;
        break;
      case Jmp:
        ip = code_start + address;
        break;
      case Je:
        if (equal)
          ip = code_start + address;
        break;
      case Jne:
        if (!equal)
          ip = code_start + address;
        break;
      case Push:
        *sp = regs[a_regA];
        ++sp;
        break;
      case Pop:
        --sp;
        regs[a_regA] = *sp;
        break;
      case Ret:
        --sp;
        // Check for the sentinel before forming the return pointer: code_start + 0xFFFFFFFF
        // would point far outside m_program.
        if (*sp == cReturnSentinel)
          return true;
        ip = code_start + *sp + 1;
        break;
    }
  }
}


// Terminates the process with exit status -1 and a message if a_input has reached
// end-of-file; otherwise returns without side effects.
void CheckInput(istream& a_input)
{
  if (!a_input.eof())
    return;

  printf("unexpected end of file\n");
  exit(-1);
}


// Reads a program image from a_input into a_program. The image consists of:
//   1. a ProgramHeader;
//   2. the .text words;
//   3. the x86 text words;
//   4. the .rodata words.
// Section sizes are counts of 32-bit words taken from the header.
// CheckInput runs before each read and exits the process if the stream is already at EOF.
// NOTE(review): nothing checks the stream after the final (.rodata) read, so a truncated
// last section is not detected; header sizes are also trusted as-is.
void ReadProgram(Program& a_program, istream& a_input)
{
  auto& header = a_program.m_header;

  CheckInput(a_input);
  a_input.read(reinterpret_cast<char*>(&header), sizeof(ProgramHeader));

  // Size every section from the header before reading any of them.
  a_program.m_text.resize(header.m_textSize);
  a_program.m_text_x86.resize(header.m_textSize_x86);
  a_program.m_rodata.resize(header.m_rodataSize);

  CheckInput(a_input);
  a_input.read(reinterpret_cast<char*>(a_program.m_text.data()), header.m_textSize * sizeof(uint32_t));

  CheckInput(a_input);
  a_input.read(reinterpret_cast<char*>(a_program.m_text_x86.data()), header.m_textSize_x86 * sizeof(uint32_t));

  CheckInput(a_input);
  a_input.read(reinterpret_cast<char*>(a_program.m_rodata.data()), header.m_rodataSize * sizeof(uint32_t));
}


// Entry point. It:
//   1. loads the program image named by argv[1];
//   2. initializes a machine and loads the program into it;
//   3. runs the program 100 times.
// The machine is not re-initialized between runs, so memory contents carry over from one run
// to the next.
// Returns -1 on a usage or file-open error, 0 otherwise (ReadProgram exits on a truncated file).
int main(int argc, char* argv[])
{
  if (argc < 2)
  {
    printf("bad number of arguments\n");
    return -1;
  }

  // The program image is raw binary. In the STL build, open it in binary mode so the stream
  // performs no newline translation, which would corrupt the image on platforms such as Windows.
#if USE_STL
  ifstream input(argv[1], std::ios::binary);
#else
  ifstream input(argv[1]);
#endif
  if (!input.is_open())
  {
    printf("couldn't open input file\n");
    return -1;
  }

  Program program;
  ReadProgram(program, input);

  MachineState machine;
  InitializeMachine(machine);
  LoadProgram(machine, program);
  // NOTE(review): repeated 100 times, presumably to benchmark the interpreter — confirm.
  for (int i = 0; i < 100; i++)
    RunProgram(machine);
  return 0;
}