本系列从第 01 篇开始逐步构建一个能读取真实 SQLite 文件的 Rust 程序:第 01–02 篇解析文件头和页头,第 03 篇遍历 B-Tree,第 04 篇解码 varint 和 Record,第 05 篇读取 sqlite_schema,第 06 篇实现带列名的 SELECT *,第 07 篇实现 SELECT count(*)。这一篇在第 06 篇的基础上加入 WHERE 条件过滤,让我们能执行:
SELECT * FROM users WHERE name = 'Alice'
SELECT * FROM users WHERE age = 30
策略是最朴素的:全表扫描后过滤——先读出所有行,再逐行判断目标列的值是否与过滤值匹配,符合条件的行才输出。
一、实现目标
本篇要支持的 SQL 形式是等值过滤,语法约定如下:
SELECT * FROM <table> WHERE <col> = '<string_value>'— 字符串值用单引号包裹SELECT * FROM <table> WHERE <col> = <numeric_value>— 数值直接写,不加引号
过滤逻辑:全表扫描每一个叶页 Cell,解码每行数据,找到目标列对应的值,将其转为字符串后与过滤值做字符串比较,相等则保留该行。
整体流程如下图:
SQL 字符串
↓ parse_sql()
↓ 发现含 WHERE 关键字
↓ parse_where_clause() → WhereClause { column, value }
↓
select_where(pager, table, where_clause)
↓ 读 sqlite_schema → 列定义 + root_page
↓ 找 where_clause.column 对应的列下标 col_idx
↓ collect_leaf_cells() → 所有叶页 Cell
↓ 对每个 Cell:
parse_leaf_cell() → (rowid, values)
组装完整行(处理 rowid 别名)
values[col_idx] 转字符串 == where_clause.value ?
是 → 加入结果集
↓
QueryResult { columns, rows }
二、SQL 解析
在 main.rs(或独立的 parser.rs)中扩展解析逻辑,识别 SQL 字符串里的 WHERE 子句,提取出 列名 和 过滤值 两个字符串。
// src/main.rs(解析部分)
/// 从 WHERE 子句中提取 (列名, 值字符串) 元组
/// 支持:
/// WHERE name = 'Alice' → ("name", "Alice")
/// WHERE age = 30 → ("age", "30")
fn parse_where_clause(where_str: &str) -> Option<(String, String)> {
// where_str 是 "WHERE name = 'Alice'" 或 "WHERE age = 30"
// 先去掉前缀 "WHERE"(大小写不敏感)
let body = where_str.trim();
let body = if body.to_uppercase().starts_with("WHERE") {
body[5..].trim()
} else {
body
};
// 按 '=' 切分,左边是列名,右边是值
let eq_pos = body.find('=')?;
let col_name = body[..eq_pos].trim().to_string();
let val_raw = body[eq_pos + 1..].trim();
// 处理字符串值:去掉单引号
let value = if val_raw.starts_with('\'') && val_raw.ends_with('\'') && val_raw.len() >= 2 {
val_raw[1..val_raw.len() - 1].to_string()
} else {
// 数值或无引号值:直接使用
val_raw.to_string()
};
Some((col_name, value))
}
/// 检测 SQL 是否含 WHERE 子句,返回 (table_name, WhereClause) 或 None
fn parse_select_where(sql: &str) -> Option<(String, crate::sql::WhereClause)> {
let upper = sql.to_uppercase();
// 必须以 SELECT * FROM 开头
let rest = upper.strip_prefix("SELECT")?
.trim_start()
.strip_prefix('*')?
.trim_start()
.strip_prefix("FROM")?
.trim_start();
// 找 WHERE 关键字
let where_upper_pos = rest.find("WHERE")?;
let table_upper = rest[..where_upper_pos].trim();
// 从原始 sql 还原表名(保持大小写)
let offset = sql.len() - rest.len();
let table_name = sql[offset..offset + table_upper.len()].trim().to_string();
// 提取 WHERE 子句(从原始 sql 中取,保持值的大小写)
let where_offset = offset + where_upper_pos;
let where_str = &sql[where_offset..];
let (col, val) = parse_where_clause(where_str)?;
Some((table_name, crate::sql::WhereClause { column: col, value: val }))
}
关键点:
- 用
to_uppercase()做关键字匹配,但从原始sql字符串中取表名和值,保留大小写 - 字符串值通过检测前后单引号来识别,去掉引号后存入
WhereClause.value - 数值(如
30)直接以字符串形式存储,过滤时再做字符串比较
三、过滤实现
在 src/sql.rs 中新增 WhereClause 结构体和 select_where() 函数。
// src/sql.rs
use crate::btree::collect_leaf_cells;
use crate::pager::Pager;
use crate::record::{parse_leaf_cell, Value};
use crate::schema::{read_schema, parse_column_defs, ColumnDef};
pub struct QueryResult {
pub columns: Vec<String>,
pub rows: Vec<Vec<String>>,
}
/// WHERE 等值过滤条件
pub struct WhereClause {
pub column: String, // 列名,如 "name"
pub value: String, // 过滤值(字符串形式),如 "Alice" 或 "30"
}
/// 执行 SELECT * FROM <table> WHERE <col> = <val>
pub fn select_where(
pager: &mut Pager,
table_name: &str,
clause: &WhereClause,
) -> Result<QueryResult, String> {
// 1. 从 sqlite_schema 拿到表的 root_page 和列定义
let entries = read_schema(pager);
let entry = entries.iter()
.find(|e| e.object_type == "table" && e.name == table_name)
.ok_or_else(|| format!("table '{}' not found", table_name))?;
let root_page = entry.root_page;
let col_defs: Vec<ColumnDef> = entry.sql.as_deref()
.map(parse_column_defs)
.unwrap_or_default();
let col_names: Vec<String> = col_defs.iter().map(|c| c.name.clone()).collect();
// 2. 找到 WHERE 列对应的下标(在 col_names 中的位置)
let filter_col_idx = col_names.iter()
.position(|n| n.eq_ignore_ascii_case(&clause.column))
.ok_or_else(|| format!("column '{}' not found in table '{}'", clause.column, table_name))?;
// 3. 遍历所有叶页 Cell,解码每行,按条件过滤
let cells = collect_leaf_cells(pager, root_page);
let mut rows = Vec::new();
for cell_addr in &cells {
let page_data = pager.read_page(cell_addr.page_num).unwrap();
let (rowid, values) = parse_leaf_cell(&page_data, cell_addr.offset as usize);
// 4. 组装完整行(处理 INTEGER PRIMARY KEY rowid 别名)
let mut row = Vec::new();
let mut val_idx = 0usize;
for col_def in &col_defs {
if col_def.is_rowid_alias {
row.push(rowid.to_string());
} else {
let v = values.get(val_idx)
.map(value_to_string)
.unwrap_or_else(|| "NULL".to_string());
row.push(v);
val_idx += 1;
}
}
// 5. 取目标列的值,与过滤值做字符串比较
if let Some(cell_val) = row.get(filter_col_idx) {
if cell_val == &clause.value {
rows.push(row);
}
}
}
Ok(QueryResult { columns: col_names, rows })
}
fn value_to_string(v: &Value) -> String {
match v {
Value::Null => "NULL".to_string(),
Value::Integer(n) => n.to_string(),
Value::Float(f) => {
if *f == f.trunc() && f.abs() < 1e15 {
format!("{}", *f as i64)
} else {
format!("{}", f)
}
}
Value::Text(s) => s.clone(),
Value::Blob(b) => format!("<blob {}>", b.len()),
}
}
核心步骤说明:
- 获取列定义:从
sqlite_schema.sql解析出col_defs,包含列名和 rowid 别名标记 - 定位列下标:用
eq_ignore_ascii_case大小写不敏感地在列名列表中查找 WHERE 列的位置,返回错误而非 panic - 扫描并组装行:与
select_all()相同的行组装逻辑,正确处理INTEGER PRIMARY KEY列 - 字符串比较过滤:将组装好的行中目标列的值(已通过
value_to_string()转换)与clause.value做==比较
四、更新 main.rs 调度
在 main.rs 的命令路由中,检测 SQL 字符串是否含 WHERE 子句,有则走 select_where(),无则走原来的 select_all()。
// src/main.rs
mod header; mod page; mod pager; mod btree;
mod varint; mod record; mod schema; mod sql;
use header::DbHeader;
use pager::Pager;
use schema::read_schema;
use sql::{select_all, select_where};
fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 3 {
eprintln!("Usage: sqlite-rs <db> <.tables | SQL>");
std::process::exit(1);
}
let db_path = &args[1];
let command = args[2..].join(" ");
let db_header = DbHeader::read_from_file(db_path).unwrap();
let mut pager = Pager::open(db_path, db_header.page_size).unwrap();
if command == ".tables" {
let entries = read_schema(&mut pager);
let names: Vec<&str> = entries.iter()
.filter(|e| e.object_type == "table" && !e.name.starts_with("sqlite_"))
.map(|e| e.name.as_str())
.collect();
println!("{}", names.join(" "));
return;
}
// 检测 SELECT count(*)
if let Some(table_name) = parse_select_count(&command) {
// ... count(*) 逻辑(第 07 篇)
let _ = table_name;
return;
}
// 检测 SELECT * FROM table WHERE col = val
if let Some((table_name, where_clause)) = parse_select_where(&command) {
match select_where(&mut pager, &table_name, &where_clause) {
Ok(result) => print_result(&result),
Err(e) => { eprintln!("Error: {}", e); std::process::exit(1); }
}
return;
}
// 检测 SELECT * FROM table(无 WHERE)
if let Some(table_name) = parse_select_star(&command) {
match select_all(&mut pager, table_name) {
Ok(result) => print_result(&result),
Err(e) => { eprintln!("Error: {}", e); std::process::exit(1); }
}
return;
}
eprintln!("Unsupported command: {}", command);
eprintln!("Supported: .tables | SELECT * FROM <table> [WHERE col = val] | SELECT count(*) FROM <table>");
std::process::exit(1);
}
fn print_result(result: &sql::QueryResult) {
println!("{}", result.columns.join("|"));
for row in &result.rows {
println!("{}", row.join("|"));
}
}
/// 从 "SELECT * FROM users" 解析表名(忽略 WHERE 及后续内容)
fn parse_select_star(sql: &str) -> Option<&str> {
let upper = sql.to_uppercase();
let rest = upper.strip_prefix("SELECT")?.trim_start()
.strip_prefix('*')?.trim_start()
.strip_prefix("FROM")?
.trim_start();
let offset = sql.len() - rest.len();
// 表名到下一个空格(WHERE 前)
let name = sql[offset..].split_whitespace().next()?;
Some(name)
}
fn parse_select_count(sql: &str) -> Option<&str> {
let upper = sql.to_uppercase();
let rest = upper.strip_prefix("SELECT")?.trim_start()
.strip_prefix("COUNT(*)")?.trim_start()
.strip_prefix("FROM")?
.trim_start();
let offset = sql.len() - rest.len();
Some(sql[offset..].split_whitespace().next()?)
}
调度顺序很重要:先检测含 WHERE 的 SELECT *,再检测不含 WHERE 的 SELECT *,避免误匹配。
五、测试
字符串字段过滤
cargo run -- test.db "SELECT * FROM users WHERE name = 'Alice'"
id|name|age
1|Alice|30
全表有 3 行(Alice/Bob/Charlie),只返回 name = 'Alice' 的那一行,id 列正确填入 rowid 值 1。
整数字段过滤
cargo run -- test.db "SELECT * FROM users WHERE age = 30"
id|name|age
1|Alice|30
数值过滤通过字符串比较实现:存储值 30(Integer)经 value_to_string() 转为 "30",与过滤值 "30" 相等,匹配成功。
大文件,多行匹配
cargo run -- big.db "SELECT * FROM users WHERE age = 42"
id|name|age
42|User42|42
142|User142|42
242|User242|42
342|User342|42
442|User442|42
big.db 有 500 行,age 字段循环分布(age = (id % 100) + 20),age = 42 的有 5 行,全部正确命中。跨越多个叶页的扫描工作正常。
无匹配行
cargo run -- test.db "SELECT * FROM users WHERE name = 'Dave'"
id|name|age
没有匹配行时只输出列名,行为符合预期。
列名不存在
cargo run -- test.db "SELECT * FROM users WHERE score = 100"
Error: column 'score' not found in table 'users'
明确报错,不会 panic 或静默返回空集。
六、这种方式的局限
当前实现有三个明显局限,了解它们有助于理解后续篇章要解决的问题:
全表扫描,不使用索引
每次过滤都要遍历所有叶页的所有 Cell。对于 500 行的 big.db 这没问题,但如果表有 500 万行,即使结果只有 1 行,也要扫描 500 万条记录。SQLite 真正的查询优化器会在有索引的情况下走 B-Tree 索引查找,只需 O(log N) 次页访问。第 09 篇将探索索引 B-Tree 的结构。
纯字符串比较
过滤时把所有值都转成字符串再做 == 比较。这对整数和短字符串是正确的,但对浮点数可能有精度问题(如 3.14000000000001 vs 3.14),对 BLOB 类型无法工作,对需要大小写不敏感匹配的场景也不对。真实的 SQL 引擎会保留类型信息,按类型做比较。
仅支持等值过滤
当前只支持 col = val。范围查询(age > 25、age BETWEEN 20 AND 40)、多条件组合(AND/OR)、LIKE 模糊匹配等都尚未实现。这些需要更完整的表达式求值框架。
七、代码结构
sqlite-rs/
└── src/
├── main.rs ← 命令路由:.tables / SELECT * / SELECT * WHERE / count(*)(本篇新增 WHERE 分支)
├── header.rs ← 文件头解析(第 01 篇)
├── page.rs ← 页头 + Cell 指针(第 02 篇)
├── pager.rs ← 页读取器(第 02 篇)
├── btree.rs ← B-Tree 叶页遍历(第 03 篇)
├── varint.rs ← varint 解码(第 04 篇)
├── record.rs ← Record 解析(第 04 篇)
├── schema.rs ← sqlite_schema + 列定义解析(第 05、06 篇)
└── sql.rs ← select_all / select_where / WhereClause(本篇新增)
本篇改动集中在两处:
sql.rs:新增WhereClause结构体和select_where()函数main.rs:新增parse_select_where()和parse_where_clause()解析函数,在调度逻辑中插入 WHERE 分支
原有的 select_all()、parse_column_defs()、B-Tree 遍历等代码没有任何改动,过滤逻辑完全是在已有基础上的增量。
八、关键点总结
- WHERE 过滤的核心是:先用已有的
select_all逻辑组装完整行,再按列下标取值做字符串比较 - 解析 WHERE 子句时要区分单引号字符串值和无引号数值,统一去掉引号后存为字符串
- 列名查找用
eq_ignore_ascii_case做大小写不敏感匹配,找不到时返回Err,不 panic - 调度顺序:含 WHERE 的分支必须在不含 WHERE 的
SELECT *分支之前检测,否则parse_select_star会先匹配成功并截断表名 - 全表扫描过滤对小表完全够用;对大表需要索引——这正是下一篇(第 09 篇)的主题:解析索引 B-Tree 结构,实现基于索引的等值查找,将扫描代价从 O(N) 降到 O(log N)