summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/parser/compile.lua2
-rw-r--r--script/parser/guide.lua108
-rw-r--r--script/parser/luadoc.lua30
-rw-r--r--script/vm/getDocs.lua5
-rw-r--r--test/definition/luadoc.lua88
-rw-r--r--test/example/guide.txt6
6 files changed, 218 insertions, 21 deletions
diff --git a/script/parser/compile.lua b/script/parser/compile.lua
index d4129ab4..f9992047 100644
--- a/script/parser/compile.lua
+++ b/script/parser/compile.lua
@@ -12,6 +12,8 @@ local specials = {
['loadfile'] = true,
['pcall'] = true,
['xpcall'] = true,
+ ['pairs'] = true,
+ ['ipairs'] = true,
}
_ENV = nil
diff --git a/script/parser/guide.lua b/script/parser/guide.lua
index 6c56155d..5f1bc230 100644
--- a/script/parser/guide.lua
+++ b/script/parser/guide.lua
@@ -97,7 +97,7 @@ m.childMap = {
['doc.generic'] = {'#generics', 'comment'},
['doc.generic.object'] = {'generic', 'extends', 'comment'},
['doc.vararg'] = {'vararg', 'comment'},
- ['doc.type.table'] = {'key', 'value', 'comment'},
+ ['doc.type.table'] = {'node', 'key', 'value', 'comment'},
['doc.type.function'] = {'#args', '#returns', 'comment'},
['doc.type.typeliteral'] = {'node'},
['doc.overload'] = {'overload', 'comment'},
@@ -169,6 +169,21 @@ function m.getParentFunction(obj)
return nil
end
+--- 寻找父的table类型 doc.type.table
+function m.getParentDocTypeTable(obj)
+ for _ = 1, 1000 do
+ local parent = obj.parent
+ if not parent then
+ return nil
+ end
+ if parent.type == 'doc.type.table' then
+ return obj
+ end
+ obj = parent
+ end
+ error('guide.getParentDocTypeTable overstack')
+end
+
--- 寻找所在区块
function m.getBlock(obj)
for _ = 1, 1000 do
@@ -1671,12 +1686,20 @@ function m.checkSameSimpleOfRefByDocSource(status, obj, start, pushQueue, mode)
end
end
-local function getArrayLevel(obj)
+local function getArrayOrTableLevel(obj)
local level = 0
while true do
local parent = obj.parent
if parent.type == 'doc.type.array' then
level = level + 1
+ elseif parent.type == 'doc.type.table' then
+ if obj.type == 'doc.type' then
+ level = level + 1
+ -- else 只存在 obj.type == 'doc.type.name' 的情况,即 table<k,v> 中的 table,这种是不需要再增加层级的
+ end
+ elseif parent.type == 'doc.type' and parent.parent and parent.parent.type == 'doc.type.table' then
+ level = level + 1
+ parent = parent.parent
else
break
end
@@ -1733,9 +1756,10 @@ function m.checkSameSimpleByDoc(status, obj, start, pushQueue, mode)
for _, res in ipairs(pieceResult) do
pushQueue(res, start, true)
end
+
local state = m.getDocState(obj)
if state.type == 'doc.type' and mode == 'ref' then
- m.checkSameSimpleOfRefByDocSource(status, state, start - getArrayLevel(obj), pushQueue, mode)
+ m.checkSameSimpleOfRefByDocSource(status, state, start - getArrayOrTableLevel(obj), pushQueue, mode)
end
return true
elseif obj.type == 'doc.field' then
@@ -1746,6 +1770,10 @@ function m.checkSameSimpleByDoc(status, obj, start, pushQueue, mode)
elseif obj.type == 'doc.type.array' then
pushQueue(obj.node, start + 1, true)
return true
+ elseif obj.type == 'doc.type.table' then
+ pushQueue(obj.node, start, true)
+ pushQueue(obj.value, start + 1, true)
+ return true
end
end
@@ -2211,6 +2239,72 @@ function m.checkSameSimpleAsSetValue(status, ref, start, pushQueue)
end
end
+local function getTableAndIndexIfIsForPairsKeyOrValue(ref)
+ if ref.type ~= 'local' then
+ return
+ end
+
+ if not ref.parent or ref.parent.type ~= 'in' then
+ return
+ end
+
+ if not ref.value or ref.value.type ~= 'select' then
+ return
+ end
+
+ local rootSelectObj = ref.value
+ if rootSelectObj.index ~= 1 and rootSelectObj.index ~= 2 then
+ return
+ end
+
+ if not rootSelectObj.vararg or rootSelectObj.vararg.type ~= 'call' then
+ return
+ end
+ local rootCallObj = rootSelectObj.vararg
+
+ if not rootCallObj.node or rootCallObj.node.type ~= 'call' then
+ return
+ end
+ local pairsCallObj = rootCallObj.node
+
+ if not pairsCallObj.node
+ or (pairsCallObj.node.special ~= 'pairs' and pairsCallObj.node.special ~= 'ipairs') then
+ return
+ end
+
+ if not pairsCallObj.args or not pairsCallObj.args[1] then
+ return
+ end
+ local tableObj = pairsCallObj.args[1]
+
+ return tableObj, rootSelectObj.index
+end
+
+function m.checkSameSimpleAsKeyOrValueInForParis(status, ref, start, pushQueue)
+ local tableObj, index = getTableAndIndexIfIsForPairsKeyOrValue(ref)
+ if not tableObj then
+ return
+ end
+
+ local newStatus = m.status(status)
+ m.searchRefs(newStatus, tableObj, 'def')
+ for _, def in ipairs(newStatus.results) do
+ if def.bindDocs then
+ for _, binddoc in ipairs(def.bindDocs) do
+ if binddoc.type == 'doc.type' then
+ if binddoc.types[1] and binddoc.types[1].type == 'doc.type.table' then
+ if index == 1 then
+ pushQueue(binddoc.types[1].key, start, true)
+ elseif index == 2 then
+ pushQueue(binddoc.types[1].value, start, true)
+ end
+ end
+ end
+ end
+ end
+ end
+end
+
local function hasTypeName(doc, name)
if doc.type ~= 'doc.type' then
return false
@@ -2447,6 +2541,8 @@ function m.checkSameSimple(status, simple, ref, start, force, mode, pushQueue)
m.checkSameSimpleAsReturn(status, ref, i, pushQueue)
-- 检查形如 a = f 的情况
m.checkSameSimpleAsSetValue(status, ref, i, pushQueue)
+ -- 检查形如 for k,v in pairs()/ipairs() do end 的情况
+ m.checkSameSimpleAsKeyOrValueInForParis(status, ref, i, pushQueue)
end
end
if i == #simple then
@@ -2888,7 +2984,7 @@ function m.viewInferType(infers)
or src.type == 'doc.class.name'
or src.type == 'doc.type.name'
or src.type == 'doc.type.array'
- or src.type == 'doc.type.generic' then
+ or src.type == 'doc.type.table' then
if infer.type ~= 'any' then
hasDoc = true
break
@@ -2903,7 +2999,7 @@ function m.viewInferType(infers)
or src.type == 'doc.class.name'
or src.type == 'doc.type.name'
or src.type == 'doc.type.array'
- or src.type == 'doc.type.generic'
+ or src.type == 'doc.type.table'
or src.type == 'doc.type.enum'
or src.type == 'doc.resume' then
local tp = infer.type or 'any'
@@ -3132,7 +3228,7 @@ function m.getDocTypeUnitName(status, unit)
typeName = 'function'
elseif unit.type == 'doc.type.array' then
typeName = m.getDocTypeUnitName(status, unit.node) .. '[]'
- elseif unit.type == 'doc.type.generic' then
+ elseif unit.type == 'doc.type.table' then
typeName = ('%s<%s, %s>'):format(
m.getDocTypeUnitName(status, unit.node),
m.viewInferType(m.getDocTypeNames(status, unit.key)),
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index 832cdc87..4512bc9e 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -261,30 +261,36 @@ local function parseTypeUnitArray(node)
return result
end
-local function parseTypeUnitGeneric(node)
+local function parseTypeUnitTable(parent, node)
if not checkToken('symbol', '<', 1) then
return nil
end
if not nextSymbolOrError('<') then
return nil
end
- local key = parseType(node)
+
+ local result = {
+ type = 'doc.type.table',
+ start = node.start,
+ node = node,
+ parent = parent,
+ }
+
+ local key = parseType(result)
if not key or not nextSymbolOrError(',') then
return nil
end
- local value = parseType(node)
+ local value = parseType(result)
if not value then
return nil
end
nextSymbolOrError('>')
- local result = {
- type = 'doc.type.generic',
- start = node.start,
- finish = getFinish(),
- node = node,
- key = key,
- value = value,
- }
+
+ node.parent = result;
+ result.finish = getFinish()
+ result.key = key
+ result.value = value
+
return result
end
@@ -398,7 +404,7 @@ local function parseTypeUnit(parent, content)
result.parent = parent
while true do
local newResult = parseTypeUnitArray(result)
- or parseTypeUnitGeneric(result)
+ or parseTypeUnitTable(parent, result)
if not newResult then
break
end
diff --git a/script/vm/getDocs.lua b/script/vm/getDocs.lua
index 632dd1c2..790a9b50 100644
--- a/script/vm/getDocs.lua
+++ b/script/vm/getDocs.lua
@@ -16,6 +16,11 @@ local function getTypesOfFile(uri)
or src.type == 'doc.class.name'
or src.type == 'doc.extends.name'
or src.type == 'doc.alias.name' then
+ if src.type == 'doc.type.name' then
+ if guide.getParentDocTypeTable(src) then
+ return
+ end
+ end
local name = src[1]
if name then
if not types[name] then
diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua
index 1f3dae00..5315b5fd 100644
--- a/test/definition/luadoc.lua
+++ b/test/definition/luadoc.lua
@@ -253,3 +253,91 @@ function Generic(arg1) print(arg1) end
local v1 = Generic("Foo")
print(v1.<?bar1?>)
]]
+
+TEST [[
+---@class Foo
+local Foo = {}
+function Foo:<!bar1!>() end
+
+---@type table<number, Foo>
+local v1
+print(v1[1].<?bar1?>)
+]]
+
+TEST [[
+---@class Foo
+local Foo = {}
+function Foo:<!bar1!>() end
+
+---@class Foo2
+local Foo2 = {}
+function Foo2:bar1() end
+
+---@type Foo2<number, Foo>
+local v1
+print(v1[1].<?bar1?>)
+]]
+
+--TODO 得扩展 simple 的信息才能识别这种情况了
+--TEST [[
+-----@class Foo
+--local Foo = {}
+--function Foo:bar1() end
+--
+-----@class Foo2
+--local Foo2 = {}
+--function Foo2:<!bar1!>() end
+--
+-----@type Foo2<number, Foo>
+--local v1
+--print(v1.<?bar1?>)
+--]]
+
+TEST [[
+---@class Foo
+local Foo = {}
+function Foo:<!bar1!>() end
+
+---@type table<number, Foo>
+local v1
+local ipairs = ipairs
+for i, v in ipairs(v1) do
+ print(v.<?bar1?>)
+end
+]]
+
+TEST [[
+---@class Foo
+local Foo = {}
+function Foo:<!bar1!>() end
+
+---@type table<Foo, Foo>
+local v1
+for k, v in pairs(v1) do
+ print(k.<?bar1?>)
+ print(v.bar1)
+end
+]]
+
+TEST [[
+---@class Foo
+local Foo = {}
+function Foo:<!bar1!>() end
+
+---@type table<number, table<number, Foo>>
+local v1
+for i, v in ipairs(v1) do
+ local v2 = v[1]
+ print(v2.<?bar1?>)
+end
+]]
+
+TEST [[
+---@class Foo
+local Foo = {}
+function Foo:<!bar1!>() end
+
+---@type table<number, table<number, Foo>>
+local v1
+print(v1[1][1].<?bar1?>)
+]]
diff --git a/test/example/guide.txt b/test/example/guide.txt
index 437e37b0..da8d5c32 100644
--- a/test/example/guide.txt
+++ b/test/example/guide.txt
@@ -2702,7 +2702,7 @@ function m.viewInferType(infers)
or src.type == 'doc.class.name'
or src.type == 'doc.type.name'
or src.type == 'doc.type.array'
- or src.type == 'doc.type.generic' then
+ or src.type == 'doc.type.table' then
if infer.type ~= 'any' then
hasDoc = true
break
@@ -2717,7 +2717,7 @@ function m.viewInferType(infers)
or src.type == 'doc.class.name'
or src.type == 'doc.type.name'
or src.type == 'doc.type.array'
- or src.type == 'doc.type.generic'
+ or src.type == 'doc.type.table'
or src.type == 'doc.type.enum'
or src.type == 'doc.resume' then
local tp = infer.type or 'any'
@@ -2946,7 +2946,7 @@ local function getDocTypeUnitName(status, unit)
typeName = 'function'
elseif unit.type == 'doc.type.array' then
typeName = getDocTypeUnitName(status, unit.node) .. '[]'
- elseif unit.type == 'doc.type.generic' then
+ elseif unit.type == 'doc.type.table' then
typeName = ('%s<%s, %s>'):format(
getDocTypeUnitName(status, unit.node),
m.viewInferType(m.getDocTypeNames(status, unit.key)),