Skip to content

知识库检索工程技术详解

📋 概述

检索工程是知识库系统的核心能力,负责将用户的自然语言查询转换为精准的知识检索结果。本文档详细介绍多路召回、重排序、缓存等关键技术的实现原理和代码架构。

🏗️ 检索架构设计

检索流程概览

mermaid
graph TD
    A[用户查询] --> B[查询预处理]
    B --> C[查询重写]
    C --> D[多路并行召回]
    D --> E[向量检索]
    D --> F[全文检索]
    D --> G[结构化检索]
    E --> H[结果融合]
    F --> H
    G --> H
    H --> I[重排序]
    I --> J[结果后处理]
    J --> K[返回结果]

核心技术框架

本项目使用 字节跳动开源的 Eino 框架 构建检索链,这是一个专为 AI 应用设计的组合式编程框架。

Eino 框架优势:

  • 组合式编程: 链式和并行组合
  • 类型安全: 强类型检查
  • 高性能: 支持并发执行
  • 可观测性: 内置链路追踪

🔧 检索链实现

检索链构建

位置: backend/domain/knowledge/service/retrieve.go:55

go
func (k *knowledgeSVC) Retrieve(ctx context.Context, request *RetrieveRequest) (*RetrieveResponse, error) {
    // 1. 构建检索上下文
    retrieveContext, err := k.newRetrieveContext(ctx, request)
    if err != nil {
        return nil, err
    }
    
    // 2. 构建 Eino 检索链
    chain := compose.NewChain[*RetrieveContext, []*knowledgeModel.RetrieveSlice]()
    
    // 3. 定义处理节点
    rewriteNode := compose.InvokableLambda(k.queryRewriteNode)           // 查询重写
    vectorRetrieveNode := compose.InvokableLambda(k.vectorRetrieveNode)  // 向量检索
    esRetrieveNode := compose.InvokableLambda(k.esRetrieveNode)         // ES检索
    nl2SqlRetrieveNode := compose.InvokableLambda(k.nl2SqlRetrieveNode) // NL2SQL检索
    passRequestContextNode := compose.InvokableLambda(k.passRequestContext) // 上下文传递
    reRankNode := compose.InvokableLambda(k.reRankNode)                 // 重排序
    packResult := compose.InvokableLambda(k.packResults)                // 结果打包
    
    // 4. 构建并行召回节点
    parallelNode := compose.NewParallel().
        AddLambda("vectorRetrieveNode", vectorRetrieveNode).
        AddLambda("esRetrieveNode", esRetrieveNode).
        AddLambda("nl2SqlRetrieveNode", nl2SqlRetrieveNode).
        AddLambda("passRequestContext", passRequestContextNode)
    
    // 5. 组装完整检索链
    r, err := chain.
        AppendLambda(rewriteNode).          // 查询重写
        AppendParallel(parallelNode).       // 并行召回
        AppendLambda(reRankNode).           // 重排序
        AppendLambda(packResult).           // 结果打包
        Compile(ctx)
    
    if err != nil {
        return nil, errorx.New(errno.ErrKnowledgeBuildRetrieveChainFailCode, errorx.KV("msg", err.Error()))
    }
    
    // 6. 执行检索链
    output, err := r.Invoke(ctx, retrieveContext)
    return &RetrieveResponse{Slices: output}, err
}

检索上下文设计

go
type RetrieveContext struct {
    // 原始查询信息
    Query               string
    KnowledgeIDs        []int64
    TopK                int
    ScoreThreshold      float64
    
    // 重写后的查询
    RewrittenQueries    []string
    
    // 检索结果
    VectorResults       []*VectorRetrieveResult
    ESResults          []*ESRetrieveResult
    NL2SQLResults      []*NL2SQLRetrieveResult
    
    // 上下文信息
    Documents          []*entity.Document
    SearchStoreMap     map[int64]searchstore.SearchStore
    
    // 配置参数
    EnableRewrite      bool
    EnableRerank       bool
    RerankTopK         int
}

🚀 多路召回实现

1. 向量检索

核心思想: 基于语义相似度的密集检索

