阅读视图

发现新文章,点击刷新页面。
🔲 ☆

使用Elasticsearch分析腾讯云EO日志

腾讯云EO可以查看一些指标信息,但是更加详细的信息需要我们下载离线日志自行分析。

获取日志下载链接

腾讯云会将日志打包为.gz格式,解压后文件会包含多行,每一行都是一个JSON格式的数据,对应一条EO的请求日志,日志格式可以参考腾讯云文档

我们可以批量获取最近一个月的日志下载链接

之后复制所有链接并保存到urls.txt文件中。

启动Elasticsearch集群

我们参考官方文档使用docker来启动集群,首先下载.envdocker-compose.yml,之后在.env文件中设置es和kibana的密码都是123456,然后设置STACK_VERSION=9.2.3。考虑到数据量比较大,可以提高容器的内存大小,我这里设置了一台8G。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Password for the 'elastic' user (at least 6 characters)
ELASTIC_PASSWORD=123456

# Password for the 'kibana_system' user (at least 6 characters)
KIBANA_PASSWORD=123456

# Version of Elastic products
STACK_VERSION=9.2.3

# Set the cluster name
CLUSTER_NAME=elasticsearch-cluster

# Set to 'basic' or 'trial' to automatically start the 30-day trial
LICENSE=basic

# Port to expose Elasticsearch HTTP API to the host
ES_PORT=9200

# Port to expose Kibana to the host
KIBANA_PORT=5601

# Increase or decrease based on the available host memory (in bytes)
MEM_LIMIT=8589934592

# Project namespace (defaults to the current folder name if not set)
COMPOSE_PROJECT_NAME=elasticsearch-project

设置好了之后使用命令docker-compose up -d启动ES集群。

之后可以通过http://127.0.0.1:5601访问kibana,用户名elastic,密码123456。

写入日志

使用如下的代码下载解析日志,并保存到ES中

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gzip
import json
import os
from datetime import datetime
from urllib.parse import urlparse

import requests
from elasticsearch import Elasticsearch, helpers

ES_URL = "https://localhost:9200"
ES_USER = "elastic"
ES_PASSWORD = "123456"
INDEX_NAME = "eo_logs"
DOWNLOAD_DIR = "downloaded_logs"

es = Elasticsearch([ES_URL], basic_auth=(ES_USER, ES_PASSWORD), verify_certs=False, ssl_show_warn=False)
os.makedirs(DOWNLOAD_DIR, exist_ok=True)


def download_file(url):
filename = os.path.basename(urlparse(url).path)
filepath = os.path.join(DOWNLOAD_DIR, filename)
if os.path.exists(filepath):
print(f"文件已存在: {filename}")
return filepath
print(f"下载: {filename}")
response = requests.get(url, stream=True, timeout=300)
with open(filepath, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
return filepath


def parse_gz(filepath):
logs = []
print(f"解析: {os.path.basename(filepath)}")
with gzip.open(filepath, 'rt', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
log = json.loads(line)
log['_source_file'] = os.path.basename(filepath)
log['_import_time'] = datetime.utcnow().isoformat()
logs.append(log)

print(f"解析完成: {len(logs)} 条")
return logs


def save_to_es(logs):
if not logs:
return
print(f"保存 {len(logs)} 条到 ES")
actions = [{"_index": INDEX_NAME, "_source": log} for log in logs]
success, _ = helpers.bulk(es, actions, chunk_size=1000, request_timeout=60)
print(f"保存完成: {success} 条")


def process_url(url):
filepath = download_file(url)
logs = parse_gz(filepath)
save_to_es(logs)


def main():
with open("urls.txt", 'r') as f:
urls = [line.strip() for line in f if line.strip()]
print(f"开始处理 {len(urls)} 个文件\n")
for i, url in enumerate(urls, 1):
print(f"\n[{i}/{len(urls)}]")
process_url(url)
print("\n处理完成!")


if __name__ == "__main__":
main()

执行如上代码,就能够下载日志并保存到ES了(这会花费比较多的时间,我这里花费了100多分钟)。

分析日志

数据索引完毕之后,我们可以查看索引信息

1
2
~ curl 'https://127.0.0.1:9200/eo_logs/_count' --header 'Authorization: Basic ZWxhc3RpYzo9dk5Cc0QwSTNZRWFPa2RoZFFhZg==' -k
{"count":31398691,"_shards":{"total":1,"successful":1,"skipped":0,"failed":0}}%

可以看到一共索引了3000多万条数据,我们还可以查看索引的mapping和详细信息如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
{
"eo_logs": {
"aliases": {},
"mappings": {
"properties": {
"ClientIP": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"ClientISP": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"ClientRegion": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"ClientState": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"ContentID": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"EdgeCacheStatus": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"EdgeFunctionSubrequest": {
"type": "long"
},
"EdgeInternalTime": {
"type": "long"
},
"EdgeResponseBodyBytes": {
"type": "long"
},
"EdgeResponseBytes": {
"type": "long"
},
"EdgeResponseStatusCode": {
"type": "long"
},
"EdgeResponseTime": {
"type": "long"
},
"EdgeServerID": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"EdgeServerIP": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"ParentRequestID": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RemotePort": {
"type": "long"
},
"RequestBytes": {
"type": "long"
},
"RequestHost": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestID": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestMethod": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestProtocol": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestRange": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestReferer": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestStatus": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestTime": {
"type": "date"
},
"RequestUA": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestUrl": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"RequestUrlQueryString": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
},
"_import_time": {
"type": "date"
},
"_source_file": {
"type": "text",
"fields": {
"keyword": {
"type": "keyword",
"ignore_above": 256
}
}
}
}
},
"settings": {
"index": {
"routing": {
"allocation": {
"include": {
"_tier_preference": "data_content"
}
}
},
"number_of_shards": "1",
"provided_name": "eo_logs",
"creation_date": "1766816305347",
"number_of_replicas": "1",
"uuid": "wi9l88cjRh-Kq7lgl4NReg",
"version": {
"created": "9039003"
}
}
}
}
}

具体每个字段的含义如下

字段名含义说明
ClientIP客户端 IP访问 EdgeOne 边缘节点的真实用户 IP
ClientISP客户端运营商用户网络所属运营商,如电信、联通、移动
ClientRegion客户端地区用户所在国家或地区
ClientState客户端省份/州用户所在省份或州级行政区
ContentID内容标识EO 内部用于标识访问资源的唯一 ID
EdgeCacheStatus缓存状态边缘节点缓存命中情况:Hit / Miss / RefreshHit / Bypass
EdgeFunctionSubrequest子请求数量边缘函数触发的内部子请求次数
EdgeInternalTime内部处理耗时边缘节点内部处理请求所消耗的时间(毫秒)
EdgeResponseBodyBytes响应体大小返回给客户端的响应 Body 字节数
EdgeResponseBytes响应总大小返回给客户端的总字节数(Header + Body)
EdgeResponseStatusCode响应状态码边缘节点返回的 HTTP 状态码
EdgeResponseTime总响应耗时从边缘节点接收请求到完成响应的总耗时(毫秒)
EdgeServerID边缘节点 ID实际处理请求的 EdgeOne 节点标识
EdgeServerIP边缘节点 IP实际处理请求的边缘节点 IP 地址
ParentRequestID父请求 ID关联内部转发或子请求的父级请求标识
RemotePort客户端端口客户端发起连接时使用的端口
RequestBytes请求大小客户端请求报文大小(字节)
RequestHost请求域名客户端请求的 Host 域名
RequestID请求 IDEdgeOne 为请求生成的唯一标识
RequestMethod请求方法HTTP 请求方法,如 GET、POST
RequestProtocol请求协议使用的 HTTP 协议版本(HTTP/1.1、HTTP/2、HTTP/3)
RequestRangeRange 请求请求头中的 Range 字段,用于分段或断点下载
RequestReferer来源页面请求头中的 Referer 信息
RequestStatus请求状态EdgeOne 定义的请求处理状态
RequestTime请求时间请求到达 EdgeOne 的时间
RequestUAUser-Agent客户端 User-Agent 信息
RequestUrl请求路径请求的 URL 路径(不包含查询参数)
RequestUrlQueryString查询参数请求 URL 中的 Query String
_import_time导入时间日志被导入 Elasticsearch 的时间
_source_file日志来源生成该日志的原始文件或对象标识

然后我们想看指定域名的请求耗时情况(从EdgeOne接收到客户端发起的请求开始,到响应给客户端最后一个字节,整个过程的耗时,对应字段EdgeResponseTime),可以使用如下DSL

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
POST /eo_logs/_search
{
"size": 0,
"query": {
"bool": {
"filter": [
{
"term": {
"RequestHost.keyword": "static.example.com"
}
}
]
}
},
"aggs": {
"edge_response_stats": {
"stats": {
"field": "EdgeResponseTime"
}
},
"edge_response_percentiles": {
"percentiles": {
"field": "EdgeResponseTime",
"percents": [
50,
90,
95,
99
]
}
},
"edge_response_hist": {
"histogram": {
"field": "EdgeResponseTime",
"interval": 50,
"min_doc_count": 1
}
}
}
}

得到结果如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
{
"took": 3128,
"timed_out": false,
"_shards": {
"total": 1,
"successful": 1,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 10000,
"relation": "gte"
},
"max_score": null,
"hits": []
},
"aggregations": {
"edge_response_percentiles": {
"values": {
"50.0": 5.014287434842656,
"90.0": 25.778307762642324,
"95.0": 73.78316545752277,
"99.0": 593.9728031414846
}
},
"edge_response_hist": {
"buckets": [
{
"key": 0.0,
"doc_count": 25997272
},
{
"key": 50.0,
"doc_count": 841843
},
{
"key": 100.0,
"doc_count": 377168
},
{
"key": 150.0,
"doc_count": 109181
},
{
"key": 200.0,
"doc_count": 53672
},
{
"key": 250.0,
"doc_count": 37425
},
{
"key": 300.0,
"doc_count": 32744
},
{
"key": 350.0,
"doc_count": 36445
},
{
"key": 400.0,
"doc_count": 26137
},
{
"key": 450.0,
"doc_count": 22807
},
{
"key": 500.0,
"doc_count": 21111
},
{
"key": 550.0,
"doc_count": 16784
},
{
"key": 600.0,
"doc_count": 13214
},
{
"key": 650.0,
"doc_count": 11211
},
{
"key": 700.0,
"doc_count": 11760
},
{
"key": 750.0,
"doc_count": 11911
},
{
"key": 800.0,
"doc_count": 10381
},
{
"key": 850.0,
"doc_count": 9158
},
{
"key": 900.0,
"doc_count": 6851
},
{
"key": 950.0,
"doc_count": 5822
},
{
"key": 1000.0,
"doc_count": 5195
},
...
]
},
"edge_response_stats": {
"count": 27840645,
"min": 1.0,
"max": 707706.0,
"avg": 46.91420216737076,
"sum": 1.306121648E9
}
}
}

我们重点关注百分比:

百分位含义解读
p505 ms一半请求 5ms 内完成(极快)
p9025 ms90% 的请求很健康
p9574 ms95% 的请求 < 100ms(优秀)
p99594 ms1% 请求接近 / 超过 0.5s

可以看到这个域名的请求速度还是很快的。

此外,我们还可以分析哪些资源的下载比较慢

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
POST /eo_logs/_search
{
"size": 0,
"query": {
"bool": {
"filter": [
{
"term": {
"RequestHost.keyword": "static.example.com"
}
},
{
"exists": {
"field": "RequestUrl.keyword"
}
},
{
"exists": {
"field": "EdgeResponseTime"
}
}
]
}
},
"aggs": {
"by_url": {
"terms": {
"field": "RequestUrl.keyword",
"size": 200,
"order": {
"p95_edge_response[95.0]": "desc"
}
},
"aggs": {
"p95_edge_response": {
"percentiles": {
"field": "EdgeResponseTime",
"percents": [
95
]
}
},
"avg_edge_response": {
"avg": {
"field": "EdgeResponseTime"
}
},
"count_requests": {
"value_count": {
"field": "EdgeResponseTime"
}
}
}
}
}
}

我们可以针对上面查询到的慢速URL去做特定的优化和缓存预热。只是,上面的这个DSL不够严谨,因为单纯使用请求时间来判断速度快慢是不足够的,请求时间也会受到资源大小的影响。因此,我们使用资源的大小比上请求耗时,这个就代表这个资源的下载速度,之后我们从小到大排序,就可以知道哪些资源可能会下载比较慢了。具体DSL如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
POST /eo_logs/_search
{
"size": 0,
"query": {
"bool": {
"filter": [
{
"term": {
"RequestHost.keyword": "static.example.com"
}
},
{
"exists": {
"field": "RequestUrl.keyword"
}
},
{
"exists": {
"field": "EdgeResponseTime"
}
},
{
"exists": {
"field": "EdgeResponseBodyBytes"
}
},
{
"range": {
"EdgeResponseBodyBytes": {
"gt": 0
}
}
},
{
"range": {
"EdgeResponseTime": {
"gt": 0
}
}
}
]
}
},
"aggs": {
"by_url": {
"terms": {
"field": "RequestUrl.keyword",
"size": 2000,
"order": {
"avg_kbps": "asc"
}
},
"aggs": {
"avg_kbps": {
"avg": {
"script": {
"lang": "painless",
"source": "double b = doc['EdgeResponseBodyBytes'].value; double t = doc['EdgeResponseTime'].value; return (b / t) * (1000.0 / 1024.0);"
}
}
},
"p95_kbps": {
"percentiles": {
"script": {
"lang": "painless",
"source": "double b = doc['EdgeResponseBodyBytes'].value; double t = doc['EdgeResponseTime'].value; return (b / t) * (1000.0 / 1024.0);"
},
"percents": [
95
]
}
},
"avg_time_ms": {
"avg": {
"field": "EdgeResponseTime"
}
},
"avg_body_bytes": {
"avg": {
"field": "EdgeResponseBodyBytes"
}
},
"req_count": {
"value_count": {
"field": "EdgeResponseTime"
}
}
}
}
}
}

根据上面的查询结果,我们就可以知道哪些资源的下载速度可能比较慢,之后就可以针对这些URL对应的资源去做专门的优化了。

🔲 ☆

使用NGINX的auth_request进行统一jwt鉴权

NGINX的auth_request模块提供了一种统一的认证机制,可以在NGINX层面进行JWT鉴权,而不需要在每个后端服务中重复实现认证逻辑。

首先我们定义一下nginx的配置,它的配置如下

flat
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
server {
listen 8965;

# 鉴权接口,仅供 Nginx 内部 auth_request 使用
location = /auth {
internal; # 该接口只能被 Nginx 内部请求,防止外部访问
# 转发到实际的认证服务
proxy_pass http://localhost:5001/api/auth/verify;
# 不转发请求体,提升效率
proxy_pass_request_body off;
# 防止后端因 Content-Length 不确定而报错
proxy_set_header Content-Length "";
# 将客户端传来的 Authorization 头(JWT Token)传给认证服务
proxy_set_header Authorization $http_authorization;
# 传入原始请求路径,供认证服务判断路径是否需要鉴权
proxy_set_header X-Original-URI $request_uri;
}

# 所有 / 路径下的请求都进行认证
location / {
# 认证请求会先调用上面的 /auth 接口
auth_request /auth;
# 如果认证失败(如返回 401),跳转到自定义处理逻辑
error_page 401 = @unauthorized;
# 从 /auth 的响应头中提取用户信息
auth_request_set $user_id $upstream_http_x_user_id;
# 把用户信息注入到请求头中,转发给后端业务服务
proxy_set_header X-User-ID $user_id;
# 转发到后端服务
proxy_pass http://localhost:5001;
}

# 自定义未授权响应(认证失败时返回)
location @unauthorized {
# 返回 401 状态码 + 文本内容
return 401 "Unauthorized";
}
}

