summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--script/core/definition.lua6
-rw-r--r--script/parser/luadoc.lua162
-rw-r--r--script/vm/compiler.lua45
-rw-r--r--script/vm/generic.lua2
-rw-r--r--test/definition/luadoc.lua42
5 files changed, 153 insertions, 104 deletions
diff --git a/script/core/definition.lua b/script/core/definition.lua
index 687c071e..a01a6a25 100644
--- a/script/core/definition.lua
+++ b/script/core/definition.lua
@@ -170,10 +170,8 @@ return function (uri, offset)
goto CONTINUE
end
end
- if src.type == 'doc.type.name' then
- if src.typeGeneric then
- goto CONTINUE
- end
+ if src.type == 'doc.generic.name' then
+ goto CONTINUE
end
if src.type == 'doc.param' then
goto CONTINUE
diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua
index d516643b..62dff4bd 100644
--- a/script/parser/luadoc.lua
+++ b/script/parser/luadoc.lua
@@ -243,6 +243,69 @@ local function parseIndexField(tp, parent)
end
end
+local function parseTable(parent)
+ if not checkToken('symbol', '{', 1) then
+ return nil
+ end
+ nextToken()
+ local typeUnit = {
+ type = 'doc.type.table',
+ start = getStart(),
+ parent = parent,
+ fields = {},
+ }
+
+ while true do
+ if checkToken('symbol', '}', 1) then
+ nextToken()
+ break
+ end
+ local field = {
+ type = 'doc.type.field',
+ parent = typeUnit,
+ }
+
+ do
+ field.name = parseName('doc.field.name', field)
+ or parseIndexField('doc.field.name', field)
+ if not field.name then
+ pushWarning {
+ type = 'LUADOC_MISS_FIELD_NAME',
+ start = getFinish(),
+ finish = getFinish(),
+ }
+ break
+ end
+ if not field.start then
+ field.start = field.name.start
+ end
+ if checkToken('symbol', '?', 1) then
+ nextToken()
+ field.optional = true
+ end
+ field.finish = getFinish()
+ if not nextSymbolOrError(':') then
+ break
+ end
+ field.extends = parseType(field)
+ if not field.extends then
+ break
+ end
+ field.finish = getFinish()
+ end
+
+ typeUnit.fields[#typeUnit.fields+1] = field
+ if checkToken('symbol', ',', 1) then
+ nextToken()
+ else
+ nextSymbolOrError('}')
+ break
+ end
+ end
+ typeUnit.finish = getFinish()
+ return typeUnit
+end
+
local function parseClass(parent)
local result = {
type = 'doc.class',
@@ -269,6 +332,7 @@ local function parseClass(parent)
while true do
local extend = parseName('doc.extends.name', result)
+ or parseTable(result)
if not extend then
pushWarning {
type = 'LUADOC_MISS_CLASS_EXTENDS_NAME',
@@ -351,9 +415,14 @@ local function parseDots(tp, parent)
return dots
end
-local function parseTypeUnitFunction()
+local function parseTypeUnitFunction(parent)
+ if not checkToken('name', 'fun', 1) then
+ return nil
+ end
+ nextToken()
local typeUnit = {
type = 'doc.type.function',
+ parent = parent,
start = getStart(),
args = {},
returns = {},
@@ -424,74 +493,17 @@ local function parseTypeUnitFunction()
return typeUnit
end
-local function parseTypeUnitLiteralTable()
- local typeUnit = {
- type = 'doc.type.table',
- start = getStart(),
- fields = {},
- }
-
- while true do
- if checkToken('symbol', '}', 1) then
- nextToken()
- break
- end
- local field = {
- type = 'doc.type.field',
- parent = typeUnit,
- }
-
- do
- field.name = parseName('doc.field.name', field)
- or parseIndexField('doc.field.name', field)
- if not field.name then
- pushWarning {
- type = 'LUADOC_MISS_FIELD_NAME',
- start = getFinish(),
- finish = getFinish(),
- }
- break
- end
- if not field.start then
- field.start = field.name.start
- end
- if checkToken('symbol', '?', 1) then
- nextToken()
- field.optional = true
- end
- field.finish = getFinish()
- if not nextSymbolOrError(':') then
- break
- end
- field.extends = parseType(field)
- if not field.extends then
- break
- end
- field.finish = getFinish()
- end
-
- typeUnit.fields[#typeUnit.fields+1] = field
- if checkToken('symbol', ',', 1) then
- nextToken()
- else
- nextSymbolOrError('}')
- break
- end
- end
- typeUnit.finish = getFinish()
- return typeUnit
-end
-
local parseTypeUnit
-local function parseDocFunction(parent, content)
+local function parseFunction(parent)
+ local _, content = peekToken()
if content == 'async' then
+ nextToken()
local pos = getStart()
local tp, cont = peekToken()
if tp == 'name' then
if cont == 'fun' then
- nextToken()
- local func = parseTypeUnit(parent, cont)
+ local func = parseTypeUnit(parent)
if func then
func.async = true
func.asyncPos = pos
@@ -501,29 +513,26 @@ local function parseDocFunction(parent, content)
end
end
if content == 'fun' then
- return parseTypeUnitFunction()
+ return parseTypeUnitFunction(parent)
end
end
-function parseTypeUnit(parent, content)
- local result = parseDocFunction(parent, content)
- if not result then
- if content == '{' then
- result = parseTypeUnitLiteralTable()
- end
- end
+function parseTypeUnit(parent)
+ local result = parseFunction(parent)
+ or parseTable(parent)
if not result then
+ local _, token = nextToken()
result = {
type = 'doc.type.name',
start = getStart(),
finish = getFinish(),
- [1] = content,
+ parent = parent,
+ [1] = token,
}
end
if not result then
return nil
end
- result.parent = parent
while true do
local newResult = parseTypeUnitArray(parent, result)
if not newResult then
@@ -594,8 +603,7 @@ function parseType(parent)
end
if tp == 'name' then
- nextToken()
- local typeUnit = parseTypeUnit(result, content)
+ local typeUnit = parseTypeUnit(result)
if not typeUnit then
break
end
@@ -621,8 +629,7 @@ function parseType(parent)
result.start = typeEnum.start
end
elseif tp == 'symbol' and content == '{' then
- nextToken()
- local typeUnit = parseTypeUnit(result, content)
+ local typeUnit = parseTypeUnit(result)
if not typeUnit then
break
end
@@ -933,11 +940,10 @@ local function parseOverload()
}
return nil
end
- nextToken()
local result = {
type = 'doc.overload',
}
- result.overload = parseDocFunction(result, name)
+ result.overload = parseFunction(result)
if not result.overload then
return nil
end
diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua
index 75575fc4..59ef9cbe 100644
--- a/script/vm/compiler.lua
+++ b/script/vm/compiler.lua
@@ -335,7 +335,7 @@ local function selectNode(source, list, index)
if exp.type == '...' then
-- TODO
end
- return m.compileNode(exp)
+ return nodeMgr.setNode(source, m.compileNode(exp))
end
local compilerMap = util.switch()
@@ -396,6 +396,10 @@ local compilerMap = util.switch()
nodeMgr.setNode(source, m.compileNode(setfield.node))
end
end
+ -- for x in ... do
+ if source.parent.type == 'in' then
+ m.compileNode(source.parent)
+ end
end)
: case 'getlocal'
: call(function (source)
@@ -459,7 +463,7 @@ local compilerMap = util.switch()
end
if func.returns and not hasMarkDoc then
for _, rtn in ipairs(func.returns) do
- nodeMgr.setNode(source, selectNode(source, rtn, index))
+ selectNode(source, rtn, index)
end
end
end)
@@ -467,7 +471,26 @@ local compilerMap = util.switch()
: call(function (source)
local vararg = source.vararg
if vararg.type == 'call' then
- nodeMgr.setNode(source, getReturn(vararg.node, source.sindex, source, vararg.args))
+ getReturn(vararg.node, source.sindex, source, vararg.args)
+ end
+ end)
+ : case 'in'
+ : call(function (source)
+ if not source._iterator then
+ -- for k, v in pairs(t) do
+ --> for k, v in iterator, status, initValue do
+ --> local k, v = iterator(status, initValue)
+ source._iterator = {}
+ source._iterArgs = {{}, {}}
+ -- iterator
+ selectNode(source._iterator, source.exps, 1)
+ -- status
+ selectNode(source._iterArgs[1], source.exps, 2)
+ -- initValue
+ selectNode(source._iterArgs[2], source.exps, 3)
+ end
+ for i, loc in ipairs(source.keys) do
+ getReturn(source._iterator, i, loc, source._iterArgs)
end
end)
: case 'doc.type'
@@ -557,6 +580,22 @@ local function compileByGlobal(source)
end
end
end
+ if source._globalNode.cate == 'type' then
+ for _, set in ipairs(source._globalNode:getSets()) do
+ if set.type == 'doc.class' then
+ if set.extends then
+ for _, ext in ipairs(set.extends) do
+ if ext.type == 'doc.type.table' then
+ nodeMgr.setNode(source, m.compileNode(ext))
+ end
+ end
+ end
+ end
+ if set.type == 'doc.alias' then
+ nodeMgr.setNode(source, m.compileNode(set.extends))
+ end
+ end
+ end
return
end
end
diff --git a/script/vm/generic.lua b/script/vm/generic.lua
index c3e1ecbc..6d5a19da 100644
--- a/script/vm/generic.lua
+++ b/script/vm/generic.lua
@@ -80,7 +80,7 @@ local function cloneObject(node, resolved)
end
return newDocFunc
end
- return nil
+ return node
end
---@param argNodes vm.node[]
diff --git a/test/definition/luadoc.lua b/test/definition/luadoc.lua
index 592cbef7..ff54f0ed 100644
--- a/test/definition/luadoc.lua
+++ b/test/definition/luadoc.lua
@@ -143,6 +143,20 @@ function f(<?...?>) end
]]
TEST [[
+---@alias A <!fun()!>
+
+---@type A
+local <!<?x?>!>
+]]
+
+TEST [[
+---@class A: <!{}!>
+
+---@type A
+local <!<?x?>!>
+]]
+
+TEST [[
---@overload <!fun(y: boolean)!>
---@param x number
---@param y boolean
@@ -602,39 +616,31 @@ end
]]
TEST [[
----@class C
-local <!v!>
+---@alias C <!fun()!>
----@type C
-local <!v1!>
+---@type C[]
+local v1
---@generic V, T
---@param t T
----@return fun(t: V): V
+---@return fun(t: V[]): V
---@return T
local function iterator(t) end
-for <!v!> in iterator(<!v1!>) do
+for <!v!> in iterator(v1) do
print(<?v?>)
end
]]
TEST [[
----@class C
-local <!v!>
+---@class TT<V>: { x: V }
----@type C[]
-local v1
+---@type TT<A>
+local t
----@generic V, T
----@param t T
----@return fun(t: V[]): V
----@return T
-local function iterator(t) end
+---@class A: <!{}!>
-for <!v!> in iterator(v1) do
- print(<?v?>)
-end
+print(t.<?x?>)
]]
TEST [[