go
func (k *knowledgeSVC) vectorRetrieveNode(ctx context.Context, retrieveCtx *RetrieveContext) (*RetrieveContext, error) {
    var allResults []*VectorRetrieveResult
    
    // 对每个知识库并行检索
    for knowledgeID, searchStore := range retrieveCtx.SearchStoreMap {
        // 构建检索请求
        retrieveRequest := &retriever.RetrieveRequest{
            Query:  retrieveCtx.Query,
            TopK:   retrieveCtx.TopK,
            Option: map[string]interface{}{
                "score_threshold": retrieveCtx.ScoreThreshold,
                "collection_name": getCollectionName(knowledgeID),
            },
        }
        
        // 执行向量检索
        retrieveResp, err := searchStore.Retrieve(ctx, retrieveRequest)
        if err != nil {
            logs.CtxErrorf(ctx, "vector retrieve failed: %v", err)
            continue
        }
        
        // 转换检索结果
        for _, doc := range retrieveResp.Documents {
            result := &VectorRetrieveResult{
                SliceID:     extractSliceID(doc.ID),
                Score:       doc.Score,
                Content:     doc.Content,
                KnowledgeID: knowledgeID,
                Source:      "vector",
            }
            allResults = append(allResults, result)
        }
    }
    
    // 按分数排序
    sort.Slice(allResults, func(i, j int) bool {
        return allResults[i].Score > allResults[j].Score
    })
    
    retrieveCtx.VectorResults = allResults
    return retrieveCtx, nil
}

2. 全文检索 (Elasticsearch)

核心思想: 基于关键词匹配的稀疏检索

go
func (k *knowledgeSVC) esRetrieveNode(ctx context.Context, retrieveCtx *RetrieveContext) (*RetrieveContext, error) {
    var allResults []*ESRetrieveResult
    
    for knowledgeID, searchStore := range retrieveCtx.SearchStoreMap {
        // 构建ES查询
        query := buildESQuery(retrieveCtx.Query)
        
        retrieveRequest := &retriever.RetrieveRequest{
            Query:  query,
            TopK:   retrieveCtx.TopK,
            Option: map[string]interface{}{
                "index_name": getIndexName(knowledgeID),
                "highlight": map[string]interface{}{
                    "fields": map[string]interface{}{
                        "content": map[string]interface{}{},
                    },
                },
            },
        }
        
        // 执行ES检索
        retrieveResp, err := searchStore.Retrieve(ctx, retrieveRequest)
        if err != nil {
            logs.CtxErrorf(ctx, "es retrieve failed: %v", err)
            continue
        }
        
        // 处理检索结果
        for _, doc := range retrieveResp.Documents {
            result := &ESRetrieveResult{
                SliceID:     extractSliceID(doc.ID),
                Score:       doc.Score,
                Content:     doc.Content,
                Highlight:   extractHighlight(doc.Metadata),
                KnowledgeID: knowledgeID,
                Source:      "elasticsearch",
            }
            allResults = append(allResults, result)
        }
    }
    
    retrieveCtx.ESResults = allResults
    return retrieveCtx, nil
}

// 构建ES查询DSL
func buildESQuery(userQuery string) string {
    queryDSL := map[string]interface{}{
        "query": map[string]interface{}{
            "bool": map[string]interface{}{
                "should": []interface{}{
                    // 精确匹配
                    map[string]interface{}{
                        "match_phrase": map[string]interface{}{
                            "content": map[string]interface{}{
                                "query": userQuery,
                                "boost": 3.0,
                            },
                        },
                    },
                    // 模糊匹配
                    map[string]interface{}{
                        "match": map[string]interface{}{
                            "content": map[string]interface{}{
                                "query":     userQuery,
                                "boost":     2.0,
                                "fuzziness": "AUTO",
                            },
                        },
                    },
                    // 多字段匹配
                    map[string]interface{}{
                        "multi_match": map[string]interface{}{
                            "query":  userQuery,
                            "fields": []string{"title^2", "content"},
                            "boost":  1.0,
                        },
                    },
                },
                "minimum_should_match": 1,
            },
        },
    }
    
    jsonBytes, _ := json.Marshal(queryDSL)
    return string(jsonBytes)
}

3. 结构化检索 (NL2SQL)

核心思想: 自然语言转SQL查询表格数据

