aboutsummaryrefslogtreecommitdiff
path: root/devtools/tasmrecover/tasm
diff options
context:
space:
mode:
Diffstat (limited to 'devtools/tasmrecover/tasm')
-rw-r--r--devtools/tasmrecover/tasm/__init__.py0
-rw-r--r--devtools/tasmrecover/tasm/cpp.py581
-rw-r--r--devtools/tasmrecover/tasm/lex.py52
-rw-r--r--devtools/tasmrecover/tasm/op.py410
-rw-r--r--devtools/tasmrecover/tasm/parser.py261
-rw-r--r--devtools/tasmrecover/tasm/proc.py171
6 files changed, 1475 insertions, 0 deletions
diff --git a/devtools/tasmrecover/tasm/__init__.py b/devtools/tasmrecover/tasm/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/devtools/tasmrecover/tasm/__init__.py
diff --git a/devtools/tasmrecover/tasm/cpp.py b/devtools/tasmrecover/tasm/cpp.py
new file mode 100644
index 0000000000..dfdfb239f4
--- /dev/null
+++ b/devtools/tasmrecover/tasm/cpp.py
@@ -0,0 +1,581 @@
+import op, traceback, re, proc
+from copy import copy
+proc_module = proc
+
+class CrossJump(Exception):
+ pass
+
+def parse_bin(s):
+ b = s.group(1)
+ v = hex(int(b, 2))
+ #print "BINARY: %s -> %s" %(b, v)
+ return v
+
+class cpp:
+ def __init__(self, context, namespace, skip_first = 0, blacklist = []):
+ self.namespace = namespace
+ fname = namespace.lower() + ".cpp"
+ header = namespace.lower() + ".h"
+ banner = "/* PLEASE DO NOT MODIFY THIS FILE. ALL CHANGES WILL BE LOST! LOOK FOR README FOR DETAILS */"
+ self.fd = open(fname, "wt")
+ self.hd = open(header, "wt")
+ hid = "TASMRECOVER_%s_STUBS_H__" %namespace.upper()
+ self.hd.write("""#ifndef %s
+#define %s
+
+%s
+
+""" %(hid, hid, banner))
+ self.context = context
+ self.data_seg = context.binary_data
+ self.procs = context.proc_list
+ self.skip_first = skip_first
+ self.proc_queue = []
+ self.proc_done = []
+ self.blacklist = blacklist
+ self.failed = list(blacklist)
+ self.translated = []
+ self.proc_addr = []
+ self.methods = []
+ self.fd.write("""%s
+
+#include \"%s\"
+
+namespace %s {
+""" %(banner, header, namespace))
+
+ def expand_cb(self, match):
+ name = match.group(0).lower()
+ if len(name) == 2 and \
+ ((name[0] in ['a', 'b', 'c', 'd'] and name[1] in ['h', 'x', 'l']) or name in ['si', 'di', 'es', 'ds', 'cs']):
+ return "%s" %name
+
+ if self.indirection == -1:
+ try:
+ offset,p,p = self.context.get_offset(name)
+ print "OFFSET = %d" %offset
+ self.indirection = 0
+ return str(offset)
+ except:
+ pass
+
+ g = self.context.get_global(name)
+ if isinstance(g, op.const):
+ value = self.expand_equ(g.value)
+ print "equ: %s -> %s" %(name, value)
+ elif isinstance(g, proc.proc):
+ if self.indirection != -1:
+ raise Exception("invalid proc label usage")
+ value = str(g.offset)
+ self.indirection = 0
+ else:
+ size = g.size
+ if size == 0:
+ raise Exception("invalid var '%s' size %u" %(name, size))
+ if self.indirection == 0:
+ value = "data.%s(k%s)" %("byte" if size == 1 else "word", name.capitalize())
+ elif self.indirection == -1:
+ value = "%s" %g.offset
+ self.indirection = 0
+ else:
+ raise Exception("invalid indirection %d" %self.indirection)
+ return value
+
+ def get_size(self, expr):
+ #print 'get_size("%s")' %expr
+ try:
+ v = self.context.parse_int(expr)
+ return 1 if v < 256 else 2
+ except:
+ pass
+
+ if re.match(r'byte\s+ptr\s', expr) is not None:
+ return 1
+
+ if re.match(r'word\s+ptr\s', expr) is not None:
+ return 2
+
+ if len(expr) == 2 and expr[0] in ['a', 'b', 'c', 'd'] and expr[1] in ['h', 'l']:
+ return 1
+ if expr in ['ax', 'bx', 'cx', 'dx', 'si', 'di', 'sp', 'bp', 'ds', 'cs', 'es', 'fs']:
+ return 2
+
+ m = re.match(r'[a-zA-Z_]\w*', expr)
+ if m is not None:
+ name = m.group(0)
+ try:
+ g = self.context.get_global(name)
+ return g.size
+ except:
+ pass
+
+ return 0
+
+ def expand_equ_cb(self, match):
+ name = match.group(0).lower()
+ g = self.context.get_global(name)
+ if isinstance(g, op.const):
+ return g.value
+ return str(g.offset)
+
+ def expand_equ(self, expr):
+ n = 1
+ while n > 0:
+ expr, n = re.subn(r'\b[a-zA-Z_][a-zA-Z0-9_]+\b', self.expand_equ_cb, expr)
+ expr = re.sub(r'\b([0-9][a-fA-F0-9]*)h', '0x\\1', expr)
+ return "(%s)" %expr
+
+ def expand(self, expr, def_size = 0):
+ #print "EXPAND \"%s\"" %expr
+ size = self.get_size(expr) if def_size == 0 else def_size
+ indirection = 0
+ seg = None
+ reg = True
+
+ m = re.match(r'seg\s+(.*?)$', expr)
+ if m is not None:
+ return "data"
+
+ match_id = True
+ m = re.match(r'offset\s+(.*?)$', expr)
+ if m is not None:
+ indirection -= 1
+ expr = m.group(1).strip()
+
+ m = re.match(r'byte\s+ptr\s+(.*?)$', expr)
+ if m is not None:
+ expr = m.group(1).strip()
+
+ m = re.match(r'word\s+ptr\s+(.*?)$', expr)
+ if m is not None:
+ expr = m.group(1).strip()
+
+ m = re.match(r'\[(.*)\]$', expr)
+ if m is not None:
+ indirection += 1
+ expr = m.group(1).strip()
+
+ m = re.match(r'(\w{2,2}):(.*)$', expr)
+ if m is not None:
+ seg_prefix = m.group(1)
+ expr = m.group(2).strip()
+ print "SEGMENT %s, remains: %s" %(seg_prefix, expr)
+ else:
+ seg_prefix = "ds"
+
+ m = re.match(r'(([abcd][xhl])|si|di|bp|sp)([\+-].*)?$', expr)
+ if m is not None:
+ reg = m.group(1)
+ plus = m.group(3)
+ if plus is not None:
+ plus = self.expand(plus)
+ else:
+ plus = ""
+ match_id = False
+ #print "COMMON_REG: ", reg, plus
+ expr = "%s%s" %(reg, plus)
+
+ expr = re.sub(r'\b([0-9][a-fA-F0-9]*)h', '0x\\1', expr)
+ expr = re.sub(r'\b([0-1]+)b', parse_bin, expr)
+ expr = re.sub(r'"(.)"', '\'\\1\'', expr)
+ if match_id:
+ #print "BEFORE: %d" %indirection
+ self.indirection = indirection
+ expr = re.sub(r'\b[a-zA-Z_][a-zA-Z0-9_]+\b', self.expand_cb, expr)
+ indirection = self.indirection
+ #print "AFTER: %d" %indirection
+
+ if indirection == 1:
+ if size == 1:
+ expr = "%s.byte(%s)" %(seg_prefix, expr)
+ elif size == 2:
+ expr = "%s.word(%s)" %(seg_prefix, expr)
+ else:
+ expr = "@invalid size 0"
+ elif indirection == 0:
+ pass
+ elif indirection == -1:
+ expr = "&%s" %expr
+ else:
+ raise Exception("invalid indirection %d" %indirection)
+ return expr
+
+ def mangle_label(self, name):
+ name = name.lower()
+ return re.sub(r'\$', '_tmp', name)
+
+ def resolve_label(self, name):
+ name = name.lower()
+ if not name in self.proc.labels:
+ try:
+ offset, proc, pos = self.context.get_offset(name)
+ except:
+ print "no label %s, trying procedure" %name
+ proc = self.context.get_global(name)
+ pos = 0
+ if not isinstance(proc, proc_module.proc):
+ raise CrossJump("cross-procedure jump to non label and non procedure %s" %(name))
+ self.proc.labels.add(name)
+ for i in xrange(0, len(self.unbounded)):
+ u = self.unbounded[i]
+ if u[1] == proc:
+ if pos < u[2]:
+ self.unbounded[i] = (name, proc, pos)
+ return self.mangle_label(name)
+ self.unbounded.append((name, proc, pos))
+
+ return self.mangle_label(name)
+
+ def jump_to_label(self, name):
+ jump_proc = False
+ if name in self.blacklist:
+ jump_proc = True
+
+ if self.context.has_global(name) :
+ g = self.context.get_global(name)
+ if isinstance(g, proc_module.proc):
+ jump_proc = True
+
+ if jump_proc:
+ return "{ %s(); return; }" %name
+ else:
+ # TODO: name or self.resolve_label(name) or self.mangle_label(name)??
+ if name in self.proc.retlabels:
+ return "return /* (%s) */" % (name)
+ return "goto %s" %self.resolve_label(name)
+
+ def _label(self, name):
+ self.body += "%s:\n" %self.mangle_label(name)
+
+ def schedule(self, name):
+ name = name.lower()
+ if name in self.proc_queue or name in self.proc_done or name in self.failed:
+ return
+ print "+scheduling function %s..." %name
+ self.proc_queue.append(name)
+
+ def _call(self, name):
+ name = name.lower()
+ if name == 'ax':
+ self.body += "\t__dispatch_call(%s);\n" %self.expand('ax', 2)
+ return
+ self.body += "\t%s();\n" %name
+ self.schedule(name)
+
+ def _ret(self):
+ self.body += "\treturn;\n"
+
+ def parse2(self, dst, src):
+ dst_size, src_size = self.get_size(dst), self.get_size(src)
+ if dst_size == 0:
+ if src_size == 0:
+ raise Exception("both sizes are 0")
+ dst_size = src_size
+ if src_size == 0:
+ src_size = dst_size
+
+ dst = self.expand(dst, dst_size)
+ src = self.expand(src, src_size)
+ return dst, src
+
+ def _mov(self, dst, src):
+ self.body += "\t%s = %s;\n" %self.parse2(dst, src)
+
+ def _add(self, dst, src):
+ self.body += "\t_add(%s, %s);\n" %self.parse2(dst, src)
+
+ def _sub(self, dst, src):
+ self.body += "\t_sub(%s, %s);\n" %self.parse2(dst, src)
+
+ def _and(self, dst, src):
+ self.body += "\t_and(%s, %s);\n" %self.parse2(dst, src)
+
+ def _or(self, dst, src):
+ self.body += "\t_or(%s, %s);\n" %self.parse2(dst, src)
+
+ def _xor(self, dst, src):
+ self.body += "\t_xor(%s, %s);\n" %self.parse2(dst, src)
+
+ def _neg(self, dst):
+ dst = self.expand(dst)
+ self.body += "\t_neg(%s);\n" %(dst)
+
+ def _cbw(self):
+ self.body += "\tax.cbw();\n"
+
+ def _shr(self, dst, src):
+ self.body += "\t_shr(%s, %s);\n" %self.parse2(dst, src)
+
+ def _shl(self, dst, src):
+ self.body += "\t_shl(%s, %s);\n" %self.parse2(dst, src)
+
+ #def _sar(self, dst, src):
+ # self.body += "\t_sar(%s%s);\n" %self.parse2(dst, src)
+
+ #def _sal(self, dst, src):
+ # self.body += "\t_sal(%s, %s);\n" %self.parse2(dst, src)
+
+ #def _rcl(self, dst, src):
+ # self.body += "\t_rcl(%s, %s);\n" %self.parse2(dst, src)
+
+ #def _rcr(self, dst, src):
+ # self.body += "\t_rcr(%s, %s);\n" %self.parse2(dst, src)
+
+ def _mul(self, src):
+ src = self.expand(src)
+ self.body += "\t_mul(%s);\n" %(src)
+
+ def _div(self, src):
+ src = self.expand(src)
+ self.body += "\t_div(%s);\n" %(src)
+
+ def _inc(self, dst):
+ dst = self.expand(dst)
+ self.body += "\t_inc(%s);\n" %(dst)
+
+ def _dec(self, dst):
+ dst = self.expand(dst)
+ self.body += "\t_dec(%s);\n" %(dst)
+
+ def _cmp(self, a, b):
+ self.body += "\t_cmp(%s, %s);\n" %self.parse2(a, b)
+
+ def _test(self, a, b):
+ self.body += "\t_test(%s, %s);\n" %self.parse2(a, b)
+
+ def _js(self, label):
+ self.body += "\tif (flags.s())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jns(self, label):
+ self.body += "\tif (!flags.s())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jz(self, label):
+ self.body += "\tif (flags.z())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jnz(self, label):
+ self.body += "\tif (!flags.z())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jl(self, label):
+ self.body += "\tif (flags.l())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jg(self, label):
+ self.body += "\tif (!flags.le())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jle(self, label):
+ self.body += "\tif (flags.le())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jge(self, label):
+ self.body += "\tif (!flags.l())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jc(self, label):
+ self.body += "\tif (flags.c())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _jnc(self, label):
+ self.body += "\tif (!flags.c())\n\t\t%s;\n" %(self.jump_to_label(label))
+
+ def _xchg(self, dst, src):
+ self.body += "\t_xchg(%s, %s);\n" %self.parse2(dst, src)
+
+ def _jmp(self, label):
+ self.body += "\t%s;\n" %(self.jump_to_label(label))
+
+ def _loop(self, label):
+ self.body += "\tif (--cx)\n\t\t%s;\n" %self.jump_to_label(label)
+
+ def _push(self, regs):
+ p = str();
+ for r in regs:
+ r = self.expand(r)
+ p += "\tpush(%s);\n" %(r)
+ self.body += p
+
+ def _pop(self, regs):
+ p = str();
+ for r in regs:
+ self.temps_count -= 1
+ i = self.temps_count
+ r = self.expand(r)
+ p += "\t%s = pop();\n" %r
+ self.body += p
+
+ def _rep(self):
+ self.body += "\twhile(cx--)\n\t"
+
+ def _lodsb(self):
+ self.body += "\t_lodsb();\n"
+
+ def _lodsw(self):
+ self.body += "\t_lodsw();\n"
+
+ def _stosb(self, n, clear_cx):
+ self.body += "\t_stosb(%s%s);\n" %("" if n == 1 else n, ", true" if clear_cx else "")
+
+ def _stosw(self, n, clear_cx):
+ self.body += "\t_stosw(%s%s);\n" %("" if n == 1 else n, ", true" if clear_cx else "")
+
+ def _movsb(self, n, clear_cx):
+ self.body += "\t_movsb(%s%s);\n" %("" if n == 1 else n, ", true" if clear_cx else "")
+
+ def _movsw(self, n, clear_cx):
+ self.body += "\t_movsw(%s%s);\n" %("" if n == 1 else n, ", true" if clear_cx else "")
+
+ def _stc(self):
+ self.body += "\tflags._c = true;\n "
+
+ def _clc(self):
+ self.body += "\tflags._c = false;\n "
+
+ def __proc(self, name, def_skip = 0):
+ try:
+ skip = def_skip
+ self.temps_count = 0
+ self.temps_max = 0
+ if self.context.has_global(name):
+ self.proc = self.context.get_global(name)
+ else:
+ print "No procedure named %s, trying label" %name
+ off, src_proc, skip = self.context.get_offset(name)
+
+ self.proc = proc_module.proc(name)
+ self.proc.stmts = copy(src_proc.stmts)
+ self.proc.labels = copy(src_proc.labels)
+ self.proc.retlabels = copy(src_proc.retlabels)
+ #for p in xrange(skip, len(self.proc.stmts)):
+ # s = self.proc.stmts[p]
+ # if isinstance(s, op.basejmp):
+ # o, p, s = self.context.get_offset(s.label)
+ # if p == src_proc and s < skip:
+ # skip = s
+
+
+ self.proc_addr.append((name, self.proc.offset))
+ self.body = str()
+ self.body += "void %sContext::%s() {\n\tSTACK_CHECK;\n" %(self.namespace, name);
+ self.proc.optimize()
+ self.unbounded = []
+ self.proc.visit(self, skip)
+
+ #adding remaining labels:
+ for i in xrange(0, len(self.unbounded)):
+ u = self.unbounded[i]
+ print "UNBOUNDED: ", u
+ proc = u[1]
+ for p in xrange(u[2], len(proc.stmts)):
+ s = proc.stmts[p]
+ if isinstance(s, op.basejmp):
+ self.resolve_label(s.label)
+
+ #adding statements
+ #BIG FIXME: this is quite ugly to handle code analysis from the code generation. rewrite me!
+ for label, proc, offset in self.unbounded:
+ self.body += "\treturn;\n" #we need to return before calling code from the other proc
+ self.body += "/*continuing to unbounded code: %s from %s:%d-%d*/\n" %(label, proc.name, offset, len(proc.stmts))
+ start = len(self.proc.stmts)
+ self.proc.add_label(label)
+ for s in proc.stmts[offset:]:
+ if isinstance(s, op.label):
+ self.proc.labels.add(s.name)
+ self.proc.stmts.append(s)
+ self.proc.add("ret")
+ print "skipping %d instructions, todo: %d" %(start, len(self.proc.stmts) - start)
+ print "re-optimizing..."
+ self.proc.optimize(keep_labels=[label])
+ self.proc.visit(self, start)
+ self.body += "}\n";
+ self.translated.insert(0, self.body)
+ self.proc = None
+ if self.temps_count > 0:
+ raise Exception("temps count == %d at the exit of proc" %self.temps_count);
+ return True
+ except (CrossJump, op.Unsupported) as e:
+ print "%s: ERROR: %s" %(name, e)
+ self.failed.append(name)
+ except:
+ raise
+
+ def get_type(self, width):
+ return "uint%d_t" %(width * 8)
+
+ def write_stubs(self, fname, procs):
+ fd = open(fname, "wt")
+ fd.write("namespace %s {\n" %self.namespace)
+ for p in procs:
+ fd.write("void %sContext::%s() {\n\t::error(\"%s\");\n}\n\n" %(self.namespace, p, p))
+ fd.write("} /*namespace %s */\n" %self.namespace)
+ fd.close()
+
+
+ def generate(self, start):
+ #print self.prologue()
+ #print context
+ self.proc_queue.append(start)
+ while len(self.proc_queue):
+ name = self.proc_queue.pop()
+ if name in self.failed or name in self.proc_done:
+ continue
+ if len(self.proc_queue) == 0 and len(self.procs) > 0:
+ print "queue's empty, adding remaining procs:"
+ for p in self.procs:
+ self.schedule(p)
+ self.procs = []
+ print "continuing on %s" %name
+ self.proc_done.append(name)
+ self.__proc(name)
+ self.methods.append(name)
+ self.write_stubs("_stubs.cpp", self.failed)
+ self.methods += self.failed
+ done, failed = len(self.proc_done), len(self.failed)
+
+ self.fd.write("\n")
+ self.fd.write("\n".join(self.translated))
+ self.fd.write("\n\n")
+ print "%d ok, %d failed of %d, %.02g%% translated" %(done, failed, done + failed, 100.0 * done / (done + failed))
+ print "\n".join(self.failed)
+ data_bin = self.data_seg
+ data_impl = "\n\tstatic const uint8 src[] = {\n\t\t"
+ n = 0
+ for v in data_bin:
+ data_impl += "0x%02x, " %v
+ n += 1
+ if (n & 0xf) == 0:
+ data_impl += "\n\t\t"
+ data_impl += "};\n\tds.assign(src, src + sizeof(src));\n"
+ self.hd.write(
+"""\n#include "dreamweb/runtime.h"
+
+namespace %s {
+
+class %sContext : public Context {
+public:
+ void __start();
+ void __dispatch_call(uint16 addr);
+
+"""
+%(self.namespace, self.namespace))
+ offsets = []
+ for k, v in self.context.get_globals().items():
+ if isinstance(v, op.var):
+ offsets.append((k.capitalize(), v.offset))
+ elif isinstance(v, op.const):
+ offsets.append((k.capitalize(), self.expand_equ(v.value))) #fixme: try to save all constants here
+
+ offsets = sorted(offsets, key=lambda t: t[1])
+ for o in offsets:
+ self.hd.write("\tconst static uint16 k%s = %s;\n" %o)
+ self.hd.write("\n")
+ for p in set(self.methods):
+ self.hd.write("\tvoid %s();\n" %p)
+
+ self.hd.write("};\n}\n\n#endif\n")
+ self.hd.close()
+
+ self.fd.write("\nvoid %sContext::__start() { %s%s(); \n}\n" %(self.namespace, data_impl, start))
+
+ self.fd.write("\nvoid %sContext::__dispatch_call(uint16 addr) {\n\tswitch(addr) {\n" %self.namespace)
+ self.proc_addr.sort(cmp = lambda x, y: x[1] - y[1])
+ for name,addr in self.proc_addr:
+ self.fd.write("\t\tcase 0x%04x: %s(); break;\n" %(addr, name))
+ self.fd.write("\t\tdefault: ::error(\"invalid call to %04x dispatched\", (uint16)ax);")
+ self.fd.write("\n\t}\n}\n\n} /*namespace*/\n")
+
+ self.fd.close()
diff --git a/devtools/tasmrecover/tasm/lex.py b/devtools/tasmrecover/tasm/lex.py
new file mode 100644
index 0000000000..cf7e6e19bf
--- /dev/null
+++ b/devtools/tasmrecover/tasm/lex.py
@@ -0,0 +1,52 @@
+def parse_args(text):
+ #print "parsing: [%s]" %text
+ escape = False
+ string = False
+ result = []
+ token = str()
+ value = 0;
+ for c in text:
+ #print "[%s]%s: %s: %s" %(token, c, escape, string)
+ if c == '\\':
+ escape = True
+ continue
+
+ if escape:
+ if not string:
+ raise SyntaxError("escape found in no string: %s" %text);
+
+ #print "escaping[%s]" %c
+ escape = False
+ token += c
+ continue
+
+ if string:
+ if c == '\'' or c == '"':
+ string = False
+
+ token += c
+ continue
+
+ if c == '\'' or c == '"':
+ string = True
+ token += c
+ continue
+
+ if c == ',':
+ result.append(token.strip())
+ token = str()
+ continue
+
+ if c == ';': #comment, bailing out
+ break
+
+ token += c
+ #token = token.strip()
+ if len(token):
+ result.append(token)
+ #print result
+ return result
+
+def compile(width, data):
+ print data
+ return data
diff --git a/devtools/tasmrecover/tasm/op.py b/devtools/tasmrecover/tasm/op.py
new file mode 100644
index 0000000000..10fdd8a568
--- /dev/null
+++ b/devtools/tasmrecover/tasm/op.py
@@ -0,0 +1,410 @@
+import re
+import lex
+
+class Unsupported(Exception):
+ pass
+
+class var:
+ def __init__(self, size, offset):
+ self.size = size
+ self.offset = offset
+
+class const:
+ def __init__(self, value):
+ self.value = value
+
+class reg:
+ def __init__(self, name):
+ self.name = name
+ def size(self):
+ return 2 if self.name[1] == 'x' else 1
+ def __str__(self):
+ return "<register %s>" %self.name
+
+class unref:
+ def __init__(self, exp):
+ self.exp = exp
+ def __str__(self):
+ return "<unref %s>" %self.exp
+
+class ref:
+ def __init__(self, name):
+ self.name = name
+ def __str__(self):
+ return "<ref %s>" %self.name
+
+class glob:
+ def __init__(self, name):
+ self.name = name
+ def __str__(self):
+ return "<global %s>" %self.name
+
+class segment:
+ def __init__(self, name):
+ self.name = name
+ def __str__(self):
+ return "<segment %s>" %self.name
+
+class baseop(object):
+ def parse_arg(self, arg):
+ return arg
+
+ def split(self, text):
+ a, b = lex.parse_args(text)
+ return self.parse_arg(a), self.parse_arg(b)
+ def __str__(self):
+ return str(self.__class__)
+
+class basejmp(baseop):
+ pass
+
+class _call(baseop):
+ def __init__(self, arg):
+ self.name = arg
+ def visit(self, visitor):
+ visitor._call(self.name)
+ def __str__(self):
+ return "call(%s)" %self.name
+
+class _rep(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ visitor._rep()
+
+class _mov(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._mov(self.dst, self.src)
+ def __str__(self):
+ return "mov(%s, %s)" %(self.dst, self.src)
+
+class _mov2(baseop):
+ def __init__(self, dst, src):
+ self.dst, self.src = dst, src
+ def visit(self, visitor):
+ visitor._mov(self.dst, self.src)
+
+class _shr(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._shr(self.dst, self.src)
+
+class _shl(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._shl(self.dst, self.src)
+
+class _ror(baseop):
+ def __init__(self, arg):
+ pass
+
+class _rol(baseop):
+ def __init__(self, arg):
+ pass
+
+class _sar(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._sar(self.dst, self.src)
+
+class _sal(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._sal(self.dst, self.src)
+
+class _rcl(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._rcl(self.dst, self.src)
+
+class _rcr(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._rcr(self.dst, self.src)
+
+class _neg(baseop):
+ def __init__(self, arg):
+ self.arg = arg
+ def visit(self, visitor):
+ visitor._neg(self.arg)
+
+class _dec(baseop):
+ def __init__(self, arg):
+ self.dst = arg
+ def visit(self, visitor):
+ visitor._dec(self.dst)
+
+class _inc(baseop):
+ def __init__(self, arg):
+ self.dst = arg
+ def visit(self, visitor):
+ visitor._inc(self.dst)
+
+class _add(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._add(self.dst, self.src)
+
+class _sub(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._sub(self.dst, self.src)
+
+class _mul(baseop):
+ def __init__(self, arg):
+ self.arg = self.parse_arg(arg)
+ def visit(self, visitor):
+ visitor._mul(self.arg)
+
+class _div(baseop):
+ def __init__(self, arg):
+ self.arg = self.parse_arg(arg)
+ def visit(self, visitor):
+ visitor._div(self.arg)
+
+class _and(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._and(self.dst, self.src)
+
+class _xor(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._xor(self.dst, self.src)
+
+class _or(baseop):
+ def __init__(self, arg):
+ self.dst, self.src = self.split(arg)
+ def visit(self, visitor):
+ visitor._or(self.dst, self.src)
+
+class _cmp(baseop):
+ def __init__(self, arg):
+ self.a, self.b = self.split(arg)
+ def visit(self, visitor):
+ visitor._cmp(self.a, self.b)
+
+class _test(baseop):
+ def __init__(self, arg):
+ self.a, self.b = self.split(arg)
+ def visit(self, visitor):
+ visitor._test(self.a, self.b)
+
+class _xchg(baseop):
+ def __init__(self, arg):
+ self.a, self.b = self.split(arg)
+ def visit(self, visitor):
+ visitor._xchg(self.a, self.b)
+
+class _jnz(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jnz(self.label)
+
+class _jz(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jz(self.label)
+
+class _jc(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jc(self.label)
+
+class _jnc(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jnc(self.label)
+
+class _js(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._js(self.label)
+
+class _jns(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jns(self.label)
+
+class _jl(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jl(self.label)
+
+class _jg(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jg(self.label)
+
+class _jle(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jle(self.label)
+
+class _jge(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jge(self.label)
+
+class _jmp(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._jmp(self.label)
+
+class _loop(basejmp):
+ def __init__(self, label):
+ self.label = label
+ def visit(self, visitor):
+ visitor._loop(self.label)
+
+class _push(baseop):
+ def __init__(self, arg):
+ self.regs = []
+ for r in arg.split():
+ self.regs.append(self.parse_arg(r))
+ def visit(self, visitor):
+ visitor._push(self.regs)
+
+class _pop(baseop):
+ def __init__(self, arg):
+ self.regs = []
+ for r in arg.split():
+ self.regs.append(self.parse_arg(r))
+ def visit(self, visitor):
+ visitor._pop(self.regs)
+
+class _ret(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ visitor._ret()
+
+class _lodsb(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ visitor._lodsb()
+
+class _lodsw(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ visitor._lodsw()
+
+class _stosw(baseop):
+ def __init__(self, arg):
+ self.repeat = 1
+ self.clear_cx = False
+ def visit(self, visitor):
+ visitor._stosw(self.repeat, self.clear_cx)
+
+class _stosb(baseop):
+ def __init__(self, arg):
+ self.repeat = 1
+ self.clear_cx = False
+ def visit(self, visitor):
+ visitor._stosb(self.repeat, self.clear_cx)
+
+class _movsw(baseop):
+ def __init__(self, arg):
+ self.repeat = 1
+ self.clear_cx = False
+ def visit(self, visitor):
+ visitor._movsw(self.repeat, self.clear_cx)
+
+class _movsb(baseop):
+ def __init__(self, arg):
+ self.repeat = 1
+ self.clear_cx = False
+ def visit(self, visitor):
+ visitor._movsb(self.repeat, self.clear_cx)
+
+class _in(baseop):
+ def __init__(self, arg):
+ self.arg = arg
+ def visit(self, visitor):
+ raise Unsupported("input from port: %s" %self.arg)
+
+class _out(baseop):
+ def __init__(self, arg):
+ self.arg = arg
+ def visit(self, visitor):
+ raise Unsupported("out to port: %s" %self.arg)
+
+class _cli(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ raise Unsupported("cli")
+
+class _sti(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ raise Unsupported("sli")
+
+class _int(baseop):
+ def __init__(self, arg):
+ self.arg = arg
+ def visit(self, visitor):
+ raise Unsupported("interrupt: %s" %self.arg)
+
+class _iret(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ raise Unsupported("interrupt return")
+
+class _cbw(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ visitor._cbw()
+
+class _nop(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ pass
+
+class _stc(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ visitor._stc()
+
+class _clc(baseop):
+ def __init__(self, arg):
+ pass
+ def visit(self, visitor):
+ visitor._clc()
+
+class label(baseop):
+ def __init__(self, name):
+ self.name = name
+ def visit(self, visitor):
+ visitor._label(self.name)
+
diff --git a/devtools/tasmrecover/tasm/parser.py b/devtools/tasmrecover/tasm/parser.py
new file mode 100644
index 0000000000..4cea496722
--- /dev/null
+++ b/devtools/tasmrecover/tasm/parser.py
@@ -0,0 +1,261 @@
+import os, re
+from proc import proc
+import lex
+import op
+
+class parser:
+ def __init__(self):
+ self.strip_path = 0
+ self.__globals = {}
+ self.__offsets = {}
+ self.__stack = []
+ self.proc = None
+ self.proc_list = []
+ self.binary_data = []
+
+ self.symbols = []
+ self.link_later = []
+
+ def visible(self):
+ for i in self.__stack:
+ if not i or i == 0:
+ return False
+ return True
+
+ def push_if(self, text):
+ value = self.eval(text)
+ #print "if %s -> %s" %(text, value)
+ self.__stack.append(value)
+
+ def push_else(self):
+ #print "else"
+ self.__stack[-1] = not self.__stack[-1]
+
+ def pop_if(self):
+ #print "endif"
+ return self.__stack.pop()
+
+ def set_global(self, name, value):
+ if len(name) == 0:
+ raise Exception("empty name is not allowed")
+ name = name.lower()
+ #print "adding global %s -> %s" %(name, value)
+ if self.__globals.has_key(name):
+ raise Exception("global %s was already defined", name)
+ self.__globals[name] = value
+
+ def get_global(self, name):
+ name = name.lower()
+ g = self.__globals[name]
+ g.used = True
+ return g
+
+ def get_globals(self):
+ return self.__globals
+
+ def has_global(self, name):
+ name = name.lower()
+ return self.__globals.has_key(name)
+
+ def set_offset(self, name, value):
+ if len(name) == 0:
+ raise Exception("empty name is not allowed")
+ name = name.lower()
+ #print "adding global %s -> %s" %(name, value)
+ if self.__offsets.has_key(name):
+ raise Exception("global %s was already defined", name)
+ self.__offsets[name] = value
+
+ def get_offset(self, name):
+ name = name.lower()
+ return self.__offsets[name]
+
+ def include(self, basedir, fname):
+ path = fname.split('\\')[self.strip_path:]
+ path = os.path.join(basedir, os.path.pathsep.join(path))
+ #print "including %s" %(path)
+
+ self.parse(path)
+
+ def eval(self, stmt):
+ try:
+ return self.parse_int(stmt)
+ except:
+ pass
+ value = self.__globals[stmt.lower()].value
+ return int(value)
+
+ def expr_callback(self, match):
+ name = match.group(1).lower()
+ g = self.get_global(name)
+ if isinstance(g, op.const):
+ return g.value
+ else:
+ return "0x%04x" %g.offset
+
+ def eval_expr(self, expr):
+ n = 1
+ while n > 0:
+ expr, n = re.subn(r'\b([a-zA-Z_]+[a-zA-Z0-9_]*)', self.expr_callback, expr)
+ return eval(expr)
+
+ def expand_globals(self, text):
+ return text
+
+ def fix_dollar(self, v):
+ print("$ = %d" %len(self.binary_data))
+ return re.sub(r'\$', "%d" %len(self.binary_data), v)
+
+ def parse_int(self, v):
+ if re.match(r'[01]+b$', v):
+ v = int(v[:-1], 2)
+ if re.match(r'[\+-]?[0-9a-f]+h$', v):
+ v = int(v[:-1], 16)
+ return int(v)
+
+ def compact_data(self, width, data):
+ #print "COMPACTING %d %s" %(width, data)
+ r = []
+ base = 0x100 if width == 1 else 0x10000
+ for v in data:
+ if v[0] == '"':
+ if v[-1] != '"':
+ raise Exception("invalid string %s" %v)
+ if width == 2:
+ raise Exception("string with data width more than 1") #we could allow it :)
+ for i in xrange(1, len(v) - 1):
+ r.append(ord(v[i]))
+ continue
+
+ m = re.match(r'(\w+)\s+dup\s+\((\s*\S+\s*)\)', v)
+ if m is not None:
+ #we should parse that
+ n = self.parse_int(m.group(1))
+ if m.group(2) != '?':
+ value = self.parse_int(m.group(2))
+ else:
+ value = 0
+ for i in xrange(0, n):
+ v = value
+ for b in xrange(0, width):
+ r.append(v & 0xff);
+ v >>= 8
+ continue
+
+ try:
+ v = self.parse_int(v)
+ if v < 0:
+ v += base
+ except:
+ #global name
+ print "global/expr: %s" %v
+ try:
+ g = self.get_global(v)
+ v = g.offset
+ except:
+ print "unknown address %s" %(v)
+ self.link_later.append((len(self.binary_data) + len(r), v))
+ v = 0
+
+ for b in xrange(0, width):
+ r.append(v & 0xff);
+ v >>= 8
+ #print r
+ return r
+
+ def parse(self, fname):
+# print "opening file %s..." %(fname, basedir)
+ fd = open(fname, 'rb')
+ for line in fd:
+ line = line.strip()
+ if len(line) == 0 or line[0] == ';' or line[0] == chr(0x1a):
+ continue
+
+ #print line
+ m = re.match('(\w+)\s*?:', line)
+ if m is not None:
+ line = line[len(m.group(0)):].strip()
+ if self.visible():
+ name = m.group(1)
+ if self.proc is not None:
+ self.proc.add_label(name)
+ print "offset %s -> %d" %(name, len(self.binary_data))
+ self.set_offset(name, (len(self.binary_data), self.proc, len(self.proc.stmts) if self.proc is not None else 0))
+ #print line
+
+ cmd = line.split()
+ if len(cmd) == 0:
+ continue
+
+ cmd0 = str(cmd[0])
+ if cmd0 == 'if':
+ self.push_if(cmd[1])
+ continue
+ elif cmd0 == 'else':
+ self.push_else()
+ continue
+ elif cmd0 == 'endif':
+ self.pop_if()
+ continue
+
+ if not self.visible():
+ continue
+
+ if cmd0 == 'db' or cmd0 == 'dw' or cmd0 == 'dd':
+ arg = line[len(cmd0):].strip()
+ print "%d:1: %s" %(len(self.binary_data), arg) #fixme: COPYPASTE
+ binary_width = {'b': 1, 'w': 2, 'd': 4}[cmd0[1]]
+ self.binary_data += self.compact_data(binary_width, lex.parse_args(arg))
+ continue
+ elif cmd0 == 'include':
+ self.include(os.path.dirname(fname), cmd[1])
+ continue
+ elif cmd0 == 'endp':
+ self.proc = None
+ continue
+ elif cmd0 == 'assume':
+ print "skipping: %s" %line
+ continue
+ elif cmd0 == 'rep':
+ self.proc.add(cmd0)
+ self.proc.add(" ".join(cmd[1:]))
+ continue
+
+ if len(cmd) >= 3:
+ cmd1 = cmd[1]
+ if cmd1 == 'equ':
+ v = cmd[2]
+ self.set_global(cmd0, op.const(self.fix_dollar(v)))
+ elif cmd1 == 'db' or cmd1 == 'dw' or cmd1 == 'dd':
+ binary_width = {'b': 1, 'w': 2, 'd': 4}[cmd1[1]]
+ offset = len(self.binary_data)
+ arg = line[len(cmd0):].strip()
+ arg = arg[len(cmd1):].strip()
+ print "%d: %s" %(offset, arg)
+ self.binary_data += self.compact_data(binary_width, lex.parse_args(arg))
+ self.set_global(cmd0.lower(), op.var(binary_width, offset))
+ continue
+ elif cmd1 == 'proc':
+ name = cmd0.lower()
+ self.proc = proc(name)
+ print "procedure %s, #%d" %(name, len(self.proc_list))
+ self.proc_list.append(name)
+ self.set_global(name, self.proc)
+ continue
+ if (self.proc):
+ self.proc.add(line)
+ else:
+ #print line
+ pass
+
+ fd.close()
+ return self
+
+ def link(self):
+ for addr, expr in self.link_later:
+ v = self.eval_expr(expr)
+ print "link: patching %04x -> %04x" %(addr, v)
+ while v != 0:
+ self.binary_data[addr] = v & 0xff
+ addr += 1
+ v >>= 8
diff --git a/devtools/tasmrecover/tasm/proc.py b/devtools/tasmrecover/tasm/proc.py
new file mode 100644
index 0000000000..ed7053df89
--- /dev/null
+++ b/devtools/tasmrecover/tasm/proc.py
@@ -0,0 +1,171 @@
+import re
+import op
+
+class proc:
+ last_addr = 0xc000
+
+ def __init__(self, name):
+ self.name = name
+ self.calls = []
+ self.stmts = []
+ self.labels = set()
+ self.retlabels = set()
+ self.__label_re = re.compile(r'^(\S+):(.*)$')
+ self.offset = proc.last_addr
+ proc.last_addr += 4
+
+ def add_label(self, label):
+ self.stmts.append(op.label(label))
+ self.labels.add(label)
+
+ def remove_label(self, label):
+ try:
+ self.labels.remove(label)
+ except:
+ pass
+ for i in xrange(len(self.stmts)):
+ if isinstance(self.stmts[i], op.label) and self.stmts[i].name == label:
+ self.stmts[i] = op._nop(None)
+ return
+
+ def optimize_sequence(self, cls):
+ i = 0
+ stmts = self.stmts
+ while i < len(stmts):
+ if not isinstance(stmts[i], cls):
+ i += 1
+ continue
+ if i > 0 and isinstance(stmts[i - 1], op._rep): #skip rep prefixed instructions for now
+ i += 1
+ continue
+ j = i + 1
+
+ while j < len(stmts):
+ if not isinstance(stmts[j], cls):
+ break
+ j = j + 1
+
+ n = j - i
+ if n > 1:
+ print "Eliminate consequtive storage instructions at %u-%u" %(i, j)
+ for k in range(i+1,j):
+ stmts[k] = op._nop(None)
+ stmts[i].repeat = n
+ else:
+ i = j
+
+ i = 0
+ while i < len(stmts):
+ if not isinstance(stmts[i], op._rep):
+ i += 1
+ continue
+ if i + 1 >= len(stmts):
+ break
+ if isinstance(stmts[i + 1], cls):
+ stmts[i + 1].repeat = 'cx'
+ stmts[i + 1].clear_cx = True
+ stmts[i] = op._nop(None)
+ i += 1
+ return
+
+ def optimize(self, keep_labels=[]):
+ print "optimizing..."
+ #trivial simplifications
+ while len(self.stmts) and isinstance(self.stmts[-1], op.label):
+ print "stripping last label"
+ self.stmts.pop()
+ #mark labels that directly precede a ret
+ for i in range(len(self.stmts)):
+ if not isinstance(self.stmts[i], op.label):
+ continue
+ j = i
+ while j < len(self.stmts) and isinstance(self.stmts[j], (op.label, op._nop)):
+ j += 1
+ if j == len(self.stmts) or isinstance(self.stmts[j], op._ret):
+ print "Return label: %s" % (self.stmts[i].name,)
+ self.retlabels.add(self.stmts[i].name)
+ #merging push ax pop bx constructs
+ i = 0
+ while i + 1 < len(self.stmts):
+ a, b = self.stmts[i], self.stmts[i + 1]
+ if isinstance(a, op._push) and isinstance(b, op._pop):
+ ar, br = a.regs, b.regs
+ movs = []
+ while len(ar) and len(br):
+ src = ar.pop()
+ dst = br.pop(0)
+ movs.append(op._mov2(dst, src))
+ if len(br) == 0:
+ self.stmts.pop(i + 1)
+ print "merging %d push-pops into movs" %(len(movs))
+ for m in movs:
+ print "\t%s <- %s" %(m.dst, m.src)
+ self.stmts[i + 1:i + 1] = movs
+ if len(ar) == 0:
+ self.stmts.pop(i)
+ else:
+ i += 1
+
+ #eliminating unused labels
+ for s in list(self.stmts):
+ if not isinstance(s, op.label):
+ continue
+ print "checking label %s..." %s.name
+ used = s.name in keep_labels
+ if s.name not in self.retlabels:
+ for j in self.stmts:
+ if isinstance(j, op.basejmp) and j.label == s.name:
+ print "used"
+ used = True
+ break
+ if not used:
+ print self.labels
+ self.remove_label(s.name)
+
+ #removing duplicate rets and rets at end
+ for i in xrange(len(self.stmts)):
+ if isinstance(self.stmts[i], op._ret):
+ j = i+1
+ while j < len(self.stmts) and isinstance(self.stmts[j], op._nop):
+ j += 1
+ if j == len(self.stmts) or isinstance(self.stmts[j], op._ret):
+ self.stmts[i] = op._nop(None)
+
+ self.optimize_sequence(op._stosb);
+ self.optimize_sequence(op._stosw);
+ self.optimize_sequence(op._movsb);
+ self.optimize_sequence(op._movsw);
+
+ def add(self, stmt):
+ #print stmt
+ comment = stmt.rfind(';')
+ if comment >= 0:
+ stmt = stmt[:comment]
+ stmt = stmt.strip()
+
+ r = self.__label_re.search(stmt)
+ if r is not None:
+ #label
+ self.add_label(r.group(1).lower())
+ #print "remains: %s" %r.group(2)
+ stmt = r.group(2).strip()
+
+ if len(stmt) == 0:
+ return
+
+ s = stmt.split(None)
+ cmd = s[0]
+ cl = getattr(op, '_' + cmd)
+ arg = " ".join(s[1:]) if len(s) > 1 else str()
+ o = cl(arg)
+ self.stmts.append(o)
+
+ def __str__(self):
+ r = []
+ for i in self.stmts:
+ r.append(i.__str__())
+ return "\n".join(r)
+
+ def visit(self, visitor, skip = 0):
+ for i in xrange(skip, len(self.stmts)):
+ self.stmts[i].visit(visitor)