有了这个nginx的配置之后,我们就可以实现鉴权的逻辑了,具体逻辑如下

flat
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import jwt
from flask import Flask, request, make_response, jsonify
from jwt import InvalidTokenError

app = Flask(__name__)

SECRET_KEY = "a-string-secret-at-least-256-bits-long"

# 路径白名单,不需要鉴权的接口
PUBLIC_PATHS = [
"/api/public/hello",
"/api/login",
"/api/register"
]


@app.route("/api/auth/verify", methods=["GET"])
def verify_token():
original_uri = request.headers.get("X-Original-URI", "")
# 不需要鉴权的接口,直接返回200
if original_uri in PUBLIC_PATHS:
return "", 200

# 否则继续验证 JWT
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return "Missing or invalid Authorization header", 401

# 解析得到token
token = auth_header.split(" ", 1)[1]
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
response = make_response("", 200)
response.headers["X-User-ID"] = str(payload.get("user_id", ""))
return response
except InvalidTokenError:
return "Invalid token", 401


@app.route("/api/hello")
def hello():
return jsonify({"message": "hello from auth", "user_id": request.headers.get("X-User-ID")})


@app.route("/api/public/hello")
def ping():
return {"msg": "hello without auth"}


if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)

启动nginx和如上python服务,之后我们使用如下payload和header以及密钥生成token

payload

{    "alg": "HS256",    "typ": "JWT"}

header

{    "user_id": 656670838050885}

密钥

a-string-secret-at-least-256-bits-long

生成得到token

eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjo2NTY2NzA4MzgwNTA4ODV9.caQ6cp-BA-OMxXu4zTUjV0OiZo1iygvdi7GPQNjNVHM

之后我们就可以使用token进行测试了,具体测试结果如下

~ AUTH_HEADER="Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjo2NTY2NzA4MzgwNTA4ODV9.caQ6cp-BA-OMxXu4zTUjV0OiZo1iygvdi7GPQNjNVHM"~ curl -H "$AUTH_HEADER" http://localhost:8965/api/hello{"message":"hello from auth","user_id":"656670838050885"}~ curl -H "$AUTH_HEADER" http://localhost:8965/api/public/hello{"msg":"hello without auth"}~ curl -H "$AUTH_HEADER" http://localhost:8965/api/login<!doctype html><html lang=en><title>404 Not Found</title><h1>Not Found</h1><p>The requested URL was not found on the server. If you entered the URL manually please check your spelling and try again.</p>~ curl http://localhost:8965/api/public/hello{"msg":"hello without auth"}~ curl http://localhost:8965/api/helloUnauthorized%

如上我们正确设置了Authorization之后就可以正常访问需要鉴权的接口了,但是去掉了Authorization之后需要鉴权的接口就会返回Unauthorized。此外还可以看到,不需要鉴权的接口,即使不添加鉴权配置也是可以正常访问的。

🔲 ☆

使用APISIX解析jwt并获取payload信息

APISIX支持获取jwt的信息,并且将这个信息进行解码并转发给后端服务。

1. 启动服务

首先我们根据官方脚本来启动APISIX服务

~ curl -sL "https://run.api7.ai/apisix/quickstart" | shDestroying existing apisix-quickstart container, if any.Installing APISIX with the quickstart options.Creating bridge network apisix-quickstart-net.77e35df073894075ad77facd9d1c7d2a35b280213732c1b631052caede079bab✔ network apisix-quickstart-net createdStarting the container etcd-quickstart.d123605c8b7658b130be97e5f44e7a160aa85858db008032ecf594266225e342✔ etcd is listening on etcd-quickstart:2379Starting the container apisix-quickstart.38434806c63b3a72f53fb6ad849cb4c11781eebaff79c8db04510226593fcf46⚠ WARNING: The Admin API key is currently disabled. You should turn on admin_key_required and set a strong Admin API key in production for security.✔ APISIX is ready!

2. 配置插件

启动了APISIX之后,我们首先创建一个插件配置。在这个插件中我们定义了一个Lua方法,这个方法的目的是从请求的header中获取authorization信息,并进行解码,之后将解码的信息放到HTTP header中传给后端

curl --location --request PUT 'http://127.0.0.1:9180/apisix/admin/plugin_configs/1001' \--header 'Content-Type: application/json' \--header 'Accept: */*' \--header 'Host: 127.0.0.1:9180' \--header 'Connection: keep-alive' \--data-raw '{    "plugins": {        "serverless-pre-function": {            "phase": "access",            "functions": [                "return function(_, ctx) local core = require(\"apisix.core\") local jwt = require(\"resty.jwt\") local auth_header = ctx.var.http_authorization if not auth_header then return end local token = auth_header:match(\"Bearer%s+(.+)\") if not token then return end local obj = jwt:load_jwt(token) if obj and obj.valid and obj.payload then if obj.payload.user_id then core.request.set_header(\"X-User-Id\", obj.payload.user_id) end if obj.payload.role then core.request.set_header(\"X-User-Role\", obj.payload.role) end end end"            ]        }    }}'

如上的fucntions属性中添加了一个Lua方法,格式化之后的Lua代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
return function(_, ctx)
local core = require("apisix.core")
local jwt = require("resty.jwt")

local auth_header = ctx.var.http_authorization
if not auth_header then
return
end

local token = auth_header:match("Bearer%s+(.+)")
if not token then
return
end

local obj = jwt:load_jwt(token)
if obj and obj.valid and obj.payload then
if obj.payload.user_id then
core.request.set_header("X-User-Id", obj.payload.user_id)
end
if obj.payload.role then
core.request.set_header("X-User-Role", obj.payload.role)
end
end
end

这段代码实现了如下几个功能:

  1. 从 Authorization: Bearer 中提取 JWT
  2. 使用 resty.jwt 解码
  3. 如果合法,提取 user_id 和 role
  4. 注入到 header(X-User-Id, X-User-Role)中供后端读取

3. 配置consumer

创建了这个插件之后,我们再新建一个consumer。在APISIX中,consumer代表了一类客户端,比如APP。我们可以针对这类客户端添加一些配置,多种不同类型的客户端(比如APP、网页、开放平台,等等)可以分别设置成不同的consumer以方便管理

curl --location --request PUT 'http://127.0.0.1:9180/apisix/admin/consumers/app' \--header 'Content-Type: application/json' \--header 'Accept: */*' \--header 'Host: 127.0.0.1:9180' \--header 'Connection: keep-alive' \--data-raw '{    "username": "app",    "plugins": {        "jwt-auth": {            "key": "app-key",            "secret": "a-string-secret-at-least-256-bits-long",            "algorithm": "HS256"        }    }}'

如上添加了一个名为app的consumer,它的key是app-key,加密方式是HS256,密钥是a-string-secret-at-least-256-bits-long。有了解析插件和consumer之后,我们就可以创建路由了。

4. 配置路由

如下请求会创建一个ID为1的路由,使用了ID为1001插件,并且添加了jwt-auth的配置,路由的后端是https://httpbin.org,这个网站会把我们请求的信息返回给我们。

curl --location --request PUT 'http://127.0.0.1:9180/apisix/admin/routes/1' \--header 'Content-Type: application/json' \--header 'Accept: */*' \--header 'Host: 127.0.0.1:9180' \--header 'Connection: keep-alive' \--data-raw '{    "uri": "/headers",    "plugin_config_id": 1001,    "plugins": {        "jwt-auth": {}    },    "upstream": {        "type": "roundrobin",        "nodes": {            "httpbin.org:80": 1        }    }}'

5. 发起请求

在创建好了plugin_config、consumer和route之后,我们就可以测试请求了。首先我们构建如下payload

1
2
3
4
5
6
{
"key": "app-key",
"user_id": 100001,
"role": "admin",
"exp": 1900000000
}

这个payload包含了user_idrole两个业务属性,exp代表这个jwt的过期时间戳,key是APISIX用于识别匹配哪个consumer的,这里我们选择匹配app-key这个consumer。之后我们将该payload和密钥a-string-secret-at-least-256-bits-long一起在https://jwt.io/进行编码,得到编码jwt信息如下

eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJrZXkiOiJhcHAta2V5IiwidXNlcl9pZCI6MTAwMDAxLCJyb2xlIjoiYWRtaW4iLCJleHAiOjE5MDAwMDAwMDB9.qG7PNPz2XlatmjrhNW_xf6SmI8T9JSIx2lJVJcAox0I

之后我们执行HTTP请求,将这个jwt放到Authorization header中

curl --location --request GET 'http://127.0.0.1:9080/headers' \--header 'Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJrZXkiOiJhcHAta2V5IiwidXNlcl9pZCI6MTAwMDAxLCJyb2xlIjoiYWRtaW4iLCJleHAiOjE5MDAwMDAwMDB9.qG7PNPz2XlatmjrhNW_xf6SmI8T9JSIx2lJVJcAox0I' \--header 'Accept: */*' \--header 'Host: httpbin.org:80' \--header 'Connection: keep-alive'

请求得到的响应如下,可以看到user_id和role属性已经成功的传给后端服务了

1
2
3
4
5
6
7
8
9
10
11
12
13
{
"headers": {
"Accept": "*/*",
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJrZXkiOiJhcHAta2V5IiwidXNlcl9pZCI6MTAwMDAxLCJyb2xlIjoiYWRtaW4iLCJleHAiOjE5MDAwMDAwMDB9.qG7PNPz2XlatmjrhNW_xf6SmI8T9JSIx2lJVJcAox0I",
"Host": "httpbin.org",
"User-Agent": "curl/7.81.0",
"X-Amzn-Trace-Id": "Root=1-68709903-309891501348175943af3223",
"X-Consumer-Username": "app",
"X-Forwarded-Host": "httpbin.org",
"X-User-Id": "100001",
"X-User-Role": "admin"
}
}
🔲 ☆

《推荐系统实践》

从某种意义上说,推荐系统和搜索引擎对于用户来说是两个互补的工具。搜索引擎满足了用户有明确目的时的主动查找需求,而推荐系统能够在用户没有明确目的的时候帮助他们发现感兴趣的新内容。

基于用户行为分析的推荐算法是个性化推荐系统的重要算法,学术界一般将这种类型的算法称为协同过滤(Collaborative filtering)算法。顾名思义,协同过滤就是指用户可以齐心协力,通过不断地和网站互动,使自己的推荐列表能够不断过滤掉自己不感兴趣的物品,从而越来越满足自己的需求。

用户行为分类

用户行为在个性化推荐系统中一般分两种——显性反馈行为(explicit feedback)和隐性反馈行为(implicit feedback)。显示反馈行为是用户主动做的,比如给视频点赞、给书籍打分等等;隐式反馈行为的代表就是用户浏览页面,这种行为显示出来的用户偏好不是那么明显,但是数据量更大。

常用算法

基于邻域的算法

  • 基于用户的协同过滤算法 这种算法给用户推荐和他兴趣相似的其他用户喜欢的物品。
    1. 找到和目标用户兴趣相似的用户集合(P45)。
    2. 找到这个集合中的用户喜欢的,且目标用户没有听说过的物品推荐给目标用户。
  • 基于物品的协同过滤算法 这种算法给用户推荐和他之前喜欢的物品相似的物品。
    1. 计算物品之间的相似度(P53)。
    2. 根据物品的相似度和用户的历史行为给用户生成推荐列表。

基于用户的协同过滤算法

计算两个用户的兴趣相似程度:给定用户u和用户v,N(u)表示用户u曾经有过正反馈的物品集合,N(v)表示用户v曾经有过正反馈的物品集合。可以使用Jaccard公式计算两个用户的兴趣相似程度

wuv=|N(u)N(v)||N(u)N(v)|

或者使用余弦相似公式计算相似程度

wuv=|N(u)N(v)||N(u)||N(v)|

以余弦相似公式为例,假设有用户ABCD,物品abcde,用户喜欢的物品如下

用户物品 a物品 b物品 c物品 d物品 e
A☑️☑️☑️
B☑️☑️
C☑️☑️
D☑️☑️☑️

那么我们可以得到用户A和BCD的相似度

wAB=|{a,b,d}{a,c}||{a,b,d}||{a,c}|=16

wAC=|{a,b,d}{b,e}||{a,b,d}||{b,e}|=16

wAD=|{a,b,d}{c,d,e}||{a,b,d}||{c,d,e}|=13

具体计算过程以AD的相似度计算为例

  1. 分子为交集并且交集为 {d}|{d}| = 1,所以分子为1
  2. 分母为并集,3 x 3 = 9,开根号为3
    最终值为 1 / 3

以上逻辑可以用代码进行实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def similarity(users):
w = defaultdict(dict)
for u, v in combinations(users.keys(), 2):
r1 = len(users[u] & users[v])
r2 = math.sqrt(len(users[u]) * len(users[v]) * 1.0)
r = r1 / r2
w[u][v], w[v][u] = r, r # 保存两次,方便后面使用
return w

def main():
users = {
'A': {'a', 'b', 'd'},
'B': {'a', 'c'},
'C': {'b', 'e'},
'D': {'c', 'd', 'e'}
}
for k, v in similarity(users).items():
print(f'{k}: {json.dumps(v)}')

执行后得到结果如下

A: {"B": 0.4082482904638631, "C": 0.4082482904638631, "D": 0.3333333333333333}B: {"A": 0.4082482904638631, "C": 0.0, "D": 0.4082482904638631}C: {"A": 0.4082482904638631, "B": 0.0, "D": 0.4082482904638631}D: {"A": 0.3333333333333333, "B": 0.4082482904638631, "C": 0.4082482904638631}

据此我们就可以得到各个用户之间的兴趣相似度了。有了用户兴趣的相似度之后,我们可以给用户推荐和他兴趣最相似的K个用户喜欢的物品。我们可以使用如下公式计算用户u对物品i的感兴趣程度

p(u,i)=vS(u,K)N(i)wuvrvi

其中,S(u, K)包含和用户u兴趣最接近的K个用户,N(i)是对物品i有过行为的用户集合,wuv是用户u和用户v的兴趣相似度,rvi代表用户v对物品i的兴趣,因为使用的是单一行为的隐反馈数据,所以所有的rvi=1。

具体的逻辑实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def recommend(user, users, w, k):
"""
:param user: 计算指定用户的物品推荐程度
:param users: 数据集
:param w: 前一步计算得到的用户兴趣相似度
:param k: 取k个兴趣最相似的用户
:return:
"""
rank = defaultdict(float)
# 获取指定用户和其它用户的兴趣相似度,并按照相似度从大到小排序,取前k个数据
for v, wuv in sorted(w[user].items(), key=lambda item: item[1], reverse=True)[:k]:
# 取出指定用户的数据集
for i in users[v]:
# 如果这个数据已经在当前用户的数据集里面,跳过,因为已经感兴趣的数据不需要再次推荐
if i in users[user]:
continue
rank[i] += wuv
return rank

通过这个代码我们就可以计算得到指定用户的物品推荐程度了。完整的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
import math
from collections import defaultdict
from itertools import combinations


def similarity(users):
w = defaultdict(dict)
for u, v in combinations(users.keys(), 2):
r1 = len(users[u] & users[v])
r2 = math.sqrt(len(users[u]) * len(users[v]) * 1.0)
r = r1 / r2
w[u][v], w[v][u] = r, r # 保存两次,方便后面使用
return w


