手写 SQLite 06:实现 SELECT * FROM table 带列名输出

上一篇我们实现了全表扫描,但输出是 Alice|30 这样无列名的格式。这一篇做两件事:

  1. sqlite_schema.sql 解析列定义,拿到每列的名字
  2. 正确处理 INTEGER PRIMARY KEY——它是 rowid 的别名,不存在 Record 里,需要特殊填充

完成后,SELECT * FROM users 能输出:

id|name|age
1|Alice|30
2|Bob|25
3|Charlie|35

一、列定义从哪来

sqlite_schemasql 列存储了建表 SQL,例如:

CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)

从这段 SQL 中解析出列名列表,是本篇的核心工作。我们不需要实现完整的 SQL 解析器——只需要提取括号内的列定义,然后按逗号切分,取每段的第一个 token 作为列名。

二、解析列定义

// src/schema.rs (新增函数)

/// 列定义:列名 + 是否是 INTEGER PRIMARY KEY(rowid 别名)
#[derive(Debug, Clone)]
pub struct ColumnDef {
    pub name:          String,
    pub is_rowid_alias: bool,  // INTEGER PRIMARY KEY 列
}

/// 从 CREATE TABLE SQL 中提取列定义列表
/// 输入示例:
///   "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)"
pub fn parse_column_defs(create_sql: &str) -> Vec<ColumnDef> {
    // 1. 找到第一个 '(' 和最后一个 ')',取括号内的内容
    let start = match create_sql.find('(') {
        Some(i) => i + 1,
        None    => return vec![],
    };
    let end = match create_sql.rfind(')') {
        Some(i) => i,
        None    => return vec![],
    };
    let cols_str = &create_sql[start..end];

    // 2. 按顶层逗号切分(需要处理括号嵌套,如 CHECK(..., ...) 中的逗号)
    let col_segments = split_top_level_commas(cols_str);

    // 3. 解析每个列段
    let mut defs = Vec::new();
    for seg in col_segments {
        let seg = seg.trim();
        if seg.is_empty() { continue; }

        // 跳过表级约束(PRIMARY KEY(...), UNIQUE(...), FOREIGN KEY(...)等)
        let upper = seg.to_uppercase();
        if upper.starts_with("PRIMARY KEY")
            || upper.starts_with("UNIQUE")
            || upper.starts_with("CHECK")
            || upper.starts_with("FOREIGN KEY")
            || upper.starts_with("CONSTRAINT")
        {
            continue;
        }

        // 列名是第一个 token(可能被引号包裹)
        let name = extract_identifier(seg);
        if name.is_empty() { continue; }

        // 判断是否是 INTEGER PRIMARY KEY(rowid 别名)
        // 条件:类型包含 INTEGER,且有 PRIMARY KEY 约束
        let is_rowid_alias = upper.contains("INTEGER")
            && upper.contains("PRIMARY")
            && upper.contains("KEY")
            && !upper.contains("FOREIGN");

        defs.push(ColumnDef { name, is_rowid_alias });
    }
    defs
}

/// 按顶层逗号切分(括号内的逗号不切)
fn split_top_level_commas(s: &str) -> Vec<&str> {
    let mut result = Vec::new();
    let mut depth  = 0usize;
    let mut start  = 0usize;

    for (i, ch) in s.char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                result.push(&s[start..i]);
                start = i + 1;
            }
            _ => {}
        }
    }
    result.push(&s[start..]);
    result
}

/// 从列定义段提取标识符(处理引号和反引号)
fn extract_identifier(s: &str) -> String {
    let s = s.trim();
    // 处理 `name`、"name"、[name] 三种引号形式
    if let Some(stripped) = s.strip_prefix('`') {
        return stripped.split('`').next().unwrap_or("").to_string();
    }
    if let Some(stripped) = s.strip_prefix('"') {
        return stripped.split('"').next().unwrap_or("").to_string();
    }
    if let Some(stripped) = s.strip_prefix('[') {
        return stripped.split(']').next().unwrap_or("").to_string();
    }
    // 无引号:取第一个空白前的内容
    s.split_whitespace().next().unwrap_or("").to_string()
}

三、INTEGER PRIMARY KEY 的特殊处理

这是 SQLite 最重要的特殊规则之一:

  • 如果某列声明为 INTEGER PRIMARY KEY,它成为 rowid 的别名
  • 该列的值不存储在 Record 里,Record 中对应位置的 Serial Type 为 0(NULL)
  • 读取时需要用 rowid 填充这一列

