使用 Go 实现简化的 Lua 解释器

Lua 是一种轻量、嵌入式的脚本语言,广泛用于游戏开发和嵌入式系统。实现一个 Lua 解释器是学习编译原理和 Go 语言的绝佳实践。本文将展示如何使用 Go 开发一个简化的 Lua 解释器,支持变量赋值、基本算术运算和函数调用。文章将涵盖词法分析、语法解析、执行逻辑。

背景与动机

Lua 解释器以其简单高效的设计而闻名,核心功能包括词法分析(Lexing)、语法解析(Parsing)和执行(Execution)。本项目旨在实现一个最小化的 Lua 解释器,专注于以下特性:

  • 变量赋值:支持 a = 42 形式的赋值。
  • 算术运算:支持加减乘除(如 a = 1 + 2 * 3)。
  • 函数调用:支持简单内置函数(如 print)。
  • 错误处理:捕获语法错误和运行时错误。

🛠️ 项目设计

1. 架构概述

解释器分为以下模块:

  • 词法分析器(Lexer):将 Lua 代码分解为 Token(如数字、运算符、标识符)。
  • 语法解析器(Parser):构建抽象语法树(AST),表示代码结构。
  • 执行器(Evaluator):遍历 AST,执行计算和函数调用。
  • 环境(Environment):管理变量和内置函数。

2. 支持的 Lua 语法

为简化实现,我们支持以下 Lua 子集:

1
2
3
4
5
6
-- 变量赋值
x = 42
y = x + 10 * 2

-- 函数调用
print(x + y)

3. 技术栈

  • Go 1.22+:使用标准库处理字符串和正则表达式。
  • 算法
    • 词法分析:基于有限状态机(FSM)解析 Token。
    • 语法解析:使用递归下降解析(Recursive Descent Parsing)。
    • 执行:基于 AST 的树遍历,时间复杂度 O(n)。

🔧 实现代码

以下是完整的 Go 实现,包含词法分析、语法解析和执行逻辑。

main.go

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// TokenType 定义 Token 类型
type TokenType int

const (
	TokenEOF TokenType = iota
	TokenNumber
	TokenIdentifier
	TokenEqual
	TokenPlus
	TokenMinus
	TokenMultiply
	TokenDivide
	TokenLParen
	TokenRParen
	TokenPrint
)

// Token 表示词法单元
type Token struct {
	Type    TokenType
	Literal string
}

// Lexer 词法分析器
type Lexer struct {
	input   string
	pos     int
	tokens  []Token
}

func NewLexer(input string) *Lexer {
	return &Lexer{input: input}
}

func (l *Lexer) NextToken() Token {
	for l.pos < len(l.input) {
		ch := l.input[l.pos]
		switch {
		case ch == ' ' || ch == '\t' || ch == '\n':
			l.pos++
			continue
		case ch == '=':
			l.pos++
			return Token{Type: TokenEqual, Literal: "="}
		case ch == '+':
			l.pos++
			return Token{Type: TokenPlus, Literal: "+"}
		case ch == '-':
			l.pos++
			return Token{Type: TokenMinus, Literal: "-"}
		case ch == '*':
			l.pos++
			return Token{Type: TokenMultiply, Literal: "*"}
		case ch == '/':
			l.pos++
			return Token{Type: TokenDivide, Literal: "/"}
		case ch == '(':
			l.pos++
			return Token{Type: TokenLParen, Literal: "("}
		case ch == ')':
			l.pos++
			return Token{Type: TokenRParen, Literal: ")"}
		case isLetter(ch):
			start := l.pos
			for l.pos < len(l.input) && (isLetter(l.input[l.pos]) || isDigit(l.input[l.pos])) {
				l.pos++
			}
			literal := l.input[start:l.pos]
			if literal == "print" {
				return Token{Type: TokenPrint, Literal: literal}
			}
			return Token{Type: TokenIdentifier, Literal: literal}
		case isDigit(ch):
			start := l.pos
			for l.pos < len(l.input) && isDigit(l.input[l.pos]) {
				l.pos++
			}
			return Token{Type: TokenNumber, Literal: l.input[start:l.pos]}
		default:
			l.pos++
			return Token{Type: TokenEOF, Literal: ""}
		}
	}
	return Token{Type: TokenEOF, Literal: ""}
}