def recommend(user, users, w, k):
"""
:param user: 计算指定用户的物品推荐程度
:param users: 数据集
:param w: 前一步计算得到的用户兴趣相似度
:param k: 取k个兴趣最相似的用户
:return:
"""
rank = defaultdict(float)
# 获取指定用户和其它用户的兴趣相似度,并按照相似度从大到小排序,取前k个数据
for v, wuv in sorted(w[user].items(), key=lambda item: item[1], reverse=True)[:k]:
# 取出指定用户的数据集
for i in users[v]:
# 如果这个数据已经在当前用户的数据集里面,跳过,因为已经感兴趣的数据不需要再次推荐
if i in users[user]:
continue
rank[i] += wuv
return rank


def main():
users = {
'A': {'a', 'b', 'd'},
'B': {'a', 'c'},
'C': {'b', 'e'},
'D': {'c', 'd', 'e'}
}

w = similarity(users)
for k, v in w.items():
print(f'{k}: {json.dumps(v)}')

rank = recommend('C', users, w, 3)
for k, v in sorted(rank.items(), key=lambda item: item[1], reverse=True):
print(f'{k}: {v}')


if __name__ == '__main__':
main()

执行代码得到的结果如下

A: {"B": 0.4082482904638631, "C": 0.4082482904638631, "D": 0.3333333333333333}B: {"A": 0.4082482904638631, "C": 0.0, "D": 0.4082482904638631}C: {"A": 0.4082482904638631, "B": 0.0, "D": 0.4082482904638631}D: {"A": 0.3333333333333333, "B": 0.4082482904638631, "C": 0.4082482904638631}d: 0.8164965809277261a: 0.4082482904638631c: 0.4082482904638631

由上面的结果我们可以知道,针对用户C,最推荐的物品是物品d

根据上面的例子我们已经简单了解了基于用户的协同过滤算法,不过这种算法存在问题,主要是

  1. 随着网站的用户数目越来越大,计算用户兴趣相似度矩阵将越来越困难,其运算时间复杂度和空间复杂度的增长和用户数的增长近似于平方关系
  2. 基于用户的协同过滤很难对推荐结果作出解释

因此,在实际的使用中,更常见的是基于物品的协同过滤算法

基于物品的协同过滤算法

为了挖掘长尾信息,避免热门物品对推荐产生影响,减小二八定律的出现。可以用如下公式计算物品之间的相似度

wij=|N(i)N(j)||N(i)||N(j)|

分子是同时喜欢物品i和物品j的用户数,分母是喜欢两个物品用户数的并集。为了减小计算量,我们可以构建一个矩阵来存储某个用户喜欢的物品集合。

举个例子,比如用户A喜欢物品 {a, b, d},那我们可以构建如下矩阵

    |  a  |  b  |  c  |  d  |  e  |----|-----|-----|-----|-----|-----|a   |  0  |  1  |  0  |  1  |  0  |b   |  1  |  0  |  0  |  1  |  0  |c   |  0  |  0  |  0  |  0  |  0  |d   |  1  |  1  |  0  |  0  |  0  |e   |  0  |  0  |  0  |  0  |  0  |

因为a、b、d可以组成ab、ad、bd,所以将矩阵中的对应位置都填上1。这是一个用户的物品信息,对于多个用户,只需要把这些矩阵相加即可。例如有5个用户,他们的物品信息和生成的对应物品矩阵如下

用户 1: {a, c, d}

    |  a  |  b  |  c  |  d  |  e  |----|-----|-----|-----|-----|-----|a   |  0  |  0  |  1  |  1  |  0  |b   |  0  |  0  |  0  |  0  |  0  |c   |  1  |  0  |  0  |  1  |  0  |d   |  1  |  0  |  1  |  0  |  0  |e   |  0  |  0  |  0  |  0  |  0  |

用户 2: {b, c, e}

    |  a  |  b  |  c  |  d  |  e  |----|-----|-----|-----|-----|-----|a   |  0  |  0  |  0  |  0  |  0  |b   |  0  |  0  |  1  |  0  |  1  |c   |  0  |  1  |  0  |  0  |  1  |d   |  0  |  0  |  0  |  0  |  0  |e   |  0  |  1  |  1  |  0  |  0  |

用户 3: {a, d, e}

    |  a  |  b  |  c  |  d  |  e  |----|-----|-----|-----|-----|-----|a   |  0  |  0  |  0  |  1  |  1  |b   |  0  |  0  |  0  |  0  |  0  |c   |  0  |  0  |  0  |  0  |  0  |d   |  1  |  0  |  0  |  0  |  1  |e   |  1  |  0  |  0  |  1  |  0  |

用户 4: {b, d}

    |  a  |  b  |  c  |  d  |  e  |----|-----|-----|-----|-----|-----|a   |  0  |  0  |  0  |  0  |  0  |b   |  0  |  0  |  0  |  1  |  0  |c   |  0  |  0  |  0  |  0  |  0  |d   |  0  |  1  |  0  |  0  |  0  |e   |  0  |  0  |  0  |  0  |  0  |

用户 5: {a, b, c, e}

    |  a  |  b  |  c  |  d  |  e  |----|-----|-----|-----|-----|-----|a   |  0  |  1  |  1  |  0  |  1  |b   |  1  |  0  |  1  |  0  |  1  |c   |  1  |  1  |  0  |  0  |  1  |d   |  0  |  0  |  0  |  0  |  0  |e   |  1  |  1  |  1  |  0  |  0  |

将这5个用户的物品信息相加,得到矩阵

    |  a  |  b  |  c  |  d  |  e  |----|-----|-----|-----|-----|-----|a   |  0  |  1  |  2  |  3  |  2  |b   |  1  |  0  |  3  |  2  |  3  |c   |  2  |  3  |  0  |  2  |  3  |d   |  3  |  2  |  2  |  0  |  2  |e   |  2  |  3  |  3  |  2  |  0  |

在这个矩阵中值越高,代表物品的相关度越高。接下来我们将这个规则用代码进行实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
number = 'number'
def item_similarity(train):
# c[i][number]表示使用物品i的用户数量
# c[i][j]表示同时交互物品i和j的用户数
c = defaultdict(lambda: defaultdict(int))
for user, items in train.items():
for i in items:
# 统计每个物品被交互的总次数
c[i][number] += 1
# 统计物品i与其他物品的共现次数
for j in items:
if i == j:
continue
c[i][j] += 1
# 计算最终的相似度矩阵 w
w = defaultdict(dict)
for i, related_items in c.items():
for j, cij in related_items.items():
if j == number: continue
# 余弦相似度公式
similarity = cij / math.sqrt(c[i][number] * c[j][number])
w[i][j] = similarity
return w

如上我们先计算每个物品各自被用户喜欢的次数,再计算每个物品和其它物品同时被某个用户喜欢的次数,之后根据物品相似度公式即可计算出物品之间的相关性。为了简单起见,如上代码只使用了一个字典变量,物品自己被喜欢的次数被保存在key为number的字段中,物品和其它物品同时被喜欢的次数则保存在key为其它物品ID的字段中。

有了如上逻辑之后,我们就可以计算物品相似度了,假设有用户如下

{  'A': {'a', 'b', 'd'},  'B': {'a', 'c'},  'C': {'b', 'e', 'a'},  'D': {'c', 'd', 'e'}}

计算得到的物品相似度

b: {'a': 0.8164965809277261, 'd': 0.5, 'e': 0.5}a: {'b': 0.8164965809277261, 'd': 0.4082482904638631, 'c': 0.4082482904638631, 'e': 0.4082482904638631}d: {'b': 0.5, 'c': 0.5, 'e': 0.5, 'a': 0.4082482904638631}c: {'e': 0.5, 'd': 0.5, 'a': 0.4082482904638631}e: {'b': 0.5, 'c': 0.5, 'd': 0.5, 'a': 0.4082482904638631}

可以看到物品a和物品b的相关度是最高的。在得到物品的相关度之后,我们可以使用如下公式计算用户u对一个物品j的兴趣

puj=iN(u)S(j,K)wjirui

N(u)是用户喜欢的物品集合,S(j, K)是物品j最相似的K个物品的集合,wji是物品j和i的相似度,rui是用户对物品i的兴趣,可令rui为1。我们可以把这个逻辑使用代码进行实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def recommend(interacted_items: Union[set, dict], w, k):
"""
:param interacted_items: 指定用户交互过的物品
:param w: 物品的相似度
:param k: 取最相似的k个物品
:return:
"""
if isinstance(interacted_items, set): # 如果只有物品,没有评分,那么将评分统一设置为1
interacted_items = {k: 1 for k in interacted_items}
rank = defaultdict(float)
# 用户交互过的物品,和用户对这个物品的评分
for item, score in interacted_items.items():
# 物品的相似度信息,得到related_item和item的相似度similarity,按照相似度的值从大到小排序,取k个值
for related_item, similarity in sorted(w[item].items(), key=lambda x: x[1], reverse=True)[:k]:
# 如果这个物品已经被用户交互过了,跳过
if related_item in interacted_items:
continue
# 计算相关的物品的相似度评分
rank[related_item] += score * similarity
return rank

有了计算用户和物品相关度的代码,我们就可以把逻辑结合起来,实现向用户推荐物品了。完整的代码实现如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import math
from collections import defaultdict
from typing import Union

number = 'number'


def item_similarity(train):
# c[i][number]表示使用物品i的用户数量
# c[i][j]表示同时交互物品i和j的用户数
c = defaultdict(lambda: defaultdict(int))
for user, items in train.items():
for i in items:
# 统计每个物品被交互的总次数
c[i][number] += 1
# 统计物品i与其他物品的共现次数
for j in items:
if i == j:
continue
c[i][j] += 1

# 计算最终的相似度矩阵 w
w = defaultdict(dict)
for i, related_items in c.items():
for j, cij in related_items.items():
if j == number: continue
# 余弦相似度公式
similarity = cij / math.sqrt(c[i][number] * c[j][number])
w[i][j] = similarity

return w


def recommend(interacted_items: Union[set, dict], w, k):
"""
:param interacted_items: 指定用户交互过的物品
:param w: 物品的相似度
:param k: 取最相似的k个物品
:return:
"""
if isinstance(interacted_items, set): # 如果只有物品,没有评分,那么将评分统一设置为1
interacted_items = {k: 1 for k in interacted_items}
rank = defaultdict(float)
# 用户交互过的物品,和用户对这个物品的评分
for item, score in interacted_items.items():
# 物品的相似度信息,得到related_item和item的相似度similarity,按照相似度的值从大到小排序,取k个值
for related_item, similarity in sorted(w[item].items(), key=lambda x: x[1], reverse=True)[:k]:
# 如果这个物品已经被用户交互过了,跳过
if related_item in interacted_items:
continue
# 计算相关的物品的相似度评分
rank[related_item] += score * similarity
return rank


def main():
users = {
'A': {'a', 'b', 'd'},
'B': {'a', 'c'},
'C': {'b', 'e', 'a'},
'D': {'c', 'd', 'e'}
}
w = item_similarity(users)
for k, v in w.items():
print(f'{k}: {dict(sorted(v.items(), key=lambda item: item[1], reverse=True))}')

rank = recommend(users['B'], w, 3)
for k, v in sorted(rank.items(), key=lambda item: item[1], reverse=True):
print(f'{k}: {v}')


if __name__ == '__main__':
main()

以上代码的执行结果如下,可见用户B和物品d的相关度最高

b: {'a': 0.8164965809277261, 'd': 0.5, 'e': 0.5}d: {'b': 0.5, 'c': 0.5, 'e': 0.5, 'a': 0.4082482904638631}a: {'b': 0.8164965809277261, 'd': 0.4082482904638631, 'c': 0.4082482904638631, 'e': 0.4082482904638631}c: {'d': 0.5, 'e': 0.5, 'a': 0.4082482904638631}e: {'b': 0.5, 'c': 0.5, 'd': 0.5, 'a': 0.4082482904638631}d: 0.9082482904638631b: 0.8164965809277261e: 0.5

基于物品的推荐在工程中使用的比基于用户的推荐要多,因为UserCF(User Collaborative Filtering)的推荐更社会化,反映了用户所在的小型兴趣群体中物品的热门程度,而ItemCF(Item Collaborative Filtering)的推荐更加个性化,反映了用户自己的兴趣传承。

LFM(latent factor model)隐语义模型

隐语义模型核心思想是通过隐含特征(latent factor)联系用户兴趣和物品,它可以通过对数据进行分类来实现推荐。这种基于用户对数据的兴趣分类的方式,需要解决如下三个问题:

  1. 如何给物品分类
  2. 如何确定用户对哪些分类感兴趣,以及感兴趣的程度
  3. 对于一个分类,选择哪些物品推荐给用户,以及这些物品的权重如何

隐含语义分析技术(latent variable analysis)采取基于用户行为统计的自动聚类,来实现数据自动分类。

评测指标

一个推荐系统好不好,可以从用户满意度、预测准确度、覆盖率、多样性、新颖性、惊喜度、信任度、实时性、健壮性、商业目标等多个角度来进行评测

准确度

我们可以使用TopN推荐的方式来计算准确度,TopN的准确度一般通过准确率(precision)/召回率(recall)来进行度量。令R(u)是根据用户在训练集上的行为给用户作出的推荐列表,而T(u)是用户在测试集上的行为列表。那么,推荐结果的召回率定义为:

Recall=uU|R(u)T(u)|uU|T(u)|

推荐结果的准确率定义为:

Precision=uU|R(u)T(u)|uU|R(u)|

简单来说,R(u)代表系统推荐给用户u的Top-N列表(预测值),T(u)代表用户实际喜欢或点击过的项目(真实值),召回率和准确率公式的分子都是同时存在于推荐列表和用户喜欢列表的物品数。召回率的分母是用户喜欢的物品数,召回率是看系统有没有把用户喜欢的物品推荐出来。准确率的分母是系统推荐的物品总数,目的是看推荐有多少是对的。

指标含义关注点
Recall你真正喜欢的内容中,被系统找出来了多少不漏掉好东西
Precision系统推荐的内容中,有多少真的是你喜欢的不乱推荐垃圾

我们可以把召回率和准确率的计算通过如下代码实现

flat
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def precision_recall(test_data, train_data, n, recommend_func):
"""
计算推荐系统在测试集上的准确率和召回率

:param test_data: dict,用户 -> 测试集中真实交互物品列表
:param train_data: dict,用户 -> 训练集中交互物品列表(用于生成推荐)
:param n: int,每个用户推荐的物品数量
:param recommend_func: function(user, n, train_data),返回推荐物品列表
:return: [recall, precision]
"""
hit = 0 # 交集
total_actual = 0 # 所有用户的真实物品总数
total_recommend = 0 # 所有推荐物品总数

for user, actual_items in test_data.items():
# 计算推荐物品
recommended_items = recommend_func(user, n, train_data)
# 计算交集
hit += len(set(recommended_items) & set(actual_items))
# 真实物品数
total_actual += len(actual_items)
# 推荐物品数
total_recommend += len(recommended_items)

recall = hit / total_actual if total_actual else 0
precision = hit / total_recommend if total_recommend else 0
return [recall, precision]

参考

推荐系统实践

🔲 ☆

《上瘾:让用户养成使用习惯的四大产品逻辑》

Hooked: How to Build Habit-Forming Products

如何卖出更多的产品:产能 -> 营销/渠道 -> 产品设计

上瘾如何设计产品:触发 -> 行动 -> 多变的酬赏 -> 投入

习惯是大脑借以掌握复杂举动的途径之一。神经系统科学家指出,人脑中存在一个负责无意识行为的基底神经节,那些无意中产生的条件反射会以习惯的形式存储在基底神经节中,从而使人们腾出精力来关注其他的事物。当大脑试图走捷径而不再主动思考接下来该做些什么时,习惯就养成了。为解决当下面临的问题,大脑会在极短的时间内从行为存储库里提取出相宜的对策。
(就是基底核,有点像缓存的作用)

