summary | refs | log | tree | commit | diff
path: root/script/core/command/autoRequire.lua
blob: c0deecfc73c879b71f801d2b9baa7738277d3ea0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
local files  = require 'files'
local furi   = require 'file-uri'
local config = require 'config'
local rpath  = require 'workspace.require-path'
local client = require 'client'
local lang   = require 'language'
local guide  = require 'parser.guide'

--- Tell whether `pos` lies inside one of the comments recorded in `state.comms`.
--- The scan gives up as soon as it meets a comment that starts after `pos`,
--- so it presumes `state.comms` is ordered by `start` (TODO confirm upstream).
---@param state table    parsed file state exposing a `comms` array
---@param pos   integer  position compared against each comment's `start`/`finish`
---@return boolean
local function inComment(state, pos)
    for _, comment in ipairs(state.comms) do
        if comment.start > pos then
            -- Every later comment starts even further away.
            return false
        end
        if pos <= comment.finish then
            return true
        end
    end
    return false
end

--- Locate where a new `require` line should be inserted and how to format it.
---
--- Rows are visited from 0 upward; a row whose column-0 position lies inside
--- a comment is skipped. Until a line containing `require` is seen, lines that
--- start with `local` plus whitespace, blank lines and `--` lines are passed
--- over, and any other line stops the scan. After a `require` line has been
--- seen, the first later line without `require` stops the scan.
---@param uri string
---@return integer row  the row after the last `require` line seen, or 0
---@return table   fmt  style hints sampled from the `require` lines seen:
---                     `pair` (last such line contains `(`), `quot` (first
---                     quote character found, default `"`), `col` (1-based
---                     column of the first `=` found, else nil)
local function findInsertRow(uri)
    local source  = files.getText(uri)
    local state   = files.getState(uri)
    local offsets = state.lines
    local fmt = {
        pair = false,
        quot = '"',
        col  = nil,
    }
    local insertRow
    for row = 0, #offsets do
        if not inComment(state, guide.positionOf(row, 0)) then
            -- `offsets[row]` is the byte offset where this row starts in `source`.
            local line = source:match('[^\r\n]*', offsets[row])
            if line:find('require', 1, true) then
                -- Insert after this row, and sample its style.
                insertRow = row + 1
                fmt.pair  = line:find '%(' ~= nil
                fmt.quot  = line:match [=[(['"])]=] or fmt.quot
                fmt.col   = line:find '=' or fmt.col
            elseif insertRow then
                -- First non-require line after the require block.
                break
            elseif not (line:match '^local%s'
                     or line:match '^%s*$'
                     or line:match '^%-%-') then
                -- Something other than local/blank/comment precedes any require.
                break
            end
        end
    end
    return insertRow or 0, fmt
end

--- Ask the user, through the client, which require path to insert.
--- Each candidate's `expect` string is turned into a label with `lang.script`;
--- when two candidates yield the same label only the first is offered. An extra
--- last option sets `Lua.completion.autoRequire` to false for `uri`.
---@async
---@param uri          string
---@param visiblePaths table[]  candidates, each carrying an `expect` field
---@return string? expect  the chosen candidate's `expect`; nothing when the
---                        request returns no answer or the user disabled the feature
local function askAutoRequire(uri, visiblePaths)
    local options   = {}
    local expectFor = {}
    for _, candidate in ipairs(visiblePaths) do
        local path  = candidate.expect
        local label = lang.script(path)
        if not expectFor[label] then
            expectFor[label] = path
            options[#options + 1] = label
        end
    end
    local disableLabel = lang.script.COMPLETION_DISABLE_AUTO_REQUIRE
    options[#options + 1] = disableLabel

    local choice = client.awaitRequestMessage('Info', lang.script.COMPLETION_ASK_AUTO_REQUIRE, options)
    if not choice then
        return
    end
    if choice ~= disableLabel then
        return expectFor[choice]
    end
    client.setConfig {
        {
            key    = 'Lua.completion.autoRequire',
            action = 'set',
            value  = false,
            uri    = uri,
        }
    }
end

--- Insert `local <name> = require <result>` plus a newline at column 0 of `row`.
--- The statement copies the style described by `fmt`: the quote character
--- (`fmt.quot`), parenthesised call versus string-call syntax (`fmt.pair`), and
--- the column of `=` (`fmt.col`) used to pad after the name.
---@param uri    string
---@param row    integer  row passed to `guide.positionOf` as the insert point
---@param name   string   local variable name to bind
---@param result string   module path to require
---@param fmt    table    `{ pair, quot, col }` as produced by findInsertRow
local function applyAutoRequire(uri, row, name, result, fmt)
    local literal = ('%q'):format(result)
    if fmt.quot == "'" then
        -- Turn the double-quoted %q output into a single-quoted literal:
        -- escape bare `'`, and drop the escaping `\"` no longer needs.
        local body = literal:sub(2, -2)
        body = body:gsub([[']], [[\']])
        body = body:gsub([[\"]], [["]])
        literal = ([['%s']]):format(body)
    end
    local callArgs = fmt.pair and ('(%s)'):format(literal) or (' %s'):format(literal)

    local head = ('local %s'):format(name)
    local gap  = ' '
    if fmt.col and fmt.col > #head then
        -- Pad so that `=` lands in the same column as in the existing requires.
        gap = (' '):rep(fmt.col - #head - 1)
    end
    local statement = ('local %s%s= require%s\n'):format(name, gap, callArgs)

    local at = guide.positionOf(row, 0)
    client.editText(uri, {
        {
            start  = at,
            finish = at,
            text   = statement,
        }
    })
end

--- Command entry point: add a `require` for `data.target` to `data.uri`.
--- Finds the require paths from which `data.uri` can reach the decoded
--- `data.target` and asks the user to pick one (shortest first). It then inserts
--- `local <data.name> = require ...` after the file's existing require lines.
---@async
---@param data table  `{ uri, target, name }`; `target` is decoded with `furi.decode`
return function (data)
    local uri    = data.uri
    local target = data.target
    local name   = data.name
    if not files.getState(uri) then
        return
    end

    local path = furi.decode(target)
    local visiblePaths = rpath.getVisiblePath(uri, path)
    if not visiblePaths or #visiblePaths == 0 then
        return
    end
    -- Shortest path first. Equal lengths are ordered by text, because
    -- `table.sort` is not stable and would otherwise list them in an
    -- unspecified order.
    table.sort(visiblePaths, function (a, b)
        if #a.expect ~= #b.expect then
            return #a.expect < #b.expect
        end
        return a.expect < b.expect
    end)

    local result = askAutoRequire(uri, visiblePaths)
    if not result then
        return
    end

    -- askAutoRequire is async and may yield while waiting on the user. The
    -- file can be closed meanwhile, and findInsertRow indexes its parsed state
    -- without checking for nil, so check again.
    if not files.getState(uri) then
        return
    end

    local row, fmt = findInsertRow(uri)
    applyAutoRequire(uri, row, name, result, fmt)
end