func isLetter(ch byte) bool {
	return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z')
}

func isDigit(ch byte) bool {
	return ch >= '0' && ch <= '9'
}

// NodeType 定义 AST 节点类型
type NodeType int

const (
	NodeNumber NodeType = iota
	NodeIdentifier
	NodeBinaryOp
	NodeAssign
	NodePrint
)

// Node 表示 AST 节点
type Node struct {
	Type     NodeType
	Value    string
	Left     *Node
	Right    *Node
	Operator TokenType
}

// Parser 语法解析器
type Parser struct {
	lexer  *Lexer
	tokens []Token
	pos    int
}

func NewParser(lexer *Lexer) *Parser {
	p := &Parser{lexer: lexer}
	for {
		token := lexer.NextToken()
		p.tokens = append(p.tokens, token)
		if token.Type == TokenEOF {
			break
		}
	}
	return p
}

func (p *Parser) Parse() ([]*Node, error) {
	var statements []*Node
	for p.pos < len(p.tokens) && p.tokens[p.pos].Type != TokenEOF {
		stmt, err := p.parseStatement()
		if err != nil {
			return nil, err
		}
		statements = append(statements, stmt)
	}
	return statements, nil
}

func (p *Parser) parseStatement() (*Node, error) {
	if p.pos < len(p.tokens) && p.tokens[p.pos].Type == TokenPrint {
		return p.parsePrint()
	}
	return p.parseAssign()
}

func (p *Parser) parsePrint() (*Node, error) {
	if p.pos >= len(p.tokens) || p.tokens[p.pos].Type != TokenPrint {
		return nil, fmt.Errorf("expected print")
	}
	p.pos++
	if p.pos >= len(p.tokens) || p.tokens[p.pos].Type != TokenLParen {
		return nil, fmt.Errorf("expected (")
	}
	p.pos++
	expr, err := p.parseExpression()
	if err != nil {
		return nil, err
	}
	if p.pos >= len(p.tokens) || p.tokens[p.pos].Type != TokenRParen {
		return nil, fmt.Errorf("expected )")
	}
	p.pos++
	return &Node{Type: NodePrint, Left: expr}, nil
}

func (p *Parser) parseAssign() (*Node, error) {
	if p.pos >= len(p.tokens) || p.tokens[p.pos].Type != TokenIdentifier {
		return nil, fmt.Errorf("expected identifier")
	}
	ident := p.tokens[p.pos]
	p.pos++
	if p.pos >= len(p.tokens) || p.tokens[p.pos].Type != TokenEqual {
		return nil, fmt.Errorf("expected =")
	}
	p.pos++
	expr, err := p.parseExpression()
	if err != nil {
		return nil, err
	}
	return &Node{Type: NodeAssign, Value: ident.Literal, Right: expr}, nil
}

func (p *Parser) parseExpression() (*Node, error) {
	node, err := p.parseTerm()
	if err != nil {
		return nil, err
	}
	for p.pos < len(p.tokens) && (p.tokens[p.pos].Type == TokenPlus || p.tokens[p.pos].Type == TokenMinus) {
		op := p.tokens[p.pos]
		p.pos++
		right, err := p.parseTerm()
		if err != nil {
			return nil, err
		}
		node = &Node{Type: NodeBinaryOp, Operator: op.Type, Left: node, Right: right}
	}
	return node, nil
}

func (p *Parser) parseTerm() (*Node, error) {
	node, err := p.parseFactor()
	if err != nil {
		return nil, err
	}
	for p.pos < len(p.tokens) && (p.tokens[p.pos].Type == TokenMultiply || p.tokens[p.pos].Type == TokenDivide) {
		op := p.tokens[p.pos]
		p.pos++
		right, err := p.parseFactor()
		if err != nil {
			return nil, err
		}
		node = &Node{Type: NodeBinaryOp, Operator: op.Type, Left: node, Right: right}
	}
	return node, nil
}

