#include "vm.hpp"
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#define TEST 0
// Kinds of program section the assembler can be positioned in while reading
// its input.
enum class SectionType
{
    NONE,    // starting value of AssemblerState::m_currentSection
    TEXT,    // selected by the ".text" directive
    RODATA,  // selected by the ".data" directive
    BSS,
    DATA,
    STACK
};
// Classification of a single instruction operand, chosen from the operand's
// first character.
enum class ArgType
{
    Register,     // operand starts with '%'
    Label,        // operand starts with '@'
    Immediate,    // operand starts with '#'
    Dereference,  // operand starts with '['
    Define        // anything else: a named constant
};
// A pending patch for an instruction that referenced a label. It carries
// everything needed to re-encode the instruction once the label's offset is
// known. Member order matters: the struct is brace-initialised positionally.
struct Fixup
{
    uint32_t    m_position;      // index of the placeholder in the VM text
    uint32_t    m_position_x86;  // offset of the placeholder in the x86 text
    OpCode      m_opCode;        // opcode of the instruction to re-encode
    uint32_t    m_regA;          // first register operand
    uint32_t    m_regB;          // second register operand
    std::string m_forwardLabel;  // name of the label whose address is looked up
};
// One parsed instruction operand. Only the member matching m_type carries a
// meaningful value; the others keep their defaults. Every member has a
// default so a partially filled Arg can be copied without reading
// indeterminate values.
struct Arg
{
    ArgType     m_type = ArgType::Define;  // which kind of operand this is
    std::string m_label;                   // label name (ArgType::Label)
    uint32_t    m_register = 0;            // register number (Register / Dereference)
    uint32_t    m_immediate = 0;           // numeric value (Immediate / Define)
};
// All mutable state threaded through one assembler run.
struct AssemblerState
{
    // Section that labels, data and instructions are currently added to.
    SectionType m_currentSection = SectionType::NONE;

    // The program being built.
    Program m_program;

    // Mangled mnemonic (e.g. "movRI") -> opcode.
    std::map<std::string, OpCode> m_opCodeMap;

    // Label name -> offset, kept separately for the VM text, the x86 text
    // and the data section.
    std::map<std::string,size_t> m_textLabelOffsetMap;
    std::map<std::string,size_t> m_textLabelOffsetMap_x86;
    std::map<std::string,size_t> m_dataLabelOffsetMap;