我们所要描述的体验更接近于“痒”,它是潜伏于我们内心的一种渴求,当这种渴求得不到满足时,不适感就会出现。那些让我们养成某种习惯的产品正好可以缓解这种不适感。比起听之任之的做法,利用技术或产品来”挠痒痒”能够更快地满足我们的渴求。一旦我们对某种技术或产品产生依赖,那它就是唯一的灵丹妙药了。

福格行为模型可以用公式来呈现,即B=MAT。B代表行为,M代表动机,A代表能力,T代表触发。要想使人们完成特定的行为,动机、能力、触发这三样缺一不可。1否则,人们将无法跨过”行动线”,也就是说,不会实施某种行为。

  • 稀缺效应:物以稀为贵
  • 环境效应:环境会影响人们的价值判断
  • 锚定效应
  • 赠券效应

多变的酬赏主要表现为三种形式:社交酬赏,猎物酬赏,自我酬赏

沉没成本:通过用户对产品的投入程度,留住用户

总体评价:很薄的一本书,有部分的观点有参考意义,但是大部分的论调都是老生常谈。大多数的观点在很多心理学的书籍里面已经讲过了,本书主要是讲怎么依赖于这些原理来进行实操,有一定的参考意义。

https://book.douban.com/subject/27030507/

🔲 ☆

Kotlin与Java对照手册

1. 基本类型

类型Kotlin 写法Java 写法简要说明
数字Int, Long, Float, Double, Short, Byteint, long, float, double, short, byteKotlin 数值类型映射到相应的原生/包装类型。
布尔Booleanboolean只能取 true/false,与数字不互通。
字符Charchar单个 Unicode 字符,支持转义序列。
字符串StringString不可变;支持多行文本块 """..."""
数组Array<T>, IntArrayT[]提供原始类型专用数组如 IntArrayByteArray
无符号整型UInt, ULong, UShort, UByte编译时检查范围,运行时越界抛 IllegalArgumentException

2. 语法对照

功能Java 写法Kotlin 写法简要说明
变量定义int x = 10; final String name = "Tom";var x = 10 val name = "Tom"var 可变,val 只读;类型由编译器推断。
类 + 构造public class P { P(String n) { ... } }class P(val name: String)主构造中声明属性,自动生成字段 & 访问器。
数据类手动写字段/构造/equals/toStringdata class User(val id: Int, val n: String)data 自动生成常用方法 & 解构组件。
函数定义public int sum(int a, int b) { return a + b; }fun sum(a: Int, b: Int) = a + b表达式函数可省略大括号和 return
空安全if (s != null) len = s.length(); else len = 0;val len = s?.length ?: 0String? 可空,?.?: 插入编译期空检查。
分支匹配switch(x) { case 1: ... }when(x) { 1 -> ...; else -> ... }when 是表达式,支持范围 & 任意对象比较。
循环 & 集合for(int i=0;i<10;i++)``list.stream().filter()for(i in 0 until 10)``list.filter{}0 until 生成 IntRange;集合链式调用基于扩展函数。
单例class S { private static S i=new S(); … }object S { fun foo() {} }object 编译时生成线程安全单例,无需额外样板。

3. 独有亮点

特性示例简要说明
默认 & 命名参数fun g(msg: String = "Hi", name: String = "You") g(name="Tom")编译器生成默认方法,命名参数避免重载歧义。
扩展函数fun String.ex() = uppercase()编译后为静态方法,第一个参数是接收者,调用如成员方法。
解构声明val (x, y) = Point(1, 2)data class 自动生成 componentN(),一行取多值。
密封类sealed class R; data class Ok(val d: String): R(); object Err: R()限定子类范围,when 可做穷尽检查。
内联函数inline fun <T> m(b: ()->T): T { … }在调用处展开函数体,减少高阶函数的运行时开销。
集合构造器listOf(1, 2), mutableListOf("A"), mapOf("a" to 1)内建集合工厂函数,语法简洁;to 表示键值对。
数组构造器arrayOf(1, 2), intArrayOf(1, 2)支持泛型与原始类型数组,避免装箱。
表达式返回值val max = if (a > b) a else b val result = try { … } catch { … }ifwhentry 都是表达式,可直接赋值。
区间语法 & 步进for (i in 1..5), for (j in 1 until 5 step 2).. 表闭区间,until 表半开,step 控制步长。
字符串模板"Hello, $name" "Length: ${s.length}"$变量 可直接拼接,复杂表达式用 ${}
Lambda 尾随语法list.filter { it > 0 }.map { it * 2 }大括号可直接跟随函数调用,链式语法自然、简洁。

4. 常用标准库函数

函数用法示例简要说明
letuser?.let { print(it.name) }非空时执行块,it 引用原对象。
applyUser().apply { age = 18 }在对象上执行块并返回该对象,常用于初始化。
alsolist.also { println("init") }执行副作用并返回对象,常用于日志 / 调试。
runval r = run { compute(); result }无接收者的作用域块,返回最后一行结果。
withwith(cfg) { load(); validate() }对象上下文块,this 指向接收者,返回结果。
takeIfstr.takeIf { it.isNotBlank() }条件为真返回对象,否则返回 null
sequencesequenceOf(1,2,3).map { … }惰性集合处理,适合大规模数据管道。

5. 类型系统对比

功能Java 写法Kotlin 写法简要说明
泛型List<String>List<String>支持协变 / 逆变(out / in)和 reified 泛型函数。
类型别名typealias Name = String简化复杂类型声明。
枚举类enum Color { RED, GREEN }enum class Color { RED, GREEN }支持在枚举中定义属性 & 方法。
内联类@JvmInline value class USD(val amount: Int)编译时包装或展开,零开销封装。

6. 类型检测与转换

功能Java 写法Kotlin 写法简要说明
类型检查if (obj instanceof String)if (obj is String)is 后自动智能转换,无需显式强转。
安全转换(String) objobj as String / obj as? Stringas? 安全转换失败返回 null
基本转换Integer.parseInt(str)str.toInt(), toDouble(), toLong()通过扩展函数提供常见类型转换。

7. 控制流程 & 异常

功能Java 写法Kotlin 写法简要说明
条件 & 循环if, switch, for, while, do-whileif, when, for, while, do-whilewhen 可做表达式,替代 switch
返回 & 跳转return, break, continue, throw同 Java支持在 lambda 中局部返回,如 return@label
异常处理try-catch-finally, checked exceptiontry-catch-finally,无 checked exceptionKotlin 不区分受检异常,简化错误处理。

8. 包与导入

功能Java 写法Kotlin 写法简要说明
包声明package com.example;package com.example不需要分号。
导入import java.util.List;import java.util.List支持导入顶层函数和属性。
别名导入import foo.Bar as Baz解决命名冲突或简化引用。

9. 面向对象相关

功能Java 写法Kotlin 写法简要说明
接口默认实现default void f() {}接口中可直接写方法体接口内方法可有实现,无需关键字。
抽象类abstract class Shape { … }abstract class Shape { … }抽象成员不需再加 abstract 前缀。
继承 & 覆写class A extends B { @Override … }class A : B() { override fun … }: 表示继承,override 必显式标注。
可见性修饰符public/protected/privatepublic/protected/private/internalinternal 表示同模块内可见。
内部类class Outer { class Inner {} }class Outer { inner class Inner {} }默认是静态嵌套,加 inner 变为非静态内部类。

10. 协程 vs 多线程

场景Java 写法(线程/异步)Kotlin 写法(协程)简要说明
启动任务new Thread(() -> work()).start();GlobalScope.launch { work() }协程更轻量、省资源,适合大规模并发。
异步返回值Future<Integer> f = exec.submit(...);val result = async { compute() }.await()内建 async/await,语义更清晰。
延迟执行Thread.sleep(1000)delay(1000)非阻塞挂起,不占用线程。
结构化并发手动管理线程池和生命周期coroutineScope { … }协程作用域自动管理生命周期,避免泄漏。

📦 11. 集合操作对比

功能Java 写法(Stream)Kotlin 写法(扩展函数)简要说明
过滤list.stream().filter(x -> x > 0).collect(...)list.filter { it > 0 }语法简洁,链式调用更直观。
映射list.stream().map(x -> x * 2).collect(...)list.map { it * 2 }Lambda 简洁,扩展函数无额外依赖。
分组Collectors.groupingBy(...)list.groupBy { it.key }直接返回 Map<K, List<V>>,更易读。
排序list.sort(Comparator.comparing(...))list.sortedBy { it.prop }函数式排序,链式可读性好。
聚合reduce, sum, collectreduce, sumOf, fold内建多种聚合函数,常用时无需额外导入。

参考:

Kotlin内核编程
Kotlin 语言参考文档
深入理解Kotlin协程
有没有 Kotlin 讲协程比较好的书籍或博客连载
Kotlin 官方文档 中文版

🔲 ⭐

利用whisper为视频自动生成字幕

whisper是一个由openai开发的通用语言识别模型,我们可以使用它来为视频自动创建字幕。

环境安装

为了加速,我们需要使用GPU来进行计算,因此需要安装基于CUDA的pytorch。首先我们需要安装Miniconda,这里安装的时候直接点击下一步即可。

安装完毕之后,我们需要创建一个新的环境,这里我们创建一个名为whisper的环境:

conda create -n whisper python=3.8conda activate whisper

1. 安装CUDA

安装好了Miniconda之后,我们需要安装CUDA,执行nvidia-smi

$ nvidia-smiThu Jan  2 11:49:53 2025+-----------------------------------------------------------------------------------------+| NVIDIA-SMI 560.94                 Driver Version: 560.94         CUDA Version: 12.6     ||-----------------------------------------+------------------------+----------------------+| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC || Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. ||                                         |                        |               MIG M. ||=========================================+========================+======================||   0  NVIDIA GeForce GTX 1060 6GB  WDDM  |   00000000:01:00.0  On |                  N/A ||  0%   39C    P8             10W /  120W |     505MiB /   6144MiB |      0%      Default ||                                         |                        |                  N/A |+-----------------------------------------+------------------------+----------------------+

通过这个命令可以看到Driver Version: 560.94CUDA Version: 12.6,因此我们需要安装12.6版本的CUDA,更加详细的版本对照表在这里。在安装的时候可以选择自定义安装选项,一般来说只要勾选CUDA下的 Development和Runtime即可。

安装完毕之后执行命令nvcc -V查看CUDA版本:

$ nvcc -Vnvcc: NVIDIA (R) Cuda compiler driverCopyright (c) 2005-2024 NVIDIA CorporationBuilt on Thu_Sep_12_02:55:00_Pacific_Daylight_Time_2024Cuda compilation tools, release 12.6, V12.6.77Build cuda_12.6.r12.6/compiler.34841621_0

2. 安装cuDNN

根据自己下载的CUDA来选择对应版本的cuDNN,下载地址在这里。下载完毕之后解压到CUDA的安装目录下,一般来说是C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA{版本号},如果有重名的文件直接替换即可。

之后进入extras\demo_suite目录,执行如下命令:

bandwidthTest.exedeviceQuery.exe

如果出现了PASS的字样,说明安装成功。

3. 安装pytorch

切换到我们之前创建的whisper环境,使用如下命令安装CUDA版本的pytorch:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

安装之后执行python命令进入python环境,执行如下代码:

1
2
import torch
torch.cuda.is_available()

如果显示True则说明CUDA版本的pytorch安装成功。

4. 安装whisper

切换到我们之前创建的whisper环境,执行如下命令安装whisper:

pip install -U openai-whisperpip install setuptools-rust

安装完毕之后执行如下命令就可以使用whisper了:

whisper 'C:/Users/raymond/Desktop/voice.aac' --language zh --model turbo

如上命令表示对C:/Users/raymond/Desktop/voice.aac文件进行中文语言的识别,使用turbo模型。第一次执行该命令会下载模型文件,模型文件较大,下载时请确保网络通畅。执行结果如下

[00:00.000 --> 00:03.060] 提到肉毒毒素[00:03.060 --> 00:04.540] 你会想到什么[00:04.540 --> 00:10.820] 你真的了解它吗[00:10.820 --> 00:12.540] 2017年[00:12.540 --> 00:14.180] 肉毒毒素以万能药标签[00:14.180 --> 00:15.500] 登上时代周刊方面[00:15.500 --> 00:17.280] 目前它在全球[00:17.280 --> 00:18.960] 已被应用于几十种适应症[00:18.960 --> 00:20.560] 仅在2019年[00:20.560 --> 00:23.000] 接受注射的就已超过620万例[00:23.000 --> 00:24.880] 但不要忘了[00:24.880 --> 00:26.780] 肉毒毒素更是一种神经毒素[00:26.780 --> 00:29.000] 还曾被当作生化武器使用... 省略 ...

生成字幕

我们可以使用ffmpeg将音频从视频中提取出来,然后使用whisper生成字幕,最后使用ffmpeg将字幕添加到视频中。

使用如下命令提取音频:

ffmpeg -i input.mp4 -vn -acodec copy output.aac

然后使用whisper生成字幕,我们先在pycharm中创建一个test-whisper项目,并且把python解释器设置为Miniconda创建的whisper环境。创建一个main.py文件,写入如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import whisper
from whisper.utils import get_writer

root = 'E:/'

# 使用turbo模型
model = whisper.load_model('turbo')
prompt = '如果使用了中文,请使用简体中文来表示文本内容'

# 选择声音文件,识别中文,并且打印详细信息
result = model.transcribe(root + 'output.aac', language='zh', initial_prompt=prompt, verbose=True)
print(result['text'])

# 保存字幕文件
writer = get_writer('srt', root)
writer(result, 'output.srt')

如上代码表示使用turbo模型,识别中文,打印详细信息,并且保存字幕文件。执行完毕之后我们可以在E:/目录下看到生成的字幕文件。

最后我们使用ffmpeg将字幕添加到视频中:

ffmpeg -i input.mp4 -i output.srt -c:s mov_text -c:v copy -c:a copy output.mp4

之后我们在播放这个视频的时候就会有字幕了。

参考

video-subtitle-generator
基于Anaconda的pytorch-cuda
CUDA与cuDNN的安装与配置
ffmpeg视频合并、格式转换、截图

🔲 ⭐

ffmpeg笔记

合并一个文件夹内的所有视频
1
2
3
4
5
find *.mp4 | sed 's:\ :\\\ :g'| sed 's/^/file /' > fl.txt
ffmpeg -f concat -i fl.txt -c copy output.mp4
// 忽略错误信息
ffmpeg -safe 0 -f concat -i fl.txt -c copy output.mp4
rm fl.txt

参考资源

视频压缩
1
2
3
4
5
6
// 视频使用h.264编码,声音使用aac编码
ffmpeg -i input.mp4 -vcodec h264 -acodec aac output.mp4
// 视频使用h.265编码,压缩到更小文档
ffmpeg -i input.mp4 -vcodec libx265 -crf 28 output.mp4
// 视频使用h.264编码,保留更好的质量
ffmpeg -i input.mp4 -vcodec libx264 -crf 20 output.mp4

crf越小,视频质量越高;crf越大,视频文件越小

编码参数也可以简写,从-vcodec-acodec改为-c:v-c:a

1
2
3
ffmpeg -i input.mp4 -c:v libx264 -crf 23 output.mp4
ffmpeg -i input.mp4 -c:v libx265 -crf 28 output.mp4
ffmpeg -i input.mp4 -c:v libvpx-vp9 -crf 31 -b:v 0 output.mkv

参考资源

其中AVC/H264HEVC/H265都是软件编码,速度很慢。可以选择英伟达的硬件编码:hevc_nvenc与h264_nvenc,它们使用硬件加速,速度很快。

