手写 SQLite 08:实现 WHERE 条件过滤

本系列从第 01 篇开始逐步构建一个能读取真实 SQLite 文件的 Rust 程序:第 01–02 篇解析文件头和页头,第 03 篇遍历 B-Tree,第 04 篇解码 varint 和 Record,第 05 篇读取 sqlite_schema,第 06 篇实现带列名的 SELECT *,第 07 篇实现 SELECT count(*)。这一篇在第 06 篇的基础上加入 WHERE 条件过滤,让我们能执行:

SELECT * FROM users WHERE name = 'Alice'
SELECT * FROM users WHERE age = 30

策略是最朴素的:全表扫描后过滤——先读出所有行,再逐行判断目标列的值是否与过滤值匹配,符合条件的行才输出。

一、实现目标

本篇要支持的 SQL 形式是等值过滤,语法约定如下:

  • SELECT * FROM <table> WHERE <col> = '<string_value>' — 字符串值用单引号包裹
  • SELECT * FROM <table> WHERE <col> = <numeric_value> — 数值直接写,不加引号

过滤逻辑:全表扫描每一个叶页 Cell,解码每行数据,找到目标列对应的值,将其转为字符串后与过滤值做字符串比较,相等则保留该行。

整体流程如下图:

SQL 字符串
    ↓ parse_sql()
    ↓ 发现含 WHERE 关键字
    ↓ parse_where_clause()  →  WhereClause { column, value }
    ↓
select_where(pager, table, where_clause)
    ↓ 读 sqlite_schema → 列定义 + root_page
    ↓ 找 where_clause.column 对应的列下标 col_idx
    ↓ collect_leaf_cells() → 所有叶页 Cell
    ↓ 对每个 Cell:
        parse_leaf_cell() → (rowid, values)
        组装完整行(处理 rowid 别名)
        values[col_idx] 转字符串 == where_clause.value ?
        是 → 加入结果集
    ↓
QueryResult { columns, rows }

二、SQL 解析

main.rs(或独立的 parser.rs)中扩展解析逻辑,识别 SQL 字符串里的 WHERE 子句,提取出 列名过滤值 两个字符串。

// src/main.rs(解析部分)

/// 从 WHERE 子句中提取 (列名, 值字符串) 元组
/// 支持:
///   WHERE name = 'Alice'   →  ("name", "Alice")
///   WHERE age = 30         →  ("age", "30")
fn parse_where_clause(where_str: &str) -> Option<(String, String)> {
    // where_str 是 "WHERE name = 'Alice'" 或 "WHERE age = 30"
    // 先去掉前缀 "WHERE"(大小写不敏感)
    let body = where_str.trim();
    let body = if body.to_uppercase().starts_with("WHERE") {
        body[5..].trim()
    } else {
        body
    };

    // 按 '=' 切分,左边是列名,右边是值
    let eq_pos = body.find('=')?;
    let col_name = body[..eq_pos].trim().to_string();
    let val_raw  = body[eq_pos + 1..].trim();

    // 处理字符串值:去掉单引号
    let value = if val_raw.starts_with('\'') && val_raw.ends_with('\'') && val_raw.len() >= 2 {
        val_raw[1..val_raw.len() - 1].to_string()
    } else {
        // 数值或无引号值:直接使用
        val_raw.to_string()
    };

    Some((col_name, value))
}

/// 检测 SQL 是否含 WHERE 子句,返回 (table_name, WhereClause) 或 None
fn parse_select_where(sql: &str) -> Option<(String, crate::sql::WhereClause)> {
    let upper = sql.to_uppercase();

    // 必须以 SELECT * FROM 开头
    let rest = upper.strip_prefix("SELECT")?
        .trim_start()
        .strip_prefix('*')?
        .trim_start()
        .strip_prefix("FROM")?
        .trim_start();

    // 找 WHERE 关键字
    let where_upper_pos = rest.find("WHERE")?;
    let table_upper = rest[..where_upper_pos].trim();

    // 从原始 sql 还原表名(保持大小写)
    let offset = sql.len() - rest.len();
    let table_name = sql[offset..offset + table_upper.len()].trim().to_string();

    // 提取 WHERE 子句(从原始 sql 中取,保持值的大小写)
    let where_offset = offset + where_upper_pos;
    let where_str = &sql[where_offset..];
    let (col, val) = parse_where_clause(where_str)?;

    Some((table_name, crate::sql::WhereClause { column: col, value: val }))
}