实际验证:

python3 -c "
data = open('test.db','rb').read()
# Page 1 Cell[0] 偏移
cell_off = int.from_bytes(data[108:110], 'big')
b = data[cell_off:]
print('payload_size:', b[0])
print('rowid:', b[1])
header_size = b[2]
print('header_size:', header_size)
# serial types: b[3] ~ b[3+header_size-1]
sts = list(b[3:3+header_size-1])
print('serial_types:', sts)
# 第一个 serial type 是 0 (NULL),对应 id 列(rowid 别名)
# 第二个是 23(TEXT 5字节),对应 name='Alice'
# 第三个是 4(INT32),对应 age=30
"
serial_types: [0, 23, 4]
              ↑ id 列是 NULL(用 rowid 填充)

四、完整实现

// src/sql.rs  — SQL 查询执行器

use crate::btree::collect_leaf_cells;
use crate::pager::Pager;
use crate::record::{parse_leaf_cell, Value};
use crate::schema::{read_schema, parse_column_defs, ColumnDef};

pub struct QueryResult {
    pub columns: Vec<String>,
    pub rows:    Vec<Vec<String>>,
}

/// 执行 SELECT * FROM <table_name>
pub fn select_all(pager: &mut Pager, table_name: &str) -> Result<QueryResult, String> {
    // 1. 从 sqlite_schema 找到表定义
    let entries = read_schema(pager);
    let entry = entries.iter()
        .find(|e| e.object_type == "table" && e.name == table_name)
        .ok_or_else(|| format!("table '{}' not found", table_name))?;

    let root_page = entry.root_page;
    let col_defs: Vec<ColumnDef> = entry.sql.as_deref()
        .map(parse_column_defs)
        .unwrap_or_default();

    let col_names: Vec<String> = col_defs.iter().map(|c| c.name.clone()).collect();

    // 2. 遍历 B-Tree,解析每一行
    let cells = collect_leaf_cells(pager, root_page);
    let mut rows = Vec::new();

    for cell_addr in &cells {
        let page_data = pager.read_page(cell_addr.page_num).unwrap();
        let (rowid, values) = parse_leaf_cell(&page_data, cell_addr.offset as usize);

        // 3. 按列定义顺序组装行
        //    values 对应 Record 里的列,可能比 col_defs 少一列(rowid 别名列)
        let mut row = Vec::new();
        let mut val_idx = 0usize;

        for col_def in &col_defs {
            if col_def.is_rowid_alias {
                // INTEGER PRIMARY KEY:用 rowid 填充,不消耗 values
                row.push(rowid.to_string());
            } else {
                // 普通列:从 values 取下一个值
                let v = values.get(val_idx)
                    .map(value_to_string)
                    .unwrap_or_else(|| "NULL".to_string());
                row.push(v);
                val_idx += 1;
            }
        }
        rows.push(row);
    }

    Ok(QueryResult { columns: col_names, rows })
}

fn value_to_string(v: &Value) -> String {
    match v {
        Value::Null       => "NULL".to_string(),
        Value::Integer(n) => n.to_string(),
        Value::Float(f)   => {
            // 去掉多余的小数零:500000.0 → "500000"
            if *f == f.trunc() && f.abs() < 1e15 {
                format!("{}", *f as i64)
            } else {
                format!("{}", f)
            }
        }
        Value::Text(s)    => s.clone(),
        Value::Blob(b)    => format!("<blob {}>", b.len()),
    }
}

更新 main.rs

// src/main.rs

mod header; mod page; mod pager; mod btree;
mod varint; mod record; mod schema; mod sql;

use header::DbHeader;
use pager::Pager;
use schema::read_schema;
use sql::select_all;