参考资源

使用英伟达显卡进行编码:

1
ffmpeg -i video.mp4 -c:v hevc_nvenc -crf 28 output.mp4

将视频从H.264转码到H.265,花了55分钟,视频体积从3.8GB减小到430MB,效果立竿见影。转码命令:ffmpeg -i 1.mp4 -c:v libx265 -vtag hvc1 -c:a copy 1_hevc.mp4

在win10可以用scoop安装ffmpeg,更新Windows上面通过scoop安装的所有程序
scoop list | foreach { scoop update $_.Name }

将视频以同样的编码,按照指定时间进行裁剪

1
ffmpeg -ss 00:05 -to 08:53.500 -i ./input.mp4 -c copy video.mp4

利用ffmpeg快速剪辑视频

1
ffmpeg -ss 07:18 -to 13:45 -i ./aaa.mkv -c copy bbb.mkv
  • -ss表示开始时间
  • -to表示结束时间
  • -i是输入文档
  • -c表示使用被剪辑视频一样的编码
  • bbb是输出文档的名称

合并视频和声音,视频使用原始编码,声音改为aac编码

1
ffmpeg -i 1.mp4 -i 1.opus -c:v copy -c:a aac output.mp4

将PNG格式图片转为JPG格式图片

1
ffmpeg -i image.png -preset ultrafast image.jpg

修改图片的尺寸

1
2
ffmpeg -i image.jpeg -vf scale=413:626 2寸.jpeg
ffmpeg -i image.jpeg -vf scale=390:567 1寸.jpeg

将一个音频重复10次

1
ffmpeg -stream_loop 10 -i input.m4a -c copy output.m4a
🔲 ⭐

自己动手实现一个可以运行在JVM上的编程语言

众所周知,JVM虚拟机被设计为可以执行栈式指令的机器。因此任何一个语言只要编译之后得到的字节码符合JVM的标准,就可以在JVM上执行,例如Kotlin、Groovy、Scala、Clojure。

我们自己设计一款语言,并命名为Jinx,它支持类定义、变量定义、变量打印。它的语法解析逻辑如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
grammar Jinx;

@header {
package com.nosuchfield.jinx.code;
}

jinx: CLASS ID LEFT_BR classBody RIGHT_BR EOF;
classBody: (variable | print)*;
variable: VARIABLE ID EQUALS value;
print: PRINT ID;
value: STRING | INT | DOUBLE;

LEFT_BR: '{';
RIGHT_BR: '}';
CLASS: 'class';
VARIABLE: 'var';
PRINT: 'print';
EQUALS: '=';
STRING: '"' ('\\"' | ~'"')+ '"';
DOUBLE: [0-9]+ '.' [0-9]+;
INT: [0-9]+;
// 这个ID不能放在前面,不然会被提前解析,导致print等字符串被解析为ID
ID: [a-zA-Z] [a-zA-Z0-9]*;

WS: [\n\r\t ]+ -> skip;

Jinx的最外层是类class,class的内部可以包含变量的定义和打印,变量的值支持字符串、整数和小数。有了ANTLR4的解析逻辑之后,我们就可以处理程序的语法树了,语法树的解析如下

flat
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
public class Loader extends JinxBaseListener {

/**
* 变量表,以变量名为key,包括:变量索引idx、变量类型
*/
private final Map<String, ImmutablePair<Integer, Integer>> variables = new HashMap<>();

/**
* 指令列表
*/
private final List<Instruction> instructions = new ArrayList<>();

private String className;

@Override
public void enterJinx(JinxParser.JinxContext ctx) {
className = ctx.ID().getText();
}

@Override
public void exitVariable(JinxParser.VariableContext ctx) {
// 变量名
String name = ctx.ID().getText();
JinxParser.ValueContext variable = ctx.value();
// 变量值
String text = variable.getText();
// 变量类型
int type = variable.getStart().getType();
// 变量索引(在局部变量表中这是第几个变量)
int idx = variables.size();

// 把这个变量保存在内存,方便后面知道这个变量的索引和类型
variables.put(name, ImmutablePair.of(idx, type));
// 创建保存这个变量的指令
instructions.add(new VariableInstruction(idx, type, text));
}

@Override
public void exitPrint(JinxParser.PrintContext ctx) {
String name = ctx.ID().getText();
if (!variables.containsKey(name)) {
System.err.printf("variable %s not exist\n", name);
System.exit(1);
}
int idx = variables.get(name).getLeft();
int type = variables.get(name).getRight();
// 创建打印的指令
instructions.add(new PrintInstruction(idx, type));
}

public List<Instruction> getInstructions() {
return instructions;
}

public String getClassName() {
return className;
}

}

在上面的语法树解析中,我们会解析每一个变量的定义语法和打印语法。

变量定义

我们会在定义每个变量的时候记录下变量的类型和索引,并把记录的数据关联到这个变量的名字上。此外,我们还会针对这个变量的类型、索引和值生成JVM保存变量的指令。

变量打印

在打印程序的解析中,我们会先通过变量的名称从关联表中取出变量的类型和索引(如果不存在就报错),之后根据变量的类型和索引创建JVM打印的指令。

上面的语法树解析最终生成了一个指令列表instructions,我们接下来根据这个指令列表生成JVM所需要的字节码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
private byte[] generateBytecode(List<Instruction> instructions, String className) {
ClassWriter classWriter = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
classWriter.visit(V1_8, ACC_PUBLIC + ACC_SUPER, className, null, "java/lang/Object", null);
// main方法
MethodVisitor methodVisitor = classWriter.visitMethod(ACC_PUBLIC + ACC_STATIC, "main",
"([Ljava/lang/String;)V", null, null);
for (Instruction instruction : instructions) {
instruction.apply(methodVisitor);
}
methodVisitor.visitInsn(RETURN);
methodVisitor.visitMaxs(0, 0); // 设置COMPUTE_FRAMES后会自动计算,但是此处设置不能省略
methodVisitor.visitEnd();
classWriter.visitEnd();
return classWriter.toByteArray();
}

如上我们根据指令和类名使用ASM生成了字节码数据,它生成了一个包含main方法的类,并且把我们的指令放在main方法中。每个指令都调用了其apply方法,接下来我们具体看一下变量定义和变量打印的apply方法是如何实现的。