关键点:

  • to_uppercase() 做关键字匹配,但从原始 sql 字符串中取表名和值,保留大小写
  • 字符串值通过检测前后单引号来识别,去掉引号后存入 WhereClause.value
  • 数值(如 30)直接以字符串形式存储,过滤时再做字符串比较

三、过滤实现

src/sql.rs 中新增 WhereClause 结构体和 select_where() 函数。

// src/sql.rs

use crate::btree::collect_leaf_cells;
use crate::pager::Pager;
use crate::record::{parse_leaf_cell, Value};
use crate::schema::{read_schema, parse_column_defs, ColumnDef};

pub struct QueryResult {
    pub columns: Vec<String>,
    pub rows:    Vec<Vec<String>>,
}

/// WHERE 等值过滤条件
pub struct WhereClause {
    pub column: String,   // 列名,如 "name"
    pub value:  String,   // 过滤值(字符串形式),如 "Alice" 或 "30"
}

/// 执行 SELECT * FROM <table> WHERE <col> = <val>
pub fn select_where(
    pager:  &mut Pager,
    table_name: &str,
    clause: &WhereClause,
) -> Result<QueryResult, String> {
    // 1. 从 sqlite_schema 拿到表的 root_page 和列定义
    let entries = read_schema(pager);
    let entry = entries.iter()
        .find(|e| e.object_type == "table" && e.name == table_name)
        .ok_or_else(|| format!("table '{}' not found", table_name))?;

    let root_page = entry.root_page;
    let col_defs: Vec<ColumnDef> = entry.sql.as_deref()
        .map(parse_column_defs)
        .unwrap_or_default();

    let col_names: Vec<String> = col_defs.iter().map(|c| c.name.clone()).collect();

    // 2. 找到 WHERE 列对应的下标(在 col_names 中的位置)
    let filter_col_idx = col_names.iter()
        .position(|n| n.eq_ignore_ascii_case(&clause.column))
        .ok_or_else(|| format!("column '{}' not found in table '{}'", clause.column, table_name))?;

    // 3. 遍历所有叶页 Cell,解码每行,按条件过滤
    let cells = collect_leaf_cells(pager, root_page);
    let mut rows = Vec::new();

    for cell_addr in &cells {
        let page_data = pager.read_page(cell_addr.page_num).unwrap();
        let (rowid, values) = parse_leaf_cell(&page_data, cell_addr.offset as usize);

        // 4. 组装完整行(处理 INTEGER PRIMARY KEY rowid 别名)
        let mut row = Vec::new();
        let mut val_idx = 0usize;

        for col_def in &col_defs {
            if col_def.is_rowid_alias {
                row.push(rowid.to_string());
            } else {
                let v = values.get(val_idx)
                    .map(value_to_string)
                    .unwrap_or_else(|| "NULL".to_string());
                row.push(v);
                val_idx += 1;
            }
        }

        // 5. 取目标列的值,与过滤值做字符串比较
        if let Some(cell_val) = row.get(filter_col_idx) {
            if cell_val == &clause.value {
                rows.push(row);
            }
        }
    }

    Ok(QueryResult { columns: col_names, rows })
}

fn value_to_string(v: &Value) -> String {
    match v {
        Value::Null       => "NULL".to_string(),
        Value::Integer(n) => n.to_string(),
        Value::Float(f)   => {
            if *f == f.trunc() && f.abs() < 1e15 {
                format!("{}", *f as i64)
            } else {
                format!("{}", f)
            }
        }
        Value::Text(s)    => s.clone(),
        Value::Blob(b)    => format!("<blob {}>", b.len()),
    }
}