    // Label-referencing instructions waiting to be patched.
    std::vector<Fixup> m_fixups;
};
// Builds a fresh AssemblerState whose mnemonic table is populated.
//
// Keys are "mangled" mnemonics: the instruction name followed by one letter
// per operand (R = register, L = label, I = immediate, D = dereference,
// S = define).
AssemblerState InitializedAssembler()
{
    struct MnemonicEntry
    {
        const char* m_name;
        OpCode      m_opCode;
    };

    static const MnemonicEntry kMnemonics[] =
    {
        // Zero args
        { "nop",    OpCode::Nop   },
        { "ret",    OpCode::Ret   },
        // One arg
        { "pushR",  OpCode::Push  },
        { "popR",   OpCode::Pop   },
        { "callL",  OpCode::Call  },
        { "jmpL",   OpCode::Jmp   },
        { "jeL",    OpCode::Je    },
        { "jneL",   OpCode::Jne   },
        { "ioctlS", OpCode::Ioctl },
        // Two args
        { "addRR",  OpCode::Add   },
        { "subRR",  OpCode::Sub   },
        { "shlRR",  OpCode::Shl   },
        { "shrRR",  OpCode::Shr   },
        { "mulRR",  OpCode::Mul   },
        { "divRR",  OpCode::Div   },
        { "modRR",  OpCode::Mod   },
        { "notRR",  OpCode::Not   },
        { "orRR",   OpCode::Or    },
        { "xorRR",  OpCode::Xor   },
        { "andRR",  OpCode::And   },
        { "cmpRR",  OpCode::Cmp   },
        { "movRR",  OpCode::MovRR },
        { "movRL",  OpCode::MovIR },  // a label is encoded like an immediate
        { "movRI",  OpCode::MovIR },
        { "movRD",  OpCode::MovCR },
    };

    // Reference notes on Mov opcode variants, kept from the original author:
    //   MovRR   reg         -> reg
    //   MovMR   mem[r]      -> reg
    //   MovI0   imm25       -> reg0
    //   MovM0   mem[addr]   -> reg0
    //   MovC0   const[addr] -> reg0
    //   MovIM   imm20       -> mem[r]
    //   MovRM   reg         -> mem[r]
    //   Mov0M   reg0        -> mem[addr]
    //   MovMM   mem[r]      -> mem[r]
    //   MovCM   const[r]    -> mem[r]

    AssemblerState assembler;
    for (const MnemonicEntry& entry : kMnemonics)
    {
        assembler.m_opCodeMap[entry.m_name] = entry.m_opCode;
    }
    return assembler;
}
// Packs an opcode, two register numbers and an address/immediate into one
// 32-bit VM instruction word.
//
// Layout: opcode from bit 25 upward, a_regA in bits 20-24, a_regB in bits
// 15-19, a_address in the low 25 bits. The address field overlaps both
// register fields, so all parts are simply OR-ed together.
uint32_t CreateInstruction(OpCode a_instruction, uint32_t a_regA = 0, uint32_t a_regB = 0, uint32_t a_address = 0)
{
    const uint32_t opCodeBits  = static_cast<uint32_t>(a_instruction) << 25;
    const uint32_t regABits    = (a_regA & 0x1F) << 20;
    const uint32_t regBBits    = (a_regB & 0x1F) << 15;
    const uint32_t addressBits = a_address & 0x1FFFFFF;
    return opCodeBits | regABits | regBBits | addressBits;
}
// Table from an encoded VM instruction word (as produced by CreateInstruction)
// to the x86 byte sequence emitted for it. MapInstruction_x86 reads it.
std::map<int, std::vector<uint8_t>> opcodes;
// Expands opcodes.h inside this function body.
// NOTE(review): opcodes.h is not visible here; presumably it fills `opcodes`
// using the two macros defined below - confirm against that header.
void BuildX86CodeMap()
{
// OpCodeSequence is defined with an empty expansion; the commented-out line
// below is an earlier definition kept by the original author.
#define OpCodeSequence
//#define OpCodeSequence(len,bytes) OpCodeSeq{ len, std::vector<uint8_t>(bytes) }
// MakeOp forwards to CreateInstruction, so whatever the header builds with it
// uses the same encoding that MapInstruction_x86 uses for its lookups.
#define MakeOp(op,r1,r2,addr) CreateInstruction(op,r1,r2,addr)
#include "opcodes.h"
}
// Returns the x86 byte sequence for a VM instruction.
//
// The `opcodes` table is keyed by encoded VM instruction words. Entries that
// take an address or immediate are keyed with the placeholder 0xb4b5b6b7,
// and their bytes contain that placeholder in little-endian order
// (b7 b6 b5 b4), which is overwritten here with a_address.
//
// a_instruction, a_regA, a_regB: as passed to CreateInstruction.
// a_address: address or immediate to substitute. Zero is ambiguous (no
//            operand, or an operand of 0), so the plain entry is tried first
//            and the placeholder entry second.
//
// Prints an error and exits the process with -1 when no table entry exists,
// or when a placeholder entry does not contain the placeholder bytes.
std::vector<uint8_t> MapInstruction_x86(OpCode a_instruction, uint32_t a_regA = 0, uint32_t a_regB = 0, uint32_t a_address = 0) // uint32_t a_vmInstruction)
{
    const uint32_t kAddressPlaceholder = 0xb4b5b6b7;

    bool addressSubstitutionNeeded = (a_address != 0);
    uint32_t vmInstruction = CreateInstruction(a_instruction, a_regA, a_regB, addressSubstitutionNeeded ? kAddressPlaceholder : 0);
    auto entry = opcodes.find(vmInstruction);
    if (entry == opcodes.end() && !addressSubstitutionNeeded)
    {
        // search again with address substitution - might be either an address of 0, or an immediate value of 0
        vmInstruction = CreateInstruction(a_instruction, a_regA, a_regB, kAddressPlaceholder);
        addressSubstitutionNeeded = true;
        entry = opcodes.find(vmInstruction);
    }
    if (entry == opcodes.end())
    {
        printf("error, x86 mapping of opcode not found: %u => %u,%u,%u,%u \n", vmInstruction, static_cast<unsigned>(a_instruction), a_regA, a_regB, a_address);
        exit(-1);
    }

    std::vector<uint8_t> bytes = entry->second;
    if (addressSubstitutionNeeded)
    {
        // Match the whole four-byte placeholder rather than only its first
        // byte: a lone 0xb7 can also occur as an ordinary opcode byte ahead
        // of the placeholder.
        static const uint8_t kPattern[4] = { 0xb7, 0xb6, 0xb5, 0xb4 };
        const size_t notFound = bytes.size();
        size_t at = notFound;
        for (size_t i = 0; i + 4 <= bytes.size(); ++i)
        {
            if (bytes[i] == kPattern[0] && bytes[i + 1] == kPattern[1] &&
                bytes[i + 2] == kPattern[2] && bytes[i + 3] == kPattern[3])
            {
                at = i;
                break;
            }
        }
        if (at == notFound)
        {
            // Checked at run time (not with assert) so that a build without
            // assertions cannot write past the end of `bytes`.
            printf("error, x86 mapping has no address placeholder: %u\n", vmInstruction);
            exit(-1);
        }
        for (size_t k = 0; k < 4; ++k)
        {
            bytes[at + k] = static_cast<uint8_t>((a_address >> (8 * k)) & 0xff);
        }
    }
    return bytes;
}
// Appends one instruction to the program, in both encodings: the VM word
// goes to m_text and the mapped x86 bytes go to m_text_x86.
//
// Returns the index of the new instruction within m_text.
size_t AddInstruction(Program& a_program, OpCode a_instruction, uint32_t a_regA = 0, uint32_t a_regB = 0, uint32_t a_address = 0)
{
    const size_t index = a_program.m_text.size();
    a_program.m_text.emplace_back(CreateInstruction(a_instruction, a_regA, a_regB, a_address));

    const std::vector<uint8_t> x86Bytes = MapInstruction_x86(a_instruction, a_regA, a_regB, a_address);
    for (size_t i = 0; i < x86Bytes.size(); ++i)
    {
        a_program.m_text_x86.emplace_back(x86Bytes[i]);
    }
    return index; // the address of the instruction just added
}
// Appends a NUL-terminated string, terminator included, to the program's
// read-only data, one element per character.
//
// Returns the offset within m_rodata at which the string starts.
size_t AddConstantString(Program& a_program, const char* a_string)
{
    const size_t start = a_program.m_rodata.size();
    const char* cursor = a_string;
    while (*cursor)
    {
        a_program.m_rodata.emplace_back(*cursor);
        ++cursor;
    }
    a_program.m_rodata.emplace_back(0);
    return start;
}
#if TEST
// Hand-assembled sample program, compiled only when TEST is non-zero.
// It stores two strings in the constant section and emits code that calls a
// print routine for each of them before issuing the exit ioctl.
// The calls are forward references: a Nop is emitted as a place-holder and
// replaced with the real Call once the print routine's address is known.
// Note that only m_text is patched below; the x86 bytes that AddInstruction
// emitted for the Nop place-holders are left untouched.
void AssembleProgramTest(Program& a_program)
{
// constant section
uint32_t str1 =
AddConstantString(a_program, "hello world\n");
uint32_t str2 =
AddConstantString(a_program, "blah blah\n");
// a_program code
AddInstruction(a_program, MovIR, 2, 0, str1);
uint32_t patch0 =
AddInstruction(a_program, Nop); // place-holder
AddInstruction(a_program, MovIR, 2, 0, str2);
uint32_t patch1 =
AddInstruction(a_program, Nop); // place-holder
AddInstruction(a_program, Ioctl, 0, 0, 0); // exit
// print routine: starts at the instruction whose index is saved in `print`
uint32_t print =
AddInstruction(a_program, MovIR, 3, 0, 1); // reg3 = 1
AddInstruction(a_program, MovCR, 0, 2, 0); // reg0 = c[reg2] = str[0]
uint32_t loop =
AddInstruction(a_program, Ioctl, 0, 0, 1); // ioctl putc
AddInstruction(a_program, Add, 2, 3, 0); // reg2 += reg3
AddInstruction(a_program, MovCR, 0, 2, 0); // reg0 = c[reg2]
AddInstruction(a_program, Cmp, 0, 4, 0); // reg0 == reg4
AddInstruction(a_program, Jne, 0, 0, loop); // loop
AddInstruction(a_program, Ret);
// patching up forward references
a_program.m_text[patch0] = CreateInstruction(Call, 0, 0, print);
a_program.m_text[patch1] = CreateInstruction(Call, 0, 0, print);
}
#endif
// Fills in the program header from the current section contents.
// Sizes are element counts; offsets are in bytes from the start of the image.
void UpdateProgramHeader(Program& a_program)
{
    auto& header = a_program.m_header;
    const size_t headerBytes = sizeof(ProgramHeader);
    const size_t textCount = a_program.m_text.size();

    header.m_headerSize = headerBytes;
    header.m_textOffset = headerBytes;   // the VM text directly follows the header
    header.m_textSize = textCount;
    header.m_textSize_x86 = a_program.m_text_x86.size();
    // NOTE(review): this offset covers only the header and the VM text; it
    // does not include the size of the x86 text - confirm that this matches
    // the layout the loader expects.
    header.m_rodataOffset = headerBytes + textCount * sizeof(uint32_t);
    header.m_rodataSize = a_program.m_rodata.size();
}
// Writes the program image to a_output in this order: header, VM text,
// x86 text, read-only data.
//
// Each section's byte count is derived from that container's own element
// size. A fixed sizeof(uint32_t) would read past the end of any buffer whose
// elements are narrower than four bytes.
void WriteProgram(Program& a_program, std::ostream& a_output)
{
    const auto writeBytes = [&a_output](const void* a_data, size_t a_byteCount)
    {
        a_output.write(static_cast<const char*>(a_data), static_cast<std::streamsize>(a_byteCount));
    };

    // write out the program
    writeBytes(&a_program.m_header, sizeof(a_program.m_header));
    writeBytes(a_program.m_text.data(), a_program.m_text.size() * sizeof(*a_program.m_text.data()));
    writeBytes(a_program.m_text_x86.data(), a_program.m_text_x86.size() * sizeof(*a_program.m_text_x86.data()));
    writeBytes(a_program.m_rodata.data(), a_program.m_rodata.size() * sizeof(*a_program.m_rodata.data()));
}
// Need to parse something like this:
/*
;
; Comments using semi-colon
;
.text
_main:
mov %r2, @str1
call @_print
mov %r2, @str2
call @_print
ioctl Syscall_exit
_print:
mov %r3, #1
mov %r0, [%r2]
loop:
ioctl Syscall_putc
add %r2, %r3
mov %r0, [%r2]
cmp %r0, %r4
jne @loop
.data
str1:
db "hello world\n"
str2:
db "blah blah\n"
*/
// Classifies an operand by its first character:
//   '%' register, '@' label, '#' immediate, '[' dereference,
//   anything else (including an empty string) a define.
ArgType ArgTypeFromString(const std::string& str)
{
    switch (str[0])
    {
    case '%': return ArgType::Register;
    case '@': return ArgType::Label;
    case '#': return ArgType::Immediate;
    case '[': return ArgType::Dereference;
    default:  return ArgType::Define;
    }
}
// Reads the run of decimal digits in a_text that starts at index
// a_firstDigit and returns its value. Prints an error and returns 0 when
// there is no digit at that position (including when the text is too short).
static uint32_t ParseRegisterNumber(const std::string& a_text, size_t a_firstDigit)
{
    uint32_t number = 0;
    size_t pos = a_firstDigit;
    while (pos < a_text.size() && a_text[pos] >= '0' && a_text[pos] <= '9')
    {
        number = number * 10 + static_cast<uint32_t>(a_text[pos] - '0');
        ++pos;
    }
    if (pos == a_firstDigit)
    {
        printf("invalid register operand: %s\n", a_text.c_str());
    }
    return number;
}

