cs.py
import os
import platform
from transformers import AutoTokenizer, AutoModel
import torch
from docx import Document
import re
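
# Batch analysis of thesis documents with ChatGLM3-6B: each article is classified from its
# abstract as 算法类 / 系统类 / 理论推导类, then experiment or module information is extracted
# from the full text and written to per-article [指标] and [归纳] files under logsave/.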
print(torch.cuda.is_available())
print(torch.__version__)
torch.cuda.empty_cache()
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# MODEL_PATH = os.environ.get('MODEL_PATH', '/HOME/scw6c94/run/ac/llm/ChatGLM3/basic_demo/32k')
MODEL_PATH = os.environ.get('MODEL_PATH', '/home/liuqt/long/32k')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map='cuda:0').eval()
# add .quantize(bits=4, device="cuda").cuda() before .eval() to use int4 model
# must use cuda to load int4 model
os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False
welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
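
# Rebuilds the console transcript from the (query, response) history; only useful for the
# interactive loop that is commented out in main(), and not referenced elsewhere in this script.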
def build_prompt(history):
prompt = welcome_prompt
for query, response in history:
prompt += f"\n\n用户:{query}"
prompt += f"\n\nChatGLM3-6B:{response}"
return prompt
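
# Reads the article's abstract and asks the model to classify the thesis type;
# returns the model's final streamed response (e.g. a string containing [算法类]).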
def query_zy_class(aname, history, past_key_values, stop_stream):
    # Read the saved abstract (摘要) for this article into a string.
    with open(r'/home/liuqt/long/zy_output_files/' + aname + '.txt',
              "r", encoding='utf-8') as f:
        str_zy = f.read()
# f = open(
# r"D:\PyProject\docx_input\prompt\log_alg.txt",
# "r", encoding='utf-8') # 设置文件对象
# str2 = f.read() # 将txt文件的所有内容读入到字符串str中
# f.close()
query = "\n用户:" + str_zy + '阅读该文章摘要,说明是什么类型的论文,回答格式为:[算法类]或者[系统类]或者[理论推导类]'
print("\nChatGLM:", end="")
    current_length = 0
    article_class = ""  # default in case streaming is interrupted before any response arrives
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
stop_stream = False
break
else:
print(response[current_length:], end="", flush=True)
article_class=response
current_length = len(response)
print("")
return article_class
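
# For algorithm-type theses: loads the table of contents and the plain-text body. The
# chapter-matching queries are currently commented out, so the whole text is returned as
# a single block.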
def query_menu_struct_alg(aname, history, past_key_values, stop_stream):
    # Read the parsed table of contents for this article (a Python-literal list).
    with open(r'/home/liuqt/long/menu_out_files/' + aname + '.txt',
              "r", encoding='utf-8') as f:
        menu_str = f.read()
    mlist = eval(menu_str)  # menu entries: [text, heading level, paragraph index]; ast.literal_eval would be safer
    print(mlist)
    # Read the plain-text version of the article.
    with open(r'/home/liuqt/long/file/txt2/' + aname + '.txt',
              "r", encoding='gbk') as f:
        str_txt = f.read()
# f = open(
# r"D:\PyProject\docx_input\prompt\log_alg.txt",
# "r", encoding='utf-8') # 设置文件对象
# str2 = f.read() # 将txt文件的所有内容读入到字符串str中
# f.close()
qlist = []
# past_key_values, history = None, []
# global stop_stream
# query = "\n用户:" + str + '阅读该文目录列表,说明本文所做的实验以及对应的章节,回答格式为:一系列的[xxxxx,B]其中B为所给出目录中的章节,例如[实体链接模型实验,实体链接算法研究],[关系抽取模型实验,关系抽取算法研究],[基于知识库推理的联合优化算法实验,基于知识库推理的联合优化算法设计与实现]'
#
# print("\nChatGLM:", end="")
# current_length = 0
# for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
# temperature=0.01,
# past_key_values=past_key_values,
# return_past_key_values=True):
# if stop_stream:
# stop_stream = False
# break
# else:
# print(response[current_length:], end="", flush=True)
#
# current_length = len(response)
# print("")
# #得到实验和对应的章节
# query = "\n用户:结合历史信息,说明本文所做的实验对应的所有章节,仅回答对应章节名称,回答格式为:[xxxxx]其中为所给出目录中的章节,例如[实体链接算法研究],[关系抽取算法研究],[基于知识库推理的联合优化算法设计与实现],[实体链接模型实验],[关系抽取模型实验],[基于知识库推理的联合优化算法实验],[实体链接任务的样本分布偏差讨论]"
#
#
# print("\nChatGLM:", end="")
# current_length = 0
# for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
# temperature=0.01,
# past_key_values=past_key_values,
# return_past_key_values=True):
# if stop_stream:
# stop_stream = False
# break
# else:
# print(response[current_length:], end="", flush=True)
#
# current_length = len(response)
# print("")
# #得到章节
# rep=response
# tlist =re.findall('\[(.*?)]',rep)
# tlist2=[]
# for i in tlist:
# if ',' in i:
# x=list(i.split(','))
# tlist2.append(x[0].replace(' ',''))
# tlist2.append(x[1].replace(' ',''))
# else:
# tlist2.append(i)
# tlist2=list(set(tlist2))
# tlist_t=[]
# for i in tlist2:
# if i in str:
# print(i+'in menu')
# tlist_t.append(i)
# else:
# print(i+'not in menu')
# #章节验证
# tlist_t.append('致谢')
# #添加文末
# t_pnum_list=[]
# for i in tlist_t:
# for j in mlist:
# if i==j[0] or i in j[0]:
# t_pnum_list.append([i,j[2]])
# #生成章节段落序号
# print(t_pnum_list)
    # Paragraph indices of the chapters that correspond to experiments (from the commented-out matching above).
    docx_path = '/home/liuqt/long/file/docx/' + aname + '.docx'
    docx = Document(docx_path)
    # Load the original docx.
    # print(file_path)
    # print(len(doc.paragraphs))
    # Split into blocks by where each heading occurs in the txt content (currently disabled).
    # bl=blockdivide(t_pnum_list,docx.paragraphs)
    # bl2=blockdivide2(t_pnum_list,str_txt)
    # bl3=blockdivide3(t_pnum_list,str_txt)
    return [str_txt]
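
# Runs three chained prompts over one text block: list the experiments, extract the metric
# results (appended to <aname>[指标].txt), then produce a templated summary (appended to
# <aname>[归纳].txt).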
def query_experiment_alg(string, history, past_key_values, stop_stream, aname, save_path):
# past_key_values, history = None, []
# global stop_stream
query = string + '阅读此部分文章,说明其完成了哪些实验(不超过10个),回答格式为一系列的:[xxxx],例如[实体链接模型实验];回答实验结果是什么,给出文中的量化数据和来源。'
print("\nChatGLM:", end="")
current_length = 0
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
stop_stream = False
break
else:
print(response[current_length:], end="", flush=True)
current_length = len(response)
print("")
    # End of the experiment-content query.
    # Metrics query: ask for the quantitative results.
query = '实验指标结果是什么,以[xxx,xxx]的格式回答。'
print("\nChatGLM:", end="")
current_length = 0
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
stop_stream = False
break
else:
print(response[current_length:], end="", flush=True)
with open(save_path+aname+'[指标].txt', 'a', encoding='utf-8') as f:
print(response[current_length:],end='', file=f)
current_length = len(response)
print("")
    # Summary query: condense everything into the fixed template.
query = '根据以上信息对进行归纳,回答时要涵盖量化数据:针对[]问题,提出了[], 在[]上进行了[], []实验指标为[],结果表明[]。'
print("\nChatGLM:", end="")
current_length = 0
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
stop_stream = False
break
else:
print(response[current_length:], end="", flush=True)
with open(save_path+aname+'[归纳].txt', 'a', encoding='utf-8') as f:
print(response[current_length:],end='', file=f)
current_length = len(response)
print("")
return
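
# System-type counterpart of query_menu_struct_alg: the chapter matching is commented out,
# so only the plain-text body is read and returned as a single block.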
def query_menu_struct_sys(aname, history, past_key_values, stop_stream):
# f = open(r'/home/liuqt/long/menu_out_files/'+aname+'.txt',"r", encoding='utf-8')
# # 设置文件对象
# str = f.read()
# mlist=eval(str)# 将txt文件的所有内容读入到字符串str中
# f.close() # 将文件关闭
# #读入目录
# print (mlist)
#
#
    # Read the plain-text version of the article.
    with open(r'/home/liuqt/long/file/txt2/' + aname + '.txt', "r", encoding='gbk') as f:
        str_txt = f.read()
# #读入文章txt版本
#
# # past_key_values, history = None, []
# # global stop_stream
# #query
# query = "\n用户:" + str + '阅读该文目录列表,说明本文所做系统的模块以及对应的章节,回答格式为:一系列的[xxxxx,B]其中B为所给出目录中的章节,例如[xxx,xxxxx],[xxx,xxxxx]'
# print(query)
# print("\nChatGLM:", end="")
# current_length = 0
# for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
# temperature=0.01,
# past_key_values=past_key_values,
# return_past_key_values=True):
# if stop_stream:
# stop_stream = False
# break
# else:
# print(response[current_length:], end="", flush=True)
#
# current_length = len(response)
# print("")
# #得到实验和对应的章节
# query = "\n用户:结合历史信息,说明本文涉及到的章节名称,只回答对应章节名称,回答格式例如[系统性能需求分析],[质量模块需求分析],一个中括号中只能有一个章节名称"
# print(query)
#
# print("\nChatGLM:", end="")
# current_length = 0
# for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
# temperature=0.01,
# past_key_values=past_key_values,
# return_past_key_values=True):
# if stop_stream:
# stop_stream = False
# break
# else:
# print(response[current_length:], end="", flush=True)
#
# current_length = len(response)
# print("")
# #得到章节
# rep=response
# tlist =re.findall('\[(.*?)]',rep)
# tlist_t=[]
# for i in tlist:
# if i in str:
# print(i+'in menu')
# tlist_t.append(i)
# else:
# print(i+'not in menu')
# #章节验证
# tlist_t.append('致谢')
# print('tlist_t')
# print(tlist_t)
# print('mlist')
# print(mlist)
# #添加文末
# t_pnum_list=[]
# for i in tlist_t:
# for j in mlist:
# #可能是i==j[0]
# if i in j[0]:
# t_pnum_list.append([i,j[2]])
# #生成章节段落序号
# print(t_pnum_list)
    # Paragraph indices of the chapters that correspond to experiments (from the commented-out matching above).
    docx_path = '/home/liuqt/long/file/docx/' + aname + '.docx'
    docx = Document(docx_path)
    # Load the original docx.
    # print(file_path)
    # print(len(doc.paragraphs))
    # Split into blocks by where each heading occurs in the txt content:
def blockdivide2(hd,str_txt):
# print(a.getname())
blocks = []
hdnum = []
        # each item in hd is [heading text, paragraph index]
        for i in hd:
            hdnum.append(str_txt.find(i[0]))  # character offset of the heading text in the article
        hdnum.sort()
        print(hdnum)
for num in range(0, len(hdnum) - 1):
blocklist = []
if num < len(hdnum) - 1:
# for i in str_txt[hdnum[num]:hdnum[num + 1]]:
# # print(i.text)
blocklist.append(str_txt[hdnum[num]:hdnum[num + 1]])
# else:
# for i in a.getparagraphs()[hdnum[num]:len(a.getparagraphs())]:
# # print(i.text)
# blocklist.append(i.text)
# print("----------------------------------------------------------------------")
blocks.append([hdnum[num], hdnum[num + 1], blocklist])
return blocks
    # Alternative: split by paragraph ranges within the docx (commented out below).
# def blockdivide(hd,p):
# # print(a.getname())
# blocks = []
# hdnum = []
# # menuitem===>text ,heading x, pi
# for i in hd:
# hdnum.append(i[1]) # hd[2]--->pi
# hdnum.sort()
# for num in range(0, len(hdnum)-1):
# blocklist = []
# if num < len(hdnum) - 1:
# for i in p[hdnum[num]:hdnum[num + 1]]:
# # print(i.text)
# blocklist.append(i.text)
# # else:
# # for i in a.getparagraphs()[hdnum[num]:len(a.getparagraphs())]:
# # # print(i.text)
# # blocklist.append(i.text)
# # print("----------------------------------------------------------------------")
# blocks.append([hdnum[num], hdnum[num + 1], blocklist])
# return blocks
# bl=blockdivide(t_pnum_list,docx.paragraphs)
# bl2=blockdivide2(t_pnum_list,str_txt)
# bl3=blockdivide3(t_pnum_list,str_txt)
return [str_txt]
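
# Asks which modules the implemented system contains, then asks for each module's requirements,
# architecture, functions and tests; the second answer is appended to <aname>[指标].txt.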
def query_experiment_sys(string, history, past_key_values, stop_stream, aname, save_path):
# past_key_values, history = None, []
# global stop_stream
query = string + '阅读此部分文章,回答本文所实现的系统或软件有哪些模块(不超过10个),回答格式为[xxx],[xxx],[xxx],例如[用户管理模块],[题目管理模块]。'
print("\nChatGLM:", end="")
current_length = 0
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
stop_stream = False
break
else:
print(response[current_length:], end="", flush=True)
rep=response
current_length = len(response)
print("")
    # End of the module-listing query.
    # Detail query: requirements, architecture, functions and tests for each module.
query = '回答每一个模块的需求是什么,架构是什么,功能是什么,该模块测试方式是什么,如果有对应指标,给出所在文中的依据,否则回答[未找到具体指标],回答格式为:[xxx]。'
print(query)
print("\nChatGLM:", end="")
current_length = 0
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
stop_stream = False
break
else:
print(response[current_length:], end="", flush=True)
            with open(save_path + aname + '[指标].txt', 'a', encoding='utf-8') as f:
                print(response[current_length:], end='', file=f)
current_length = len(response)
print("")
    # The summary query is handled separately by query_summary_sys.
return
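
# Produces the per-module templated summary from the accumulated history and appends it to
# <aname>[归纳].txt.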
def query_summary_sys(aname, history, past_key_values, stop_stream,save_path):
query = '根据以上信息对每一个模块进行归纳,回答时尽可能涵盖量化数据,回答格式为:针对[]需求,提出了[]模块, 进行了[], 结果表明[]。'
print("\nChatGLM:", end="")
current_length = 0
for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
temperature=0.01,
past_key_values=past_key_values,
return_past_key_values=True):
if stop_stream:
stop_stream = False
break
else:
print(response[current_length:], end="", flush=True)
            with open(save_path + aname + '[归纳].txt', 'a', encoding='utf-8') as f:
                print(response[current_length:], end='', file=f)
current_length = len(response)
print("")
def blockdivide2(hd,str_txt):
# print(a.getname())
blocks = []
hdnum = []
    # each item in hd is [heading text, paragraph index]
    for i in hd:
        print(i, str_txt.find(i[0]))
        hdnum.append(str_txt.find(i[0]))  # character offset of the heading text in the article
    hdnum.sort()
    print(hdnum)
for num in range(0, len(hdnum) - 1):
blocklist = []
if num < len(hdnum) - 1:
# for i in str_txt[hdnum[num]:hdnum[num + 1]]:
# # print(i.text)
blocklist.append(str_txt[hdnum[num]:hdnum[num + 1]])
# else:
# for i in a.getparagraphs()[hdnum[num]:len(a.getparagraphs())]:
# # print(i.text)
# blocklist.append(i.text)
# print("----------------------------------------------------------------------")
blocks.append([hdnum[num], hdnum[num + 1], blocklist])
return blocks
def blockdivide3(hd, str_txt):
# return [str_txt[0:int(len(str_txt)/2)],str_txt[int(len(str_txt)/2):]]
return [str_txt]
# Split into blocks by paragraph index within the docx.
def blockdivide(hd,p):
# print(a.getname())
blocks = []
hdnum = []
    # each item in hd is [heading text, paragraph index] (t_pnum_list entries)
    for i in hd:
        hdnum.append(i[1])  # paragraph index of each heading
hdnum.sort()
for num in range(0, len(hdnum)-1):
blocklist = []
if num < len(hdnum) - 1:
for i in p[hdnum[num]:hdnum[num + 1]]:
# print(i.text)
blocklist.append(i.text)
# else:
# for i in a.getparagraphs()[hdnum[num]:len(a.getparagraphs())]:
# # print(i.text)
# blocklist.append(i.text)
# print("----------------------------------------------------------------------")
blocks.append([hdnum[num], hdnum[num + 1], blocklist])
return blocks
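
# For each article name: classify the thesis, then run the algorithm- or system-oriented
# extraction pipeline and write the results under /home/liuqt/long/file/logsave/.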
def main(args):
for i in args[0]:
past_key_values, history = None, []
global stop_stream
        save_path2 = '/home/liuqt/long/file/logsave/'
        save_path = 'D:\\PyProject\\docx_input\\file\\logsave\\'  # Windows path, not used below
# aname='17373118_刘阳_面向国产ARM平台的浮点计算分析工具(最终版)'
aname=i
# aname='17373086_黎明_C语言程序自动评测系统设计与实现'
articleclass=query_zy_class(aname,history,past_key_values,stop_stream)
if '算法类' in articleclass:
            with open(save_path2 + aname + '[归纳].txt', 'w', encoding='utf-8') as f:
                print('文章归纳', end='\n', file=f)
            with open(save_path2 + aname + '[指标].txt', 'w', encoding='utf-8') as f:
                print('文章指标', end='\n', file=f)
bl=query_menu_struct_alg(aname, history, past_key_values, stop_stream)
query_experiment_alg(bl[0], history, past_key_values, stop_stream, aname, save_path2)
# for i in bl:
# print(i[2])
# string=';'.join(i[2])
# query_experiment_alg(string, history, past_key_values, stop_stream, aname, save_path2)
elif '系统类' in articleclass:
bl = query_menu_struct_sys(aname, history, past_key_values, stop_stream)
# for i in bl:
# print(i[2])
# string = ';'.join(i[2])
            # Initialize (clear) the output files for this article.
            with open(save_path2 + aname + '[归纳].txt', 'w', encoding='utf-8') as f:
                print('文章归纳', end='\n', file=f)
            with open(save_path2 + aname + '[指标].txt', 'w', encoding='utf-8') as f:
                print('文章指标', end='\n', file=f)
query_experiment_sys(bl[0], history, past_key_values, stop_stream, aname, save_path2)
query_summary_sys(aname,history,past_key_values,stop_stream,save_path2)
else:
print('理论推导类')
# while True:
# query = input("\n用户:")
# if query.strip() == "stop":
# break
# if query.strip() == "clear":
# past_key_values, history = None, []
# os.system(clear_command)
# print(welcome_prompt)
# continue
# print("\nChatGLM:", end="")
# current_length = 0
# for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
# temperature=0.01,
# past_key_values=past_key_values,
# return_past_key_values=True):
# if stop_stream:
# stop_stream = False
# break
# else:
# print(response[current_length:], end="", flush=True)
# current_length = len(response)
# print("")
if __name__ == "__main__":
    files = os.listdir('file/txt2')
    filenames = [name.split('.')[0] for name in files]
    args = [filenames]
    main(args)