func (p *Parser) parseFactor() (*Node, error) {
	if p.pos >= len(p.tokens) {
		return nil, fmt.Errorf("unexpected EOF")
	}
	token := p.tokens[p.pos]
	switch token.Type {
	case TokenNumber:
		p.pos++
		return &Node{Type: Node this, Value: token.Literal}, nil
	case TokenIdentifier:
		p.pos++
		return &Node{Type: NodeIdentifier, Value: token.Literal}, nil
	case TokenLParen:
		p.pos++
		expr, err := p.parseExpression()
		if err != nil {
			return nil, err
		}
		if p.pos >= len(p.tokens) || p.tokens[p.pos].Type != TokenRParen {
			return nil, fmt.Errorf("expected )")
		}
		p.pos++
		return expr, nil
	default:
		return nil, fmt.Errorf("unexpected token: %v", token.Literal)
	}
}

// Environment 存储变量和内置函数
type Environment struct {
	vars map[string]float64
}

func NewEnvironment() *Environment {
	return &Environment{vars: make(map[string]float64)}
}

// Evaluator 执行器
type Evaluator struct {
	env *Environment
}

func NewEvaluator() *Evaluator {
	return &Evaluator{env: NewEnvironment()}
}

func (e *Evaluator) Eval(statements []*Node) error {
	for _, stmt := range statements {
		_, err := e.evalNode(stmt)
		if err != nil {
			return err
		}
	}
	return nil
}

func (e *Evaluator) evalNode(node *Node) (float64, error) {
	switch node.Type {
	case NodeNumber:
		return strconv.ParseFloat(node.Value, 64)
	case NodeIdentifier:
		if val, ok := e.env.vars[node.Value]; ok {
			return val, nil
		}
		return 0, fmt.Errorf("undefined variable: %s", node.Value)
	case NodeBinaryOp:
		left, err := e.evalNode(node.Left)
		if err != nil {
			return 0, err
		}
		right, err := e.evalNode(node.Right)
		if err != nil {
			return 0, err
		}
		switch node.Operator {
		case TokenPlus:
			return left + right, nil
		case TokenMinus:
			return left - right, nil
		case TokenMultiply:
			return left * right, nil
		case TokenDivide:
			if right == 0 {
				return 0, fmt.Errorf("division by zero")
			}
			return left / right, nil
		}
	case NodeAssign:
		value, err := e.evalNode(node.Right)
		if err != nil {
			return 0, err
		}
		e.env.vars[node.Value] = value
		return value, nil
	case NodePrint:
		value, err := e.evalNode(node.Left)
		if err != nil {
			return 0, err
		}
		fmt.Printf("%f\n", value)
		return value, nil
	}
	return 0, fmt.Errorf("unknown node type")
}

func main() {
	script := `
x = 42
y = x + 10 * 2
print(x + y)
`
	lexer := NewLexer(script)
	parser := NewParser(lexer)
	statements, err := parser.Parse()
	if err != nil {
		fmt.Printf("Parse error: %v\n", err)
		return
	}
	evaluator := NewEvaluator()
	if err := evaluator.Eval(statements); err != nil {
		fmt.Printf("Eval error: %v\n", err)
		return
	}
}

代码说明

  • 词法分析器(Lexer)
    • 使用有限状态机(FSM)逐字符扫描输入,生成 Token 序列。
    • 支持数字、标识符、运算符(如 +, -, *, /)、括号和 print 关键字。
    • 时间复杂度:O(n),n 为输入字符串长度。
  • 语法解析器(Parser)
    • 使用递归下降解析(Recursive Descent Parsing),构建 AST。
    • 支持赋值语句(x = expr)、表达式(1 + 2 * 3)和 print 函数调用。
    • 优先级处理:*/ 高于 +-
  • 执行器(Evaluator)
    • 遍历 AST,执行算术运算和变量赋值。
    • 使用 Environment 存储变量,print 函数输出结果。
    • 时间复杂度:O(n),n 为 AST 节点数。
  • 错误处理
    • 捕获语法错误(如缺失括号)和运行时错误(如未定义变量、除零)。

运行示例

1
go run main.go

输入脚本:

1
2
3
x = 42
y = x + 10 * 2
print(x + y)

输出:

1
62

参考资源

使用 Hugo 构建
主题 StackJimmy 设计