Remove global should_exit and fix tests

This commit is contained in:
Vyacheslav Bakhmutov 2014-06-13 14:14:55 +07:00
parent 00fd243810
commit b7c1d05782
5 changed files with 16 additions and 19 deletions

View File

@ -577,7 +577,7 @@ class ConsoleMaster(flow.FlowMaster):
self.view_flowlist() self.view_flowlist()
self.server.start_slave(controller.Slave, controller.Channel(self.masterq)) self.server.start_slave(controller.Slave, controller.Channel(self.masterq, self.should_exit))
if self.options.rfile: if self.options.rfile:
ret = self.load_flows(self.options.rfile) ret = self.load_flows(self.options.rfile)
@ -780,7 +780,7 @@ class ConsoleMaster(flow.FlowMaster):
def loop(self): def loop(self):
changed = True changed = True
try: try:
while not controller.should_exit: while not self.should_exit.is_set():
startloop = time.time() startloop = time.time()
if changed: if changed:
self.statusbar.redraw() self.statusbar.redraw()

View File

@ -1,9 +1,6 @@
from __future__ import absolute_import from __future__ import absolute_import
import Queue, threading import Queue, threading
should_exit = False
class DummyReply: class DummyReply:
""" """
A reply object that does nothing. Useful when we need an object to seem A reply object that does nothing. Useful when we need an object to seem
@ -37,8 +34,9 @@ class Reply:
class Channel: class Channel:
def __init__(self, q): def __init__(self, q, should_exit):
self.q = q self.q = q
self.should_exit = should_exit
def ask(self, mtype, m): def ask(self, mtype, m):
""" """
@ -47,7 +45,7 @@ class Channel:
""" """
m.reply = Reply(m) m.reply = Reply(m)
self.q.put((mtype, m)) self.q.put((mtype, m))
while not should_exit: while not self.should_exit.is_set():
try: try:
# The timeout is here so we can handle a should_exit event. # The timeout is here so we can handle a should_exit event.
g = m.reply.q.get(timeout=0.5) g = m.reply.q.get(timeout=0.5)
@ -89,6 +87,7 @@ class Master:
""" """
self.server = server self.server = server
self.masterq = Queue.Queue() self.masterq = Queue.Queue()
self.should_exit = threading.Event()
def tick(self, q): def tick(self, q):
changed = False changed = False
@ -107,10 +106,9 @@ class Master:
return changed return changed
def run(self): def run(self):
global should_exit self.should_exit.clear()
should_exit = False self.server.start_slave(Slave, Channel(self.masterq, self.should_exit))
self.server.start_slave(Slave, Channel(self.masterq)) while not self.should_exit.is_set():
while not should_exit:
self.tick(self.masterq) self.tick(self.masterq)
self.shutdown() self.shutdown()
@ -123,8 +121,7 @@ class Master:
obj.reply() obj.reply()
def shutdown(self): def shutdown(self):
global should_exit if not self.should_exit.is_set():
if not should_exit: self.should_exit.set()
should_exit = True
if self.server: if self.server:
self.server.shutdown() self.server.shutdown()

View File

@ -654,6 +654,7 @@ class FlowMaster(controller.Master):
self.server.config, self.server.config,
f, f,
self.masterq, self.masterq,
self.should_exit
) )
rt.start() # pragma: no cover rt.start() # pragma: no cover
if block: if block:
@ -792,8 +793,8 @@ class FilteredFlowWriter:
class RequestReplayThread(threading.Thread): class RequestReplayThread(threading.Thread):
name="RequestReplayThread" name="RequestReplayThread"
def __init__(self, config, flow, masterq): def __init__(self, config, flow, masterq, should_exit):
self.config, self.flow, self.channel = config, flow, controller.Channel(masterq) self.config, self.flow, self.channel = config, flow, controller.Channel(masterq, should_exit)
threading.Thread.__init__(self) threading.Thread.__init__(self)
def run(self): def run(self):

View File

@ -672,7 +672,6 @@ class TestFlowMaster:
fm.handle_error(f.error) fm.handle_error(f.error)
def test_server_playback(self): def test_server_playback(self):
controller.should_exit = False
s = flow.State() s = flow.State()
f = tutils.tflow() f = tutils.tflow()
@ -695,7 +694,7 @@ class TestFlowMaster:
fm.start_server_playback(pb, False, [], True, False) fm.start_server_playback(pb, False, [], True, False)
q = Queue.Queue() q = Queue.Queue()
fm.tick(q) fm.tick(q)
assert controller.should_exit assert fm.should_exit.is_set()
fm.stop_server_playback() fm.stop_server_playback()
assert not fm.server_playback assert not fm.server_playback

View File

@ -283,7 +283,7 @@ class ChainProxTest(ProxTestBase):
def teardownAll(cls): def teardownAll(cls):
super(ChainProxTest, cls).teardownAll() super(ChainProxTest, cls).teardownAll()
for p in cls.chain: for p in cls.chain:
p.tmaster.server.shutdown() p.tmaster.shutdown()
def setUp(self): def setUp(self):
super(ChainProxTest, self).setUp() super(ChainProxTest, self).setUp()