fix: 引入sqlglot修复sql语句解析异常的问题
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
from functools import lru_cache
|
||||
from pydantic import computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Literal
|
||||
|
||||
@@ -51,6 +52,13 @@ class DataBaseSettings(BaseSettings):
|
||||
db_pool_recycle: int = 3600
|
||||
db_pool_timeout: int = 30
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def sqlglot_parse_dialect(self) -> str:
|
||||
if self.db_type == 'postgresql':
|
||||
return 'postgres'
|
||||
return self.db_type
|
||||
|
||||
|
||||
class RedisSettings(BaseSettings):
|
||||
"""
|
||||
|
@@ -2,6 +2,7 @@ from datetime import datetime, time
|
||||
from sqlalchemy import delete, func, select, text, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlglot.expressions import Expression
|
||||
from typing import List
|
||||
from config.env import DataBaseConfig
|
||||
from module_generator.entity.do.gen_do import GenTable, GenTableColumn
|
||||
@@ -75,15 +76,17 @@ class GenTableDao:
|
||||
return gen_table_all
|
||||
|
||||
@classmethod
|
||||
async def create_table_by_sql_dao(cls, db: AsyncSession, sql: str):
|
||||
async def create_table_by_sql_dao(cls, db: AsyncSession, sql_statements: List[Expression]):
|
||||
"""
|
||||
根据sql语句创建表结构
|
||||
|
||||
:param db: orm对象
|
||||
:param sql: sql语句
|
||||
:param sql_statements: sql语句的ast列表
|
||||
:return:
|
||||
"""
|
||||
await db.execute(text(sql))
|
||||
for sql_statement in sql_statements:
|
||||
sql = sql_statement.sql(dialect=DataBaseConfig.sqlglot_parse_dialect)
|
||||
await db.execute(text(sql))
|
||||
|
||||
@classmethod
|
||||
async def get_gen_table_list(cls, db: AsyncSession, query_object: GenTablePageQueryModel, is_page: bool = False):
|
||||
|
@@ -1,13 +1,14 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlglot import parse as sqlglot_parse
|
||||
from sqlglot.expressions import Add, Alter, Create, Delete, Drop, Expression, Insert, Table, TruncateTable, Update
|
||||
from typing import List
|
||||
from config.constant import GenConstant
|
||||
from config.env import GenConfig
|
||||
from config.env import DataBaseConfig, GenConfig
|
||||
from exceptions.exception import ServiceException
|
||||
from module_admin.entity.vo.common_vo import CrudResponseModel
|
||||
from module_admin.entity.vo.user_vo import CurrentUserModel
|
||||
@@ -197,10 +198,11 @@ class GenTableService:
|
||||
:param current_user: 当前用户信息对象
|
||||
:return: 创建表结构结果
|
||||
"""
|
||||
if cls.__is_valid_create_table(sql):
|
||||
sql_statements = sqlglot_parse(sql, dialect=DataBaseConfig.sqlglot_parse_dialect)
|
||||
if cls.__is_valid_create_table(sql_statements):
|
||||
try:
|
||||
table_names = re.findall(r'create\s+table\s+(\w+)', sql, re.IGNORECASE)
|
||||
await GenTableDao.create_table_by_sql_dao(query_db, sql)
|
||||
table_names = cls.__get_table_names(sql_statements)
|
||||
await GenTableDao.create_table_by_sql_dao(query_db, sql_statements)
|
||||
gen_table_list = await cls.get_gen_db_table_list_by_name_services(query_db, table_names)
|
||||
await cls.import_gen_table_services(query_db, gen_table_list, current_user)
|
||||
|
||||
@@ -211,22 +213,39 @@ class GenTableService:
|
||||
raise ServiceException(message='建表语句不合法')
|
||||
|
||||
@classmethod
|
||||
def __is_valid_create_table(cls, sql: str):
|
||||
def __is_valid_create_table(cls, sql_statements: List[Expression]):
|
||||
"""
|
||||
校验sql语句是否为合法的建表语句
|
||||
|
||||
:param sql: sql语句
|
||||
:param sql_statements: sql语句的ast列表
|
||||
:return: 校验结果
|
||||
"""
|
||||
create_table_pattern = r'^\s*CREATE\s+TABLE\s+'
|
||||
if not re.search(create_table_pattern, sql, re.IGNORECASE):
|
||||
validate_create = [isinstance(sql_statement, Create) for sql_statement in sql_statements]
|
||||
validate_forbidden_keywords = [
|
||||
isinstance(
|
||||
sql_statement,
|
||||
(Add, Alter, Delete, Drop, Insert, TruncateTable, Update),
|
||||
)
|
||||
for sql_statement in sql_statements
|
||||
]
|
||||
if not any(validate_create) or any(validate_forbidden_keywords):
|
||||
return False
|
||||
forbidden_keywords = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER', 'TRUNCATE']
|
||||
for keyword in forbidden_keywords:
|
||||
if re.search(rf'\b{keyword}\b', sql, re.IGNORECASE):
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def __get_table_names(cls, sql_statements: List[Expression]):
|
||||
"""
|
||||
获取sql语句中所有的建表表名
|
||||
|
||||
:param sql_statements: sql语句的ast列表
|
||||
:return: 建表表名列表
|
||||
"""
|
||||
table_names = []
|
||||
for sql_statement in sql_statements:
|
||||
if isinstance(sql_statement, Create):
|
||||
table_names.append(sql_statement.find(Table).name)
|
||||
return table_names
|
||||
|
||||
@classmethod
|
||||
async def preview_code_services(cls, query_db: AsyncSession, table_id: int):
|
||||
"""
|
||||
|
@@ -14,4 +14,5 @@ PyMySQL==1.1.1
|
||||
redis==5.2.1
|
||||
requests==2.32.3
|
||||
SQLAlchemy[asyncio]==2.0.38
|
||||
sqlglot[rs]==26.6.0
|
||||
user-agents==2.2.0
|
||||
|
Reference in New Issue
Block a user