diff options
Diffstat (limited to 'devtools/tasmrecover/tasm/proc.py')
-rw-r--r-- | devtools/tasmrecover/tasm/proc.py | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/devtools/tasmrecover/tasm/proc.py b/devtools/tasmrecover/tasm/proc.py new file mode 100644 index 0000000000..ed7053df89 --- /dev/null +++ b/devtools/tasmrecover/tasm/proc.py @@ -0,0 +1,171 @@ +import re +import op + +class proc: + last_addr = 0xc000 + + def __init__(self, name): + self.name = name + self.calls = [] + self.stmts = [] + self.labels = set() + self.retlabels = set() + self.__label_re = re.compile(r'^(\S+):(.*)$') + self.offset = proc.last_addr + proc.last_addr += 4 + + def add_label(self, label): + self.stmts.append(op.label(label)) + self.labels.add(label) + + def remove_label(self, label): + try: + self.labels.remove(label) + except: + pass + for i in xrange(len(self.stmts)): + if isinstance(self.stmts[i], op.label) and self.stmts[i].name == label: + self.stmts[i] = op._nop(None) + return + + def optimize_sequence(self, cls): + i = 0 + stmts = self.stmts + while i < len(stmts): + if not isinstance(stmts[i], cls): + i += 1 + continue + if i > 0 and isinstance(stmts[i - 1], op._rep): #skip rep prefixed instructions for now + i += 1 + continue + j = i + 1 + + while j < len(stmts): + if not isinstance(stmts[j], cls): + break + j = j + 1 + + n = j - i + if n > 1: + print "Eliminate consequtive storage instructions at %u-%u" %(i, j) + for k in range(i+1,j): + stmts[k] = op._nop(None) + stmts[i].repeat = n + else: + i = j + + i = 0 + while i < len(stmts): + if not isinstance(stmts[i], op._rep): + i += 1 + continue + if i + 1 >= len(stmts): + break + if isinstance(stmts[i + 1], cls): + stmts[i + 1].repeat = 'cx' + stmts[i + 1].clear_cx = True + stmts[i] = op._nop(None) + i += 1 + return + + def optimize(self, keep_labels=[]): + print "optimizing..." + #trivial simplifications + while len(self.stmts) and isinstance(self.stmts[-1], op.label): + print "stripping last label" + self.stmts.pop() + #mark labels that directly precede a ret + for i in range(len(self.stmts)): + if not isinstance(self.stmts[i], op.label): + continue + j = i + while j < len(self.stmts) and isinstance(self.stmts[j], (op.label, op._nop)): + j += 1 + if j == len(self.stmts) or isinstance(self.stmts[j], op._ret): + print "Return label: %s" % (self.stmts[i].name,) + self.retlabels.add(self.stmts[i].name) + #merging push ax pop bx constructs + i = 0 + while i + 1 < len(self.stmts): + a, b = self.stmts[i], self.stmts[i + 1] + if isinstance(a, op._push) and isinstance(b, op._pop): + ar, br = a.regs, b.regs + movs = [] + while len(ar) and len(br): + src = ar.pop() + dst = br.pop(0) + movs.append(op._mov2(dst, src)) + if len(br) == 0: + self.stmts.pop(i + 1) + print "merging %d push-pops into movs" %(len(movs)) + for m in movs: + print "\t%s <- %s" %(m.dst, m.src) + self.stmts[i + 1:i + 1] = movs + if len(ar) == 0: + self.stmts.pop(i) + else: + i += 1 + + #eliminating unused labels + for s in list(self.stmts): + if not isinstance(s, op.label): + continue + print "checking label %s..." %s.name + used = s.name in keep_labels + if s.name not in self.retlabels: + for j in self.stmts: + if isinstance(j, op.basejmp) and j.label == s.name: + print "used" + used = True + break + if not used: + print self.labels + self.remove_label(s.name) + + #removing duplicate rets and rets at end + for i in xrange(len(self.stmts)): + if isinstance(self.stmts[i], op._ret): + j = i+1 + while j < len(self.stmts) and isinstance(self.stmts[j], op._nop): + j += 1 + if j == len(self.stmts) or isinstance(self.stmts[j], op._ret): + self.stmts[i] = op._nop(None) + + self.optimize_sequence(op._stosb); + self.optimize_sequence(op._stosw); + self.optimize_sequence(op._movsb); + self.optimize_sequence(op._movsw); + + def add(self, stmt): + #print stmt + comment = stmt.rfind(';') + if comment >= 0: + stmt = stmt[:comment] + stmt = stmt.strip() + + r = self.__label_re.search(stmt) + if r is not None: + #label + self.add_label(r.group(1).lower()) + #print "remains: %s" %r.group(2) + stmt = r.group(2).strip() + + if len(stmt) == 0: + return + + s = stmt.split(None) + cmd = s[0] + cl = getattr(op, '_' + cmd) + arg = " ".join(s[1:]) if len(s) > 1 else str() + o = cl(arg) + self.stmts.append(o) + + def __str__(self): + r = [] + for i in self.stmts: + r.append(i.__str__()) + return "\n".join(r) + + def visit(self, visitor, skip = 0): + for i in xrange(skip, len(self.stmts)): + self.stmts[i].visit(visitor) |