aboutsummaryrefslogtreecommitdiff
path: root/devtools/tasmrecover/tasm/proc.py
diff options
context:
space:
mode:
Diffstat (limited to 'devtools/tasmrecover/tasm/proc.py')
-rw-r--r--devtools/tasmrecover/tasm/proc.py47
1 files changed, 35 insertions, 12 deletions
diff --git a/devtools/tasmrecover/tasm/proc.py b/devtools/tasmrecover/tasm/proc.py
index 1350ea1e0b..c127c406f7 100644
--- a/devtools/tasmrecover/tasm/proc.py
+++ b/devtools/tasmrecover/tasm/proc.py
@@ -9,6 +9,7 @@ class proc:
self.calls = []
self.stmts = []
self.labels = set()
+ self.retlabels = set()
self.__label_re = re.compile(r'^(\S+):(.*)$')
self.offset = proc.last_addr
proc.last_addr += 4
@@ -60,20 +61,28 @@ class proc:
if i + 1 >= len(stmts):
break
if isinstance(stmts[i + 1], cls):
- stmts[i + 1].repeat = 'context.cx'
+ stmts[i + 1].repeat = 'cx'
+ stmts[i + 1].clear_cx = True
del stmts[i]
i += 1
return
- def optimize(self):
+ def optimize(self, keep_labels=[]):
print "optimizing..."
- #trivial simplifications, removing last ret
+ #trivial simplifications
while len(self.stmts) and isinstance(self.stmts[-1], op.label):
print "stripping last label"
self.stmts.pop()
- if isinstance(self.stmts[-1], op._ret) and (len(self.stmts) < 2 or not isinstance(self.stmts[-2], op.label)):
- print "stripping last ret"
- self.stmts.pop()
+ #mark labels that directly precede a ret
+ for i in range(len(self.stmts)):
+ if not isinstance(self.stmts[i], op.label):
+ continue
+ j = i
+ while j < len(self.stmts) and isinstance(self.stmts[j], op.label):
+ j += 1
+ if j == len(self.stmts) or isinstance(self.stmts[j], op._ret):
+ print "Return label: %s" % (self.stmts[i].name,)
+ self.retlabels.add(self.stmts[i].name)
#merging push ax pop bx constructs
i = 0
while i + 1 < len(self.stmts):
@@ -101,16 +110,30 @@ class proc:
if not isinstance(s, op.label):
continue
print "checking label %s..." %s.name
- used = False
- for j in self.stmts:
- if isinstance(j, op.basejmp) and j.label == s.name:
- print "used"
- used = True
- break
+ used = s.name in keep_labels
+ if s.name not in self.retlabels:
+ for j in self.stmts:
+ if isinstance(j, op.basejmp) and j.label == s.name:
+ print "used"
+ used = True
+ break
if not used:
print self.labels
self.remove_label(s.name)
+ #removing duplicate rets
+ i = 0
+ while i < len(self.stmts)-1:
+ if isinstance(self.stmts[i], op._ret) and isinstance(self.stmts[i+1], op._ret):
+ del self.stmts[i]
+ else:
+ i += 1
+
+ #removing last ret
+ while len(self.stmts) > 0 and isinstance(self.stmts[-1], op._ret) and (len(self.stmts) < 2 or not isinstance(self.stmts[-2], op.label)):
+ print "stripping last ret"
+ self.stmts.pop()
+
self.optimize_sequence(op._stosb);
self.optimize_sequence(op._stosw);
self.optimize_sequence(op._movsb);