fn main() {
    let args: Vec<String> = std::env::args().collect();
    if args.len() < 3 {
        eprintln!("Usage: sqlite-rs <db> <.tables | SELECT * FROM <table>>");
        std::process::exit(1);
    }

    let db_path = &args[1];
    // 把剩余参数拼成一条"SQL"
    let command = args[2..].join(" ");

    let db_header = DbHeader::read_from_file(db_path).unwrap();
    let mut pager = Pager::open(db_path, db_header.page_size).unwrap();

    if command == ".tables" {
        let entries = read_schema(&mut pager);
        let names: Vec<&str> = entries.iter()
            .filter(|e| e.object_type == "table" && !e.name.starts_with("sqlite_"))
            .map(|e| e.name.as_str())
            .collect();
        println!("{}", names.join(" "));
        return;
    }

    // 简单解析 SELECT * FROM <table>
    if let Some(table_name) = parse_select_star(&command) {
        match select_all(&mut pager, table_name) {
            Ok(result) => {
                println!("{}", result.columns.join("|"));
                for row in &result.rows {
                    println!("{}", row.join("|"));
                }
            }
            Err(e) => {
                eprintln!("Error: {}", e);
                std::process::exit(1);
            }
        }
    } else {
        eprintln!("Unsupported command: {}", command);
        eprintln!("Supported: .tables | SELECT * FROM <table>");
        std::process::exit(1);
    }
}

/// 从 "SELECT * FROM users" 解析出表名 "users"
fn parse_select_star(sql: &str) -> Option<&str> {
    let upper = sql.to_uppercase();
    // 匹配 SELECT * FROM <name> 或 SELECT * FROM <name> (忽略后续条件)
    let rest = upper.strip_prefix("SELECT")?.trim_start()
        .strip_prefix('*')?.trim_start()
        .strip_prefix("FROM")?
        .trim_start();
    // 从原始 sql 里找对应位置(保持大小写)
    let offset = sql.len() - rest.len();
    // 表名到下一个空格或字符串结尾
    let name = sql[offset..].split_whitespace().next()?;
    Some(name)
}

五、测试

cargo run -- test.db "SELECT * FROM users"
id|name|age
1|Alice|30
2|Bob|25
3|Charlie|35

id 列现在正确填入了 rowid 值。

cargo run -- multi.db "SELECT * FROM employees"
id|name|dept
1|Alice|Engineering
2|Bob|Marketing
cargo run -- multi.db "SELECT * FROM departments"
id|name|budget
1|Engineering|500000
2|Marketing|200000

浮点数 500000.0 被正确格式化为 500000(去掉多余小数零)。

cargo run -- big.db "SELECT * FROM users" | head -5
id|name|age
1|User1|21
2|User2|22
3|User3|23
4|User4|24

500 行,全部输出正确,列名也对了。

六、边界情况处理

无 PRIMARY KEY 的表

如果表没有 INTEGER PRIMARY KEY 列,所有列都从 Record 直接读,没有 rowid 填充逻辑,正常工作。

列名带引号

CREATE TABLE t ("my col" TEXT, `other` INTEGER)

extract_identifier() 处理了三种引号形式(`"[),能正确提取列名。

多列主键 / 表级 PRIMARY KEY

CREATE TABLE t (a INTEGER, b TEXT, PRIMARY KEY(a, b))

这种情况下没有 rowid 别名列(SQLite 文档明确:只有单列 INTEGER PRIMARY KEY 才是 rowid 别名),所有列都正常读 Record。split_top_level_commas 会把 PRIMARY KEY(a, b) 作为一个表级约束段,被 starts_with("PRIMARY KEY") 过滤掉,不产生列定义。

七、当前代码结构

sqlite-rs/
└── src/
    ├── main.rs     ← SELECT * / .tables 命令路由(本篇)
    ├── header.rs   ← 文件头(第 01 篇)
    ├── page.rs     ← 页头 + Cell 指针(第 02 篇)
    ├── pager.rs    ← 页读取器(第 02 篇)
    ├── btree.rs    ← B-Tree 遍历(第 03 篇)
    ├── varint.rs   ← varint 解码(第 04 篇)
    ├── record.rs   ← Record 解析(第 04 篇)
    ├── schema.rs   ← sqlite_schema 解析 + 列定义解析(第 05、06 篇)
    └── sql.rs      ← SQL 执行器(本篇)

八、关键点总结

  • sqlite_schema.sql 字段存有 CREATE TABLE SQL,从中可解析列定义
  • 只需要提取括号内容、按顶层逗号切分、取每段第一个 token 作为列名
  • INTEGER PRIMARY KEY 是 rowid 别名:Record 里该列 Serial Type 为 0(NULL),需要用 rowid 值填充
  • 多列主键或表级 PRIMARY KEY 不是 rowid 别名,不需要特殊处理
  • 浮点数格式化需要去掉多余小数零(500000.0500000

下一篇:实现 SELECT count(*) FROM table——只统计行数,不解码每行数据,直接用 B-Tree 遍历收集到的 Cell 数量作为结果,速度极快。