summaryrefslogtreecommitdiff
path: root/script
diff options
context:
space:
mode:
author最萌小汐 <sumneko@hotmail.com>2021-07-14 15:03:01 +0800
committer最萌小汐 <sumneko@hotmail.com>2021-07-14 15:03:01 +0800
commitf6cc2f276a6113404722e0a9dae952ed38d78c75 (patch)
treef3969d54e07da507af2dbbd6d877f4237eb9e505 /script
parent18ae98ec2c658b6ea13046132d314604a8ad7917 (diff)
downloadlua-language-server-f6cc2f276a6113404722e0a9dae952ed38d78c75.zip
update doctor
Diffstat (limited to 'script')
-rw-r--r--script/doctor.lua365
1 files changed, 283 insertions, 82 deletions
diff --git a/script/doctor.lua b/script/doctor.lua
index 08ec69cf..91a7e4b8 100644
--- a/script/doctor.lua
+++ b/script/doctor.lua
@@ -1,24 +1,66 @@
local type = type
local next = next
+local pairs = pairs
local ipairs = ipairs
local rawget = rawget
+local rawset = rawset
local pcall = pcall
+local tostring = tostring
+local select = select
+local stderr = io.stderr
+local sformat = string.format
local getregistry = debug.getregistry
local getmetatable = debug.getmetatable
local getupvalue = debug.getupvalue
-local getuservalue = debug.getuservalue
+---@diagnostic disable-next-line: deprecated
+local getuservalue = debug.getuservalue or debug.getfenv
local getlocal = debug.getlocal
local getinfo = debug.getinfo
-local maxinterger = math.maxinteger
+local maxinterger = 10000
local mathType = math.type
-local tableConcat = table.concat
local _G = _G
local registry = getregistry()
-local tableSort = table.sort
_ENV = nil
-local m = {}
+local hasPoint = pcall(sformat, '%p', _G)
+local multiUserValue = not pcall(getuservalue, stderr, '')
+
+local function getPoint(obj)
+ if hasPoint then
+ return ('%p'):format(obj)
+ else
+ local mt = getmetatable(obj)
+ local ts
+ if mt then
+ ts = rawget(mt, '__tostring')
+ if ts then
+ rawset(mt, '__tostring', nil)
+ end
+ end
+ local name = tostring(obj)
+ if ts then
+ rawset(mt, '__tostring', ts)
+ end
+ return name:match(': (.+)')
+ end
+end
+
+local function formatObject(obj, tp, ext)
+ local text = ('%s:%s'):format(tp, getPoint(obj))
+ if ext then
+ text = ('%s(%s)'):format(text, ext)
+ end
+ return text
+end
+
+local function isInteger(obj)
+ if mathType then
+ return mathType(obj) == 'integer'
+ else
+ return obj % 1 == 0
+ end
+end
local function getTostring(obj)
local mt = getmetatable(obj)
@@ -50,7 +92,7 @@ local function formatName(obj)
return 'boolean:false'
end
elseif tp == 'number' then
- if mathType(obj) == 'integer' then
+ if isInteger(obj) then
return ('number:%d'):format(obj)
else
-- 如果浮点数可以完全表示为整数,那么就转换为整数
@@ -73,11 +115,11 @@ local function formatName(obj)
elseif tp == 'function' then
local info = getinfo(obj, 'S')
if info.what == 'c' then
- return ('function:%p(C)'):format(obj)
+ return formatObject(obj, 'function', 'C')
elseif info.what == 'main' then
- return ('function:%p(main)'):format(obj)
+ return formatObject(obj, 'function', 'main')
else
- return ('function:%p(%s:%d-%d)'):format(obj, info.source, info.linedefined, info.lastlinedefined)
+ return formatObject(obj, 'function', ('%s:%d-%d'):format(info.source, info.linedefined, info.lastlinedefined))
end
elseif tp == 'table' then
local id = getTostring(obj)
@@ -89,27 +131,63 @@ local function formatName(obj)
end
end
if id then
- return ('table:%p(%s)'):format(obj, id)
+ return formatObject(obj, 'table', id)
else
- return ('table:%p'):format(obj)
+ return formatObject(obj, 'table')
end
elseif tp == 'userdata' then
local id = getTostring(obj)
if id then
- return ('userdata:%p(%s)'):format(obj, id)
+ return formatObject(obj, 'userdata', id)
else
- return ('userdata:%p'):format(obj)
+ return formatObject(obj, 'userdata')
end
else
- return ('%s:%p'):format(tp, obj)
+ return formatObject(obj, tp)
+ end
+end
+
+local _private = {}
+
+---@generic T
+---@param o T
+---@return T
+local function private(o)
+ if not o then
+ return nil
end
+ _private[o] = true
+ return o
end
---- 内存快照
+local m = private {}
+--- 获取内存快照,生成一个内部数据结构。
+--- 一般不用这个API,改用 report 或 catch。
---@return table
-function m.snapshot()
- local mark = {}
+m.snapshot = private(function ()
+ if m._lastCache then
+ return m._lastCache
+ end
+
+ local exclude = {}
+ if m._exclude then
+ for _, o in ipairs(m._exclude) do
+ exclude[o] = true
+ end
+ end
+ local function private(o)
+ if not o then
+ return nil
+ end
+ exclude[o] = true
+ return o
+ end
+
+ private(exclude)
+
local find
+ local mark = private {}
+
local function findTable(t, result)
result = result or {}
@@ -130,7 +208,7 @@ function m.snapshot()
if not wk then
local keyInfo = find(k)
if keyInfo then
- result[#result+1] = {
+ result[#result+1] = private {
type = 'key',
name = formatName(k),
info = keyInfo,
@@ -140,7 +218,7 @@ function m.snapshot()
if not wv then
local valueInfo = find(v)
if valueInfo then
- result[#result+1] = {
+ result[#result+1] = private {
type = 'field',
name = formatName(k) .. '|' .. formatName(v),
info = valueInfo,
@@ -150,19 +228,16 @@ function m.snapshot()
end
local MTInfo = find(getmetatable(t))
if MTInfo then
- result[#result+1] = {
+ result[#result+1] = private {
type = 'metatable',
name = '',
info = MTInfo,
}
end
- if #result == 0 then
- return nil
- end
return result
end
- local function findFunction(f, result, trd, stack)
+ local function findFunction(f, result)
result = result or {}
for i = 1, maxinterger do
local n, v = getupvalue(f, i)
@@ -171,45 +246,27 @@ function m.snapshot()
end
local valueInfo = find(v)
if valueInfo then
- result[#result+1] = {
+ result[#result+1] = private {
type = 'upvalue',
name = n,
info = valueInfo,
}
end
end
- if trd then
- for i = 1, maxinterger do
- local n, l = getlocal(trd, stack, i)
- if not n then
- break
- end
- local valueInfo = find(l)
- if valueInfo then
- result[#result+1] = {
- type = 'local',
- name = n,
- info = valueInfo,
- }
- end
- end
- end
- if #result == 0 then
- return nil
- end
return result
end
local function findUserData(u, result)
result = result or {}
- for i = 1, maxinterger do
+ local maxUserValue = multiUserValue and maxinterger or 1
+ for i = 1, maxUserValue do
local v, b = getuservalue(u, i)
if not b then
break
end
local valueInfo = find(v)
if valueInfo then
- result[#result+1] = {
+ result[#result+1] = private {
type = 'uservalue',
name = formatName(i),
info = valueInfo,
@@ -218,7 +275,7 @@ function m.snapshot()
end
local MTInfo = find(getmetatable(u))
if MTInfo then
- result[#result+1] = {
+ result[#result+1] = private {
type = 'metatable',
name = '',
info = MTInfo,
@@ -232,19 +289,75 @@ function m.snapshot()
local function findThread(trd, result)
-- 不查找主线程,主线程一定是临时的(视为弱引用)
- if trd == registry[1] then
+ if m._ignoreMainThread and trd == registry[1] then
return nil
end
- result = result or {}
+ result = result or private {}
for i = 1, maxinterger do
local info = getinfo(trd, i, 'Sf')
if not info then
break
end
- local funcInfo = find(info.func, trd, i)
+ local funcInfo = find(info.func)
+ if funcInfo then
+ for ln = 1, maxinterger do
+ local n, l = getlocal(trd, i, ln)
+ if not n then
+ break
+ end
+ local valueInfo = find(l)
+ if valueInfo then
+ funcInfo[#funcInfo+1] = private {
+ type = 'local',
+ name = n,
+ info = valueInfo,
+ }
+ end
+ end
+ result[#result+1] = private {
+ type = 'stack',
+ name = i .. '@' .. formatName(info.func),
+ info = funcInfo,
+ }
+ end
+ end
+
+ if #result == 0 then
+ return nil
+ end
+ return result
+ end
+
+ local function findMainThread()
+ -- 不查找主线程,主线程一定是临时的(视为弱引用)
+ if m._ignoreMainThread then
+ return nil
+ end
+ local result = private {}
+
+ for i = 1, maxinterger do
+ local info = getinfo(i, 'Sf')
+ if not info then
+ break
+ end
+ local funcInfo = find(info.func)
if funcInfo then
- result[#result+1] = {
+ for ln = 1, maxinterger do
+ local n, l = getlocal(i, ln)
+ if not n then
+ break
+ end
+ local valueInfo = find(l)
+ if valueInfo then
+ funcInfo[#funcInfo+1] = private {
+ type = 'local',
+ name = n,
+ info = valueInfo,
+ }
+ end
+ end
+ result[#result+1] = private {
type = 'stack',
name = i .. '@' .. formatName(info.func),
info = funcInfo,
@@ -258,22 +371,25 @@ function m.snapshot()
return result
end
- function find(obj, trd, stack)
+ function find(obj)
if mark[obj] then
return mark[obj]
end
+ if exclude[obj] or _private[obj] then
+ return nil
+ end
local tp = type(obj)
if tp == 'table' then
- mark[obj] = {}
+ mark[obj] = private {}
mark[obj] = findTable(obj, mark[obj])
elseif tp == 'function' then
- mark[obj] = {}
- mark[obj] = findFunction(obj, mark[obj], trd, stack)
+ mark[obj] = private {}
+ mark[obj] = findFunction(obj, mark[obj])
elseif tp == 'userdata' then
- mark[obj] = {}
+ mark[obj] = private {}
mark[obj] = findUserData(obj, mark[obj])
elseif tp == 'thread' then
- mark[obj] = {}
+ mark[obj] = private {}
mark[obj] = findThread(obj, mark[obj])
else
return nil
@@ -284,36 +400,71 @@ function m.snapshot()
return mark[obj]
end
- return {
+ -- TODO: Lua 5.1中,主线程与_G都不在注册表里
+ local result = private {
name = formatName(registry),
type = 'root',
info = find(registry),
}
-end
+ if not registry[1] then
+ result.info[#result.info+1] = private {
+ type = 'thread',
+ name = 'main',
+ info = findMainThread(),
+ }
+ end
+ if not registry[2] then
+ result.info[#result.info+1] = private {
+ type = '_G',
+ name = '_G',
+ info = find(_G),
+ }
+ end
+ if m._cache then
+ m._lastCache = result
+ end
+ return result
+end)
---- 寻找对象的引用
----@return string
-function m.catch(...)
+--- 遍历虚拟机,寻找对象的引用。
+--- 输入既可以是对象实体,也可以是对象的描述(从其他接口的返回值中复制过来)。
+--- 返回字符串数组的数组,每个字符串描述了如何从根节点引用到指定的对象。
+--- 可以同时查找多个对象。
+---@return string[][]
+m.catch = private(function (...)
local targets = {}
- for _, target in ipairs {...} do
- targets[target] = true
+ for i = 1, select('#', ...) do
+ local target = select(i, ...)
+ if target ~= nil then
+ targets[target] = true
+ end
end
local report = m.snapshot()
- local path = {}
+ local path = {}
local result = {}
- local mark = {}
+ local mark = {}
local function push()
- result[#result+1] = tableConcat(path, ' => ')
+ local resultPath = {}
+ for i = 1, #path do
+ resultPath[i] = path[i]
+ end
+ result[#result+1] = resultPath
end
local function search(t)
path[#path+1] = ('(%s)%s'):format(t.type, t.name)
local addTarget
+ local point = getPoint(t.info.object)
if targets[t.info.object] then
targets[t.info.object] = nil
addTarget = t.info.object
- push(t)
+ push()
+ end
+ if targets[point] then
+ targets[point] = nil
+ addTarget = point
+ push()
end
if not mark[t.info] then
mark[t.info] = true
@@ -330,11 +481,14 @@ function m.catch(...)
search(report)
return result
-end
+end)
---- 生成一个报告
----@return string
-function m.report()
+---@alias report {point: string, count: integer, name: string, childs: integer}
+
+--- 生成一个内存快照的报告。
+--- 你应当将其输出到一个文件里再查看。
+---@return report[]
+m.report = private(function ()
local snapshot = m.snapshot()
local cache = {}
local mark = {}
@@ -347,12 +501,13 @@ function m.report()
or tp == 'function'
or tp == 'string'
or tp == 'thread' then
- local point = ('%p'):format(obj)
+ local point = getPoint(obj)
if not cache[point] then
cache[point] = {
- point = point,
- count = 0,
- name = formatName(obj),
+ point = point,
+ count = 0,
+ name = formatName(obj),
+ childs = #t.info,
}
end
cache[point].count = cache[point].count + 1
@@ -366,15 +521,61 @@ function m.report()
end
scan(snapshot)
-
local list = {}
- for _, info in next, cache do
+ for _, info in pairs(cache) do
list[#list+1] = info
end
- tableSort(list, function (a, b)
- return a.name < b.name
- end)
return list
-end
+end)
+
+--- 在进行快照相关操作时排除掉的对象。
+--- 你可以用这个功能排除掉一些数据表。
+m.exclude = private(function (...)
+ m._exclude = {...}
+end)
+
+--- 比较2个报告
+---@return string
+m.compare = private(function (old, new)
+ local newHash = {}
+ local ret = {}
+ for _, info in ipairs(new) do
+ newHash[info.point] = info
+ end
+ for _, info in ipairs(old) do
+ if newHash[info.point] then
+ ret[#ret + 1] = {
+ old = info,
+ new = newHash[info.point]
+ }
+ end
+ end
+ return ret
+end)
+
+--- 是否忽略主线程的栈
+---@param flag boolean
+m.ignoreMainThread = private(function (flag)
+ m._ignoreMainThread = flag
+end)
+
+--- 是否启用缓存,启用后会始终使用第一次查找的结果,
+--- 适用于连续查找引用。如果想要查找新的引用需要先关闭缓存。
+---@param flag boolean
+m.enableCache = private(function (flag)
+ if flag then
+ m._cache = true
+ else
+ m._cache = false
+ m._lastCache = nil
+ end
+end)
+
+--- 立即清除缓存
+m.flushCache = private(function ()
+ m._lastCache = nil
+end)
+
+private(getinfo(1, 'f').func)
return m