变量定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public void apply(MethodVisitor mv) {
switch (type) {
case JinxLexer.DOUBLE -> {
double val = Double.parseDouble(value);
// 常量池的数据推到栈顶
mv.visitLdcInsn(val);
// 栈顶double值存入本地局部变量,idx代表索引
mv.visitVarInsn(DSTORE, idx);
}
case JinxLexer.INT -> {
int val = Integer.parseInt(value);
mv.visitLdcInsn(val);
mv.visitVarInsn(ISTORE, idx);
}
case JinxLexer.STRING -> {
mv.visitLdcInsn(Utils.removeFirstAndLastChar(value));
mv.visitVarInsn(ASTORE, idx);
}
}

变量的定义很简单,都是先把变量的值从常量池取出,然后推到操作数栈的顶部。之后从操作数栈顶取数据,根据变量的idx把变量保存到局部变量表的指定索引位置。区别在于浮点型的保存指令是DSTORE,整型是ISTORE,字符串是ASTORE

变量打印

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public void apply(MethodVisitor mv) {
mv.visitFieldInsn(GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
switch (type) {
case JinxLexer.INT -> {
mv.visitVarInsn(ILOAD, idx);
mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(I)V", false);
}
case JinxLexer.DOUBLE -> {
mv.visitVarInsn(DLOAD, idx);
mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(D)V", false);
}
case JinxLexer.STRING -> {
mv.visitVarInsn(ALOAD, idx);
mv.visitMethodInsn(INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
}
}
}

变量的打印会先使用System.out变量,之后从局部变量表中根据变量的idx取出变量的值,然后执行println方法,入参分别为整型、浮点型和字符串。

有了以上这些指令,我们就可以正常生成字节码了,我们进行语法分析生成instructions,并使用instructions最终生成字节码文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public void compile0(String file) throws IOException {
// 词法分析
JinxLexer lexer = new JinxLexer(CharStreams.fromFileName(file));
CommonTokenStream tokens = new CommonTokenStream(lexer);
// 语法分析
JinxParser parser = new JinxParser(tokens);
parser.removeErrorListeners();
parser.addErrorListener(new ErrorHandler()); // 语法分析错误处理
ParseTree tree = parser.jinx();
// 语法树遍历
ParseTreeWalker parseTreeWalker = new ParseTreeWalker();
Loader loader = new Loader();
parseTreeWalker.walk(loader, tree);
// 遍历语法树生成Java指令
List<Instruction> instructions = loader.getInstructions();
// 生成Java.class文件
String className = loader.getClassName();
String classFile = Paths.get(new File(file).getParent(), className + ".class").toString();
writeByteArrayToFile(classFile, generateBytecode(instructions, className));
}

上面代码的最后一行就是根据指令列表和类名生成字节码,并把字节码保存到文件中。我们创建一个源代码

class Test {    var name = "Mike"    var salary = 2370    print name    print salary    var number = 1.1    print number}

使用编译器解析如上代码并最终生成一个字节码文件Test.class,运行这个字节码文件可以打印出变量的值

$ java TestMike23701.1

我们也可以查看字节码的信息如下

$ javap -verbose TestClassfile /src/main/resources/jinx/Test.classLast modified Jan 3, 2023; size 342 bytesMD5 checksum fff7d9ac9c044299ffd5a6194c452502public class Testminor version: 0major version: 52flags: ACC_PUBLIC, ACC_SUPERConstant pool:#1 = Utf8               Test#2 = Class              #1             // Test#3 = Utf8               java/lang/Object#4 = Class              #3             // java/lang/Object#5 = Utf8               main#6 = Utf8               ([Ljava/lang/String;)V#7 = Utf8               Mike#8 = String             #7             // Mike#9 = Integer            2370#10 = Utf8               java/lang/System#11 = Class              #10            // java/lang/System#12 = Utf8               out#13 = Utf8               Ljava/io/PrintStream;#14 = NameAndType        #12:#13        // out:Ljava/io/PrintStream;#15 = Fieldref           #11.#14        // java/lang/System.out:Ljava/io/PrintStream;#16 = Utf8               java/io/PrintStream#17 = Class              #16            // java/io/PrintStream#18 = Utf8               println#19 = Utf8               (Ljava/lang/String;)V#20 = NameAndType        #18:#19        // println:(Ljava/lang/String;)V#21 = Methodref          #17.#20        // java/io/PrintStream.println:(Ljava/lang/String;)V#22 = Utf8               (I)V#23 = NameAndType        #18:#22        // println:(I)V#24 = Methodref          #17.#23        // java/io/PrintStream.println:(I)V#25 = Double             1.1d#27 = Utf8               (D)V#28 = NameAndType        #18:#27        // println:(D)V#29 = Methodref          #17.#28        // java/io/PrintStream.println:(D)V#30 = Utf8               Code{public static void main(java.lang.String[]);    descriptor: ([Ljava/lang/String;)V    flags: ACC_PUBLIC, ACC_STATIC    Code:    stack=3, locals=4, args_size=1        0: ldc           #8                  // String Mike        2: astore_0        3: ldc           #9                  // int 2370        5: istore_1        6: getstatic     #15                 // Field java/lang/System.out:Ljava/io/PrintStream;        9: aload_0        10: invokevirtual #21                 // Method java/io/PrintStream.println:(Ljava/lang/String;)V        13: getstatic     #15                 // Field java/lang/System.out:Ljava/io/PrintStream;        16: iload_1        17: invokevirtual #24                 // Method java/io/PrintStream.println:(I)V        20: ldc2_w        #25                 // double 1.1d        23: dstore_2        24: getstatic     #15                 // Field java/lang/System.out:Ljava/io/PrintStream;        27: dload_2        28: invokevirtual #29                 // Method java/io/PrintStream.println:(D)V        31: return}

参考

ANTLR4表达式
Java代码
Java ASM系列
Enkel-JVM-language

🔲 ⭐

ANTLR4从入门到实践

ANTLR(ANother Tool for Language Recognition)是一个强大的解析器生成器,用于读取、处理、执行或翻译结构化文本或二进制文档。它被广泛用于构建语言、工具和框架。ANTLR根据语法定义生成解析器,解析器可以构建和遍历解析树。

安装

以Linux系统为例,我们首先安装Java17

~ java -versionjava version "17.0.6" 2023-01-17 LTSJava(TM) SE Runtime Environment (build 17.0.6+9-LTS-190)Java HotSpot(TM) 64-Bit Server VM (build 17.0.6+9-LTS-190, mixed mode, sharing)

随后我们下载antlr4的完整依赖包

wget https://www.antlr.org/download/antlr-4.13.0-complete.jar

并把依赖包添加到Java的CLASSPATH中,将以下命令添加到~/.zshrc文件中

export CLASSPATH="/home/raymond/Desktop/antlr4/antlr-4.13.0-complete.jar:$CLASSPATH"

之后我们就可以使用antlr4的Tool和TestRig了

~ java org.antlr.v4.ToolANTLR Parser Generator  Version 4.13.0-o ___              specify output directory where all output is generated-lib ___            specify location of grammars, tokens files-atn                generate rule augmented transition network diagrams-encoding ___       specify grammar file encoding; e.g., euc-jp-message-format ___ specify output style for messages in antlr, gnu, vs2005-long-messages      show exception details when available for errors and warnings-listener           generate parse tree listener (default)-no-listener        don't generate parse tree listener-visitor            generate parse tree visitor-no-visitor         don't generate parse tree visitor (default)-package ___        specify a package/namespace for the generated code-depend             generate file dependencies-D<option>=value    set/override a grammar-level option-Werror             treat warnings as errors-XdbgST             launch StringTemplate visualizer on generated code-XdbgSTWait         wait for STViz to close before continuing-Xforce-atn         use the ATN simulator for all predictions-Xlog               dump lots of logging info to antlr-timestamp.log-Xexact-output-dir  all output goes into -o dir regardless of paths/package~ java org.antlr.v4.gui.TestRigjava org.antlr.v4.gui.TestRig GrammarName startRuleName[-tokens] [-tree] [-gui] [-ps file.ps] [-encoding encodingname][-trace] [-diagnostics] [-SLL][input-filename(s)]Use startRuleName='tokens' if GrammarName is a lexer grammar.Omitting input-filename makes rig read from stdin.

可以在~/.zshrc中添加如下别名

alias antlr4='java org.antlr.v4.Tool'alias grun='java org.antlr.v4.gui.TestRig'

后面就可以直接使用antlr4和grun命令了

一个简单的例子

我们从一个最简单的例子来看antlr4,创建一个名为Hello.g4的文件并输入如下内容

1
2
3
4
5
grammar Hello;              // 语法名称,必须要和文件名称一样

r : 'hello' ID ; // 表示匹配字符串hello和ID这个token,语法名称用小写字母定义
ID : [a-z]+ ; // ID这个token的定义只允许小写字母,词法名称用大写字母定义
WS : [ \t\r\n]+ -> skip ; // 忽略一些字符

随后执行antlr4 Hello.g4 -o code命令将语法文件转化为Java的代码,具体生成的文件如下

HelloBaseListener.javaHello.interpHelloLexer.interpHelloLexer.javaHelloLexer.tokensHelloListener.javaHelloParser.javaHello.tokens

之后执行命令javac *.java将所有的Java代码进行编译,编译完了之后执行命令grun Hello r -tree并输入相关文本内容,之后输入EOF(Linux上面是Ctrl + D)可以得到解析结果

➜  grun Hello r -treehello antlr<EOF>(r hello antlr)

其中Hello是语法文件的名称,r则是语法的名称,-tree表示以lisp语法展示语法,我们也可以使用-gui选项展示语法树。

Visual Studio Code提供了antlr4的插件,可以方便的进行语法高亮和格式化等操作。IntelliJ Idea也提供了插件,具有快速生成代码、设置生成代码的参数以及查看语法树等功能。

使用antlr4构建一个计算器

首先我们创建一个Calc.g4文件,具体内容如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
grammar Calc;       // 语法的名称,要和文件名称一致

calc: (expr)* EOF; // 一个或多个表达式

expr:
BRACKET_L expr BRACKET_R // 圆括号
| (ADD | SUB)? (NUMBER | PERCENT_NUMBER) // 正负数字和百分数
| expr (MUL | DIV) expr // 乘除法
| expr (ADD | SUB) expr; // 加减法

PERCENT_NUMBER: NUMBER PERCENT; // 百分数
NUMBER: DIGIT (POINT DIGIT)?; // 小数

DIGIT: [0-9]+; // 数字
BRACKET_L: '('; // 左括号
BRACKET_R: ')'; // 右括号
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
PERCENT: '%';
POINT: '.';

WS: [ \t\r\n]+ -> skip; // 跳过空格换行等字符

执行命令antlr4 Calc.g4 -o code来生成代码,并将生成的代码放到code文件夹中。进入code文件夹,执行javac *.java命令编译代码。编译完代码之后,就可以执行测试程序了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
➜ grun Calc calc -tree
1 + 2 * (3 + 4) - 5 / 6
(calc (expr (expr (expr 1) + (expr (expr 2) * (expr ( (expr (expr 3) + (expr 4)) )))) - (expr (expr 5) / (expr 6))) <EOF>)

➜ grun Calc calc -tokens
1 + 2 * (3 + 4) - 5 / 6
[@0,0:0='1',<NUMBER>,1:0]
[@1,2:2='+',<'+'>,1:2]
[@2,4:4='2',<NUMBER>,1:4]
[@3,6:6='*',<'*'>,1:6]
[@4,8:8='(',<'('>,1:8]
[@5,9:9='3',<NUMBER>,1:9]
[@6,11:11='+',<'+'>,1:11]
[@7,13:13='4',<NUMBER>,1:13]
[@8,14:14=')',<')'>,1:14]
[@9,16:16='-',<'-'>,1:16]
[@10,18:18='5',<NUMBER>,1:18]
[@11,20:20='/',<'/'>,1:20]
[@12,22:22='6',<NUMBER>,1:22]
[@13,31:30='<EOF>',<EOF>,2:0]

➜ grun Calc calc -gui
1 + 2 * (3 + 4) - 5 / 6

第一个命令是生成Lisp风格的语法树,第二个命令是查看相应的token,第三个命令生成的语法树如下所示

通过Java代码调用生成的Lexer和Parser

还是以上面的例子为例,这次我们把词法分析和语法分析的内容分开来,分别创建CalcLexerRules.g4Calc.g4文件,它们的内容分别如下

CalcLexerRules.g4

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
lexer grammar CalcLexerRules;

PERCENT_NUMBER: NUMBER PERCENT;
NUMBER: DIGIT (POINT DIGIT)?;

DIGIT: [0-9]+;
BRACKET_L: '(';
BRACKET_R: ')';
ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
PERCENT: '%';
POINT: '.';

WS: [ \t\r\n]+ -> skip;

Calc.g4

1
2
3
4
5
6
7
8
9
10
grammar Calc;
import CalcLexerRules; // 引入CalcLexerRules的词法规则

calc: (expr)* EOF;

expr:
BRACKET_L expr BRACKET_R
| (ADD | SUB)? (NUMBER | PERCENT_NUMBER)
| expr (MUL | DIV) expr
| expr (ADD | SUB) expr;

创建这两个文件之后,执行命令antlr4 Calc.g4 -o code生成代码,antlr会自动把CalcLexerRules.g4的内容引入进来。在生成代码的code文件夹下创建Java文件CalcTest.java,并使用Java代码调用生成的Lexer和Parser类中的方法

CalcTest.java

1
2
3
4
5
6
7
8
9
10
11
12
13
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;

public class CalcTest {
public static void main(String[] args) throws Exception {
CalcLexer lexer = new CalcLexer(CharStreams.fromString("1 + 2 * (3 + 4) - 5 / 6"));
CommonTokenStream tokens = new CommonTokenStream(lexer);
CalcParser parser = new CalcParser(tokens);
ParseTree tree = parser.calc();
System.out.println(tree.toStringTree(parser));
}
}

添加了如上的类之后,执行命令javac *.java编译源码文件,之后执行命令java CalcTest来运行Java代码,得到结果如下

(calc (expr (expr (expr 1) + (expr (expr 2) * (expr ( (expr (expr 3) + (expr 4)) )))) - (expr (expr 5) / (expr 6))) <EOF>)

运行结果和上面的grun的测试结果是一致的。

通过Visitor访问代码

上面我们使用Java代码调用了CalcLexer和CalcParser类,接下来我们实现一个Visitor,通过Visitor来访问我们所需要访问的AST节点,并执行计算器的计算功能。

这里我们使用Idea的ANTLR v4插件来生成代码,上面的词法文件CalcLexerRules.g4不需要任何改变,而语法文件Calc.g4修改如下

1
2
3
4
5
6
7
8
9
10
11
12
13
grammar Calc;
@header {
package com.nosuchfield.calc.code;
}
import CalcLexerRules; // 引入词法分析文件

calc: (expr)* EOF # calculationBlock;

expr:
BRACKET_L expr BRACKET_R # expressionWithBr
| sign = (ADD | SUB)? num = (NUMBER | PERCENT_NUMBER) # expressionNumeric
| expr op = (MUL | DIV) expr # expressionMulOrDiv
| expr op = (ADD | SUB) expr # expressionAddOrSub;

这里我们添加了@header标记,表示在生成代码的时候在代码头部生成我们所需要的内容,如上就是在代码头部放上了类的package声明。

我们还在每个语法后面使用井号#设置了一个标记名称,这个名称在生成Visitor代码的时候会生成相应名称的方法。此外我们还给表达式的参数设置了名称,例如sign、num和op,这样当生成代码的时候,我们就可以用参数num取到NUMBER或者PERCENT_NUMBER的值。

我们在Calc.g4文件上右击并选择Configure ANTLR选项

之后设置代码的生成目录为src/main/java/com/nosuchfield/calc/code,并且去掉生成listener的选项,同时选择生成visitor的选项

设置好了之后我们右击Calc.g4文件并右击选择Generate ANTLR Recognizer选项,即可在com/nosuchfield/calc/code文件夹下生成相关的代码

接下来我们自定义一个继承自CalcBaseVisitor类的CalculateVisitor,具体如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package com.nosuchfield.calc;

import com.nosuchfield.calc.code.CalcBaseVisitor;
import com.nosuchfield.calc.code.CalcLexer;
import com.nosuchfield.calc.code.CalcParser;
import org.antlr.v4.runtime.Token;

import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Objects;

public class CalculateVisitor extends CalcBaseVisitor<BigDecimal> {

/**
* 用于设置BigDecimal的计算精度
*/
private static final MathContext MATH_CONTEXT = MathContext.DECIMAL128;

/**
* calc语法,包含了多个expr,返回最后一个expr的结果
*/
@Override
public BigDecimal visitCalculationBlock(CalcParser.CalculationBlockContext ctx) {
BigDecimal calcResult = null;
for (CalcParser.ExprContext expr : ctx.expr()) {
calcResult = visit(expr);
}
return calcResult;
}

/**
* 左右括号,取出括号中的表达式
*/
@Override
public BigDecimal visitExpressionWithBr(CalcParser.ExpressionWithBrContext ctx) {
return visit(ctx.expr());
}

/**
* 乘除法,返回左右两个元素的计算结果
* 其中op属性是在语法文件中自定义的
*/
@Override
public BigDecimal visitExpressionMulOrDiv(CalcParser.ExpressionMulOrDivContext ctx) {
BigDecimal left = visit(ctx.expr(0));
BigDecimal right = visit(ctx.expr(1));
switch (ctx.op.getType()) {
case CalcParser.MUL:
return left.multiply(right, MATH_CONTEXT);
case CalcParser.DIV:
return left.divide(right, MATH_CONTEXT);
default:
throw new RuntimeException("unsupported operator type");
}
}

/**
* 加减法,返回左右两个元素的计算结果
* 其中op属性是在语法文件中自定义的
*/
@Override
public BigDecimal visitExpressionAddOrSub(CalcParser.ExpressionAddOrSubContext ctx) {
BigDecimal left = visit(ctx.expr(0));
BigDecimal right = visit(ctx.expr(1));
switch (ctx.op.getType()) {
case CalcParser.ADD:
return left.add(right, MATH_CONTEXT);
case CalcParser.SUB:
return left.subtract(right, MATH_CONTEXT);
default:
throw new RuntimeException("unsupported operator type");
}
}

/**
* 获取数值,num属性是在语法文件中定义的
* 如果数值前有负号就取负值
*/
@Override
public BigDecimal visitExpressionNumeric(CalcParser.ExpressionNumericContext ctx) {
BigDecimal numeric = numberOrPercent(ctx.num);
if (Objects.nonNull(ctx.sign) && ctx.sign.getType() == CalcLexer.SUB) {
return numeric.negate();
}
return numeric;
}

/**
* 将文本内容转化为BigDecimal,包含数字和百分数
*/
private BigDecimal numberOrPercent(Token num) {
String numberStr = num.getText();
switch (num.getType()) {
case CalcLexer.NUMBER:
return new BigDecimal(numberStr);
case CalcLexer.PERCENT_NUMBER:
return new BigDecimal(numberStr.substring(0, numberStr.length() - 1).trim())
.divide(BigDecimal.valueOf(100), MATH_CONTEXT);
default:
throw new RuntimeException("unsupported number type");
}
}

}

在自定义的Visitor中我们实现了计算逻辑,可以看到,这里重写了类CalcBaseVisitor的5个方法,分别对应了语法文件中的5个标记以及它们定义的名称,而属性的定义如num则对应了方法中入参的属性。以expressionNumeric语法为例,它对应的了方法visitExpressionNumeric,我们可以通过方法入参ExpressionNumericContext取到sign和num属性,之后通过这两个属性来定义数字的值。而expressionMulOrDiv语法就是通过op取到运算符,之后对两边的数字根据运算符来进行相应的计算。

有了visitor之后,我们用一个测试类来测试计算结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
package com.nosuchfield.calc;

import com.nosuchfield.calc.code.CalcLexer;
import com.nosuchfield.calc.code.CalcParser;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.junit.Test;

import java.math.BigDecimal;

import static junit.framework.TestCase.assertEquals;

public class TestCalculate {

@Test
public void testCalculate() {
String[][] sources = new String[][]{
{"1 + 2", "3"},
{"3 - 2", "1"},
{"2 * 3", "6"},
{"6 / 3", "2"},
{"6 / (1 + 2)", "2"},
{"50%", "0.5"},
{"100 * 30%", "30.0"},
{"1 + 2 * (3 - 4) / 5", "0.6"},
{"-8 + 8 * 2 - 8", "0"}
};
for (String[] source : sources) {
String input = source[0].trim();
BigDecimal result = new BigDecimal(source[1].trim());
assertEquals(calculate(input), result);
}
}

/**
* 计算表达式
*
* @param expression 表达式
* @return 计算的结果
*/
private BigDecimal calculate(String expression) {
CharStream cs = CharStreams.fromString(expression);
CalcLexer lexer = new CalcLexer(cs);
CommonTokenStream tokens = new CommonTokenStream(lexer);
CalcParser parser = new CalcParser(tokens);
CalcParser.CalcContext context = parser.calc();
CalculateVisitor visitor = new CalculateVisitor();
return visitor.visit(context);
}

}

可以看到,我们构建的计算器已经成功的计算出了正确结果。

ANTLR4的工作流程

在上面的例子中我们已经了解到,使用antlr4的一般流程如下

  1. 书写antlr4的词法和文法规则
  2. 使用antlr4的生成工具处理写好的规则,以生成指定语言的Lexer和Parser代码
  3. 调用生成的Lexer和Parser类,书写相应的逻辑代码,将原始输入文本转化为一个抽象语法树
  4. 使用antlr4的visitor来解析语法树,实现各种功能

实际上,除了visitor之外,antlr4还提供了另一种解析语法树方式,叫做Listener。Listener是antlr4默认解析语法树的方式,它和visitor一样都可以实现对ParseTree的解析。如果开启了visitor或listener,那么antlr4除了会生成Lexer和Parser代码,还会生成相应的Visitor和Listener代码。

Listener和Visitor区别如下

ListenerVisitor
是否访问所有节点访问所有节点只访问手动指定的节点
访问节点方式通过enter和exit方法通过visit方法
方法是否有返回值没有返回值有返回值

了解了Listener和Visitor的区别之后,我们可以总结出antlr4的大致工作流程如下

如上左边的点线流程代表了通过ANTLR4,将原始的.g4规则转化为Lexer、Parser、Listener和Visitor。右边的虚线流程代表了将原始的输入流通过Lexer转化为Tokens,再将Tokens通过Parser转化为语法树,最后通过Listener或Visitor遍历ParseTree得到最终结果。

解析CSV文件

我们已经使用Visitor构建过一个计算器,接下来我们使用Listener实现对CSV的解析。
Comma-separated values (CSV)文件是一种使用英文逗号 , 来分割字段的文件格式。文件分为多行,每行又被逗号分割为多列,第一行的内容可以当作字段的名称。下面是一个例子

省份,城市,区县,描述江苏,南京,雨花台,外包大道浙江,杭州,西湖,太美丽啦!西湖,上海,黄浦,"as it says: ""hello, shanghai"""

分析这个格式,首先是一行头部,之后跟着多行数据,因此可以很容易的得出如下的语法规则

csv: hdr row*;

而头部也是一样的数据格式,因此有如下规则

hdr: row;

数据是一些由逗号分割的字段,因此可以定义数据如下。其中\n是Mac和Linux的换行符,\r\n则是Windows下的换行符,因此\r是可选的

row: field (',' field)* '\r'? '\n';

接下来只需要定义field的词法即可,因为换行和逗号都是CSV中的格式符号,不允许在字符中存在。因此可以很容易的得到

field: ~[\n,\r]+;

~代表取反,也就是除了换行和逗号之外的其它多个字符。

有了上面这个规则还不够,因为CSV标准规定了,如果有特殊字符,可以用双引号包起来。例如一个逗号如果被包含在双引号里面,那么就是一个字段的组成部分而不是字段的分隔符。如果双引号包裹的内容中又有双引号,那么需要将这个字段内部的双引号用两个双引号进行替代。
因此我们还需要一个规则

field: '"' ('""' | ~'"')* '"';

如上规则表示用双引号包裹的内容,可以是两个双引号或者除了单个双引号之外的其它任意内容。

CSV还允许空字段

field: ;

整理如上规则,并添加包配置和相关的标记

1
2
3
4
5
6
7
8
9
10
11
12
13
grammar Csv;

@header {
package com.nosuchfield.csv.code;
}

csv: hdr row*;
hdr: row;
row: field (',' field)* '\r'? '\n';
field: TEXT # text | STRING # string | # empty;

TEXT: ~[\n,\r]+;
STRING: '"' ('""' | ~'"')* '"';

之后配置代码生成目录为com/nosuchfield/csv/code,并去掉生成Visitor的选项,勾选生成Listener的选项,使用antlr4生成代码,生成的Java代码如下

CsvBaseListener.javaCsvLexer.javaCsvListener.javaCsvParser.java

可以看到除了Lexer和Parser,还生成了相应的Listener代码。我们创建一个继承自CsvBaseListener的类如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package com.nosuchfield.csv;

import com.nosuchfield.csv.code.CsvBaseListener;
import com.nosuchfield.csv.code.CsvParser;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CsvListener extends CsvBaseListener {

/**
* CSV的多行数据
*/
private final List<Map<String, String>> rows = new ArrayList<>();

/**
* CSV的头部
*/
private List<String> header;

/**
* 一行CSV数据
*/
private List<String> row;

/**
* 进入一行
*/
@Override
public void enterRow(CsvParser.RowContext ctx) {
// 创建一个list用来保存这一行的数据
row = new ArrayList<>();
}

/**
* 离开TEXT
*/
@Override
public void exitText(CsvParser.TextContext ctx) {
// 添加这一列的数据
row.add(ctx.TEXT().getText());
}

@Override
public void exitString(CsvParser.StringContext ctx) {
// 获取字符
String field = ctx.STRING().getText();
// 移除头部和尾部的双引号
field = field.substring(1, field.length() - 2);
// 因为CSV在双引号中用两个双引号代表单引号,这里转回来
field = field.replaceAll("\"\"", "\"");
row.add(field);
}

@Override
public void exitEmpty(CsvParser.EmptyContext ctx) {
// 添加空字符串
row.add("");
}

/**
* 离开某一行
*/
@Override
public void exitRow(CsvParser.RowContext ctx) {
if (ctx.getParent() instanceof CsvParser.HdrContext) {
// 如果某一行的父节点是header头部
// 那么就把header的值设置成这一行的数据
header = row;
return;
}
Map<String, String> data = new HashMap<>();
// 某一行已经遍历完毕,将这一行的数据和header组合起来,构成一个map
for (int i = 0; i < row.size(); i++) {
data.put(header.get(i), row.get(i));
}
// 将这一行数据添加到数据集中
rows.add(data);
}

public List<Map<String, String>> getRows() {
return rows;
}

}

如上的Listener在进入行的时候初始化容器,在退出字段的时候将字段的数据保存到容器中,并在退出行的时候最终保存所有的字段。我们通过Lexer和Parser来解析上面的CSV数据,最终生成一个ParseTree,并调用Listener遍历ParseTree来解析生成的数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public void testCsv() throws IOException {
// 从文件中取得CSV数据流,并生成lexer
CsvLexer lexer = new CsvLexer(CharStreams.fromFileName("src/main/resources/csv/city.csv"));
// 根据lexer生成token
CommonTokenStream tokens = new CommonTokenStream(lexer);
// 将token交给parser
CsvParser parser = new CsvParser(tokens);
// 生成语法树
ParseTree tree = parser.csv();
// 打印语法树
System.out.println(tree.toStringTree(parser));

// 构建语法树遍历器
ParseTreeWalker parseTreeWalker = new ParseTreeWalker();
// 语法树监听器
CsvListener listener = new CsvListener();
// 遍历语法树
parseTreeWalker.walk(listener, tree);
// 打印生成的结果
System.out.println(listener.getRows());
}

执行上面的代码得到结果如下,可以看到完整的打印出了CSV的数据

(csv (hdr (row (field 省份) , (field 城市) , (field 区县) , (field 描述) \r \n)) (row (field 江苏) , (field 南京) , (field 雨花台) , (field 外包大道) \r \n) (row (field 浙江) , (field 杭州) , (field 西湖) , (field 太美丽啦!西湖) \r \n) (row field , (field 上海) , (field 黄浦) , (field "as it says: ""hello, shanghai""") \r \n))[{省份=江苏, 描述=外包大道, 城市=南京, 区县=雨花台}, {省份=浙江, 描述=太美丽啦!西湖, 城市=杭州, 区县=西湖}, {省份=, 描述=as it says: "hello, shanghai", 城市=上海, 区县=黄浦}]

通过Listener构建一个计算器

在上面的例子中,我们已经使用了Visitor实现了一个计算器,实际上通过Listener也可以实现相同的功能。在Visitor中我们通过方法的返回值来存储计算结果,在Listener中方法没有返回值,那我们就需要通过另一种方式来进行计算并存储计算结果 —— 栈。

还是使用上面的词法分析和语法分析规则,这次我们勾选生成Listener选项,之后再次生成代码,这次会生成CalcListener接口和CalcBaseListener类,我们实现一个继承自CalcBaseListener的类。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
public class CalculateListener extends CalcBaseListener {

private static final MathContext MATH_CONTEXT = MathContext.DECIMAL128;

private Stack<BigDecimal> stack;

private BigDecimal result;

@Override
public void enterCalculationBlock(CalcParser.CalculationBlockContext ctx) {
// 创建新的栈
stack = new Stack<>();
}

@Override
public void exitCalculationBlock(CalcParser.CalculationBlockContext ctx) {
// 取出栈顶元素作为结果
result = stack.pop();
}

@Override
public void exitExpressionMulOrDiv(CalcParser.ExpressionMulOrDivContext ctx) {
// 将栈顶的两个元素取出来做乘除法,将结果压回栈
BigDecimal x = stack.pop();
BigDecimal y = stack.pop();
BigDecimal z;
switch (ctx.op.getType()) {
case CalcLexer.MUL:
z = y.multiply(x, MATH_CONTEXT);
break;
case CalcLexer.DIV:
z = y.divide(x, MATH_CONTEXT);
break;
default:
throw new RuntimeException("unsupported operator type");
}
stack.push(z);
}

@Override
public void exitExpressionAddOrSub(CalcParser.ExpressionAddOrSubContext ctx) {
// 将栈顶两个元素取出来做加减法,将结果压回栈
BigDecimal x = stack.pop();
BigDecimal y = stack.pop();
BigDecimal z;
switch (ctx.op.getType()) {
case CalcLexer.ADD:
z = y.add(x, MATH_CONTEXT);
break;
case CalcLexer.SUB:
z = y.subtract(x, MATH_CONTEXT);
break;
default:
throw new RuntimeException("unsupported operator type");
}
stack.push(z);
}

@Override
public void exitExpressionNumeric(CalcParser.ExpressionNumericContext ctx) {
// 计算数字
BigDecimal numeric = numberOrPercent(ctx.num);
if (Objects.nonNull(ctx.sign) && ctx.sign.getType() == CalcLexer.SUB) {
numeric = numeric.negate();
}
stack.push(numeric);
}

private BigDecimal numberOrPercent(Token num) {
String numberStr = num.getText();
switch (num.getType()) {
case CalcLexer.NUMBER:
return new BigDecimal(numberStr);
case CalcLexer.PERCENT_NUMBER:
return new BigDecimal(numberStr.substring(0, numberStr.length() - 1).trim())
.divide(BigDecimal.valueOf(100), MATH_CONTEXT);
default:
throw new RuntimeException("unsupported number type");
}
}

/**
* 获取计算结果
*
* @return 计算结果
*/
public BigDecimal getResult() {
return result;
}

}

上面的代码和Visitor非常相似,区别在于针对加减法和乘除法的计算,Visitor是直接拿方法参数计算,并将结果作为返回值返回。而Listener是从栈的顶部取出两个元素进行计算,并将计算结果压回栈。

如果你了解方法调用的一般方式,就应该知道其实方法调用的一般方式也是通过栈来存储方法的入参和出参的。在方法调用前,将方法入参的值压入栈中,之后运行方法,如果方法中还有方法调用,继续将入参压入栈。当方法开始执行时,将方法的入参弹出,等到方法执行完毕,将执行完毕的方法返回值压入栈,如此往复就形成了方法调用。

因此我们可以知道,计算Visitor和Listener的逻辑基本一致,都是使用栈来存储计算的数值和计算的结果。区别在于Visitor的值是存储在当前运行线程的栈上的,如果值过多,可能因为栈空间不够导致StackOverflow错误。而Listener的值是保存在我们自定义的位于堆内存的栈数据结构上的,可以存储更多的数据内容。

完整的代码位于https://github.com/RitterHou/test-antlr4

参考

ANTLR 4权威指南
语法解析器ANTLR4从入门到实践
从一个小例子理解Antlr4
Antlr4系列(二):实现一个计算器
ANTLR 使用——以表达式语法为例
Antlr4教程

🔲 ☆

ShardingSphere-JDBC介绍

ShardingSphere-JDBC是一款可以将JDBC操作进行封装,然后实现数据分片、分布式事务、读写分离、高可用、数据加密和数据脱敏等功能的模块。它的原理是实现JDBC的接口,随后将收到的JDBC操作进行改写和处理,再将操作命中到真正的数据库之上。因为它实现了JDBC接口,因此现有的Java项目都可以100%兼容使用,只需要依赖ShardingSphere-JDBC并提供相关的配置即可。

JDBC数据分片的简单使用

我们看一个简单的JDBC数据分片的例子,首先我们需要添加相关的maven依赖

1
2
3
4
5
6
7
8
9
10
 <dependency>
<groupId>org.apache.shardingsphere</groupId>
<artifactId>shardingsphere-jdbc-core</artifactId>
<version>5.3.2</version>
</dependency>
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>8.0.33</version>
</dependency>

如上添加了shardingsphere-jdbc和mysql的依赖,shardingsphere-jdbc是项目的核心依赖,而mysql则是jdbc操作需要用到的依赖。添加了maven依赖之后我们可以先创建相关的数据库和表,创建数据库和表的sql如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
/*!40101 SET NAMES utf8 */;
/*!50503 SET NAMES utf8mb4 */;
/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
/*!40103 SET TIME_ZONE='+08:00' */;
/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;
/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;
/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;

-- 导出 ds_0 的数据库结构
DROP DATABASE IF EXISTS `ds_0`;
CREATE DATABASE IF NOT EXISTS `ds_0` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_0`;

-- 导出 表 ds_0.t_order 结构
DROP TABLE IF EXISTS `t_order`;
CREATE TABLE IF NOT EXISTS `t_order` (
`id` varchar(50) COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` varchar(50) COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '用户id',
`order_id` varchar(50) COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '订单id'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 ds_1 的数据库结构
DROP DATABASE IF EXISTS `ds_1`;
CREATE DATABASE IF NOT EXISTS `ds_1` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_1`;

-- 导出 表 ds_1.t_order 结构
DROP TABLE IF EXISTS `t_order`;
CREATE TABLE IF NOT EXISTS `t_order` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '用户id',
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '订单id'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 ds_2 的数据库结构
DROP DATABASE IF EXISTS `ds_2`;
CREATE DATABASE IF NOT EXISTS `ds_2` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_2`;

-- 导出 表 ds_2.t_order 结构
DROP TABLE IF EXISTS `t_order`;
CREATE TABLE IF NOT EXISTS `t_order` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '用户id',
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '订单id'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 ds_3 的数据库结构
DROP DATABASE IF EXISTS `ds_3`;
CREATE DATABASE IF NOT EXISTS `ds_3` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_3`;

-- 导出 表 ds_3.t_order 结构
DROP TABLE IF EXISTS `t_order`;
CREATE TABLE IF NOT EXISTS `t_order` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '用户id',
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '订单id'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 ds_4 的数据库结构
DROP DATABASE IF EXISTS `ds_4`;
CREATE DATABASE IF NOT EXISTS `ds_4` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_4`;

-- 导出 表 ds_4.t_order 结构
DROP TABLE IF EXISTS `t_order`;
CREATE TABLE IF NOT EXISTS `t_order` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '用户id',
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '订单id'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 ds_5 的数据库结构
DROP DATABASE IF EXISTS `ds_5`;
CREATE DATABASE IF NOT EXISTS `ds_5` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_5`;

-- 导出 表 ds_5.t_order 结构
DROP TABLE IF EXISTS `t_order`;
CREATE TABLE IF NOT EXISTS `t_order` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '用户id',
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL COMMENT '订单id'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

/*!40103 SET TIME_ZONE=IFNULL(@OLD_TIME_ZONE, 'system') */;
/*!40101 SET SQL_MODE=IFNULL(@OLD_SQL_MODE, '') */;
/*!40014 SET FOREIGN_KEY_CHECKS=IFNULL(@OLD_FOREIGN_KEY_CHECKS, 1) */;
/*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */;
/*!40111 SET SQL_NOTES=IFNULL(@OLD_SQL_NOTES, 1) */;

我们会创建6个数据库,分别为ds_0ds_5,并且会在每个数据库里面创建一个名叫t_order的表。

为了使用shardingsphere-jdbc,我们需要创建相应的jdbc连接和配置,因为shardingsphere-jdbc实现了jdbc的接口,所以我们可以像使用普通的jdbc一样使用shardingsphere-jdbc。创建shardingsphere-jdbc连接的代码如下

1
2
Class.forName("org.apache.shardingsphere.driver.ShardingSphereDriver");
Connection conn = DriverManager.getConnection("jdbc:shardingsphere:classpath:shardingsphere-config.yaml");

如上我们创建了一个shardingsphere-jdbc的连接,可以看到就是一个创建JDBC的过程。其中使用的SPI类是org.apache.shardingsphere.driver.ShardingSphereDriver,而具体的jdbcUrl则是一个文件地址shardingsphere-config.yaml,shardingsphere-jdbc的配置就保存在这个文件中。根据shardingsphere-jdbc的官方文档,其配置包含五大类:

  1. JDBC逻辑数据库名称
  2. 运行模式配置
  3. 数据源集合配置
  4. 规则集合配置
  5. 属性配置

shardingsphere-jdbc的配置支持Java代码和yaml文件,这里我们只介绍yaml文件,下面是一个简单的例子shardingsphere-config.yaml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
dataSources:
ds_0:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://127.0.0.1:3306/ds_0
username: root
password: 1234
ds_1:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://localhost:3306/ds_1
username: root
password: 1234
ds_2:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://127.0.0.1:3306/ds_2
username: root
password: 1234
ds_3:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://localhost:3306/ds_3
username: root
password: 1234
ds_4:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://127.0.0.1:3306/ds_4
username: root
password: 1234
ds_5:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://localhost:3306/ds_5
username: root
password: 1234
rules:
- !SHARDING
tables:
t_order:
actualDataNodes: ds_$->{0..5}.t_order
databaseStrategy:
standard:
shardingColumn: id
shardingAlgorithmName: testInline
keyGenerateStrategy:
column: id
keyGeneratorName: snowflake
shardingAlgorithms:
testInline:
type: INLINE
props:
algorithm-expression: ds_$->{id % 6}
keyGenerators:
snowflake:
type: SNOWFLAKE
props:
sql-show: true

如上配置了6个数据源分别是数据库ds_0ds_5props设置了打印sql语句,rules包含了表、分片算法和主键生成算法的配置。表设置中创建了一个逻辑表t_order,对应的真正数据库表是ds_0.t_orderds_5.t_order,数据库的使用策略是通过id进行分片,分片算法是testInline,表的id字段的生成算法为snowflake。分片算法中定义了算法testInline,它使用INLINE内置方式来对id取模并和ds_进行拼接,构成数据库名。字段生成算法中定义了类型为SNOWFLAKE的字段生成算法。

有了如上配置之后,我们就可以使用shardingsphere-jdbc了。以一个数据插入操作为例,在引入了maven依赖、创建了相关的数据库和表、定义了相关的shardingsphere-jdbc配置之后,我们就可以使用上面创建的conn字段实现数据插入了。

1
2
3
4
5
6
7
String sql = "INSERT INTO t_order (`user_id`, `order_id`) VALUES (?, ?)";
PreparedStatement ps = conn.prepareStatement(sql))
for (int i = 0; i < 20; i++) {
ps.setString(1, "userId");
ps.setString(2, "orderId");
ps.executeUpdate();
}

如上代码会创建一条数据并且随机根据snowflake算法生成一个id字段,并根据id字段的取模结果将数据保存到真正的数据库中去。更多的增删改查操作可在如下代码中看到:https://github.com/RitterHou/test-shardingsphere/tree/basic/src/main/java/com/nosuchfield/shardingsphere/data

SpringBoot集成MyBatis使用shardingsphere-jdbc

根据官方issue,目前shardingsphere-jdbc已经不再使用spring-boot-starter,而是直接使用jdbc实现相关功能。这种方式可以完美兼容JDBC的相关接口,因此可以简化很多已有项目的使用

在SpringBoot中使用ShardingSphere需要设置如下的pom配置,在这里我们使用MyBatis作为ORM框架。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.0.0</version>
</parent>

<groupId>com.nosuchfield</groupId>
<artifactId>test-shardingsphere</artifactId>
<version>1.0.0-SNAPSHOT</version>

<properties>
<java.version>17</java.version>
<maven.compiler.source>${java.version}</maven.compiler.source>
<maven.compiler.target>${java.version}</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>

<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter</artifactId>
<version>3.0.2</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.28</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.12</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-jdbc</artifactId>
</dependency>
<dependency>
<groupId>org.apache.shardingsphere</groupId>
<artifactId>shardingsphere-jdbc-core</artifactId>
<version>5.3.2</version>
</dependency>
<dependency>
<groupId>com.mysql</groupId>
<artifactId>mysql-connector-j</artifactId>
<version>8.0.33</version>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.1.2</version>
<configuration>
<skipTests>true</skipTests>
</configuration>
</plugin>
</plugins>
</build>

</project>

SpringBoot的application.yml配置如下,这里配置的数据源驱动为ShardingSphereDriver,而url就是我们配置ShardingSphere属性的地方。除此之外,我们还配置了mybatis的SQL语句所对应xml文件的路径信息。

1
2
3
4
5
6
7
8
9
10
11
spring:
datasource:
url: jdbc:shardingsphere:classpath:shardingsphere/config.yaml
driver-class-name: org.apache.shardingsphere.driver.ShardingSphereDriver
application:
name: ShardingSphere

mybatis:
mapper-locations: classpath:mybatis/mapper/*.xml
configuration:
map-underscore-to-camel-case: true

接着我们配置ShardingSphere的配置信息config.yaml,这里的配置和上面简单使用的配置差不多,不再赘述了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
dataSources:
ds_0:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://127.0.0.1:3306/ds_0
username: root
password: 1234
ds_1:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://localhost:3306/ds_1
username: root
password: 1234
ds_2:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://127.0.0.1:3306/ds_2
username: root
password: 1234
ds_3:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://localhost:3306/ds_3
username: root
password: 1234
ds_4:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://127.0.0.1:3306/ds_4
username: root
password: 1234
ds_5:
dataSourceClassName: com.zaxxer.hikari.HikariDataSource
driverClassName: com.mysql.jdbc.Driver
jdbcUrl: jdbc:mysql://localhost:3306/ds_5
username: root
password: 1234
rules:
- !SHARDING
tables:
t_order:
actualDataNodes: ds_$->{0..5}.t_order_$->{1..2}
databaseStrategy:
standard:
shardingColumn: id
shardingAlgorithmName: databaseInline
tableStrategy:
standard:
shardingColumn: id
shardingAlgorithmName: tableInline
keyGenerateStrategy:
column: id
keyGeneratorName: snowflake
shardingAlgorithms:
databaseInline:
type: INLINE
props:
algorithm-expression: ds_$->{id % 6}
tableInline:
type: INLINE
props:
algorithm-expression: t_order_$->{id % 2 + 1}
keyGenerators:
snowflake:
type: SNOWFLAKE
props:
sql-show: true

接着我们定义一个订单模型Order,订单包含了一些属性信息

1
2
3
4
5
6
7
8
public class Order {
private String id;
private String orderId;
private Long userId;
private BigDecimal totalPrice;
private LocalDateTime createTime;
private LocalDateTime updateTime;
}

我们根据这个模型可以定义个MyBatis的Mapper,它包含了插入、查询的操作

1
2
3
4
5
6
7
@Mapper
public interface OrderMapper {
void insert(Order order);
List<Order> selectListByIds(@Param("idList") List<Long> idList);
@Select("SELECT * FROM t_order")
List<Order> getAllOrders();
}

其中getAllOrders方法通过注解实现了SQL的定义,而另外两个方法的SQL则在xml文件中进行实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.nosuchfield.shardingsphere.mapper.OrderMapper">
<insert id="insert" parameterType="com.nosuchfield.shardingsphere.model.Order">
INSERT INTO t_order(user_id, order_id, total_price, create_time, update_time)
VALUES(#{userId}, #{orderId}, #{totalPrice}, #{createTime}, #{updateTime})
</insert>
<select id="selectListByIds" resultType="com.nosuchfield.shardingsphere.model.Order">
SELECT order_id, user_id, total_price, state FROM t_order WHERE order_id IN
<foreach collection="idList" item="id" open="(" separator="," close=")">
#{id}
</foreach>
</select>
</mapper>

构建了如上的ShardingSphere和MyBatis的配置之后,我们可以创建相关的数据库和表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
/*!40101 SET NAMES utf8 */;
/*!50503 SET NAMES utf8mb4 */;
/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
/*!40103 SET TIME_ZONE='+08:00' */;
/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;
/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;
/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;


-- 导出 ds_0 的数据库结构
DROP DATABASE IF EXISTS `ds_0`;
CREATE DATABASE IF NOT EXISTS `ds_0` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_0`;

-- 导出 表 ds_0.t_order_1 结构
DROP TABLE IF EXISTS `t_order_1`;
CREATE TABLE IF NOT EXISTS `t_order_1` (
`id` varchar(50) COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 表 ds_0.t_order_2 结构
DROP TABLE IF EXISTS `t_order_2`;
CREATE TABLE IF NOT EXISTS `t_order_2` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';


-- 导出 ds_1 的数据库结构
DROP DATABASE IF EXISTS `ds_1`;
CREATE DATABASE IF NOT EXISTS `ds_1` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_1`;

-- 导出 表 ds_1.t_order_1 结构
DROP TABLE IF EXISTS `t_order_1`;
CREATE TABLE IF NOT EXISTS `t_order_1` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 表 ds_1.t_order_2 结构
DROP TABLE IF EXISTS `t_order_2`;
CREATE TABLE IF NOT EXISTS `t_order_2` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';


-- 导出 ds_2 的数据库结构
DROP DATABASE IF EXISTS `ds_2`;
CREATE DATABASE IF NOT EXISTS `ds_2` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_2`;

-- 导出 表 ds_2.t_order_1 结构
DROP TABLE IF EXISTS `t_order_1`;
CREATE TABLE IF NOT EXISTS `t_order_1` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 表 ds_2.t_order_2 结构
DROP TABLE IF EXISTS `t_order_2`;
CREATE TABLE IF NOT EXISTS `t_order_2` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';


-- 导出 ds_3 的数据库结构
DROP DATABASE IF EXISTS `ds_3`;
CREATE DATABASE IF NOT EXISTS `ds_3` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_3`;

-- 导出 表 ds_3.t_order_1 结构
DROP TABLE IF EXISTS `t_order_1`;
CREATE TABLE IF NOT EXISTS `t_order_1` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 表 ds_3.t_order_2 结构
DROP TABLE IF EXISTS `t_order_2`;
CREATE TABLE IF NOT EXISTS `t_order_2` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';


-- 导出 ds_4 的数据库结构
DROP DATABASE IF EXISTS `ds_4`;
CREATE DATABASE IF NOT EXISTS `ds_4` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_4`;

-- 导出 表 ds_4.t_order_1 结构
DROP TABLE IF EXISTS `t_order_1`;
CREATE TABLE IF NOT EXISTS `t_order_1` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 表 ds_4.t_order_2 结构
DROP TABLE IF EXISTS `t_order_2`;
CREATE TABLE IF NOT EXISTS `t_order_2` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';


-- 导出 ds_5 的数据库结构
DROP DATABASE IF EXISTS `ds_5`;
CREATE DATABASE IF NOT EXISTS `ds_5` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */ /*!80016 DEFAULT ENCRYPTION='N' */;
USE `ds_5`;

-- 导出 表 ds_5.t_order_1 结构
DROP TABLE IF EXISTS `t_order_1`;
CREATE TABLE IF NOT EXISTS `t_order_1` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

-- 导出 表 ds_5.t_order_2 结构
DROP TABLE IF EXISTS `t_order_2`;
CREATE TABLE IF NOT EXISTS `t_order_2` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`order_id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`user_id` bigint DEFAULT NULL,
`total_price` decimal(20,6) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`update_time` datetime DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci COMMENT='订单表';

/*!40103 SET TIME_ZONE=IFNULL(@OLD_TIME_ZONE, 'system') */;
/*!40101 SET SQL_MODE=IFNULL(@OLD_SQL_MODE, '') */;
/*!40014 SET FOREIGN_KEY_CHECKS=IFNULL(@OLD_FOREIGN_KEY_CHECKS, 1) */;
/*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */;
/*!40111 SET SQL_NOTES=IFNULL(@OLD_SQL_NOTES, 1) */;

有了上面的数据库和表之后,我们就可以测试ShardingSphere的数据插入和查询了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
@Slf4j
@RunWith(SpringRunner.class)
@SpringBootTest(classes = Application.class)
public class OrderMapperTest {
@Autowired
private OrderMapper orderMapper;
@Test
public void testInsert() {
for (int i = 0; i < 30; i++) {
Order order = new Order();
order.setOrderId("66666666666");
order.setUserId(1L);
order.setTotalPrice(new BigDecimal((i + 1) * 5));
order.setCreateTime(LocalDateTime.now());
order.setUpdateTime(order.getCreateTime());
this.orderMapper.insert(order);
}
}
@Test
public void testQueryAll() {
List<Order> orders = orderMapper.getAllOrders();
orders.forEach(order -> log.info(order.toString()));
}
}

读写分离和数据脱敏

上面我们测试了ShardingSphere的数据分片功能,下面我们了解一下它的读写分离和数据脱敏。我们先在ds_0ds_1ds_2数据库中创建表t_user

1
2
3
4
5
6
CREATE TABLE `t_user` (
`id` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`name` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`phone` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL,
`address` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;

之后我们在ShardingSphere的rules属性下添加如下配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
- !READWRITE_SPLITTING
dataSources:
readwrite_ds:
staticStrategy:
writeDataSourceName: ds_0
readDataSourceNames:
- ds_1
- ds_2
loadBalancerName: random
loadBalancers:
random:
type: RANDOM
- !MASK
tables:
t_user:
columns:
id:
maskAlgorithm: md5_mask
phone:
maskAlgorithm: keep_first_n_last_m_mask
maskAlgorithms:
md5_mask:
type: MD5
keep_first_n_last_m_mask:
type: KEEP_FIRST_N_LAST_M
props:
first-n: 3
last-m: 4
replace-char: '*'

配置包含了写库ds_0和读库ds_1ds_2的配置,读库的负载均衡策略为随机(这里需要先设置ds_1ds_2自动同步ds_0的数据,详细过程可查看文章MySQL实现双服务器主从同步)。数据脱敏策略为对t_user的id字段进行md5脱敏,对phone字段保留前3位和后4位,剩下的部分用*替换。创建好了表和配置之后,我们设置User的model

1
2
3
4
5
6
public class User {
private String id;
private String name;
private String phone;
private String address;
}

以及mapper

1
2
3
4
5
6
7
@Mapper
public interface UserMapper {
@Insert("INSERT INTO t_user(id, name, phone, address) VALUES (#{id}, #{name}, #{phone}, #{address})")
void save(User user);
@Select("SELECT * FROM t_user")
List<User> query();
}

之后我们测试上面的操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
@Slf4j
public class UserMapperTest extends BaseTest {
@Autowired
private UserMapper userMapper;
@Test
public void testInsert() {
userMapper.save(User.builder()
.id("888")
.name("小明")
.phone("13866688888")
.address("江苏省南京市").build());
}
@Test
public void testQuery() {
List<User> users = userMapper.query();
log.info(users.toString());
}
}

我们先插入数据,随后到从库中查询数据,得到结果如下

ShardingSphere-SQL:73 Logic SQL: SELECT * FROM t_userShardingSphere-SQL:73 Actual SQL: ds_1 ::: SELECT * FROM t_usercom.nosuchfield.shardingsphere.UserMapperTest:30 [User(id=0a113ef6b61820daa5611c870ed8d5ee, name=小明, phone=138****8888, address=江苏省南京市)]

可以看到数据插入到了主库中,随后从从库ds_1中查询出了相关的数据,并且对id和phone字段的数据进行了脱敏操作,id字段被转化为了MD5的结果,而phone的中间4位被星号替代了。

数据加密

数据加密可以保证我们存到数据库中的数据都是经过加密的,和数据脱敏刚好反过来。首先我们创建表t_member

1
2
3
4
CREATE TABLE `t_member` (
`name` varchar(50) COLLATE utf8mb4_general_ci DEFAULT NULL,
`password` varchar(50) COLLATE utf8mb4_general_ci DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;

随后我们配置ShardingSphere的数据加密配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
- !ENCRYPT
tables:
t_member:
columns:
name:
cipherColumn: name
encryptorName: name_encryptor
password:
cipherColumn: password
encryptorName: pwd_encryptor
queryWithCipherColumn: true
encryptors:
name_encryptor:
type: AES
props:
aes-key-value: 123abc
pwd_encryptor:
type: MD5
props:
salt: nosuchfield

我们将表t_member的name字段使用name_encryptor的加密方式进行加密,加密之后的字段名仍然叫做name,name_encryptor的配置在encryptors中可以看到,使用了AES加密算法并设置key为123abc。类似的,password的加密方式为MD5,在计算MD5的时候加盐nosuchfield

随后我们创建model

1
2
3
4
public class Member {
private String name;
private String password;
}

和mapper

1
2
3
4
5
6
7
@Mapper
public interface MemberMapper {
@Insert("INSERT INTO t_member(name, password) VALUES (#{name}, #{password})")
void save(Member member);
@Select("SELECT * FROM t_member")
List<Member> query();
}

并测试写入和读取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
@Slf4j
public class MemberMapperTest extends BaseTest {
@Autowired
private MemberMapper memberMapper;
@Test
public void testSave() {
memberMapper.save(Member.builder()
.name("张三")
.password("123456").build());
}
@Test
public void testQuery() {
List<Member> members = memberMapper.query();
log.info(members.toString());
}
}

在插入了数据{"name": "张三", "password": "123456"}之后,可以到数据库中查看插入的数据如下

PS C:\Program Files\MySQL\MySQL Server 8.0\bin> ./mysql -u root -pmysql> select * from t_member;+--------------------------+----------------------------------+| name                     | password                         |+--------------------------+----------------------------------+| Fod6ouOanqNvHlTdBsx1Lw== | 47514eed77109a04ce4c9f9931d0c5ec |+--------------------------+----------------------------------+1 row in set (0.00 sec)

可以看到name和password在存储到数据库的时候都加密了。随后我们执行测试代码中的查询逻辑,可以看到结果如下,name又通过AES算法解密成功,而password因为使用的是MD5算法就无法解密了

com.nosuchfield.shardingsphere.MemberMapperTest:27 [Member(name=张三, password=47514eed77109a04ce4c9f9931d0c5ec)]

本节使用到的代码:https://github.com/RitterHou/test-shardingsphere

参考

官方文档

❌