!27 feat: 新增trace中间件强化日志链路追踪和响应头

* refactor: trace_log重命名为trace_middleware
* refactor: 日志处理器重构为类式写法
* perf: 移除无用文件
* perf: 优化trace中间件部分写法
* style: 格式化代码
* Merge branch 'master' into develop
* feature: 1.日志添加traceId链路追踪 2.response-header默认添加request-id与traceId对应
This commit is contained in:
py1ren
2025-01-24 01:34:33 +00:00
committed by insistence
parent 1cfd85f9de
commit 00011f8419
6 changed files with 197 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
from fastapi import FastAPI
from middlewares.cors_middleware import add_cors_middleware
from middlewares.gzip_middleware import add_gzip_middleware
from middlewares.trace_middleware import add_trace_middleware
def handle_middleware(app: FastAPI):
@@ -11,3 +12,5 @@ def handle_middleware(app: FastAPI):
add_cors_middleware(app)
# 加载gzip压缩中间件
add_gzip_middleware(app)
# 加载trace中间件
add_trace_middleware(app)

View File

@@ -0,0 +1,17 @@
from fastapi import FastAPI
from .ctx import TraceCtx
from .middle import TraceASGIMiddleware
__all__ = ('TraceASGIMiddleware', 'TraceCtx')
__version__ = '0.1.0'
def add_trace_middleware(app: FastAPI):
"""
添加trace中间件
:param app: FastAPI对象
:return:
"""
app.add_middleware(TraceASGIMiddleware)

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""
@author: peng
@file: ctx.py
@time: 2025/1/17 16:57
"""
import contextvars
from uuid import uuid4
CTX_REQUEST_ID: contextvars.ContextVar[str] = contextvars.ContextVar('request-id', default='')
class TraceCtx:
@staticmethod
def set_id():
_id = uuid4().hex
CTX_REQUEST_ID.set(_id)
return _id
@staticmethod
def get_id():
return CTX_REQUEST_ID.get()

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""
@author: peng
@file: middle.py
@time: 2025/1/17 16:57
"""
from functools import wraps
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from .span import get_current_span, Span
class TraceASGIMiddleware:
"""
fastapi-example:
app = FastAPI()
app.add_middleware(TraceASGIMiddleware)
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
@staticmethod
async def my_receive(receive: Receive, span: Span):
await span.request_before()
@wraps(receive)
async def my_receive():
message = await receive()
await span.request_after(message)
return message
return my_receive
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope['type'] != 'http':
await self.app(scope, receive, send)
return
async with get_current_span(scope) as span:
handle_outgoing_receive = await self.my_receive(receive, span)
async def handle_outgoing_request(message: 'Message') -> None:
await span.response(message)
await send(message)
await self.app(scope, handle_outgoing_receive, handle_outgoing_request)

View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""
@author: peng
@file: span.py
@time: 2025/1/17 16:57
"""
from contextlib import asynccontextmanager
from starlette.types import Scope, Message
from .ctx import TraceCtx
class Span:
"""
整个http生命周期
request(before) --> request(after) --> response(before) --> response(after)
"""
def __init__(self, scope: Scope):
self.scope = scope
async def request_before(self):
"""
request_before: 处理header信息等, 如记录请求体信息
"""
TraceCtx.set_id()
async def request_after(self, message: Message):
"""
request_after: 处理请求bytes 如记录请求参数
example:
message: {'type': 'http.request', 'body': b'{\r\n "name": "\xe8\x8b\x8f\xe8\x8b\x8f\xe8\x8b\x8f"\r\n}', 'more_body': False}
"""
return message
async def response(self, message: Message):
"""
if message['type'] == "http.response.start": -----> request-before
pass
if message['type'] == "http.response.body": -----> request-after
message.get('body', b'')
pass
"""
if message['type'] == 'http.response.start':
message['headers'].append((b'request-id', TraceCtx.get_id().encode()))
return message
@asynccontextmanager
async def get_current_span(scope: Scope):
yield Span(scope)

View File

@@ -1,11 +1,60 @@
import os
import sys
import time
from loguru import logger
from loguru import logger as _logger
from typing import Dict
from middlewares.trace_middleware import TraceCtx
log_path = os.path.join(os.getcwd(), 'logs')
if not os.path.exists(log_path):
os.mkdir(log_path)
log_path_error = os.path.join(log_path, f'{time.strftime("%Y-%m-%d")}_error.log')
class LoggerInitializer:
def __init__(self):
self.log_path = os.path.join(os.getcwd(), 'logs')
self.__ensure_log_directory_exists()
self.log_path_error = os.path.join(self.log_path, f'{time.strftime("%Y-%m-%d")}_error.log')
logger.add(log_path_error, rotation='50MB', encoding='utf-8', enqueue=True, compression='zip')
def __ensure_log_directory_exists(self):
"""
确保日志目录存在,如果不存在则创建
"""
if not os.path.exists(self.log_path):
os.mkdir(self.log_path)
@staticmethod
def __filter(log: Dict):
"""
自定义日志过滤器添加trace_id
"""
log['trace_id'] = TraceCtx.get_id()
return log
def init_log(self):
"""
初始化日志配置
"""
# 自定义日志格式
format_str = (
'<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | '
'<cyan>{trace_id}</cyan> | '
'<level>{level: <8}</level> | '
'<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - '
'<level>{message}</level>'
)
_logger.remove()
# 移除后重新添加sys.stderr, 目的: 控制台输出与文件日志内容和结构一致
_logger.add(sys.stderr, filter=self.__filter, format=format_str, enqueue=True)
_logger.add(
self.log_path_error,
filter=self.__filter,
format=format_str,
rotation='50MB',
encoding='utf-8',
enqueue=True,
compression='zip',
)
return _logger
# 初始化日志处理器
log_initializer = LoggerInitializer()
logger = log_initializer.init_log()