dify / api /services /tag_service.py
kikuepi's picture
Upload 4913 files
4304c6d verified
import uuid
from flask_login import current_user
from sqlalchemy import func
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
query = db.session.query(
Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
).outerjoin(
TagBinding, Tag.id == TagBinding.tag_id
).filter(
Tag.type == tag_type,
Tag.tenant_id == current_tenant_id
)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
query = query.group_by(
Tag.id
)
results = query.order_by(Tag.created_at.desc()).all()
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
tags = db.session.query(Tag).filter(
Tag.id.in_(tag_ids),
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
if not tags:
return []
tag_ids = [tag.id for tag in tags]
tag_bindings = db.session.query(
TagBinding.target_id
).filter(
TagBinding.tag_id.in_(tag_ids),
TagBinding.tenant_id == current_tenant_id
).all()
if not tag_bindings:
return []
results = [tag_binding.target_id for tag_binding in tag_bindings]
return results
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
tags = db.session.query(Tag).join(
TagBinding,
Tag.id == TagBinding.tag_id
).filter(
TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id,
Tag.type == tag_type
).all()
return tags if tags else []
@staticmethod
def save_tags(args: dict) -> Tag:
tag = Tag(
id=str(uuid.uuid4()),
name=args['name'],
type=args['type'],
created_by=current_user.id,
tenant_id=current_user.current_tenant_id
)
db.session.add(tag)
db.session.commit()
return tag
@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
tag.name = args['name']
db.session.commit()
return tag
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
return count
@staticmethod
def delete_tag(tag_id: str):
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
db.session.commit()
@staticmethod
def save_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
# save tag binding
for tag_id in args['tag_ids']:
tag_binding = db.session.query(TagBinding).filter(
TagBinding.tag_id == tag_id,
TagBinding.target_id == args['target_id']
).first()
if tag_binding:
continue
new_tag_binding = TagBinding(
tag_id=tag_id,
target_id=args['target_id'],
tenant_id=current_user.current_tenant_id,
created_by=current_user.id
)
db.session.add(new_tag_binding)
db.session.commit()
@staticmethod
def delete_tag_binding(args):
# check if target exists
TagService.check_target_exists(args['type'], args['target_id'])
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(
TagBinding.target_id == args['target_id'],
TagBinding.tag_id == (args['tag_id'])
).first()
if tag_bindings:
db.session.delete(tag_bindings)
db.session.commit()
@staticmethod
def check_target_exists(type: str, target_id: str):
if type == 'knowledge':
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == current_user.current_tenant_id,
Dataset.id == target_id
).first()
if not dataset:
raise NotFound("Dataset not found")
elif type == 'app':
app = db.session.query(App).filter(
App.tenant_id == current_user.current_tenant_id,
App.id == target_id
).first()
if not app:
raise NotFound("App not found")
else:
raise NotFound("Invalid binding type")