diff --git a/flask_sockets.py b/flask_sockets.py index c0b4b64..de48e3b 100644 --- a/flask_sockets.py +++ b/flask_sockets.py @@ -27,19 +27,23 @@ if 'gevent' in locals(): class SocketMiddleware(object): - def __init__(self, wsgi_app, socket): + def __init__(self, wsgi_app, app, socket): self.ws = socket - self.app = wsgi_app + self.app = app + self.wsgi_app = wsgi_app def __call__(self, environ, start_response): adapter = self.ws.url_map.bind_to_environ(environ) try: handler, values = adapter.match() environment = environ['wsgi.websocket'] - handler(environment, **values) - return [] + + with self.app.app_context(): + with self.app.request_context(environ): + handler(environment, **values) + return [] except (NotFound, KeyError): - return self.app(environ, start_response) + return self.wsgi_app(environ, start_response) class Sockets(object): @@ -50,7 +54,7 @@ class Sockets(object): self.init_app(app) def init_app(self, app): - app.wsgi_app = SocketMiddleware(app.wsgi_app, self) + app.wsgi_app = SocketMiddleware(app.wsgi_app, app, self) def route(self, rule, **options):