trafficserver-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From a..@apache.org
Subject [trafficserver] branch master updated: Test: Update micro-server to have hooks for observers.
Date Fri, 07 Apr 2017 14:20:21 GMT
This is an automated email from the ASF dual-hosted git repository.

amc pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/trafficserver.git

The following commit(s) were added to refs/heads/master by this push:
       new  db75dee   Test: Update micro-server to have hooks for observers.
db75dee is described below

commit db75deed5c82ed4c6ca56917f5e38f97119697ae
Author: Alan M. Carroll <solidwallofcode@yahoo-inc.com>
AuthorDate: Tue Apr 4 15:33:52 2017 -0500

    Test: Update micro-server to have hooks for observers.
---
 tests/gold_tests/autest-site/microserver.test.ext |  16 ++-
 tests/tools/microServer/uWServer.py               | 156 +++++++++++++++-------
 2 files changed, 120 insertions(+), 52 deletions(-)

diff --git a/tests/gold_tests/autest-site/microserver.test.ext b/tests/gold_tests/autest-site/microserver.test.ext
index a0e8fdc..9aa777e 100644
--- a/tests/gold_tests/autest-site/microserver.test.ext
+++ b/tests/gold_tests/autest-site/microserver.test.ext
@@ -20,7 +20,7 @@ from ports import get_port
 import json
 
 def addMethod(self,testName, request_header, functionName):
-    return 
+    return
 
 # creates the full request or response block using headers and message data
 def httpObject(self,header,data):
@@ -74,7 +74,7 @@ def addTransactionToSession(txn,JFile):
     if os.path.exists(JFile):
         jf = open(JFile,'r')
         jsondata = json.load(jf)
-    
+
     if jsondata == None:
         jsondata = dict()
         jsondata["version"]='0.1'
@@ -86,9 +86,9 @@ def addTransactionToSession(txn,JFile):
         jsondata["txns"].append(txn)
     with open(JFile,'w+') as jf:
         jf.write(json.dumps(jsondata))
-        
 
-#make headers with the key and values provided        
+
+#make headers with the key and values provided
 def makeHeader(self,requestString, **kwargs):
     headerStr = requestString+'\r\n'
     for k,v in kwargs.iteritems():
@@ -97,21 +97,23 @@ def makeHeader(self,requestString, **kwargs):
     return headerStr
 
 
-def MakeOriginServer(obj, name,public_ip=False):
+def MakeOriginServer(obj, name,public_ip=False,options={}):
     server_path= os.path.join(obj.Variables.AtsTestToolsDir,'microServer/uWServer.py')
     data_dir = os.path.join(obj.RunDirectory, name)
     # create Process
     p = obj.Processes.Process(name)
     port=get_port(p,"Port")
     command = "python3 {0} --data-dir {1} --port {2} --public {3} -m test".format(server_path,
data_dir, port, public_ip)
-    
+    for flag,value in options.items() :
+        command += " {} {}".format(flag,value)
+
     # create process
     p.Command = command
     p.Setup.MakeDir(data_dir)
     p.Variables.DataDir = data_dir
     AddMethodToInstance(p,addResponse)
     AddMethodToInstance(p,addTransactionToSession)
-    
+
     return p
 
 AddTestRunSet(MakeOriginServer,name="MakeOriginServer")
diff --git a/tests/tools/microServer/uWServer.py b/tests/tools/microServer/uWServer.py
index c889f5f..09e1d88 100644
--- a/tests/tools/microServer/uWServer.py
+++ b/tests/tools/microServer/uWServer.py
@@ -32,6 +32,7 @@ from http import HTTPStatus
 import argparse
 import ssl
 import socket
+import importlib.util
 
 test_mode_enabled = True
 __version__="1.0"
@@ -54,9 +55,62 @@ HTTP_VERSION = 'HTTP/1.1'
 G_replay_dict = {}
 
 count = 0
