ucaslcl commited on
Commit
74d5f2d
1 Parent(s): 41bb3d3

Delete conversation.py

Browse files
Files changed (1) hide show
  1. conversation.py +0 -455
conversation.py DELETED
@@ -1,455 +0,0 @@
1
- import dataclasses
2
- from enum import auto, Enum
3
- from typing import List, Tuple
4
-
5
-
6
- class SeparatorStyle(Enum):
7
- """Different separator style."""
8
- SINGLE = auto()
9
- TWO = auto()
10
- MPT = auto()
11
-
12
-
13
-
14
- # simple_conv_multimodal = Conversation(
15
- # system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
16
- # "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
17
- # "Follow the instructions carefully and explain your answers in detail.",
18
- # # system="",
19
- # roles=("Human", "Assistant"),
20
- # messages=(
21
- # ("Human", "Hi!"),
22
- # ("Assistant", "Hi there! How can I help you today?\n")
23
- # ),
24
- # offset=2,
25
- # sep_style=SeparatorStyle.SINGLE,
26
- # sep="###",
27
- # )
28
-
29
- # conv_mpt = Conversation(
30
- # system="""<|im_start|>system
31
- # - You are a helpful language and vision assistant.
32
- # - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
33
- # - You should follow the instructions carefully and explain your answers in detail.""",
34
- # roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
35
- # version="mpt",
36
- # messages=(),
37
- # offset=0,
38
- # sep_style=SeparatorStyle.MPT,
39
- # sep="<|im_end|>",
40
- # )
41
-
42
- @dataclasses.dataclass
43
- class Conversation:
44
- """A class that keeps all conversation history."""
45
- system: str
46
- roles: List[str]
47
- messages: List[List[str]]
48
- offset: int
49
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
50
- sep: str = "<|im_end|>"
51
- sep2: str = None
52
- version: str = "Unknown"
53
-
54
- skip_next: bool = False
55
-
56
- def get_prompt(self):
57
- if self.sep_style == SeparatorStyle.SINGLE:
58
- ret = self.system + self.sep + '\n'
59
- for role, message in self.messages:
60
- if message:
61
- if type(message) is tuple:
62
- message, _, _ = message
63
- ret += role + ": " + message + self.sep
64
- else:
65
- ret += role + ":"
66
- return ret
67
- elif self.sep_style == SeparatorStyle.TWO:
68
- seps = [self.sep, self.sep2]
69
- ret = self.system + seps[0]
70
- for i, (role, message) in enumerate(self.messages):
71
- if message:
72
- if type(message) is tuple:
73
- message, _, _ = message
74
- ret += role + ": " + message + seps[i % 2]
75
- else:
76
- ret += role + ":"
77
- return ret
78
- if self.sep_style == SeparatorStyle.MPT:
79
- if self.system:
80
- ret = self.system + self.sep
81
- else:
82
- ret = ''
83
- for role, message in self.messages:
84
- if message:
85
- if type(message) is tuple:
86
- message, _, _ = message
87
- ret += role + message + self.sep
88
- else:
89
- ret += role
90
- return ret
91
- else:
92
- raise ValueError(f"Invalid style: {self.sep_style}")
93
- # if self.sep_style == SeparatorStyle.MPT:
94
- # if self.system:
95
- # ret = self.system + self.sep
96
- # else:
97
- # ret = ''
98
- # for role, message in self.messages:
99
- # if message:
100
- # if type(message) is tuple:
101
- # message, _, _ = message
102
- # ret += role + message + self.sep
103
- # # if 'user' in role:
104
- # # ret += role + message + self.sep + "\n"
105
- # # else:
106
- # # ret += role + message + self.sep
107
- # else:
108
- # ret += role
109
- # return ret
110
- # else:
111
- # raise ValueError(f"Invalid style: {self.sep_style}")
112
-
113
- def append_message(self, role, message):
114
- self.messages.append([role, message])
115
-
116
- def get_images(self, return_pil=False):
117
- images = []
118
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
119
- if i % 2 == 0:
120
- if type(msg) is tuple:
121
- import base64
122
- from io import BytesIO
123
- from PIL import Image
124
- msg, image, image_process_mode = msg
125
- if image_process_mode == "Pad":
126
- def expand2square(pil_img, background_color=(122, 116, 104)):
127
- width, height = pil_img.size
128
- if width == height:
129
- return pil_img
130
- elif width > height:
131
- result = Image.new(pil_img.mode, (width, width), background_color)
132
- # result.paste(pil_img, (0, (width - height) // 2))
133
- result.paste(pil_img)
134
- return result
135
- else:
136
- result = Image.new(pil_img.mode, (height, height), background_color)
137
- # result.paste(pil_img, ((height - width) // 2, 0))
138
- result.paste(pil_img)
139
- return result
140
- image = expand2square(image)
141
- elif image_process_mode == "Crop":
142
- max_hw, min_hw = max(image.size), min(image.size)
143
- aspect_ratio = max_hw / min_hw
144
- max_len, min_len = 800, 400
145
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
146
- longest_edge = int(shortest_edge * aspect_ratio)
147
- W, H = image.size
148
- if H > W:
149
- H, W = longest_edge, shortest_edge
150
- else:
151
- H, W = shortest_edge, longest_edge
152
- image = image.resize((W, H))
153
- elif image_process_mode == "Resize":
154
- image = image.resize((224, 224))
155
- else:
156
- raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
157
-
158
- if return_pil:
159
- images.append(image)
160
- else:
161
- buffered = BytesIO()
162
- image.convert('RGB').save(buffered, format="JPEG")
163
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
164
- images.append(img_b64_str)
165
- return images
166
-
167
- def to_gradio_chatbot(self):
168
- ret = []
169
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
170
- if i % 2 == 0:
171
- if type(msg) is tuple:
172
- import base64
173
- from io import BytesIO
174
- msg, image, image_process_mode = msg
175
- max_hw, min_hw = max(image.size), min(image.size)
176
- aspect_ratio = max_hw / min_hw
177
- max_len, min_len = 800, 400
178
- shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
179
- longest_edge = int(shortest_edge * aspect_ratio)
180
- W, H = image.size
181
- if H > W:
182
- H, W = longest_edge, shortest_edge
183
- else:
184
- H, W = shortest_edge, longest_edge
185
- image = image.resize((W, H))
186
- # image = image.resize((224, 224))
187
- buffered = BytesIO()
188
- image.save(buffered, format="JPEG")
189
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
190
- img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
191
- msg = msg.replace('<image>', img_str)
192
- ret.append([msg, None])
193
- else:
194
- ret[-1][-1] = msg
195
- return ret
196
-
197
- def copy(self):
198
- return Conversation(
199
- system=self.system,
200
- roles=self.roles,
201
- messages=[[x, y] for x, y in self.messages],
202
- offset=self.offset,
203
- sep_style=self.sep_style,
204
- sep=self.sep,
205
- sep2=self.sep2)
206
-
207
- def dict(self):
208
- if len(self.get_images()) > 0:
209
- return {
210
- "system": self.system,
211
- "roles": self.roles,
212
- "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
213
- "offset": self.offset,
214
- "sep": self.sep,
215
- "sep2": self.sep2,
216
- }
217
- return {
218
- "system": self.system,
219
- "roles": self.roles,
220
- "messages": self.messages,
221
- "offset": self.offset,
222
- "sep": self.sep,
223
- "sep2": self.sep2,
224
- }
225
-
226
-
227
- conv_v1 = Conversation(
228
- system="A chat between a curious human and an artificial intelligence assistant. "
229
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
230
- roles=("Human", "Assistant"),
231
- messages=(
232
- ("Human", "Give three tips for staying healthy."),
233
- ("Assistant",
234
- "Sure, here are three tips for staying healthy:\n"
235
- "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
236
- "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
237
- "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
238
- "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
239
- "activities at least two days per week.\n"
240
- "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
241
- "vegetables, whole grains, lean proteins, and healthy fats can help support "
242
- "your overall health. Try to limit your intake of processed and high-sugar foods, "
243
- "and aim to drink plenty of water throughout the day.\n"
244
- "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
245
- "and mental health. Adults should aim for seven to nine hours of sleep per night. "
246
- "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
247
- "help improve the quality of your sleep.")
248
- ),
249
- offset=2,
250
- sep_style=SeparatorStyle.SINGLE,
251
- sep="###",
252
- )
253
-
254
- conv_v1_2 = Conversation(
255
- system="A chat between a curious human and an artificial intelligence assistant. "
256
- "The assistant gives helpful, detailed, and polite answers to the human's questions.",
257
- roles=("Human", "Assistant"),
258
- messages=(
259
- ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
260
- ("Assistant",
261
- "Renewable energy sources are those that can be replenished naturally in a relatively "
262
- "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
263
- "Non-renewable energy sources, on the other hand, are finite and will eventually be "
264
- "depleted, such as coal, oil, and natural gas. Here are some key differences between "
265
- "renewable and non-renewable energy sources:\n"
266
- "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
267
- "energy sources are finite and will eventually run out.\n"
268
- "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
269
- "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
270
- "and other negative effects.\n"
271
- "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
272
- "have lower operational costs than non-renewable sources.\n"
273
- "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
274
- "locations than non-renewable sources.\n"
275
- "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
276
- "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
277
- "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
278
- "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
279
- ),
280
- offset=2,
281
- sep_style=SeparatorStyle.SINGLE,
282
- sep="###",
283
- )
284
-
285
- conv_vicuna_v1_1 = Conversation(
286
- system="A chat between a curious user and an artificial intelligence assistant. "
287
- "The assistant gives helpful, detailed, and polite answers to the user's questions.",
288
- roles=("USER", "ASSISTANT"),
289
- version="v1",
290
- messages=(),
291
- offset=0,
292
- sep_style=SeparatorStyle.TWO,
293
- sep=" ",
294
- sep2="</s>",
295
- )
296
-
297
- # conv_mpt = Conversation(
298
- # system="""<|im_start|>system
299
- # - You are designed by Megvii(旷视), and your name is GOT.
300
- # - 你叫GOT, 你来自旷视, 你是旷视开发的。
301
- # - 你擅长分析表格,仔细读图表中的内容,然后给出你的答案。""",
302
- # roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
303
- # version="mpt",
304
- # messages=(),
305
- # offset=0,
306
- # sep_style=SeparatorStyle.MPT,
307
- # sep="<|im_end|>",
308
- # )
309
-
310
- conv_mpt = Conversation(
311
- system="""<|im_start|>system
312
- You should follow the instructions carefully and explain your answers in detail.""",
313
- # system = None,
314
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
315
- version="mpt",
316
- messages=(),
317
- offset=0,
318
- sep_style=SeparatorStyle.MPT,
319
- sep="<|im_end|>",
320
- )
321
-
322
- conv_mpt_eval = Conversation(
323
- system="",
324
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
325
- version="mpt",
326
- messages=(),
327
- offset=0,
328
- sep_style=SeparatorStyle.MPT,
329
- sep="<|im_end|>",
330
- )
331
-
332
- conv_mpt_text = Conversation(
333
- system="""<|im_start|>system
334
- - You are a helpful assistant chatbot trained by MosaicML.
335
- - You answer questions.
336
- - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
337
- - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
338
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
339
- version="mpt",
340
- messages=(),
341
- offset=0,
342
- sep_style=SeparatorStyle.MPT,
343
- sep="<|im_end|>",
344
- )
345
-
346
- conv_bair_v1 = Conversation(
347
- system="BEGINNING OF CONVERSATION:",
348
- roles=("USER", "GPT"),
349
- messages=(),
350
- offset=0,
351
- sep_style=SeparatorStyle.TWO,
352
- sep=" ",
353
- sep2="</s>",
354
- )
355
-
356
- # simple_conv = Conversation(
357
- # system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology, based on LLaMA architecture."
358
- # "You are designed to assist human with a variety of tasks using natural language."
359
- # "Follow the instructions carefully.",
360
- # roles=("Human", "Assistant"),
361
- # messages=(
362
- # ("Human", "Hi!"),
363
- # ("Assistant", "Hi there! How can I help you today?\n")
364
- # ),
365
- # offset=2,
366
- # sep_style=SeparatorStyle.SINGLE,
367
- # sep="###",
368
- # )
369
-
370
-
371
- simple_conv = Conversation(
372
- system="",
373
- roles=("Human", "Assistant"),
374
- messages=(
375
- ),
376
- offset=0,
377
- sep_style=SeparatorStyle.SINGLE,
378
- sep="###",
379
- )
380
-
381
- simple_conv_multimodal = Conversation(
382
- system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
383
- "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
384
- "Follow the instructions carefully and explain your answers in detail.",
385
- # system="",
386
- roles=("Human", "Assistant"),
387
- messages=(
388
- ("Human", "Hi!"),
389
- ("Assistant", "Hi there! How can I help you today?\n")
390
- ),
391
- offset=2,
392
- sep_style=SeparatorStyle.SINGLE,
393
- sep="###",
394
- )
395
-
396
- simple_conv_mpt_multimodal = Conversation(
397
- system="""<|im_start|>system
398
- - You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.
399
- - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
400
- - You should follow the instructions carefully and explain your answers in detail.""",
401
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
402
- version="mpt",
403
- messages=(),
404
- offset=0,
405
- sep_style=SeparatorStyle.MPT,
406
- sep="<|im_end|>",
407
- )
408
-
409
- simple_conv_legacy = Conversation(
410
- system="You are GOT, a large language model trained by Foundation Model Group, Megvii Technology."
411
- "You are designed to assist human with a variety of tasks using natural language."
412
- "Follow the instructions carefully.",
413
- roles=("Human", "Assistant"),
414
- messages=(
415
- ("Human", "Hi!\n\n### Response:"),
416
- ("Assistant", "Hi there! How can I help you today?\n")
417
- ),
418
- offset=2,
419
- sep_style=SeparatorStyle.SINGLE,
420
- sep="###",
421
- )
422
-
423
- conv_llava_v1 = Conversation(
424
- system="You are GOT, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
425
- "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
426
- "Follow the instructions carefully and explain your answers in detail.",
427
- roles=("USER", "ASSISTANT"),
428
- version="v1",
429
- messages=(),
430
- offset=0,
431
- sep_style=SeparatorStyle.TWO,
432
- sep=" ",
433
- sep2="</s>",
434
- )
435
-
436
- default_conversation = conv_mpt
437
- conv_templates = {
438
- "default": simple_conv_multimodal,
439
- "simple": simple_conv,
440
- "simple_legacy": simple_conv_legacy,
441
- "multimodal": simple_conv,
442
- "mpt_multimodal": simple_conv_mpt_multimodal,
443
- "llava_v1": conv_llava_v1,
444
- "mpt_eval": conv_mpt_eval,
445
- # fastchat
446
- "v1": conv_vicuna_v1_1,
447
- "bair_v1": conv_bair_v1,
448
- "vicuna_v1_1": conv_vicuna_v1_1,
449
- "mpt": conv_mpt,
450
- "mpt_text": conv_mpt_text,
451
- }
452
-
453
-
454
- if __name__ == "__main__":
455
- print(default_conversation.get_prompt())