11import pytest
22import os
33
4- from taskingai .retrieval import Record , TokenTextSplitter
4+ from taskingai .retrieval import Record , TokenTextSplitter , TextSplitter
55from taskingai .retrieval import list_collections , create_collection , get_collection , update_collection , delete_collection , list_records , create_record , get_record , update_record , delete_record , query_chunks , create_chunk , update_chunk , get_chunk , delete_chunk , list_chunks
66from taskingai .file import upload_file
77from test .config import Config
@@ -109,11 +109,18 @@ class TestRecord:
109109
110110 text_splitter_list = [
111111 {
112- "type" : "token" , # "type": "token
112+ "type" : "token" ,
113113 "chunk_size" : 100 ,
114114 "chunk_overlap" : 10
115115 },
116- TokenTextSplitter (chunk_size = 200 , chunk_overlap = 20 )
116+ TokenTextSplitter (chunk_size = 200 , chunk_overlap = 20 ),
117+ {
118+ "type" : "separator" ,
119+ "chunk_size" : 100 ,
120+ "chunk_overlap" : 10 ,
121+ "separators" : ["." , "!" , "?" ]
122+ },
123+ TextSplitter (type = "separator" , chunk_size = 200 , chunk_overlap = 20 , separators = ["." , "!" , "?" ])
117124 ]
118125 upload_file_data_list = []
119126
@@ -129,10 +136,10 @@ class TestRecord:
129136 upload_file_data_list .append (upload_file_dict )
130137
131138 @pytest .mark .run (order = 31 )
132- def test_create_record_by_text (self , collection_id ):
139+ @pytest .mark .parametrize ("text_splitter" , text_splitter_list )
140+ def test_create_record_by_text (self , collection_id , text_splitter ):
133141
134142 # Create a text record.
135- text_splitter = TokenTextSplitter (chunk_size = 200 , chunk_overlap = 20 )
136143 text = "Machine learning is a subfield of artificial intelligence (AI) that involves the development of algorithms that allow computers to learn from and make decisions or predictions based on data."
137144 create_record_data = {
138145 "type" : "text" ,
@@ -145,17 +152,9 @@ def test_create_record_by_text(self, collection_id):
145152 "key2" : "value2"
146153 }
147154 }
148- for x in range (2 ):
149- if x == 0 :
150- create_record_data .update (
151- {"text_splitter" : {
152- "type" : "token" ,
153- "chunk_size" : 100 ,
154- "chunk_overlap" : 10
155- }})
156- res = create_record (** create_record_data )
157- res_dict = vars (res )
158- assume_record_result (create_record_data , res_dict )
155+ res = create_record (** create_record_data )
156+ res_dict = vars (res )
157+ assume_record_result (create_record_data , res_dict )
159158
160159 @pytest .mark .run (order = 31 )
161160 def test_create_record_by_web (self , collection_id ):
@@ -345,12 +344,13 @@ def test_query_chunks(self, collection_id):
345344
346345 query_text = "Machine learning"
347346 top_k = 1
348- res = query_chunks (collection_id = collection_id , query_text = query_text , top_k = top_k , max_tokens = 20000 )
347+ res = query_chunks (collection_id = collection_id , query_text = query_text , top_k = top_k , max_tokens = 20000 , score_threshold = 0.04 )
349348 pytest .assume (len (res ) == top_k )
350349 for chunk in res :
351350 chunk_dict = vars (chunk )
352351 assume_query_chunk_result (query_text , chunk_dict )
353352 pytest .assume (chunk_dict .keys () == self .chunk_keys )
353+ pytest .assume (chunk_dict ["score" ] >= 0.04 )
354354
355355 @pytest .mark .run (order = 42 )
356356 def test_create_chunk (self , collection_id ):
0 commit comments