// Parses one operand string into an Arg.
//
// Recognised forms (see ArgTypeFromString for the classification):
//   %rN     register N (any number of decimal digits)
//   @name   label reference
//   #value  unsigned decimal immediate
//   [%rN]   dereference of register N
//   name    a define; currently only Syscall_exit and Syscall_putc
//
// Malformed operands never index past the end of the string or throw: an
// error is printed and the affected value is left at 0.
Arg ArgFromString(const std::string& str)
{
    // TODO: defines are something to add to asm parsing
    static const std::map<std::string, uint32_t> kDefines =
    {
        { "Syscall_exit", 0 },
        { "Syscall_putc", 1 },
    };

    Arg arg{};
    arg.m_type = ArgTypeFromString(str);
    switch (arg.m_type)
    {
    case ArgType::Register:
        // "%rN": the digits begin after the two-character "%r" prefix
        arg.m_register = ParseRegisterNumber(str, 2);
        break;
    case ArgType::Label:
        arg.m_label = str.substr(1);
        break;
    case ArgType::Immediate:
        try
        {
            arg.m_immediate = static_cast<uint32_t>(std::stoul(str.substr(1)));
        }
        catch (const std::exception&)
        {
            // std::stoul throws on non-numeric or out-of-range input
            printf("invalid immediate operand: %s\n", str.c_str());
            arg.m_immediate = 0;
        }
        break;
    case ArgType::Dereference:
        // "[%rN]": the digits begin after the three-character "[%r" prefix
        arg.m_register = ParseRegisterNumber(str, 3);
        break;
    case ArgType::Define:
    {
        const auto found = kDefines.find(str);
        if (found != kDefines.end())
        {
            arg.m_immediate = found->second;
        }
        else
        {
            printf("unknown define: %s\n", str.c_str());
            arg.m_immediate = 0;
        }
        break;
    }
    }
    return arg;
}
// https://stackoverflow.com/questions/236129/how-do-i-iterate-over-the-words-of-a-string
//
// Splits `str` at every character contained in `delimiters`.
// With trimEmpty == false, empty pieces (adjacent, leading or trailing
// delimiters) are kept; with trimEmpty == true they are dropped.
// An empty input yields no tokens at all.
std::vector<std::string> tokenize(const std::string& str,
                                  const std::string& delimiters = " ",
                                  bool trimEmpty = false)
{
    std::vector<std::string> pieces;
    const std::size_t total = str.size();
    if (total == 0)
    {
        return pieces;
    }

    std::size_t begin = 0;
    // begin == total is a valid position: it produces the (empty) piece
    // that follows a trailing delimiter
    while (begin <= total)
    {
        std::size_t stop = str.find_first_of(delimiters, begin);
        if (stop == std::string::npos)
        {
            stop = total;
        }
        const bool pieceIsEmpty = (stop == begin);
        if (!(pieceIsEmpty && trimEmpty))
        {
            pieces.emplace_back(str, begin, stop - begin);
        }
        begin = stop + 1;
    }
    return pieces;
}
// Concatenates the strings in `v`, placing `delimiter` between consecutive
// elements. A delimiter of 0 means "no separator". An empty vector yields
// an empty string.
//
// The vector is taken by const reference so it is not copied on each call.
std::string join(const std::vector<std::string>& v, const char& delimiter)
{
    std::string joined;
    for (std::size_t i = 0; i < v.size(); ++i)
    {
        if (i != 0 && delimiter)
        {
            joined += delimiter;
        }
        joined += v[i];
    }
    return joined;
}
// Replaces backslash escape sequences in `str` with the characters they
// denote.
//
// Supported: \a \b \f \n \r \t \v \' \" and \\.
// Any other escaped character is dropped together with its backslash, and a
// lone backslash at the end of the string is dropped as well.
// Not implemented: \c (ignore rest of string) and \num (octal byte value).
std::string expandEscapes(const std::string& str)
{
    std::string expandedStr;
    expandedStr.reserve(str.size());
    bool inEscape = false;
    for (char ch : str)
    {
        if (!inEscape)
        {
            if (ch == '\\')
                inEscape = true;
            else
                expandedStr += ch;
            continue;
        }

        // The escape state is handled before looking for a new backslash,
        // so that "\\" produces one backslash instead of starting a second
        // escape.
        inEscape = false;
        switch (ch)
        {
        case 'a':  expandedStr += '\a'; break;
        case 'b':  expandedStr += '\b'; break;
        case 'f':  expandedStr += '\f'; break;
        case 'n':  expandedStr += '\n'; break;
        case 'r':  expandedStr += '\r'; break;
        case 't':  expandedStr += '\t'; break;
        case 'v':  expandedStr += '\v'; break;
        case '\'': expandedStr += '\''; break;
        case '"':  expandedStr += '"';  break;
        case '\\': expandedStr += '\\'; break;
        default:   break; // unsupported escape: dropped
        }
    }
    return expandedStr;
}
/*
Looks up the op-code for a mangled instruction name, i.e. the mnemonic
followed by one letter per operand, for example:
movRR
movRI
movRD
Prints a message and falls back to OpCode::Nop when the name is unknown.
*/
OpCode OpCodeFromMangledInstruction(const AssemblerState& a_state, const std::string& str)
{
    const auto found = a_state.m_opCodeMap.find(str);
    if (found != a_state.m_opCodeMap.end())
    {
        return found->second;
    }
    printf("opcode not found for %s\n", str.c_str());
    return OpCode::Nop;
}
void ProcessInput(AssemblerState& a_state, std::istream& a_input)
{
while (!a_input.eof())
{
std::string line;
std::getline(a_input, line);
std::istringstream iss(line);
std::string uncommentedLine;
if (std::getline(iss, uncommentedLine, ';'))
{
std::vector<std::string> tokens = tokenize(uncommentedLine, " \t\v", true);
if (tokens.size())
{
std::string firstStr = tokens[0];
std::string rest;
if (tokens.size() > 1)
{
std::vector<std::string> r(tokens.begin()+1, tokens.end());
rest = join(r, 0);
}
if (firstStr[0] == '.')
{
if (firstStr == ".text")
{
a_state.m_currentSection = SectionType::TEXT;
}
else if (firstStr == ".data")
{
a_state.m_currentSection = SectionType::RODATA;
}
else
{
printf("invalid/unhandled section type. Should be .text or .data\n");
}
// section
//printf("s");
}
else if (*(firstStr.end()-1) == ':')
{
std::string label(firstStr.begin(), firstStr.end() - 1);
switch (a_state.m_currentSection)
{
case SectionType::TEXT:
a_state.m_textLabelOffsetMap[label] = a_state.m_program.m_text.size();
a_state.m_textLabelOffsetMap_x86[label] = a_state.m_program.m_text_x86.size();
break;
case SectionType::RODATA:
a_state.m_dataLabelOffsetMap[label] = a_state.m_program.m_rodata.size();
break;
case SectionType::NONE:
case SectionType::DATA:
case SectionType::BSS:
case SectionType::STACK:
printf("Invalid/unhandled section type to contain label. Put in .text or .data section\n");
break;
}
// label
//printf("l");
}
else if (firstStr == "db")
{
if (a_state.m_currentSection != SectionType::RODATA)
{
printf("Invalid/unhandled section type to contain data. Put db in .data section\n");
}
std::vector<std::string> strs = tokenize(line + " ", "\"", false);
if (strs.size() == 3) // It is a db of a string (rather than list of bytes)
{
std::vector<std::string> r2(strs.begin() + 1, strs.end() - 1);
rest = join(r2, 0);
rest = expandEscapes(rest);
}
else
{
printf("lists of db bytes not supported, only strings currently supported. Strings with double quotes also not supported.\n");
}
AddConstantString(a_state.m_program, rest.c_str());
// data
//printf("d");
}
else
{
if (a_state.m_currentSection != SectionType::TEXT)
{
printf("Invalid/unhandled section type to contain data. Put db in .data section\n");
}
// instruction
//printf("i");
//printf("--%s-%s--\n", firstStr.c_str(), rest.c_str());
std::vector<std::string> argStrs = tokenize(rest, ",", false);
std::vector<Arg> args;
std::string instruction = firstStr;
bool containsLabel = false;
for (std::string s : argStrs)
{
Arg arg = ArgFromString(s);
args.emplace_back(arg);
switch (arg.m_type)
{
case ArgType::Register: instruction += 'R'; break;
case ArgType::Label: instruction += 'L'; containsLabel = true; break;
case ArgType::Immediate: instruction += 'I'; break;
case ArgType::Dereference: instruction += 'D'; break;
case ArgType::Define: instruction += 'S'; break;
}
}
OpCode opCode = OpCodeFromMangledInstruction(a_state, instruction);
int regIdx = 0;
uint32_t regs[2] = { 0, 0 };
uint32_t addr = 0;
std::string label = "";
for (auto arg : args)
{
switch (arg.m_type)
{
case ArgType::Register: regs[regIdx] = arg.m_register; regIdx++; break;
case ArgType::Label: label = arg.m_label; break;
case ArgType::Immediate: addr = arg.m_immediate; break;
case ArgType::Dereference: regs[regIdx] = arg.m_register; regIdx++; break;
case ArgType::Define: addr = arg.m_immediate; break;
}
}
if (containsLabel)
{
uint32_t pos_x86 = a_state.m_program.m_text_x86.size();
// Add dummy placeholder instruction until we do the fixup
// uint32_t pos = AddInstruction(a_state.m_program, OpCode::Nop);
// Because the x86 instructions are not all a fixed length, we need to put in a placeholder instruction of the same length:
uint32_t pos = AddInstruction(a_state.m_program, opCode, regs[0], regs[1], 0); // location/address of zero for now until fixup is done
// save the fixup to be done with all required information to apply it (when we find the forward declared label's actual code position)
Fixup fixup{ pos, pos_x86, opCode, regs[0], regs[1], label };
a_state.m_fixups.emplace_back(fixup);
}
else
{
AddInstruction(a_state.m_program, opCode, regs[0], regs[1], addr);
}
//printf("--%s-%s--\n", instruction.c_str(), rest.c_str());
}
}
}
}
// apply fixups
for (const Fixup& fixup : a_state.m_fixups)
{
uint32_t location = 0;
// Do fixup for the VM text
if (a_state.m_textLabelOffsetMap.count(fixup.m_forwardLabel))
location = a_state.m_textLabelOffsetMap[fixup.m_forwardLabel];
else if (a_state.m_dataLabelOffsetMap.count(fixup.m_forwardLabel))
location = a_state.m_dataLabelOffsetMap[fixup.m_forwardLabel];
else
printf("unable to find label: %s\n", fixup.m_forwardLabel.c_str());
a_state.m_program.m_text[fixup.m_position] = CreateInstruction(fixup.m_opCode, fixup.m_regA, fixup.m_regB, location);
// Do fixup again for the x86 text
if (a_state.m_textLabelOffsetMap_x86.count(fixup.m_forwardLabel))
location = a_state.m_textLabelOffsetMap[fixup.m_forwardLabel];
else if (a_state.m_dataLabelOffsetMap.count(fixup.m_forwardLabel))
location = a_state.m_dataLabelOffsetMap[fixup.m_forwardLabel];
else
printf("unable to find label: %s\n", fixup.m_forwardLabel.c_str());
auto bytes = MapInstruction_x86(fixup.m_opCode, fixup.m_regA, fixup.m_regB, location);
size_t pos = fixup.m_position_x86;
for (auto byte : bytes)
{
a_state.m_program.m_text_x86[pos] = byte;
++pos;
}
}
}
// Entry point: <program> <input.asm> <output file>
//
// Assembles argv[1] and writes the program image to argv[2]. An existing
// output file is never overwritten. Returns 0 on success and -1 on any
// failure, after printing a message.
int main(int argc, char* argv[])
{
    if (argc < 3)
    {
        printf("bad number of arguments\n");
        return -1;
    }

    {
        // being able to open the output path for reading means it already exists
        std::ifstream existing(argv[2]);
        if (existing.is_open())
        {
            printf("output file already exists\n");
            return -1;
        }
    }

    std::ifstream input(argv[1]);
    if (!input.is_open())
    {
        printf("couldn't open input file\n");
        return -1;
    }

    // The image is raw binary: binary mode stops platforms that translate
    // line endings from altering 0x0A bytes in it.
    std::ofstream output(argv[2], std::ios::out | std::ios::binary);
    if (!output.is_open())
    {
        printf("output file couldn't be created\n");
        return -1;
    }

    AssemblerState state = InitializedAssembler();
    BuildX86CodeMap();
#if TEST
    AssembleProgramTest(state.m_program);
#else
    ProcessInput(state, input);
#endif
    UpdateProgramHeader(state.m_program);
    WriteProgram(state.m_program, output);

    output.flush();
    if (!output)
    {
        printf("failed to write output file\n");
        return -1;
    }
    return 0;
}