go
func (k *knowledgeSVC) nl2SqlRetrieveNode(ctx context.Context, retrieveCtx *RetrieveContext) (*RetrieveContext, error) {
    var allResults []*NL2SQLRetrieveResult
    
    // 只对表格类型文档执行NL2SQL检索
    tableDocuments := filterTableDocuments(retrieveCtx.Documents)
    if len(tableDocuments) == 0 {
        retrieveCtx.NL2SQLResults = allResults
        return retrieveCtx, nil
    }
    
    for _, doc := range tableDocuments {
        // 构建表结构描述
        schema := buildTableSchema(doc.TableInfo)
        
        // 调用NL2SQL服务
        nl2sqlRequest := &nl2sql.TranslateRequest{
            Query:        retrieveCtx.Query,
            TableSchema:  schema,
            DatabaseType: "mysql",
        }
        
        sqlResponse, err := k.nl2Sql.Translate(ctx, nl2sqlRequest)
        if err != nil {
            logs.CtxErrorf(ctx, "nl2sql translate failed: %v", err)
            continue
        }
        
        // 执行SQL查询
        queryRequest := &rdb.QueryDataRequest{
            TableName: doc.TableInfo.PhysicalTableName,
            SQL:       sqlResponse.SQL,
            Limit:     retrieveCtx.TopK,
        }
        
        queryResponse, err := k.rdb.QueryData(ctx, queryRequest)
        if err != nil {
            logs.CtxErrorf(ctx, "sql query failed: %v", err)
            continue
        }
        
        // 构建结果
        for _, row := range queryResponse.Rows {
            result := &NL2SQLRetrieveResult{
                SliceID:     extractSliceIDFromRow(row),
                Content:     formatRowAsText(row, doc.TableInfo.Columns),
                SQL:         sqlResponse.SQL,
                KnowledgeID: doc.KnowledgeID,
                Source:      "nl2sql",
                Score:       1.0, // NL2SQL结果固定高分
            }
            allResults = append(allResults, result)
        }
    }
    
    retrieveCtx.NL2SQLResults = allResults
    return retrieveCtx, nil
}

// 构建表结构描述
func buildTableSchema(tableInfo *entity.TableInfo) *nl2sql.TableSchema {
    columns := make([]*nl2sql.Column, 0, len(tableInfo.Columns))
    
    for _, col := range tableInfo.Columns {
        columns = append(columns, &nl2sql.Column{
            Name:        col.Name,
            Type:        convertColumnType(col.Type),
            Description: col.Description,
            IsPrimary:   col.Name == "id",
            IsNullable:  !col.Indexing,
        })
    }
    
    return &nl2sql.TableSchema{
        TableName:   tableInfo.VirtualTableName,
        Description: tableInfo.TableDesc,
        Columns:     columns,
    }
}

🎯 查询重写实现

查询重写策略

go
func (k *knowledgeSVC) queryRewriteNode(ctx context.Context, retrieveCtx *RetrieveContext) (*RetrieveContext, error) {
    if !retrieveCtx.EnableRewrite || k.rewriter == nil {
        retrieveCtx.RewrittenQueries = []string{retrieveCtx.Query}
        return retrieveCtx, nil
    }
    
    // 构建重写请求
    rewriteRequest := &messages2query.MessagesToQueryRequest{
        Messages: []*messages2query.Message{
            {
                Role:    "user",
                Content: retrieveCtx.Query,
            },
        },
        Context: buildRewriteContext(retrieveCtx),
    }
    
    // 执行查询重写
    rewriteResponse, err := k.rewriter.Rewrite(ctx, rewriteRequest)
    if err != nil {
        logs.CtxErrorf(ctx, "query rewrite failed: %v", err)
        retrieveCtx.RewrittenQueries = []string{retrieveCtx.Query}
        return retrieveCtx, nil
    }
    
    // 合并原查询和重写查询
    queries := []string{retrieveCtx.Query}
    for _, query := range rewriteResponse.Queries {
        if query != retrieveCtx.Query {
            queries = append(queries, query)
        }
    }
    
    retrieveCtx.RewrittenQueries = queries
    return retrieveCtx, nil
}

// 构建重写上下文
func buildRewriteContext(retrieveCtx *RetrieveContext) map[string]interface{} {
    return map[string]interface{}{
        "knowledge_domains": extractDomains(retrieveCtx.Documents),
        "document_types":    extractDocumentTypes(retrieveCtx.Documents),
        "table_schemas":     extractTableSchemas(retrieveCtx.Documents),
    }
}

🔄 重排序实现

重排序策略

go
func (k *knowledgeSVC) reRankNode(ctx context.Context, retrieveCtx *RetrieveContext) (*RetrieveContext, error) {
    if !retrieveCtx.EnableRerank {
        return retrieveCtx, nil
    }
    
    // 1. 收集所有检索结果
    allResults := collectAllResults(retrieveCtx)
    if len(allResults) == 0 {
        return retrieveCtx, nil
    }
    
    // 2. RRF 算法融合多路结果
    rrfResults := k.applyRRF(allResults)
    
    // 3. 语义重排序
    if k.reranker != nil && len(rrfResults) > 1 {
        rerankResults, err := k.applySemanticRerank(ctx, retrieveCtx.Query, rrfResults)
        if err != nil {
            logs.CtxErrorf(ctx, "semantic rerank failed: %v", err)
        } else {
            rrfResults = rerankResults
        }
    }
    
    // 4. 更新检索上下文
    retrieveCtx.FinalResults = rrfResults[:min(len(rrfResults), retrieveCtx.RerankTopK)]
    return retrieveCtx, nil
}

