Add a hooks mechanism, based on filter expressions.

This commit is contained in:
Aldo Cortesi 2012-03-16 17:13:11 +13:00
parent d138af7217
commit 08f410cacc
4 changed files with 89 additions and 8 deletions

View File

@ -87,6 +87,7 @@ class _Rex(_Action):
except:
raise ValueError, "Cannot compile expression."
def _check_content_type(expr, o):
val = o.headers["content-type"]
if val and re.search(expr, val[0]):

View File

@ -26,6 +26,48 @@ import controller, version
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
class Hooks:
def __init__(self):
self.lst = []
def add(self, patt, func):
"""
Add a hook.
patt: A string specifying a filter pattern.
func: A callable taking the matching flow as argument.
Returns True if hook was added, False if the pattern could not be
parsed.
"""
cpatt = filt.parse(patt)
if not cpatt:
return False
self.lst.append((patt, func, cpatt))
return True
def remove(self, patt, func=None):
"""
Remove a hook.
patt: A string specifying a filter pattern.
func: Optional callable. If not specified, all hooks matching patt are removed.
"""
for i in range(len(self.lst)-1, -1, -1):
if func and (patt, func) == self.lst[i][:2]:
del self.lst[i]
elif not func and patt == self.lst[i][0]:
del self.lst[i]
def run(self, f):
for _, func, cpatt in self.lst:
if cpatt(f):
func(f)
def clear(self):
self.lst = []
class ScriptContext:
def __init__(self, master):
self._master = master
@ -75,7 +117,6 @@ class ODict:
def __getitem__(self, k):
"""
Returns a list of values matching key.
"""
ret = []
k = self._kconv(k)
@ -1229,6 +1270,7 @@ class FlowMaster(controller.Master):
self.anticache = False
self.anticomp = False
self.refresh_server_playback = False
self.hooks = Hooks()
def add_event(self, e, level="info"):
"""
@ -1438,6 +1480,7 @@ class FlowMaster(controller.Master):
def handle_error(self, r):
f = self.state.add_error(r)
self.hooks.run(f)
if f:
self.run_script_hook("error", f)
if self.client_playback:
@ -1447,12 +1490,14 @@ class FlowMaster(controller.Master):
def handle_request(self, r):
f = self.state.add_request(r)
self.hooks.run(f)
self.run_script_hook("request", f)
self.process_new_request(f)
return f
def handle_response(self, r):
f = self.state.add_response(r)
self.hooks.run(f)
if f:
self.run_script_hook("response", f)
if self.client_playback:

View File

@ -57,12 +57,12 @@ functions. They're only here so you can use load() to read precisely one
item from a file or socket without consuming any extra data.
By default tnetstrings work only with byte strings, not unicode. If you want
unicode strings then pass an optional encoding to the various functions,
unicode strings then pass an optional encoding to the various functions,
like so::
>>> print repr(tnetstring.loads("2:\\xce\\xb1,"))
'\\xce\\xb1'
>>>
>>>
>>> print repr(tnetstring.loads("2:\\xce\\xb1,","utf8"))
u'\u03b1'
@ -129,7 +129,7 @@ def _rdumpq(q,size,value,encoding=None):
write("5:false!")
return size + 8
if isinstance(value,(int,long)):
data = str(value)
data = str(value)
ldata = len(data)
span = str(ldata)
write("#")
@ -142,7 +142,7 @@ def _rdumpq(q,size,value,encoding=None):
# It round-trips more accurately.
# Probably unnecessary in later python versions that
# use David Gay's ftoa routines.
data = repr(value)
data = repr(value)
ldata = len(data)
span = str(ldata)
write("^")
@ -207,13 +207,13 @@ def _gdumps(value,encoding):
elif value is False:
yield "5:false!"
elif isinstance(value,(int,long)):
data = str(value)
data = str(value)
yield str(len(data))
yield ":"
yield data
yield "#"
elif isinstance(value,(float,)):
data = repr(value)
data = repr(value)
yield str(len(data))
yield ":"
yield data
@ -334,7 +334,7 @@ def load(file,encoding=None):
d[key] = val
return d
raise ValueError("unknown type tag")
def pop(string,encoding=None):

View File

@ -1036,9 +1036,44 @@ class udecoded(libpry.AutoTree):
assert r.content == "foo"
class uHooks(libpry.AutoTree):
def test_add_remove(self):
f = lambda(x): None
h = flow.Hooks()
h.add("~q", f)
assert h.lst
h.remove("~q", f)
assert not h.lst
h.add("~q", f)
h.add("~s", f)
assert len(h.lst) == 2
h.remove("~q", f)
assert len(h.lst) == 1
h.remove("~q")
assert len(h.lst) == 1
h.remove("~s")
assert len(h.lst) == 0
track = []
def func(x):
track.append(x)
h.add("~s", func)
f = tutils.tflow()
h.run(f)
assert not track
f = tutils.tflow_full()
h.run(f)
assert len(track) == 1
tests = [
uHooks(),
uStickyCookieState(),
uStickyAuthState(),
uServerPlaybackState(),