From f6cc2f276a6113404722e0a9dae952ed38d78c75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=80=E8=90=8C=E5=B0=8F=E6=B1=90?= Date: Wed, 14 Jul 2021 15:03:01 +0800 Subject: update doctor --- script/doctor.lua | 365 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 283 insertions(+), 82 deletions(-) (limited to 'script/doctor.lua') diff --git a/script/doctor.lua b/script/doctor.lua index 08ec69cf..91a7e4b8 100644 --- a/script/doctor.lua +++ b/script/doctor.lua @@ -1,24 +1,66 @@ local type = type local next = next +local pairs = pairs local ipairs = ipairs local rawget = rawget +local rawset = rawset local pcall = pcall +local tostring = tostring +local select = select +local stderr = io.stderr +local sformat = string.format local getregistry = debug.getregistry local getmetatable = debug.getmetatable local getupvalue = debug.getupvalue -local getuservalue = debug.getuservalue +---@diagnostic disable-next-line: deprecated +local getuservalue = debug.getuservalue or debug.getfenv local getlocal = debug.getlocal local getinfo = debug.getinfo -local maxinterger = math.maxinteger +local maxinterger = 10000 local mathType = math.type -local tableConcat = table.concat local _G = _G local registry = getregistry() -local tableSort = table.sort _ENV = nil -local m = {} +local hasPoint = pcall(sformat, '%p', _G) +local multiUserValue = not pcall(getuservalue, stderr, '') + +local function getPoint(obj) + if hasPoint then + return ('%p'):format(obj) + else + local mt = getmetatable(obj) + local ts + if mt then + ts = rawget(mt, '__tostring') + if ts then + rawset(mt, '__tostring', nil) + end + end + local name = tostring(obj) + if ts then + rawset(mt, '__tostring', ts) + end + return name:match(': (.+)') + end +end + +local function formatObject(obj, tp, ext) + local text = ('%s:%s'):format(tp, getPoint(obj)) + if ext then + text = ('%s(%s)'):format(text, ext) + end + return text +end + +local function isInteger(obj) + if mathType then + return mathType(obj) == 'integer' + else + return obj % 1 == 0 + end +end local function getTostring(obj) local mt = getmetatable(obj) @@ -50,7 +92,7 @@ local function formatName(obj) return 'boolean:false' end elseif tp == 'number' then - if mathType(obj) == 'integer' then + if isInteger(obj) then return ('number:%d'):format(obj) else -- 如果浮点数可以完全表示为整数,那么就转换为整数 @@ -73,11 +115,11 @@ local function formatName(obj) elseif tp == 'function' then local info = getinfo(obj, 'S') if info.what == 'c' then - return ('function:%p(C)'):format(obj) + return formatObject(obj, 'function', 'C') elseif info.what == 'main' then - return ('function:%p(main)'):format(obj) + return formatObject(obj, 'function', 'main') else - return ('function:%p(%s:%d-%d)'):format(obj, info.source, info.linedefined, info.lastlinedefined) + return formatObject(obj, 'function', ('%s:%d-%d'):format(info.source, info.linedefined, info.lastlinedefined)) end elseif tp == 'table' then local id = getTostring(obj) @@ -89,27 +131,63 @@ local function formatName(obj) end end if id then - return ('table:%p(%s)'):format(obj, id) + return formatObject(obj, 'table', id) else - return ('table:%p'):format(obj) + return formatObject(obj, 'table') end elseif tp == 'userdata' then local id = getTostring(obj) if id then - return ('userdata:%p(%s)'):format(obj, id) + return formatObject(obj, 'userdata', id) else - return ('userdata:%p'):format(obj) + return formatObject(obj, 'userdata') end else - return ('%s:%p'):format(tp, obj) + return formatObject(obj, tp) + end +end + +local _private = {} + +---@generic T +---@param o T +---@return T +local function private(o) + if not o then + return nil end + _private[o] = true + return o end ---- 内存快照 +local m = private {} +--- 获取内存快照,生成一个内部数据结构。 +--- 一般不用这个API,改用 report 或 catch。 ---@return table -function m.snapshot() - local mark = {} +m.snapshot = private(function () + if m._lastCache then + return m._lastCache + end + + local exclude = {} + if m._exclude then + for _, o in ipairs(m._exclude) do + exclude[o] = true + end + end + local function private(o) + if not o then + return nil + end + exclude[o] = true + return o + end + + private(exclude) + local find + local mark = private {} + local function findTable(t, result) result = result or {} @@ -130,7 +208,7 @@ function m.snapshot() if not wk then local keyInfo = find(k) if keyInfo then - result[#result+1] = { + result[#result+1] = private { type = 'key', name = formatName(k), info = keyInfo, @@ -140,7 +218,7 @@ function m.snapshot() if not wv then local valueInfo = find(v) if valueInfo then - result[#result+1] = { + result[#result+1] = private { type = 'field', name = formatName(k) .. '|' .. formatName(v), info = valueInfo, @@ -150,19 +228,16 @@ function m.snapshot() end local MTInfo = find(getmetatable(t)) if MTInfo then - result[#result+1] = { + result[#result+1] = private { type = 'metatable', name = '', info = MTInfo, } end - if #result == 0 then - return nil - end return result end - local function findFunction(f, result, trd, stack) + local function findFunction(f, result) result = result or {} for i = 1, maxinterger do local n, v = getupvalue(f, i) @@ -171,45 +246,27 @@ function m.snapshot() end local valueInfo = find(v) if valueInfo then - result[#result+1] = { + result[#result+1] = private { type = 'upvalue', name = n, info = valueInfo, } end end - if trd then - for i = 1, maxinterger do - local n, l = getlocal(trd, stack, i) - if not n then - break - end - local valueInfo = find(l) - if valueInfo then - result[#result+1] = { - type = 'local', - name = n, - info = valueInfo, - } - end - end - end - if #result == 0 then - return nil - end return result end local function findUserData(u, result) result = result or {} - for i = 1, maxinterger do + local maxUserValue = multiUserValue and maxinterger or 1 + for i = 1, maxUserValue do local v, b = getuservalue(u, i) if not b then break end local valueInfo = find(v) if valueInfo then - result[#result+1] = { + result[#result+1] = private { type = 'uservalue', name = formatName(i), info = valueInfo, @@ -218,7 +275,7 @@ function m.snapshot() end local MTInfo = find(getmetatable(u)) if MTInfo then - result[#result+1] = { + result[#result+1] = private { type = 'metatable', name = '', info = MTInfo, @@ -232,19 +289,75 @@ function m.snapshot() local function findThread(trd, result) -- 不查找主线程,主线程一定是临时的(视为弱引用) - if trd == registry[1] then + if m._ignoreMainThread and trd == registry[1] then return nil end - result = result or {} + result = result or private {} for i = 1, maxinterger do local info = getinfo(trd, i, 'Sf') if not info then break end - local funcInfo = find(info.func, trd, i) + local funcInfo = find(info.func) + if funcInfo then + for ln = 1, maxinterger do + local n, l = getlocal(trd, i, ln) + if not n then + break + end + local valueInfo = find(l) + if valueInfo then + funcInfo[#funcInfo+1] = private { + type = 'local', + name = n, + info = valueInfo, + } + end + end + result[#result+1] = private { + type = 'stack', + name = i .. '@' .. formatName(info.func), + info = funcInfo, + } + end + end + + if #result == 0 then + return nil + end + return result + end + + local function findMainThread() + -- 不查找主线程,主线程一定是临时的(视为弱引用) + if m._ignoreMainThread then + return nil + end + local result = private {} + + for i = 1, maxinterger do + local info = getinfo(i, 'Sf') + if not info then + break + end + local funcInfo = find(info.func) if funcInfo then - result[#result+1] = { + for ln = 1, maxinterger do + local n, l = getlocal(i, ln) + if not n then + break + end + local valueInfo = find(l) + if valueInfo then + funcInfo[#funcInfo+1] = private { + type = 'local', + name = n, + info = valueInfo, + } + end + end + result[#result+1] = private { type = 'stack', name = i .. '@' .. formatName(info.func), info = funcInfo, @@ -258,22 +371,25 @@ function m.snapshot() return result end - function find(obj, trd, stack) + function find(obj) if mark[obj] then return mark[obj] end + if exclude[obj] or _private[obj] then + return nil + end local tp = type(obj) if tp == 'table' then - mark[obj] = {} + mark[obj] = private {} mark[obj] = findTable(obj, mark[obj]) elseif tp == 'function' then - mark[obj] = {} - mark[obj] = findFunction(obj, mark[obj], trd, stack) + mark[obj] = private {} + mark[obj] = findFunction(obj, mark[obj]) elseif tp == 'userdata' then - mark[obj] = {} + mark[obj] = private {} mark[obj] = findUserData(obj, mark[obj]) elseif tp == 'thread' then - mark[obj] = {} + mark[obj] = private {} mark[obj] = findThread(obj, mark[obj]) else return nil @@ -284,36 +400,71 @@ function m.snapshot() return mark[obj] end - return { + -- TODO: Lua 5.1中,主线程与_G都不在注册表里 + local result = private { name = formatName(registry), type = 'root', info = find(registry), } -end + if not registry[1] then + result.info[#result.info+1] = private { + type = 'thread', + name = 'main', + info = findMainThread(), + } + end + if not registry[2] then + result.info[#result.info+1] = private { + type = '_G', + name = '_G', + info = find(_G), + } + end + if m._cache then + m._lastCache = result + end + return result +end) ---- 寻找对象的引用 ----@return string -function m.catch(...) +--- 遍历虚拟机,寻找对象的引用。 +--- 输入既可以是对象实体,也可以是对象的描述(从其他接口的返回值中复制过来)。 +--- 返回字符串数组的数组,每个字符串描述了如何从根节点引用到指定的对象。 +--- 可以同时查找多个对象。 +---@return string[][] +m.catch = private(function (...) local targets = {} - for _, target in ipairs {...} do - targets[target] = true + for i = 1, select('#', ...) do + local target = select(i, ...) + if target ~= nil then + targets[target] = true + end end local report = m.snapshot() - local path = {} + local path = {} local result = {} - local mark = {} + local mark = {} local function push() - result[#result+1] = tableConcat(path, ' => ') + local resultPath = {} + for i = 1, #path do + resultPath[i] = path[i] + end + result[#result+1] = resultPath end local function search(t) path[#path+1] = ('(%s)%s'):format(t.type, t.name) local addTarget + local point = getPoint(t.info.object) if targets[t.info.object] then targets[t.info.object] = nil addTarget = t.info.object - push(t) + push() + end + if targets[point] then + targets[point] = nil + addTarget = point + push() end if not mark[t.info] then mark[t.info] = true @@ -330,11 +481,14 @@ function m.catch(...) search(report) return result -end +end) ---- 生成一个报告 ----@return string -function m.report() +---@alias report {point: string, count: integer, name: string, childs: integer} + +--- 生成一个内存快照的报告。 +--- 你应当将其输出到一个文件里再查看。 +---@return report[] +m.report = private(function () local snapshot = m.snapshot() local cache = {} local mark = {} @@ -347,12 +501,13 @@ function m.report() or tp == 'function' or tp == 'string' or tp == 'thread' then - local point = ('%p'):format(obj) + local point = getPoint(obj) if not cache[point] then cache[point] = { - point = point, - count = 0, - name = formatName(obj), + point = point, + count = 0, + name = formatName(obj), + childs = #t.info, } end cache[point].count = cache[point].count + 1 @@ -366,15 +521,61 @@ function m.report() end scan(snapshot) - local list = {} - for _, info in next, cache do + for _, info in pairs(cache) do list[#list+1] = info end - tableSort(list, function (a, b) - return a.name < b.name - end) return list -end +end) + +--- 在进行快照相关操作时排除掉的对象。 +--- 你可以用这个功能排除掉一些数据表。 +m.exclude = private(function (...) + m._exclude = {...} +end) + +--- 比较2个报告 +---@return string +m.compare = private(function (old, new) + local newHash = {} + local ret = {} + for _, info in ipairs(new) do + newHash[info.point] = info + end + for _, info in ipairs(old) do + if newHash[info.point] then + ret[#ret + 1] = { + old = info, + new = newHash[info.point] + } + end + end + return ret +end) + +--- 是否忽略主线程的栈 +---@param flag boolean +m.ignoreMainThread = private(function (flag) + m._ignoreMainThread = flag +end) + +--- 是否启用缓存,启用后会始终使用第一次查找的结果, +--- 适用于连续查找引用。如果想要查找新的引用需要先关闭缓存。 +---@param flag boolean +m.enableCache = private(function (flag) + if flag then + m._cache = true + else + m._cache = false + m._lastCache = nil + end +end) + +--- 立即清除缓存 +m.flushCache = private(function () + m._lastCache = nil +end) + +private(getinfo(1, 'f').func) return m -- cgit v1.2.3