// RRF (Reciprocal Rank Fusion) 算法实现
func (k *knowledgeSVC) applyRRF(allResults []*RetrieveResult) []*RetrieveResult {
    const k = 60.0 // RRF 常数
    
    // 按来源分组
    sourceGroups := make(map[string][]*RetrieveResult)
    for _, result := range allResults {
        sourceGroups[result.Source] = append(sourceGroups[result.Source], result)
    }
    
    // 为每个来源的结果排序
    for source, results := range sourceGroups {
        sort.Slice(results, func(i, j int) bool {
            return results[i].Score > results[j].Score
        })
    }
    
    // 计算RRF分数
    rrfScores := make(map[string]float64)
    for source, results := range sourceGroups {
        for rank, result := range results {
            key := fmt.Sprintf("%d_%s", result.SliceID, result.Source)
            rrfScores[key] += 1.0 / (k + float64(rank+1))
        }
    }
    
    // 去重并按RRF分数排序
    uniqueResults := make(map[int64]*RetrieveResult)
    for _, result := range allResults {
        key := fmt.Sprintf("%d_%s", result.SliceID, result.Source)
        if existing, exists := uniqueResults[result.SliceID]; !exists || rrfScores[key] > existing.RRFScore {
            result.RRFScore = rrfScores[key]
            uniqueResults[result.SliceID] = result
        }
    }
    
    // 转换为切片并排序
    finalResults := make([]*RetrieveResult, 0, len(uniqueResults))
    for _, result := range uniqueResults {
        finalResults = append(finalResults, result)
    }
    
    sort.Slice(finalResults, func(i, j int) bool {
        return finalResults[i].RRFScore > finalResults[j].RRFScore
    })
    
    return finalResults
}

// 语义重排序实现
func (k *knowledgeSVC) applySemanticRerank(ctx context.Context, query string, results []*RetrieveResult) ([]*RetrieveResult, error) {
    // 构建重排序文档
    documents := make([]rerank.Document, 0, len(results))
    for i, result := range results {
        documents = append(documents, rerank.Document{
            ID:      strconv.Itoa(i),
            Content: result.Content,
        })
    }
    
    // 执行重排序
    rerankRequest := &rerank.RerankRequest{
        Query:     query,
        Documents: documents,
        TopK:      len(documents),
    }
    
    rerankResponse, err := k.reranker.Rerank(ctx, rerankRequest)
    if err != nil {
        return nil, err
    }
    
    // 重新排序结果
    rerankedResults := make([]*RetrieveResult, 0, len(rerankResponse.Documents))
    for _, doc := range rerankResponse.Documents {
        index, _ := strconv.Atoi(doc.ID)
        if index < len(results) {
            result := results[index]
            result.RerankScore = doc.Score
            rerankedResults = append(rerankedResults, result)
        }
    }
    
    return rerankedResults, nil
}

💾 缓存策略实现

多层缓存架构

go
type CacheManager struct {
    // L1: 内存缓存 (进程内)
    memoryCache *sync.Map
    
    // L2: Redis缓存 (分布式)
    redisCache cache.Cmdable
    
    // 缓存配置
    config CacheConfig
}

type CacheConfig struct {
    MemoryCacheSize    int           // 内存缓存大小
    MemoryExpire      time.Duration // 内存缓存过期时间
    RedisExpire       time.Duration // Redis缓存过期时间
    EnableCompression bool          // 是否启用压缩
}

检索结果缓存

go
func (k *knowledgeSVC) getCachedResults(ctx context.Context, cacheKey string) ([]*RetrieveResult, bool) {
    // L1: 检查内存缓存
    if value, ok := k.memoryCache.Load(cacheKey); ok {
        if cachedResults, ok := value.([]*RetrieveResult); ok {
            return cachedResults, true
        }
    }
    
    // L2: 检查Redis缓存
    cmd := k.redisCache.Get(ctx, cacheKey)
    if cmd.Err() == nil {
        var cachedResults []*RetrieveResult
        if err := json.Unmarshal([]byte(cmd.Val()), &cachedResults); err == nil {
            // 回填内存缓存
            k.memoryCache.Store(cacheKey, cachedResults)
            return cachedResults, true
        }
    }
    
    return nil, false
}

