瀏覽代碼

implemented server shutdown function and added taskgroup for client handers

digital 7 年之前
父節點
當前提交
ebaf1101d6
共有 1 個文件被更改,包括 29 次插入22 次删除
  1. 29 22
      network/__init__.py

+ 29 - 22
network/__init__.py

@@ -16,6 +16,7 @@
 # You should have received a copy of the GNU General Public License
 # along with DigiLib.  If not, see <http://www.gnu.org/licenses/>.
 
+import atexit
 import logging
 import logging.handlers
 import os
@@ -119,25 +120,24 @@ class Server(object):
             ):
         super(Server, self).__init__()
         self.exit_event =  False
-        self.host=host
-        self.port=port
+        self.host = host
+        self.port = port
         self.af_family = af_family
         self.log_ip = log_ip
         self.handler_kwargs = handler_kwargs
         self.handler = handler
         self.max_allowed_clients = max_allowed_clients
         self.socket = self.make_socket()
+        self.handle_tasks = curio.TaskGroup(name="tg_handle_clients")
         self.connection_handler = []
         self.conn_to_addr = {}
         self.addr_to_conn = {}
         self.conn_to_handler ={}
         self.handler_to_conn = {}
         self.read_sockets_expected = [self.socket]
+        atexit.register(self.shutdown)
 
-    def cleanup(self):
-        pass
-
-    def make_socket(self):
+     def make_socket(self):
         lserver.debug("making a {} socket".format(self.af_family))
         if self.af_family == "AF_INET":
             s = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
@@ -193,6 +193,23 @@ class Server(object):
         self.socket.listen(self.max_allowed_clients)
         # self.socket.settimeout(1)
 
+    def shutdown(self):
+        def error_handler(func,*args,log_text="error",**kwargs):
+            try:
+                func(*args,**kwargs)
+            except Exception as exc:
+                lserver.debug("error occured during "+log_text,exc_info=exc)
+        # this function can be called by the user and we don't need to clean
+        # up a second time when the program exits
+        atexit.unregister(self.shutdown)
+        lserver.info("shutting down server")
+        error_handler(self.handle_tasks.cancel,log_text="handler cancel")
+        if self.socket:
+            error_handler(self.socket.shutdown,log_text="socket shutdown")
+            error_handler(self.socket.close,log_text="socket close")
+        # del self.socket
+        # self.socket = None
+
     def start(self):
         # lserver.debug(dir(self))
         curio.run(self.run)
@@ -201,10 +218,6 @@ class Server(object):
         self.setup()
         lserver.debug("entering main loop")
         while ( not self.exit_event ):
-            # lserver.debug(self.read_sockets_expected)
-            # lserver.debug(self.write_sockets_expected)
-            # lserver.debug(self.exc_sockets_expected)
-            # read_sockets = select.select(self.read_sockets_expected,[],[])
             lserver.debug("waiting for client to connect")
             conn,addr = await self.socket.accept()
             if self.log_ip:
@@ -215,31 +228,25 @@ class Server(object):
             handler = self.make_handler(conn, addr)
             self.register_conn(conn, addr)
             self.register_handler(handler, conn)
-            await curio.spawn(self.wait_for_client(conn,handler))
+            await self.handle_task.spawn(self.wait_for_client(conn,handler))
 
-    async def wait_for_client(self,socket,handler):
+    async def handle_client(self,socket,handler):
         _append_task()
-        # cur_task = curio.current_task()
-        # _tasks.append(cur_task)
         # while True:
         for i in range(1):
             try:
                 if self.log_ip:
-                    lserver.debug(
-                        "waiting for {} to send something"
+                    lserver.debug("waiting for {} to send something"
                         .format(socket.getsockname()))
                 else:
-                    lserver.debug(
-                        "waiting for the client to send something")
+                    lserver.debug("waiting for the client to send something")
                 data = await handler.recv()
                 if not data:
                     if self.log_ip:
-                        lserver.info(
-                            "the connection to {} was closed"
+                        lserver.info("the connection to {} was closed"
                             .format(socket.getsockname()))
                     else:
-                        lserver.info(
-                            "the connection to the client was closed")
+                        lserver.info("the connection to the client was closed")
                     self.unregister_handler(handler, socket)
                     self.unregister_conn(socket)
                     handler.close()