import op, traceback, re, proc from copy import copy proc_module = proc class CrossJump(Exception): pass def parse_bin(s): b = s.group(1) v = hex(int(b, 2)) #print "BINARY: %s -> %s" %(b, v) return v class cpp: def __init__(self, context, namespace, skip_first = 0, blacklist = []): self.namespace = namespace fname = namespace.lower() + ".cpp" header = namespace.lower() + ".h" banner = "/* PLEASE DO NOT MODIFY THIS FILE. ALL CHANGES WILL BE LOST! LOOK FOR README FOR DETAILS */" self.fd = open(fname, "wt") self.hd = open(header, "wt") hid = "TASMRECOVER_%s_STUBS_H__" %namespace.upper() self.hd.write("""#ifndef %s #define %s %s """ %(hid, hid, banner)) self.context = context self.data_seg = context.binary_data self.procs = context.proc_list self.skip_first = skip_first self.proc_queue = [] self.proc_done = [] self.blacklist = blacklist self.failed = list(blacklist) self.translated = [] self.proc_addr = [] self.methods = [] self.fd.write("""%s #include \"%s\" namespace %s { """ %(banner, header, namespace)) def expand_cb(self, match): name = match.group(0).lower() if len(name) == 2 and \ ((name[0] in ['a', 'b', 'c', 'd'] and name[1] in ['h', 'x', 'l']) or name in ['si', 'di', 'es', 'ds', 'cs']): return "%s" %name if self.indirection == -1: try: offset,p,p = self.context.get_offset(name) print "OFFSET = %d" %offset self.indirection = 0 return str(offset) except: pass g = self.context.get_global(name) if isinstance(g, op.const): value = self.expand_equ(g.value) print "equ: %s -> %s" %(name, value) elif isinstance(g, proc.proc): if self.indirection != -1: raise Exception("invalid proc label usage") value = str(g.offset) self.indirection = 0 else: size = g.size if size == 0: raise Exception("invalid var '%s' size %u" %(name, size)) if self.indirection == 0: value = "data.%s(k%s)" %("byte" if size == 1 else "word", name.capitalize()) elif self.indirection == -1: value = "%s" %g.offset self.indirection = 0 else: raise Exception("invalid indirection %d" %self.indirection) return value def get_size(self, expr): #print 'get_size("%s")' %expr try: v = self.context.parse_int(expr) return 1 if v < 256 else 2 except: pass if re.match(r'byte\s+ptr\s', expr) is not None: return 1 if re.match(r'word\s+ptr\s', expr) is not None: return 2 if len(expr) == 2 and expr[0] in ['a', 'b', 'c', 'd'] and expr[1] in ['h', 'l']: return 1 if expr in ['ax', 'bx', 'cx', 'dx', 'si', 'di', 'sp', 'bp', 'ds', 'cs', 'es', 'fs']: return 2 m = re.match(r'[a-zA-Z_]\w*', expr) if m is not None: name = m.group(0) try: g = self.context.get_global(name) return g.size except: pass return 0 def expand_equ_cb(self, match): name = match.group(0).lower() g = self.context.get_global(name) if isinstance(g, op.const): return g.value return str(g.offset) def expand_equ(self, expr): n = 1 while n > 0: expr, n = re.subn(r'\b[a-zA-Z_][a-zA-Z0-9_]+\b', self.expand_equ_cb, expr) expr = re.sub(r'\b([0-9][a-fA-F0-9]*)h', '0x\\1', expr) return "(%s)" %expr def expand(self, expr, def_size = 0): #print "EXPAND \"%s\"" %expr size = self.get_size(expr) if def_size == 0 else def_size indirection = 0 seg = None reg = True m = re.match(r'seg\s+(.*?)$', expr) if m is not None: return "data" match_id = True m = re.match(r'offset\s+(.*?)$', expr) if m is not None: indirection -= 1 expr = m.group(1).strip() m = re.match(r'byte\s+ptr\s+(.*?)$', expr) if m is not None: expr = m.group(1).strip() m = re.match(r'word\s+ptr\s+(.*?)$', expr) if m is not None: expr = m.group(1).strip() m = re.match(r'\[(.*)\]$', expr) if m is not None: indirection += 1 expr = m.group(1).strip() m = re.match(r'(\w{2,2}):(.*)$', expr) if m is not None: seg_prefix = m.group(1) expr = m.group(2).strip() print "SEGMENT %s, remains: %s" %(seg_prefix, expr) else: seg_prefix = "ds" m = re.match(r'(([abcd][xhl])|si|di|bp|sp)([\+-].*)?$', expr) if m is not None: reg = m.group(1) plus = m.group(3) if plus is not None: plus = self.expand(plus) else: plus = "" match_id = False #print "COMMON_REG: ", reg, plus expr = "%s%s" %(reg, plus) expr = re.sub(r'\b([0-9][a-fA-F0-9]*)h', '0x\\1', expr) expr = re.sub(r'\b([0-1]+)b', parse_bin, expr) expr = re.sub(r'"(.)"', '\'\\1\'', expr) if match_id: #print "BEFORE: %d" %indirection self.indirection = indirection expr = re.sub(r'\b[a-zA-Z_][a-zA-Z0-9_]+\b', self.expand_cb, expr) indirection = self.indirection #print "AFTER: %d" %indirection if indirection == 1: if size == 1: expr = "%s.byte(%s)" %(seg_prefix, expr) elif size == 2: expr = "%s.word(%s)" %(seg_prefix, expr) else: expr = "@invalid size 0" elif indirection == 0: pass elif indirection == -1: expr = "&%s" %expr else: raise Exception("invalid indirection %d" %indirection) return expr def mangle_label(self, name): name = name.lower() return re.sub(r'\$', '_tmp', name) def resolve_label(self, name): name = name.lower() if not name in self.proc.labels: try: offset, proc, pos = self.context.get_offset(name) except: print "no label %s, trying procedure" %name proc = self.context.get_global(name) pos = 0 if not isinstance(proc, proc_module.proc): raise CrossJump("cross-procedure jump to non label and non procedure %s" %(name)) self.proc.labels.add(name) for i in xrange(0, len(self.unbounded)): u = self.unbounded[i] if u[1] == proc: if pos < u[2]: self.unbounded[i] = (name, proc, pos) return self.mangle_label(name) self.unbounded.append((name, proc, pos)) return self.mangle_label(name) def jump_to_label(self, name): jump_proc = False if name in self.blacklist: jump_proc = True if self.context.has_global(name) : g = self.context.get_global(name) if isinstance(g, proc_module.proc): jump_proc = True if jump_proc: return "{ %s(); return; }" %name else: # TODO: name or self.resolve_label(name) or self.mangle_label(name)?? if name in self.proc.retlabels: return "return /* (%s) */" % (name) return "goto %s" %self.resolve_label(name) def _label(self, name): self.body += "%s:\n" %self.mangle_label(name) def schedule(self, name): name = name.lower() if name in self.proc_queue or name in self.proc_done or name in self.failed: return print "+scheduling function %s..." %name self.proc_queue.append(name) def _call(self, name): name = name.lower() if name == 'ax': self.body += "\t__dispatch_call(%s);\n" %self.expand('ax', 2) return self.body += "\t%s();\n" %name self.schedule(name) def _ret(self): self.body += "\treturn;\n" def parse2(self, dst, src): dst_size, src_size = self.get_size(dst), self.get_size(src) if dst_size == 0: if src_size == 0: raise Exception("both sizes are 0") dst_size = src_size if src_size == 0: src_size = dst_size dst = self.expand(dst, dst_size) src = self.expand(src, src_size) return dst, src def _mov(self, dst, src): self.body += "\t%s = %s;\n" %self.parse2(dst, src) def _add(self, dst, src): self.body += "\t_add(%s, %s);\n" %self.parse2(dst, src) def _sub(self, dst, src): self.body += "\t_sub(%s, %s);\n" %self.parse2(dst, src) def _and(self, dst, src): self.body += "\t_and(%s, %s);\n" %self.parse2(dst, src) def _or(self, dst, src): self.body += "\t_or(%s, %s);\n" %self.parse2(dst, src) def _xor(self, dst, src): self.body += "\t_xor(%s, %s);\n" %self.parse2(dst, src) def _neg(self, dst): dst = self.expand(dst) self.body += "\t_neg(%s);\n" %(dst) def _cbw(self): self.body += "\tax.cbw();\n" def _shr(self, dst, src): self.body += "\t_shr(%s, %s);\n" %self.parse2(dst, src) def _shl(self, dst, src): self.body += "\t_shl(%s, %s);\n" %self.parse2(dst, src) #def _sar(self, dst, src): # self.body += "\t_sar(%s%s);\n" %self.parse2(dst, src) #def _sal(self, dst, src): # self.body += "\t_sal(%s, %s);\n" %self.parse2(dst, src) #def _rcl(self, dst, src): # self.body += "\t_rcl(%s, %s);\n" %self.parse2(dst, src) #def _rcr(self, dst, src): # self.body += "\t_rcr(%s, %s);\n" %self.parse2(dst, src) def _mul(self, src): src = self.expand(src) self.body += "\t_mul(%s);\n" %(src) def _div(self, src): src = self.expand(src) self.body += "\t_div(%s);\n" %(src) def _inc(self, dst): dst = self.expand(dst) self.body += "\t_inc(%s);\n" %(dst) def _dec(self, dst): dst = self.expand(dst) self.body += "\t_dec(%s);\n" %(dst) def _cmp(self, a, b): self.body += "\t_cmp(%s, %s);\n" %self.parse2(a, b) def _test(self, a, b): self.body += "\t_test(%s, %s);\n" %self.parse2(a, b) def _js(self, label): self.body += "\tif (flags.s())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jns(self, label): self.body += "\tif (!flags.s())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jz(self, label): self.body += "\tif (flags.z())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jnz(self, label): self.body += "\tif (!flags.z())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jl(self, label): self.body += "\tif (flags.l())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jg(self, label): self.body += "\tif (!flags.le())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jle(self, label): self.body += "\tif (flags.le())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jge(self, label): self.body += "\tif (!flags.l())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jc(self, label): self.body += "\tif (flags.c())\n\t\t%s;\n" %(self.jump_to_label(label)) def _jnc(self, label): self.body += "\tif (!flags.c())\n\t\t%s;\n" %(self.jump_to_label(label)) def _xchg(self, dst, src): self.body += "\t_xchg(%s, %s);\n" %self.parse2(dst, src) def _jmp(self, label): self.body += "\t%s;\n" %(self.jump_to_label(label)) def _loop(self, label): self.body += "\tif (--cx)\n\t\t%s;\n" %self.jump_to_label(label) def _push(self, regs): p = str(); for r in regs: r = self.expand(r) p += "\tpush(%s);\n" %(r) self.body += p def _pop(self, regs): p = str(); for r in regs: self.temps_count -= 1 i = self.temps_count r = self.expand(r) p += "\t%s = pop();\n" %r self.body += p def _rep(self): self.body += "\twhile(cx--)\n\t" def _lodsb(self): self.body += "\t_lodsb();\n" def _lodsw(self): self.body += "\t_lodsw();\n" def _stosb(self, n, clear_cx): self.body += "\t_stosb(%s%s);\n" %("" if n == 1 else n, ", true" if clear_cx else "") def _stosw(self, n, clear_cx): self.body += "\t_stosw(%s%s);\n" %("" if n == 1 else n, ", true" if clear_cx else "") def _movsb(self, n, clear_cx): self.body += "\t_movsb(%s%s);\n" %("" if n == 1 else n, ", true" if clear_cx else "") def _movsw(self, n, clear_cx): self.body += "\t_movsw(%s%s);\n" %("" if n == 1 else n, ", true" if clear_cx else "") def _stc(self): self.body += "\tflags._c = true;\n " def _clc(self): self.body += "\tflags._c = false;\n " def __proc(self, name, def_skip = 0): try: skip = def_skip self.temps_count = 0 self.temps_max = 0 if self.context.has_global(name): self.proc = self.context.get_global(name) else: print "No procedure named %s, trying label" %name off, src_proc, skip = self.context.get_offset(name) self.proc = proc_module.proc(name) self.proc.stmts = copy(src_proc.stmts) self.proc.labels = copy(src_proc.labels) self.proc.retlabels = copy(src_proc.retlabels) #for p in xrange(skip, len(self.proc.stmts)): # s = self.proc.stmts[p] # if isinstance(s, op.basejmp): # o, p, s = self.context.get_offset(s.label) # if p == src_proc and s < skip: # skip = s self.proc_addr.append((name, self.proc.offset)) self.body = str() self.body += "void %sContext::%s() {\n\tSTACK_CHECK;\n" %(self.namespace, name); self.proc.optimize() self.unbounded = [] self.proc.visit(self, skip) #adding remaining labels: for i in xrange(0, len(self.unbounded)): u = self.unbounded[i] print "UNBOUNDED: ", u proc = u[1] for p in xrange(u[2], len(proc.stmts)): s = proc.stmts[p] if isinstance(s, op.basejmp): self.resolve_label(s.label) #adding statements #BIG FIXME: this is quite ugly to handle code analysis from the code generation. rewrite me! for label, proc, offset in self.unbounded: self.body += "\treturn;\n" #we need to return before calling code from the other proc self.body += "/*continuing to unbounded code: %s from %s:%d-%d*/\n" %(label, proc.name, offset, len(proc.stmts)) start = len(self.proc.stmts) self.proc.add_label(label) for s in proc.stmts[offset:]: if isinstance(s, op.label): self.proc.labels.add(s.name) self.proc.stmts.append(s) self.proc.add("ret") print "skipping %d instructions, todo: %d" %(start, len(self.proc.stmts) - start) print "re-optimizing..." self.proc.optimize(keep_labels=[label]) self.proc.visit(self, start) self.body += "}\n"; self.translated.insert(0, self.body) self.proc = None if self.temps_count > 0: raise Exception("temps count == %d at the exit of proc" %self.temps_count); return True except (CrossJump, op.Unsupported) as e: print "%s: ERROR: %s" %(name, e) self.failed.append(name) except: raise def get_type(self, width): return "uint%d_t" %(width * 8) def write_stubs(self, fname, procs): fd = open(fname, "wt") fd.write("namespace %s {\n" %self.namespace) for p in procs: fd.write("void %sContext::%s() {\n\t::error(\"%s\");\n}\n\n" %(self.namespace, p, p)) fd.write("} /*namespace %s */\n" %self.namespace) fd.close() def generate(self, start): #print self.prologue() #print context self.proc_queue.append(start) while len(self.proc_queue): name = self.proc_queue.pop() if name in self.failed or name in self.proc_done: continue if len(self.proc_queue) == 0 and len(self.procs) > 0: print "queue's empty, adding remaining procs:" for p in self.procs: self.schedule(p) self.procs = [] print "continuing on %s" %name self.proc_done.append(name) self.__proc(name) self.methods.append(name) self.write_stubs("_stubs.cpp", self.failed) self.methods += self.failed done, failed = len(self.proc_done), len(self.failed) self.fd.write("\n") self.fd.write("\n".join(self.translated)) self.fd.write("\n\n") print "%d ok, %d failed of %d, %.02g%% translated" %(done, failed, done + failed, 100.0 * done / (done + failed)) print "\n".join(self.failed) data_bin = self.data_seg data_impl = "\n\tstatic const uint8 src[] = {\n\t\t" n = 0 for v in data_bin: data_impl += "0x%02x, " %v n += 1 if (n & 0xf) == 0: data_impl += "\n\t\t" data_impl += "};\n\tds.assign(src, src + sizeof(src));\n" self.hd.write( """\n#include "dreamweb/runtime.h" namespace %s { class %sContext : public Context { public: void __start(); void __dispatch_call(uint16 addr); """ %(self.namespace, self.namespace)) offsets = [] for k, v in self.context.get_globals().items(): if isinstance(v, op.var): offsets.append((k.capitalize(), v.offset)) elif isinstance(v, op.const): offsets.append((k.capitalize(), self.expand_equ(v.value))) #fixme: try to save all constants here offsets = sorted(offsets, key=lambda t: t[1]) for o in offsets: self.hd.write("\tconst static uint16 k%s = %s;\n" %o) self.hd.write("\n") for p in set(self.methods): self.hd.write("\tvoid %s();\n" %p) self.hd.write("};\n}\n\n#endif\n") self.hd.close() self.fd.write("\nvoid %sContext::__start() { %s%s(); \n}\n" %(self.namespace, data_impl, start)) self.fd.write("\nvoid %sContext::__dispatch_call(uint16 addr) {\n\tswitch(addr) {\n" %self.namespace) self.proc_addr.sort(cmp = lambda x, y: x[1] - y[1]) for name,addr in self.proc_addr: self.fd.write("\t\tcase 0x%04x: %s(); break;\n" %(addr, name)) self.fd.write("\t\tdefault: ::error(\"invalid call to %04x dispatched\", (uint16)ax);") self.fd.write("\n\t}\n}\n\n} /*namespace*/\n") self.fd.close()