func (k *knowledgeSVC) setCachedResults(ctx context.Context, cacheKey string, results []*RetrieveResult) {
    // 存储到内存缓存
    k.memoryCache.Store(cacheKey, results)
    
    // 存储到Redis缓存
    if data, err := json.Marshal(results); err == nil {
        k.redisCache.Set(ctx, cacheKey, data, k.config.RedisExpire)
    }
}

// 生成缓存键
func generateCacheKey(request *RetrieveRequest) string {
    hasher := sha256.New()
    hasher.Write([]byte(request.Query))
    hasher.Write([]byte(fmt.Sprintf("%v", request.KnowledgeIDs)))
    hasher.Write([]byte(fmt.Sprintf("%d_%f", request.TopK, request.ScoreThreshold)))
    
    return fmt.Sprintf("retrieve:%x", hasher.Sum(nil))
}

预热策略

go
// 热点查询预热
func (k *knowledgeSVC) warmupCache(ctx context.Context) {
    // 获取热点查询
    hotQueries := k.getHotQueries(ctx)
    
    // 并发预热
    var wg sync.WaitGroup
    semaphore := make(chan struct{}, 10) // 限制并发数
    
    for _, query := range hotQueries {
        wg.Add(1)
        go func(q string) {
            defer wg.Done()
            semaphore <- struct{}{}
            defer func() { <-semaphore }()
            
            // 执行预热查询
            request := &RetrieveRequest{
                Query:        q,
                KnowledgeIDs: k.getAllKnowledgeIDs(ctx),
                TopK:         20,
            }
            
            _, _ = k.Retrieve(ctx, request)
        }(query)
    }
    
    wg.Wait()
}

📊 性能监控

性能指标收集

go
type RetrieveMetrics struct {
    // 延迟指标
    TotalLatency       time.Duration
    RewriteLatency     time.Duration
    VectorLatency      time.Duration
    ESLatency          time.Duration
    NL2SQLLatency      time.Duration
    RerankLatency      time.Duration
    
    // 召回指标
    VectorRecallCount  int
    ESRecallCount      int
    NL2SQLRecallCount  int
    FinalResultCount   int
    
    // 缓存指标
    CacheHitRate       float64
    CacheQueryCount    int
    
    // 错误指标
    ErrorCount         int
    ErrorRate          float64
}

func (k *knowledgeSVC) recordMetrics(ctx context.Context, metrics *RetrieveMetrics) {
    // 记录到Prometheus
    retrieveLatencyHistogram.Observe(metrics.TotalLatency.Seconds())
    cacheHitRateGauge.Set(metrics.CacheHitRate)
    
    // 记录到日志
    logs.CtxInfof(ctx, "retrieve metrics: %+v", metrics)
}

🚨 故障处理

降级策略

go
func (k *knowledgeSVC) handleRetrieveFailure(ctx context.Context, retrieveCtx *RetrieveContext, err error) *RetrieveResponse {
    logs.CtxErrorf(ctx, "retrieve failed, applying fallback strategy: %v", err)
    
    // 1. 尝试缓存降级
    if cachedResults := k.getCachedFallbackResults(ctx, retrieveCtx.Query); len(cachedResults) > 0 {
        return &RetrieveResponse{Slices: cachedResults}
    }
    
    // 2. 简化检索策略
    if simpleResults := k.simpleKeywordSearch(ctx, retrieveCtx); len(simpleResults) > 0 {
        return &RetrieveResponse{Slices: simpleResults}
    }
    
    // 3. 返回空结果
    return &RetrieveResponse{Slices: []*knowledgeModel.RetrieveSlice{}}
}

// 简化关键词搜索
func (k *knowledgeSVC) simpleKeywordSearch(ctx context.Context, retrieveCtx *RetrieveContext) []*knowledgeModel.RetrieveSlice {
    // 使用数据库LIKE查询作为最后的降级方案
    keywords := extractKeywords(retrieveCtx.Query)
    
    var results []*knowledgeModel.RetrieveSlice
    for _, keyword := range keywords {
        slices, err := k.sliceRepo.SearchByKeyword(ctx, keyword, 10)
        if err == nil {
            for _, slice := range slices {
                results = append(results, convertSliceToRetrieveSlice(slice))
            }
        }
    }
    
    return results
}

文档版本: v1.0
最后更新: 2025-10-27
相关文档: 知识库架构总览

飞视数字技术|AI智能技术服务商