+
+# Simple class to hold lists of callbacks associated with a key.
+class HookSet:
+    # Helper class to provide controlled access to the HookSet to the loading module.
+    class Registrar :
+        def __init__(self, hook_set) :
+            self.hooks = hook_set
+
+        def register(self, hook, cb) :
+            self.hooks.register(hook, cb)
+
+    def __init__(self) :
+        # Define all the valid hooks here.
+        self.hooks = { 'ReadRequestHook': [] }
+        self.modules = []
+        self.registrar = HookSet.Registrar(self)
+
+    def load(self, source) :
+        try :
+            spec = importlib.util.spec_from_file_location('Observer', source)
+            mod = importlib.util.module_from_spec(spec)
+            mod.Hooks = self.registrar
+            spec.loader.exec_module(mod)
+        except ImportError:
+            print("Failed to import {}".format(source))
+        else :
+            self.modules.append(mod)
+
+    # Add a callback cb to the hook.
+    # Error if the hook isn't defined.
+    def register(self, hook, cb):
+        if hook in self.hooks :
+            self.hooks[hook].append(cb)
+        else:
+            raise ValueError("{} is not a valid hook name".format(hook))
+
+    # Invoke a hook. Pass on any additional arguments to the callback.
+    def invoke(self, hook, *args, **kwargs):
+        cb_list = self.hooks[hook]
+        if cb_list == None :
+            raise ValueError("{} is not a valid hook name to invoke".format(hook))
+        else :
+            for cb in cb_list :
+                cb(*args, **kwargs)
+
+    # Keep track of modules so the callbacks don't go out of scope.
+    def add_module(self, mod) :
+        self.modules.append(mod)
+
 class ThreadingServer(ThreadingMixIn, HTTPServer):
     '''This class forces the creation of a new thread on each connection'''
-    pass
+    def __init__(self, local_addr, handler_class, options) :
+        HTTPServer.__init__(self, local_addr, handler_class)
+        self.hook_set = HookSet()
+        if (options.load) :
+            self.hook_set.load(options.load)
 
 class ForkingServer(ForkingMixIn, HTTPServer):
     '''This class forces the creation of a new process on each connection'''
@@ -69,9 +123,14 @@ class SSLServer(ThreadingMixIn, HTTPServer):
             keys = os.path.join(pwd,options.key)
             certs = os.path.join(pwd,options.cert)
             self.options = options
+            self.hook_set = HookSet()
 
             self.daemon_threads = True
             self.protocol_version = 'HTTP/1.1'
+
+            if options.load :
+                self.hook_set.load(options.load)
+
             if options.clientverify:
             	self.socket = ssl.wrap_socket(socket.socket(self.address_family, self.socket_type),
                     keyfile=keys, certfile=certs, server_side=True, cert_reqs=ssl.CERT_REQUIRED,
ca_certs='/etc/ssl/certs/ca-certificates.crt')
@@ -110,18 +169,18 @@ class MyHandler(BaseHTTPRequestHandler):
         return key
 
     def parseRequestline(self,requestline):
-        testName=None        
+        testName=None
         return testName
 
     def testMode(self,requestline):
         print(requestline)
         key=self.parseRequestline(requestline)
-        
+
         self.send_response(200)
         self.send_header('Connection', 'close')
         self.end_headers()
 
-        
+
     def get_response_code(self, header):
         # this could totally go wrong
         return int(header.split(' ')[1])
@@ -152,7 +211,7 @@ class MyHandler(BaseHTTPRequestHandler):
 
     def readChunks(self):
         raw_data=b''
-        raw_size = self.rfile.readline(65537)        
+        raw_size = self.rfile.readline(65537)
         size = str(raw_size, 'UTF-8').rstrip('\r\n')
         #print("==========================================>",size)
         size = int(size,16)
@@ -161,7 +220,7 @@ class MyHandler(BaseHTTPRequestHandler):
             chunk = self.rfile.read(size+2) # 2 for reading /r/n
             #print("cuhnk: ",chunk)
             raw_data += chunk
-            raw_size = self.rfile.readline(65537)            
+            raw_size = self.rfile.readline(65537)
             size = str(raw_size, 'UTF-8').rstrip('\r\n')
             size = int(size,16)
         #print("full chunk",raw_data)
