diff --git a/libpathod/rparse.py b/libpathod/rparse.py index d083e335a..98254abbc 100644 --- a/libpathod/rparse.py +++ b/libpathod/rparse.py @@ -31,7 +31,7 @@ def ready_actions(length, lst): itms = list(i) if i[0] == "r": itms[0] = random.randrange(length) - if i[0] == "a": + elif i[0] == "a": itms[0] = length+1 ret.append(tuple(itms)) ret.sort() @@ -68,10 +68,10 @@ def write_values(fp, vals, actions, sofar=0, skip=0, blocksize=BLOCKSIZE): offset += send_chunk(fp, v, blocksize, offset, a[0]-sofar-offset) if a[1] == "pause": time.sleep(a[2]) - elif a[1] == "inject": - send_chunk(fp, a[2], blocksize, 0, len(a[2])) elif a[1] == "disconnect": return True + elif a[1] == "inject": + send_chunk(fp, a[2], blocksize, 0, len(a[2])) send_chunk(fp, v, blocksize, offset, len(v)) sofar += len(v) except tcp.NetLibDisconnect: @@ -501,6 +501,9 @@ class Code: class Message: version = "HTTP/1.1" def length(self): + """ + Calculate the length of the base message without any applied actions. + """ l = sum(len(x) for x in self.preamble()) l += 2 for i in self.headers: @@ -510,6 +513,20 @@ class Message: l += len(self.body) return l + def effective_length(self, actions): + """ + Calculate the length of the base message with all applied actions. + """ + # Order matters here, and must match the order of application in + # write_values. + l = self.length() + for i in reversed(actions): + if i[1] == "disconnect": + return i[0] + elif i[1] == "inject": + l += len(i[2]) + return l + def serve(self, fp): started = time.time() if self.body and not utils.get_header("Content-Length", self.headers): diff --git a/test/test_rparse.py b/test/test_rparse.py index 23c9b3e66..0cb3e3731 100644 --- a/test/test_rparse.py +++ b/test/test_rparse.py @@ -433,6 +433,36 @@ class TestResponse: testlen(rparse.parse_response({}, "400'msg':h'foo'='bar'")) testlen(rparse.parse_response({}, "400'msg':h'foo'='bar':b@100b")) + def test_effective_length(self): + def testlen(x, actions): + s = cStringIO.StringIO() + x.serve(s) + assert x.effective_length(actions) == len(s.getvalue()) + actions = [ + + ] + r = rparse.parse_response({}, "400'msg':b@100") + + actions = [ + (0, "disconnect"), + ] + r.actions = actions + testlen(r, actions) + + actions = [ + (0, "disconnect"), + (0, "inject", "foo") + ] + r.actions = actions + testlen(r, actions) + + actions = [ + (0, "inject", "foo") + ] + r.actions = actions + testlen(r, actions) + + def test_read_file(): tutils.raises(rparse.FileAccessDenied, rparse.read_file, {}, "=/foo")