核心步骤说明:

  1. 获取列定义:从 sqlite_schema.sql 解析出 col_defs,包含列名和 rowid 别名标记
  2. 定位列下标:用 eq_ignore_ascii_case 大小写不敏感地在列名列表中查找 WHERE 列的位置,返回错误而非 panic
  3. 扫描并组装行:与 select_all() 相同的行组装逻辑,正确处理 INTEGER PRIMARY KEY
  4. 字符串比较过滤:将组装好的行中目标列的值(已通过 value_to_string() 转换)与 clause.value== 比较

四、更新 main.rs 调度

main.rs 的命令路由中,检测 SQL 字符串是否含 WHERE 子句,有则走 select_where(),无则走原来的 select_all()

// src/main.rs

mod header; mod page; mod pager; mod btree;
mod varint; mod record; mod schema; mod sql;

use header::DbHeader;
use pager::Pager;
use schema::read_schema;
use sql::{select_all, select_where};

fn main() {
    let args: Vec<String> = std::env::args().collect();
    if args.len() < 3 {
        eprintln!("Usage: sqlite-rs <db> <.tables | SQL>");
        std::process::exit(1);
    }

    let db_path = &args[1];
    let command = args[2..].join(" ");

    let db_header = DbHeader::read_from_file(db_path).unwrap();
    let mut pager = Pager::open(db_path, db_header.page_size).unwrap();

    if command == ".tables" {
        let entries = read_schema(&mut pager);
        let names: Vec<&str> = entries.iter()
            .filter(|e| e.object_type == "table" && !e.name.starts_with("sqlite_"))
            .map(|e| e.name.as_str())
            .collect();
        println!("{}", names.join(" "));
        return;
    }

    // 检测 SELECT count(*)
    if let Some(table_name) = parse_select_count(&command) {
        // ... count(*) 逻辑(第 07 篇)
        let _ = table_name;
        return;
    }

    // 检测 SELECT * FROM table WHERE col = val
    if let Some((table_name, where_clause)) = parse_select_where(&command) {
        match select_where(&mut pager, &table_name, &where_clause) {
            Ok(result) => print_result(&result),
            Err(e) => { eprintln!("Error: {}", e); std::process::exit(1); }
        }
        return;
    }

    // 检测 SELECT * FROM table(无 WHERE)
    if let Some(table_name) = parse_select_star(&command) {
        match select_all(&mut pager, table_name) {
            Ok(result) => print_result(&result),
            Err(e) => { eprintln!("Error: {}", e); std::process::exit(1); }
        }
        return;
    }

    eprintln!("Unsupported command: {}", command);
    eprintln!("Supported: .tables | SELECT * FROM <table> [WHERE col = val] | SELECT count(*) FROM <table>");
    std::process::exit(1);
}

fn print_result(result: &sql::QueryResult) {
    println!("{}", result.columns.join("|"));
    for row in &result.rows {
        println!("{}", row.join("|"));
    }
}

/// 从 "SELECT * FROM users" 解析表名(忽略 WHERE 及后续内容)
fn parse_select_star(sql: &str) -> Option<&str> {
    let upper = sql.to_uppercase();
    let rest = upper.strip_prefix("SELECT")?.trim_start()
        .strip_prefix('*')?.trim_start()
        .strip_prefix("FROM")?
        .trim_start();
    let offset = sql.len() - rest.len();
    // 表名到下一个空格(WHERE 前)
    let name = sql[offset..].split_whitespace().next()?;
    Some(name)
}

fn parse_select_count(sql: &str) -> Option<&str> {
    let upper = sql.to_uppercase();
    let rest = upper.strip_prefix("SELECT")?.trim_start()
        .strip_prefix("COUNT(*)")?.trim_start()
        .strip_prefix("FROM")?
        .trim_start();
    let offset = sql.len() - rest.len();
    Some(sql[offset..].split_whitespace().next()?)
}

调度顺序很重要:先检测含 WHERESELECT *,再检测不含 WHERESELECT *,避免误匹配。

五、测试

字符串字段过滤

cargo run -- test.db "SELECT * FROM users WHERE name = 'Alice'"
id|name|age
1|Alice|30

全表有 3 行(Alice/Bob/Charlie),只返回 name = 'Alice' 的那一行,id 列正确填入 rowid 值 1。