@@ -191,21 +250,22 @@ class MyHandler(BaseHTTPRequestHandler):
         error is sent back.
 
         """
-        
+
         global count, test_mode_enabled
-        
+
         self.command = None  # set in case of error on the first line
         self.request_version = version = self.default_request_version
         self.close_connection = True
         requestline = str(self.raw_requestline, 'UTF-8')
         #print("request",requestline)
         requestline = requestline.rstrip('\r\n')
-        self.requestline = requestline        
-        
-        # Examine the headers and look for a Connection directive.        
+        self.requestline = requestline
+
+        # Examine the headers and look for a Connection directive.
         try:
             self.headers = http.client.parse_headers(self.rfile,
                                                      _class=self.MessageClass)
+            self.server.hook_set.invoke('ReadRequestHook', self.headers)
 
             # read message body
             if self.headers.get('Content-Length') != None:
@@ -228,8 +288,8 @@ class MyHandler(BaseHTTPRequestHandler):
                 str(err)
             )
             return False
-        
-        
+
+
         words = requestline.split()
         if len(words) == 3:
             command, path, version = words
@@ -337,9 +397,9 @@ class MyHandler(BaseHTTPRequestHandler):
                     if 'Transfer-Encoding' in header:
                         self.send_header('Transfer-Encoding','Chunked')
                         response_string='%X\r\n%s\r\n'%(len('ats'),'ats')
-                        chunkedResponse= True                    
+                        chunkedResponse= True
                         continue
-            
+
                     header_parts = header.split(':', 1)
                     header_field = str(header_parts[0].strip())
                     header_field_val = str(header_parts[1].strip())
@@ -352,8 +412,8 @@ class MyHandler(BaseHTTPRequestHandler):
                         response_string=resp.getBody()
                         self.send_header('Content-Length', str(length))
                 self.end_headers()
-                
-                
+
+
                 if (chunkedResponse):
                     self.writeChunkedData()
                 elif response_string!=None and response_string!='':
@@ -365,9 +425,9 @@ class MyHandler(BaseHTTPRequestHandler):
             self.send_response(400)
             self.send_header('Connection', 'close')
             self.end_headers()
-       
 
-        
+
+
     def do_HEAD(self):
         global G_replay_dict, test_mode_enabled
         if test_mode_enabled:
@@ -395,7 +455,7 @@ class MyHandler(BaseHTTPRequestHandler):
                 elif 'Content-Length' in header:
                     self.send_header('Content-Length', '0')
                     continue
-        
+
                 header_parts = header.split(':', 1)
                 header_field = str(header_parts[0].strip())
                 header_field_val = str(header_parts[1].strip())
@@ -429,7 +489,7 @@ class MyHandler(BaseHTTPRequestHandler):
                 #print("reposen is ",resp_headers)
                 # set headers
                 for header in resp_headers[1:]: # skip first one b/c it's response code
-                    
+
                     if header == '':
                         continue
                     elif 'Content-Length' in header:
@@ -439,7 +499,7 @@ class MyHandler(BaseHTTPRequestHandler):
                             header_field_val = str(header_parts[1].strip())
                             self.send_header(header_field, header_field_val)
                             continue
-                        
+
                         lengthSTR = header.split(':')[1]
                         length = lengthSTR.strip(' ')
                         if test_mode_enabled: # the length of the body is given priority
in test mode rather than the value in Content-Length. But in replay mode Content-Length gets
the priority
@@ -452,9 +512,9 @@ class MyHandler(BaseHTTPRequestHandler):
                     if 'Transfer-Encoding' in header:
                         self.send_header('Transfer-Encoding','Chunked')
                         response_string='%X\r\n%s\r\n'%(len('microserver'),'microserver')
-                        chunkedResponse= True                    
+                        chunkedResponse= True
                         continue
-                    
+
                     header_parts = header.split(':', 1)
                     header_field = str(header_parts[0].strip())
                     header_field_val = str(header_parts[1].strip())
@@ -465,9 +525,9 @@ class MyHandler(BaseHTTPRequestHandler):
                     if resp and resp.getBody():
                         length = len(bytes(resp.getBody(),'UTF-8'))
                         response_string=resp.getBody()
-                        self.send_header('Content-Length', str(length))    
+                        self.send_header('Content-Length', str(length))
                 self.end_headers()
-            
+
             if (chunkedResponse):
                 self.writeChunkedData()
             elif response_string!=None and response_string!='':
@@ -486,13 +546,13 @@ def populate_global_replay_dictionary(sessions):
     for session in sessions:
         for txn in session.getTransactionIter():
             G_replay_dict[txn._uuid] = txn.getResponse()
-    
+
     print("size",len(G_replay_dict))
-    
+
 #tests will add responses to the dictionary where key is the testname
 def addResponseHeader(key,response_header):
     G_replay_dict[key] = response_header
-    
+
 def _path(exists, arg ):
     path = os.path.abspath(arg)
     if not os.path.exists(path) and exists:
@@ -501,7 +561,7 @@ def _path(exists, arg ):
     return path
 
 def _bool(arg):
-        
+
         opt_true_values = set(['y', 'yes', 'true', 't', '1', 'on' , 'all'])
         opt_false_values = set(['n', 'no', 'false', 'f', '0', 'off', 'none'])
 
@@ -526,9 +586,9 @@ def main():
                         help="Directory with data file"
                         )
 
-    parser.add_argument("--public","-P", 
-                        type=_bool, 
-                        default=False,                        
+    parser.add_argument("--public","-P",
+                        type=_bool,
+                        default=False,
                         help="Bind server to public IP 0.0.0.0 vs private IP of 127.0.0.1"
                         )
 
@@ -540,54 +600,60 @@ def main():
 
     parser.add_argument("--port","-p",
                         type=int,
-                        default=SERVER_PORT,                        
+                        default=SERVER_PORT,
                         help="Port to use")
 
-    parser.add_argument("--timeout","-t", 
+    parser.add_argument("--timeout","-t",
                         type=float,
-                        default=None,                        
-                        help="socket time out in seconds")                        
+                        default=None,
+                        help="socket time out in seconds")
 
     parser.add_argument('-V','--version', action='version', version='%(prog)s {0}'.format(__version__))
 
     parser.add_argument("--mode","-m",
                         type=str,
-                        default="test",                        
+                        default="test",
                         help="Mode of operation")
     parser.add_argument("--connection","-c",
                         type=str,
-                        default="nonSSL",                        
+                        default="nonSSL",
                         help="use SSL")
     parser.add_argument("--key","-k",
                         type=str,
-                        default="ssl/server.pem",                        
+                        default="ssl/server.pem",
                         help="key for ssl connnection")
     parser.add_argument("--cert","-cert",
                         type=str,
-                        default="ssl/server.crt",                        
+                        default="ssl/server.crt",
                         help="certificate")
     parser.add_argument("--clientverify","-cverify",
                         type=bool,
                         default=False,
                         help="verify client cert")
+    parser.add_argument("--load",
+                        dest='load',
+                        type=str,
+                        default='',
+                        help="A file which will install observers on hooks")
 
     args=parser.parse_args()
     options = args
+
     # set up global dictionary of {uuid (string): response (Response object)}
     s = sv.SessionValidator(args.data_dir)
     populate_global_replay_dictionary(s.getSessionIter())
     print("Dropped {0} sessions for being malformed".format(len(s.getBadSessionList())))
-    
+
     # start server
     try:
         socket_timeout = args.timeout
         test_mode_enabled = args.mode=="test"
-        
-        MyHandler.protocol_version = HTTP_VERSION        
+
+        MyHandler.protocol_version = HTTP_VERSION
         if options.connection == 'ssl':
             server = SSLServer((options.ip_address,options.port), MyHandler, options)
         else:
-            server = ThreadingServer((options.ip_address, options.port), MyHandler)
+            server = ThreadingServer((options.ip_address, options.port), MyHandler, options)
         server.timeout = 5
         print("started server")
         server_thread = threading.Thread(target=server.serve_forever())

-- 
To stop receiving notification emails like this one, please contact
commits@trafficserver.apache.org.

Mime
View raw message