diff options
-rw-r--r-- | script/core/generic.lua | 10 | ||||
-rw-r--r-- | script/core/infer.lua | 61 | ||||
-rw-r--r-- | script/core/linker.lua | 15 | ||||
-rw-r--r-- | script/core/searcher.lua | 6 | ||||
-rw-r--r-- | script/parser/guide.lua | 19 | ||||
-rw-r--r-- | script/parser/luadoc.lua | 4 | ||||
-rw-r--r-- | test/definition/luadoc.lua | 4 | ||||
-rw-r--r-- | test/type_inference/init.lua | 24 |
8 files changed, 95 insertions, 48 deletions
diff --git a/script/core/generic.lua b/script/core/generic.lua index 53ced59c..97f3d635 100644 --- a/script/core/generic.lua +++ b/script/core/generic.lua @@ -108,19 +108,21 @@ local function createValue(closure, proto, callback, road) end local value = instantValue(closure, proto) value.node = node + linker.compileLink(value) return value end if proto.type == 'doc.type.table' then - local tkey = createValue(closure, proto.key, callback, road) + local tkey = createValue(closure, proto.tkey, callback, road) road[#road+1] = linker.SPLIT_CHAR - local tvalue = createValue(closure, proto.value, callback, road) + local tvalue = createValue(closure, proto.tvalue, callback, road) road[#road] = nil if not tkey and not tvalue then return nil end local value = instantValue(closure, proto) - value.key = tkey or proto.key - value.value = tvalue or proto.value + value.tkey = tkey or proto.tkey + value.tvalue = tvalue or proto.tvalue + linker.compileLink(value) return value end end diff --git a/script/core/infer.lua b/script/core/infer.lua index 14ec6be2..713e94b5 100644 --- a/script/core/infer.lua +++ b/script/core/infer.lua @@ -4,6 +4,7 @@ local linker = require 'core.linker' local BE_LEN = {'#'} local CLASS = {'CLASS'} +local TABLE = {'TABLE'} local m = {} @@ -183,6 +184,7 @@ local function cleanInfers(infers) infers['integer'] = nil infers['number'] = true end + -- 如果是通过 # 来推测的,且结果里没有其他的 table 与 string,则加入这2个类型 if infers[BE_LEN] then infers[BE_LEN] = nil if not infers['table'] and not infers['string'] then @@ -190,10 +192,16 @@ local function cleanInfers(infers) infers['string'] = true end end + -- 如果有doc标记,则先移除table类型 if infers[CLASS] then infers[CLASS] = nil infers['table'] = nil end + -- 用doc标记的table,加入table类型 + if infers[TABLE] then + infers[TABLE] = nil + infers['table'] = true + end end ---合并对象的推断类型 @@ -224,6 +232,34 @@ function m.viewInfers(infers) return infers[0] end +local function getDocName(doc) + if not doc then + return nil + end + if doc.type == 'doc.class.name' + or doc.type == 'doc.type.name' + or doc.type == 'doc.alias.name' then + local name = doc[1] or '?' + return name + end + if doc.type == 'doc.type.array' then + local nodeName = getDocName(doc.node) or '?' + return nodeName .. '[]' + end + if doc.type == 'doc.type.table' then + local key = getDocName(doc.tkey) or '?' + local value = getDocName(doc.tvalue) or '?' + return ('<%s, %s>'):format(key, value) + end + if doc.type == 'doc.type.function' then + return 'function' + end + if doc.type == 'doc.type.enum' then + local value = doc[1] or '?' + return value + end +end + ---显示对象的推断类型 ---@param source parser.guide.object ---@return string @@ -239,13 +275,14 @@ local function searchInfer(source, infers) searchInferOfValue(value, infers) return end - if source.type == 'doc.class.name' then - local name = source[1] - if name then - infers[name] = true - infers[CLASS] = true + -- check LuaDoc + local docName = getDocName(source) + if docName then + infers[docName] = true + infers[CLASS] = true + if docName == 'table' then + infers[TABLE] = true end - return end -- X.a -> table if source.next and source.next.node == source then @@ -319,16 +356,24 @@ function m.searchInfers(source) end local defs = searcher.requestDefinition(source) local infers = {} + local mark = {} + mark[source] = true searchInfer(source, infers) for _, def in ipairs(defs) do - searchInfer(def, infers) + if not mark[def] then + mark[def] = true + searchInfer(def, infers) + end end local id = linker.getID(source) if id then local link = linker.getLinkByID(source, id) if link and link.sources then for _, src in ipairs(link.sources) do - searchInfer(src, infers) + if not mark[src] then + mark[src] = true + searchInfer(src, infers) + end end end end diff --git a/script/core/linker.lua b/script/core/linker.lua index d9f3630a..e57cbaa0 100644 --- a/script/core/linker.lua +++ b/script/core/linker.lua @@ -142,6 +142,7 @@ local function getKey(source) or source.type == 'doc.param' or source.type == 'doc.vararg' or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' or source.type == 'doc.type.table' or source.type == 'doc.type.array' or source.type == 'doc.type.function' then @@ -217,6 +218,9 @@ local function checkMode(source) if source.type == 'doc.vararg' then return 'dv:' end + if source.type == 'doc.type.enum' then + return 'de:' + end if source.type == 'generic.closure' then return 'gc:' end @@ -388,6 +392,9 @@ function m.compileLink(source) pushForward(id, getID(typeUnit)) pushBackward(getID(typeUnit), id) end + for _, enumUnit in ipairs(source.enums) do + pushForward(id, getID(enumUnit)) + end end -- 分解 @class if source.type == 'doc.class' then @@ -532,12 +539,12 @@ function m.compileLink(source) end end if source.type == 'doc.type.table' then - if source.value then + if source.tvalue then local valueID = ('%s%s'):format( id, SPLIT_CHAR ) - pushForward(valueID, getID(source.value)) + pushForward(valueID, getID(source.tvalue)) end end if source.type == 'doc.type.array' then @@ -657,12 +664,12 @@ function m.compileLink(source) pushForward(nodeID, getID(source.node)) end if proto.type == 'doc.type.table' then - if source.value then + if source.tvalue then local valueID = ('%s%s'):format( id, SPLIT_CHAR ) - pushForward(valueID, getID(source.value)) + pushForward(valueID, getID(source.tvalue)) end end end diff --git a/script/core/searcher.lua b/script/core/searcher.lua index c869f456..b64b21df 100644 --- a/script/core/searcher.lua +++ b/script/core/searcher.lua @@ -48,6 +48,9 @@ function m.pushResult(status, mode, source) or source.type == 'doc.class.name' or source.type == 'doc.alias.name' or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.type.array' + or source.type == 'doc.type.table' or source.type == 'doc.type.function' then results[#results+1] = source return @@ -85,6 +88,9 @@ function m.pushResult(status, mode, source) or source.type == 'doc.alias.name' or source.type == 'doc.extends.name' or source.type == 'doc.field.name' + or source.type == 'doc.type.enum' + or source.type == 'doc.type.array' + or source.type == 'doc.type.table' or source.type == 'doc.type.function' then results[#results+1] = source return diff --git a/script/parser/guide.lua b/script/parser/guide.lua index a838a42e..5f56be44 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -74,7 +74,7 @@ m.childMap = { ['doc.generic.object'] = {'generic', 'extends', 'comment'}, ['doc.vararg'] = {'vararg', 'comment'}, ['doc.type.array'] = {'node'}, - ['doc.type.table'] = {'node', 'key', 'value', 'comment'}, + ['doc.type.table'] = {'node', 'tkey', 'tvalue', 'comment'}, ['doc.type.function'] = {'#args', '#returns', 'comment'}, ['doc.type.literal'] = {'node'}, ['doc.type.arg'] = {'extends'}, @@ -142,23 +142,6 @@ function m.getParentFunction(obj) return nil end ---- 寻找父的table类型 doc.type.table ----@param obj parser.guide.object ----@return parser.guide.object -function m.getParentDocTypeTable(obj) - for _ = 1, 1000 do - local parent = obj.parent - if not parent then - return nil - end - if parent.type == 'doc.type.table' then - return obj - end - obj = parent - end - error('guide.getParentDocTypeTable overstack') -end - --- 寻找所在区块 ---@param obj parser.guide.object ---@return parser.guide.object diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index 5f70a9e5..e2630446 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -300,8 +300,8 @@ local function parseTypeUnitTable(parent, node) node.parent = result; result.finish = getFinish() - result.key = key - result.value = value + result.tkey = key + result.tvalue = value return result end diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua index d0d95847..86366752 100644 --- a/test/definition/luadoc.lua +++ b/test/definition/luadoc.lua @@ -547,8 +547,8 @@ local v1 local function pairs(t) end for k, v in pairs(v1) do - print(k.<?bar1?>) - print(v.bar1) + print(k.bar1) + print(v.<?bar1?>) end ]] diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua index 95933b8d..3de36e5e 100644 --- a/test/type_inference/init.lua +++ b/test/type_inference/init.lua @@ -243,8 +243,10 @@ local function f(<?a?>, b) end ]] -TEST 'string' [[ ----@return string +TEST 'A' [[ +---@class A + +---@return A local function f2() end local function f() @@ -266,14 +268,6 @@ local <?x?> = f() --setmetatable(<?b?>) --]] -TEST 'function' [[ -string.<?sub?>() -]] - -TEST 'function' [[ -(''):<?sub?>() -]] - -- 不根据对方函数内的使用情况来推测 TEST 'any' [[ local function x(a) @@ -325,16 +319,23 @@ local <?x?> ]] TEST 'string' [[ +---@class string + ---@type string local <?x?> ]] TEST 'string[]' [[ +---@class string + ---@type string[] local <?x?> ]] TEST 'string|table' [[ +---@class string +---@class table + ---@type string | table local <?x?> ]] @@ -350,6 +351,9 @@ local <?x?> ]] TEST 'table<string, number>' [[ +---@class string +---@class number + ---@type table<string, number> local <?x?> ]] |