diff options
Diffstat (limited to 'script/doctor.lua')
-rw-r--r-- | script/doctor.lua | 380 |
1 files changed, 380 insertions, 0 deletions
diff --git a/script/doctor.lua b/script/doctor.lua new file mode 100644 index 00000000..08ec69cf --- /dev/null +++ b/script/doctor.lua @@ -0,0 +1,380 @@ +local type = type +local next = next +local ipairs = ipairs +local rawget = rawget +local pcall = pcall +local getregistry = debug.getregistry +local getmetatable = debug.getmetatable +local getupvalue = debug.getupvalue +local getuservalue = debug.getuservalue +local getlocal = debug.getlocal +local getinfo = debug.getinfo +local maxinterger = math.maxinteger +local mathType = math.type +local tableConcat = table.concat +local _G = _G +local registry = getregistry() +local tableSort = table.sort + +_ENV = nil + +local m = {} + +local function getTostring(obj) + local mt = getmetatable(obj) + if not mt then + return nil + end + local toString = rawget(mt, '__tostring') + if not toString then + return nil + end + local suc, str = pcall(toString, obj) + if not suc then + return nil + end + if type(str) ~= 'string' then + return nil + end + return str +end + +local function formatName(obj) + local tp = type(obj) + if tp == 'nil' then + return 'nil:nil' + elseif tp == 'boolean' then + if obj == true then + return 'boolean:true' + else + return 'boolean:false' + end + elseif tp == 'number' then + if mathType(obj) == 'integer' then + return ('number:%d'):format(obj) + else + -- 如果浮点数可以完全表示为整数,那么就转换为整数 + local str = ('%.10f'):format(obj):gsub('%.?[0]+$', '') + if str:find('.', 1, true) then + -- 如果浮点数不能表示为整数,那么再加上它的精确表示法 + str = ('%s(%q)'):format(str, obj) + end + return 'number:' .. str + end + elseif tp == 'string' then + local str = ('%q'):format(obj) + if #str > 100 then + local new = ('%s...(len=%d)'):format(str:sub(1, 100), #str) + if #new < #str then + str = new + end + end + return 'string:' .. str + elseif tp == 'function' then + local info = getinfo(obj, 'S') + if info.what == 'c' then + return ('function:%p(C)'):format(obj) + elseif info.what == 'main' then + return ('function:%p(main)'):format(obj) + else + return ('function:%p(%s:%d-%d)'):format(obj, info.source, info.linedefined, info.lastlinedefined) + end + elseif tp == 'table' then + local id = getTostring(obj) + if not id then + if obj == _G then + id = '_G' + elseif obj == registry then + id = 'registry' + end + end + if id then + return ('table:%p(%s)'):format(obj, id) + else + return ('table:%p'):format(obj) + end + elseif tp == 'userdata' then + local id = getTostring(obj) + if id then + return ('userdata:%p(%s)'):format(obj, id) + else + return ('userdata:%p'):format(obj) + end + else + return ('%s:%p'):format(tp, obj) + end +end + +--- 内存快照 +---@return table +function m.snapshot() + local mark = {} + local find + + local function findTable(t, result) + result = result or {} + local mt = getmetatable(t) + local wk, wv + if mt then + local mode = rawget(mt, '__mode') + if type(mode) == 'string' then + if mode:find('k', 1, true) then + wk = true + end + if mode:find('v', 1, true) then + wv = true + end + end + end + for k, v in next, t do + if not wk then + local keyInfo = find(k) + if keyInfo then + result[#result+1] = { + type = 'key', + name = formatName(k), + info = keyInfo, + } + end + end + if not wv then + local valueInfo = find(v) + if valueInfo then + result[#result+1] = { + type = 'field', + name = formatName(k) .. '|' .. formatName(v), + info = valueInfo, + } + end + end + end + local MTInfo = find(getmetatable(t)) + if MTInfo then + result[#result+1] = { + type = 'metatable', + name = '', + info = MTInfo, + } + end + if #result == 0 then + return nil + end + return result + end + + local function findFunction(f, result, trd, stack) + result = result or {} + for i = 1, maxinterger do + local n, v = getupvalue(f, i) + if not n then + break + end + local valueInfo = find(v) + if valueInfo then + result[#result+1] = { + type = 'upvalue', + name = n, + info = valueInfo, + } + end + end + if trd then + for i = 1, maxinterger do + local n, l = getlocal(trd, stack, i) + if not n then + break + end + local valueInfo = find(l) + if valueInfo then + result[#result+1] = { + type = 'local', + name = n, + info = valueInfo, + } + end + end + end + if #result == 0 then + return nil + end + return result + end + + local function findUserData(u, result) + result = result or {} + for i = 1, maxinterger do + local v, b = getuservalue(u, i) + if not b then + break + end + local valueInfo = find(v) + if valueInfo then + result[#result+1] = { + type = 'uservalue', + name = formatName(i), + info = valueInfo, + } + end + end + local MTInfo = find(getmetatable(u)) + if MTInfo then + result[#result+1] = { + type = 'metatable', + name = '', + info = MTInfo, + } + end + if #result == 0 then + return nil + end + return result + end + + local function findThread(trd, result) + -- 不查找主线程,主线程一定是临时的(视为弱引用) + if trd == registry[1] then + return nil + end + result = result or {} + + for i = 1, maxinterger do + local info = getinfo(trd, i, 'Sf') + if not info then + break + end + local funcInfo = find(info.func, trd, i) + if funcInfo then + result[#result+1] = { + type = 'stack', + name = i .. '@' .. formatName(info.func), + info = funcInfo, + } + end + end + + if #result == 0 then + return nil + end + return result + end + + function find(obj, trd, stack) + if mark[obj] then + return mark[obj] + end + local tp = type(obj) + if tp == 'table' then + mark[obj] = {} + mark[obj] = findTable(obj, mark[obj]) + elseif tp == 'function' then + mark[obj] = {} + mark[obj] = findFunction(obj, mark[obj], trd, stack) + elseif tp == 'userdata' then + mark[obj] = {} + mark[obj] = findUserData(obj, mark[obj]) + elseif tp == 'thread' then + mark[obj] = {} + mark[obj] = findThread(obj, mark[obj]) + else + return nil + end + if mark[obj] then + mark[obj].object = obj + end + return mark[obj] + end + + return { + name = formatName(registry), + type = 'root', + info = find(registry), + } +end + +--- 寻找对象的引用 +---@return string +function m.catch(...) + local targets = {} + for _, target in ipairs {...} do + targets[target] = true + end + local report = m.snapshot() + local path = {} + local result = {} + local mark = {} + + local function push() + result[#result+1] = tableConcat(path, ' => ') + end + + local function search(t) + path[#path+1] = ('(%s)%s'):format(t.type, t.name) + local addTarget + if targets[t.info.object] then + targets[t.info.object] = nil + addTarget = t.info.object + push(t) + end + if not mark[t.info] then + mark[t.info] = true + for _, obj in ipairs(t.info) do + search(obj) + end + end + path[#path] = nil + if addTarget then + targets[addTarget] = true + end + end + + search(report) + + return result +end + +--- 生成一个报告 +---@return string +function m.report() + local snapshot = m.snapshot() + local cache = {} + local mark = {} + + local function scan(t) + local obj = t.info.object + local tp = type(obj) + if tp == 'table' + or tp == 'userdata' + or tp == 'function' + or tp == 'string' + or tp == 'thread' then + local point = ('%p'):format(obj) + if not cache[point] then + cache[point] = { + point = point, + count = 0, + name = formatName(obj), + } + end + cache[point].count = cache[point].count + 1 + end + if not mark[t.info] then + mark[t.info] = true + for _, child in ipairs(t.info) do + scan(child) + end + end + end + + scan(snapshot) + + local list = {} + for _, info in next, cache do + list[#list+1] = info + end + tableSort(list, function (a, b) + return a.name < b.name + end) + return list +end + +return m |