summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/generic.lua10
-rw-r--r--script/core/infer.lua61
-rw-r--r--script/core/linker.lua15
-rw-r--r--script/core/searcher.lua6
-rw-r--r--script/parser/guide.lua19
-rw-r--r--script/parser/luadoc.lua4
-rw-r--r--test/definition/luadoc.lua4
-rw-r--r--test/type_inference/init.lua24
8 files changed, 95 insertions, 48 deletions
diff --git a/script/core/generic.lua b/script/core/generic.lua
index 53ced59c..97f3d635 100644
--- a/script/core/generic.lua
+++ b/script/core/generic.lua
@@ -108,19 +108,21 @@ local function createValue(closure, proto, callback, road)
end
local value = instantValue(closure, proto)
value.node = node
+ linker.compileLink(value)
return value
end
if proto.type == 'doc.type.table' then
- local tkey = createValue(closure, proto.key, callback, road)
+ local tkey = createValue(closure, proto.tkey, callback, road)
road[#road+1] = linker.SPLIT_CHAR
- local tvalue = createValue(closure, proto.value, callback, road)
+ local tvalue = createValue(closure, proto.tvalue, callback, road)
road[#road] = nil
if not tkey and not tvalue then
return nil
end
local value = instantValue(closure, proto)
- value.key = tkey or proto.key
- value.value = tvalue or proto.value
+ value.tkey = tkey or proto.tkey
+ value.tvalue = tvalue or proto.tvalue
+ linker.compileLink(value)
return value
end
end
diff --git a/script/core/infer.lua b/script/core/infer.lua
index 14ec6be2..713e94b5 100644
--- a/script/core/infer.lua
+++ b/script/core/infer.lua
@@ -4,6 +4,7 @@ local linker = require 'core.linker'
local BE_LEN = {'#'}
local CLASS = {'CLASS'}
+local TABLE = {'TABLE'}
local m = {}
@@ -183,6 +184,7 @@ local function cleanInfers(infers)
infers['integer'] = nil
infers['number'] = true
end
+ -- 如果是通过 # 来推测的,且结果里没有其他的 table 与 string,则加入这2个类型
if infers[BE_LEN] then
infers[BE_LEN] = nil
if not infers['table'] and not infers['string'] then
@@ -190,10 +192,16 @@ local function cleanInfers(infers)
infers['string'] = true
end
end
+ -- 如果有doc标记,则先移除table类型
if infers[CLASS] then
infers[CLASS] = nil
infers['table'] = nil
end
+ -- 用doc标记的table,加入table类型
+ if infers[TABLE] then
+ infers[TABLE] = nil
+ infers['table'] = true
+ end
end
---合并对象的推断类型
@@ -224,6 +232,34 @@ function m.viewInfers(infers)
return infers[0]
end
+local function getDocName(doc)
+ if not doc then
+ return nil
+ end
+ if doc.type == 'doc.class.name'
+ or doc.type == 'doc.type.name'
+ or doc.type == 'doc.alias.name' then
+ local name = doc[1] or '?'
+ return name
+ end
+ if doc.type == 'doc.type.array' then
+ local nodeName = getDocName(doc.node) or '?'
+ return nodeName .. '[]'
+ end
+ if doc.type == 'doc.type.table' then
+ local key = getDocName(doc.tkey) or '?'
+ local value = getDocName(doc.tvalue) or '?'
+ return ('<%s, %s>'):format(key, value)
+ end
+ if doc.type == 'doc.type.function' then
+ return 'function'
+ end
+ if doc.type == 'doc.type.enum' then
+ local value = doc[1] or '?'
+ return value
+ end
+end
+
---显示对象的推断类型
---@param source parser.guide.object
---@return string
@@ -239,13 +275,14 @@ local function searchInfer(source, infers)
searchInferOfValue(value, infers)
return
end
- if source.type == 'doc.class.name' then
- local name = source[1]
- if name then
- infers[name] = true
- infers[CLASS] = true
+ -- check LuaDoc
+ local docName = getDocName(source)
+ if docName then
+ infers[docName] = true
+ infers[CLASS] = true
+ if docName == 'table' then
+ infers[TABLE] = true
end
- return
end
-- X.a -> table
if source.next and source.next.node == source then
@@ -319,16 +356,24 @@ function m.searchInfers(source)
end
local defs = searcher.requestDefinition(source)
local infers = {}
+ local mark = {}
+ mark[source] = true
searchInfer(source, infers)
for _, def in ipairs(defs) do
- searchInfer(def, infers)
+ if not mark[def] then
+ mark[def] = true
+ searchInfer(def, infers)
+ end
end
local id = linker.getID(source)
if id then
local link = linker.getLinkByID(source, id)
if link and link.sources then
for _, src in ipairs(link.sources) do
- searchInfer(src, infers)
+ if not mark[src] then
+ mark[src] = true
+ searchInfer(src, infers)
+ end
end
end
end
diff --git a/script/core/linker.lua b/script/core/linker.lua
index d9f3630a..e57cbaa0 100644
--- a/script/core/linker.lua
+++ b/script/core/linker.lua
@@ -142,6 +142,7 @@ local function getKey(source)
or source.type == 'doc.param'
or source.type == 'doc.vararg'
or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
or source.type == 'doc.type.table'
or source.type == 'doc.type.array'
or source.type == 'doc.type.function' then
@@ -217,6 +218,9 @@ local function checkMode(source)
if source.type == 'doc.vararg' then
return 'dv:'
end
+ if source.type == 'doc.type.enum' then
+ return 'de:'
+ end
if source.type == 'generic.closure' then
return 'gc:'
end
@@ -388,6 +392,9 @@ function m.compileLink(source)
pushForward(id, getID(typeUnit))
pushBackward(getID(typeUnit), id)
end
+ for _, enumUnit in ipairs(source.enums) do
+ pushForward(id, getID(enumUnit))
+ end
end
-- 分解 @class
if source.type == 'doc.class' then
@@ -532,12 +539,12 @@ function m.compileLink(source)
end
end
if source.type == 'doc.type.table' then
- if source.value then
+ if source.tvalue then
local valueID = ('%s%s'):format(
id,
SPLIT_CHAR
)
- pushForward(valueID, getID(source.value))
+ pushForward(valueID, getID(source.tvalue))
end
end
if source.type == 'doc.type.array' then
@@ -657,12 +664,12 @@ function m.compileLink(source)
pushForward(nodeID, getID(source.node))
end
if proto.type == 'doc.type.table' then
- if source.value then
+ if source.tvalue then
local valueID = ('%s%s'):format(
id,
SPLIT_CHAR
)
- pushForward(valueID, getID(source.value))
+ pushForward(valueID, getID(source.tvalue))
end
end
end
diff --git a/script/core/searcher.lua b/script/core/searcher.lua
index c869f456..b64b21df 100644
--- a/script/core/searcher.lua
+++ b/script/core/searcher.lua
@@ -48,6 +48,9 @@ function m.pushResult(status, mode, source)
or source.type == 'doc.class.name'
or source.type == 'doc.alias.name'
or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.table'
or source.type == 'doc.type.function' then
results[#results+1] = source
return
@@ -85,6 +88,9 @@ function m.pushResult(status, mode, source)
or source.type == 'doc.alias.name'
or source.type == 'doc.extends.name'
or source.type == 'doc.field.name'
+ or source.type == 'doc.type.enum'
+ or source.type == 'doc.type.array'
+ or source.type == 'doc.type.table'
or source.type == 'doc.type.function' then
results[#results+1] = source
return
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index a838a42e..5f56be44 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -74,7 +74,7 @@ m.childMap = {
['doc.generic.object'] = {'generic', 'extends', 'comment'},
['doc.vararg'] = {'vararg', 'comment'},
['doc.type.array'] = {'node'},
- ['doc.type.table'] = {'node', 'key', 'value', 'comment'},
+ ['doc.type.table'] = {'node', 'tkey', 'tvalue', 'comment'},
['doc.type.function'] = {'#args', '#returns', 'comment'},
['doc.type.literal'] = {'node'},
['doc.type.arg'] = {'extends'},
@@ -142,23 +142,6 @@ function m.getParentFunction(obj)
return nil
end
---- 寻找父的table类型 doc.type.table
----@param obj parser.guide.object
----@return parser.guide.object
-function m.getParentDocTypeTable(obj)
- for _ = 1, 1000 do
- local parent = obj.parent
- if not parent then
- return nil
- end
- if parent.type == 'doc.type.table' then
- return obj
- end
- obj = parent
- end
- error('guide.getParentDocTypeTable overstack')
-end
-
--- 寻找所在区块
---@param obj parser.guide.object
---@return parser.guide.object
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 5f70a9e5..e2630446 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -300,8 +300,8 @@ local function parseTypeUnitTable(parent, node)
node.parent = result;
result.finish = getFinish()
- result.key = key
- result.value = value
+ result.tkey = key
+ result.tvalue = value
return result
end
diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua
index d0d95847..86366752 100644
--- a/test/definition/luadoc.lua
+++ b/test/definition/luadoc.lua
@@ -547,8 +547,8 @@ local v1
local function pairs(t) end
for k, v in pairs(v1) do
- print(k.<?bar1?>)
- print(v.bar1)
+ print(k.bar1)
+ print(v.<?bar1?>)
end
]]
diff --git a/test/type_inference/init.lua b/test/type_inference/init.lua
index 95933b8d..3de36e5e 100644
--- a/test/type_inference/init.lua
+++ b/test/type_inference/init.lua
@@ -243,8 +243,10 @@ local function f(<?a?>, b)
end
]]
-TEST 'string' [[
----@return string
+TEST 'A' [[
+---@class A
+
+---@return A
local function f2() end
local function f()
@@ -266,14 +268,6 @@ local <?x?> = f()
--setmetatable(<?b?>)
--]]
-TEST 'function' [[
-string.<?sub?>()
-]]
-
-TEST 'function' [[
-(''):<?sub?>()
-]]
-
-- 不根据对方函数内的使用情况来推测
TEST 'any' [[
local function x(a)
@@ -325,16 +319,23 @@ local <?x?>
]]
TEST 'string' [[
+---@class string
+
---@type string
local <?x?>
]]
TEST 'string[]' [[
+---@class string
+
---@type string[]
local <?x?>
]]
TEST 'string|table' [[
+---@class string
+---@class table
+
---@type string | table
local <?x?>
]]
@@ -350,6 +351,9 @@ local <?x?>
]]
TEST 'table<string, number>' [[
+---@class string
+---@class number
+
---@type table<string, number>
local <?x?>
]]