整数字段过滤

cargo run -- test.db "SELECT * FROM users WHERE age = 30"
id|name|age
1|Alice|30

数值过滤通过字符串比较实现:存储值 30(Integer)经 value_to_string() 转为 "30",与过滤值 "30" 相等,匹配成功。

大文件,多行匹配

cargo run -- big.db "SELECT * FROM users WHERE age = 42"
id|name|age
42|User42|42
142|User142|42
242|User242|42
342|User342|42
442|User442|42

big.db 有 500 行,age 字段循环分布(age = (id % 100) + 20),age = 42 的有 5 行,全部正确命中。跨越多个叶页的扫描工作正常。

无匹配行

cargo run -- test.db "SELECT * FROM users WHERE name = 'Dave'"
id|name|age

没有匹配行时只输出列名,行为符合预期。

列名不存在

cargo run -- test.db "SELECT * FROM users WHERE score = 100"
Error: column 'score' not found in table 'users'

明确报错,不会 panic 或静默返回空集。

六、这种方式的局限

当前实现有三个明显局限,了解它们有助于理解后续篇章要解决的问题:

全表扫描,不使用索引

每次过滤都要遍历所有叶页的所有 Cell。对于 500 行的 big.db 这没问题,但如果表有 500 万行,即使结果只有 1 行,也要扫描 500 万条记录。SQLite 真正的查询优化器会在有索引的情况下走 B-Tree 索引查找,只需 O(log N) 次页访问。第 09 篇将探索索引 B-Tree 的结构。

纯字符串比较

过滤时把所有值都转成字符串再做 == 比较。这对整数和短字符串是正确的,但对浮点数可能有精度问题(如 3.14000000000001 vs 3.14),对 BLOB 类型无法工作,对需要大小写不敏感匹配的场景也不对。真实的 SQL 引擎会保留类型信息,按类型做比较。

仅支持等值过滤

当前只支持 col = val。范围查询(age > 25age BETWEEN 20 AND 40)、多条件组合(AND/OR)、LIKE 模糊匹配等都尚未实现。这些需要更完整的表达式求值框架。

七、代码结构

sqlite-rs/
└── src/
    ├── main.rs     ← 命令路由:.tables / SELECT * / SELECT * WHERE / count(*)(本篇新增 WHERE 分支)
    ├── header.rs   ← 文件头解析(第 01 篇)
    ├── page.rs     ← 页头 + Cell 指针(第 02 篇)
    ├── pager.rs    ← 页读取器(第 02 篇)
    ├── btree.rs    ← B-Tree 叶页遍历(第 03 篇)
    ├── varint.rs   ← varint 解码(第 04 篇)
    ├── record.rs   ← Record 解析(第 04 篇)
    ├── schema.rs   ← sqlite_schema + 列定义解析(第 05、06 篇)
    └── sql.rs      ← select_all / select_where / WhereClause(本篇新增)

本篇改动集中在两处:

  • sql.rs:新增 WhereClause 结构体和 select_where() 函数
  • main.rs:新增 parse_select_where()parse_where_clause() 解析函数,在调度逻辑中插入 WHERE 分支

原有的 select_all()parse_column_defs()、B-Tree 遍历等代码没有任何改动,过滤逻辑完全是在已有基础上的增量。

八、关键点总结

  • WHERE 过滤的核心是:先用已有的 select_all 逻辑组装完整行,再按列下标取值做字符串比较
  • 解析 WHERE 子句时要区分单引号字符串值和无引号数值,统一去掉引号后存为字符串
  • 列名查找用 eq_ignore_ascii_case 做大小写不敏感匹配,找不到时返回 Err,不 panic
  • 调度顺序:含 WHERE 的分支必须在不含 WHERE 的 SELECT * 分支之前检测,否则 parse_select_star 会先匹配成功并截断表名
  • 全表扫描过滤对小表完全够用;对大表需要索引——这正是下一篇(第 09 篇)的主题:解析索引 B-Tree 结构,实现基于索引的等值查找,将扫描代价从 O(N) 降到 O(log N)