use std::{
    collections::{HashMap, HashSet},
    fs::File,
    io::Read,
    sync::Arc,
};

use anyhow::{bail, Result};
use bytes::Bytes;
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::{Client, Response, Url};
use sha2::{Digest, Sha256};
use tokio::task::JoinSet;

use composefs::{
    fsverity::FsVerityHashValue,
    repository::Repository,
    splitstream::{DigestMapEntry, SplitStreamReader},
    util::Sha256Digest,
};

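/// Downloads splitstreams and their objects over HTTP into a local
/// composefs repository, verifying fs-verity digests along the way.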
struct Downloader<ObjectID: FsVerityHashValue> {
    client: Client,
    repo: Arc<Repository<ObjectID>>,
    url: Url,
}

impl<ObjectID: FsVerityHashValue> Downloader<ObjectID> {
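    /// Reports whether the server marked this response as a symlink target
    /// (via the `text/x-symlink-target` content type) rather than file data.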
    fn is_symlink(response: &Response) -> bool {
        let Some(content_type_header) = response.headers().get("Content-Type") else {
            return false;
        };

        let Ok(content_type) = content_type_header.to_str() else {
            return false;
        };

        ["text/x-symlink-target"].contains(&content_type)
    }

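    /// Performs a GET for `name` relative to `dir` under the base URL and
    /// returns the response body, plus whether it was a symlink target.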
    async fn fetch(&self, dir: &str, name: &str) -> Result<(Bytes, bool)> {
        let object_url = self.url.join(dir)?.join(name)?;
        let request = self.client.get(object_url).build()?;
        let response = self.client.execute(request).await?;
        response.error_for_status_ref()?;
        let is_symlink = Self::is_symlink(&response);
        Ok((response.bytes().await?, is_symlink))
    }

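    /// Ensures the object `id` is present in the repository, downloading it
    /// if necessary and checking that its fs-verity digest matches. Returns
    /// true if a download happened, false if the object was already present.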
    async fn ensure_object(&self, id: &ObjectID) -> Result<bool> {
        if self.repo.open_object(id).is_err() {
            let (data, _is_symlink) = self.fetch("objects/", &id.to_object_pathname()).await?;
            let actual_id = self.repo.ensure_object_async(data.into()).await?;
            if actual_id != *id {
                bail!("Downloaded {id:?} but it has fs-verity {actual_id:?}");
            }
            Ok(true)
        } else {
            Ok(false)
        }
    }

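    /// Opens the splitstream with the given object ID from the repository.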
    fn open_splitstream(&self, id: &ObjectID) -> Result<SplitStreamReader<File, ObjectID>> {
        SplitStreamReader::new(File::from(self.repo.open_object(id)?))
    }

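    /// Reads the full contents of the object `id` from the repository.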
    fn read_object(&self, id: &ObjectID) -> Result<Vec<u8>> {
        let mut data = vec![];
        File::from(self.repo.open_object(id)?).read_to_end(&mut data)?;
        Ok(data)
    }

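    /// Fetches the named splitstream and everything it references: first the
    /// transitive closure of splitstreams, then all missing objects, and
    /// finally verifies each stream's claimed SHA-256 body checksum. Returns
    /// the root stream's content SHA-256 and its fs-verity object ID.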
    async fn ensure_stream(self: &Arc<Self>, name: &str) -> Result<(Sha256Digest, ObjectID)> {
        let progress = ProgressBar::new(2); // the first object gets "ensured" twice
        progress.set_style(
            ProgressStyle::with_template(
                "[eta {eta}] {bar:40.cyan/blue} Fetching {pos} / {len} splitstreams",
            )
            .unwrap()
            .progress_chars("##-"),
        );

| 82 | + |
| 83 | + // Ideally we'll get a symlink, but we might get the data directly |
| 84 | + let (data, is_symlink) = self.fetch("streams/", name).await?; |
| 85 | + let my_id = if is_symlink { |
| 86 | + ObjectID::from_object_pathname(&data)? |
| 87 | + } else { |
| 88 | + self.repo.ensure_object(&data)? |
| 89 | + }; |
| 90 | + progress.inc(1); |
| 91 | + |
| 92 | + let mut objects_todo = HashSet::new(); |
| 93 | + |
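        // Maps each splitstream's object ID to its claimed SHA-256 body
        // checksum; None means no checksum is claimed (only the root stream).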
        // TODO: if 'name' looks sha256ish then we ought to use it instead of None?
        let mut splitstreams = HashMap::from([(my_id.clone(), None)]);
        let mut splitstreams_todo = vec![my_id.clone()];

        // Recursively fetch all splitstreams
        // TODO: make this parallel, at least the ensure_object() part...
        while let Some(id) = splitstreams_todo.pop() {
            // this is the slow part (downloads, writing to disk, etc.)
            if self.ensure_object(&id).await? {
                progress.inc(1);
            } else {
                progress.dec_length(1);
            }

            // this part is fast: it only touches the header
            let mut reader = self.open_splitstream(&id)?;
            for DigestMapEntry { verity, body } in &reader.refs.map {
                match splitstreams.insert(verity.clone(), Some(*body)) {
                    // This is the normal case: a splitstream we haven't seen yet...
                    None => {
                        splitstreams_todo.push(verity.clone());
                        progress.inc_length(1);
                    }

                    // This is the case where we've already been asked to fetch this stream. We'll
                    // verify the SHA-256 content hashes later (after we get all the objects), so we
                    // need to make sure that all referrers of this stream agree on what that hash is.
                    Some(Some(previous)) => {
                        if previous != *body {
                            bail!(
                                "Splitstream with verity {verity:?} has different body hashes {} and {}",
                                hex::encode(previous),
                                hex::encode(body)
                            );
                        }
                    }

                    // This case should really be absolutely impossible: the only None value we
                    // record is for the original stream, and if we somehow managed to get back
                    // there via object IDs (which we check on download) then it means someone
                    // managed to construct two self-referential content-addressed objects...
                    Some(None) => bail!("Splitstream attempts to include itself recursively"),
                }
            }

            // This part is medium-fast: it needs to iterate the entire stream
            reader.get_object_refs(|id| {
                if !splitstreams.contains_key(id) {
                    objects_todo.insert(id.clone());
                }
            })?;
        }

        progress.finish();

        let progress = ProgressBar::new(objects_todo.len() as u64);
        progress.set_style(
            ProgressStyle::with_template(
                "[eta {eta}] {bar:40.cyan/blue} Fetching {pos} / {len} objects",
            )
            .unwrap()
            .progress_chars("##-"),
        );

| 157 | + |
| 158 | + // Fetch all the objects |
| 159 | + let mut set = JoinSet::<Result<bool>>::new(); |
| 160 | + let mut iter = objects_todo.into_iter(); |
| 161 | + |
| 162 | + // Queue up 100 initial requests |
| 163 | + // See SETTINGS_MAX_CONCURRENT_STREAMS in RFC 7540 |
| 164 | + // We might actually want to increase this... |
| 165 | + for id in iter.by_ref().take(100) { |
| 166 | + let self_ = Arc::clone(self); |
| 167 | + set.spawn(async move { self_.ensure_object(&id).await }); |
| 168 | + } |
| 169 | + |
        // Collect results for tasks that finish. For each finished task, add another (if there
        // are any).
        while let Some(result) = set.join_next().await {
            if result?? {
                // freshly downloaded
                progress.inc(1);
            } else {
                // already present: shrink the total instead
                progress.dec_length(1);
            }

            if let Some(id) = iter.next() {
                let self_ = Arc::clone(self);
                set.spawn(async move { self_.ensure_object(&id).await });
            }
        }

        progress.finish();

        // Now that we have all of the objects, we can verify that the merged content of each
        // splitstream corresponds to its claimed body content checksum, if any...
        let progress = ProgressBar::new(splitstreams.len() as u64);
        progress.set_style(
            ProgressStyle::with_template(
                "[eta {eta}] {bar:40.cyan/blue} Verifying {pos} / {len} splitstreams",
            )
            .unwrap()
            .progress_chars("##-"),
        );

| 199 | + |
| 200 | + let mut my_sha256 = None; |
| 201 | + // TODO: This can definitely happen in parallel... |
| 202 | + for (id, expected_checksum) in splitstreams { |
| 203 | + let mut reader = self.open_splitstream(&id)?; |
| 204 | + let mut context = Sha256::new(); |
| 205 | + reader.cat(&mut context, |id| self.read_object(id))?; |
| 206 | + let measured_checksum: Sha256Digest = context.finalize().into(); |
| 207 | + |
| 208 | + if let Some(expected) = expected_checksum { |
| 209 | + if measured_checksum != expected { |
| 210 | + bail!( |
| 211 | + "Splitstream id {id:?} should have checksum {} but is actually {}", |
| 212 | + hex::encode(expected), |
| 213 | + hex::encode(measured_checksum) |
| 214 | + ); |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + if id == my_id { |
| 219 | + my_sha256 = Some(measured_checksum); |
| 220 | + } |
| 221 | + |
| 222 | + progress.inc(1); |
| 223 | + } |
| 224 | + |
| 225 | + progress.finish(); |
| 226 | + |
| 227 | + // We've definitely set this by now: `my_id` is in `splitstreams`. |
| 228 | + let my_sha256 = my_sha256.unwrap(); |
| 229 | + |
| 230 | + Ok((my_sha256, my_id)) |
| 231 | + } |
| 232 | +} |
| 233 | + |
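/// Downloads the splitstream `name` from `url` into `repo`, along with all
/// splitstreams and objects it references, and verifies the result. Returns
/// the stream's content SHA-256 and its fs-verity object ID.
///
/// A typical call might look like this (hypothetical URL and stream name):
/// `download("https://example.com/repo/", "base-image", repo).await?`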
pub async fn download<ObjectID: FsVerityHashValue>(
    url: &str,
    name: &str,
    repo: Arc<Repository<ObjectID>>,
) -> Result<(Sha256Digest, ObjectID)> {
    let downloader = Arc::new(Downloader {
        client: Client::new(),
        repo,
        url: Url::parse(url)?,
    });

    downloader.ensure_stream(name).await
}